Commit 3241de9

Merge branch 'master' of https://github.com/ftshijt/espnet

ftshijt committed Jan 1, 2023
2 parents 766d9a9 + edd6f74
Showing 10 changed files with 549 additions and 35 deletions.
2 changes: 1 addition & 1 deletion egs2/TEMPLATE/asr1/asr.sh
@@ -1410,7 +1410,7 @@ if ! "${skip_eval}"; then
            _opts="--token_type ${_tok_type} "
            if [ "${_tok_type}" = "char" ] || [ "${_tok_type}" = "word" ]; then
                _type="${_tok_type:0:1}er"
-               _opts+="--non_linguistic_symbols \"${nlsyms_txt}\" "
+               _opts+="--non_linguistic_symbols ${nlsyms_txt} "
                _opts+="--remove_non_linguistic_symbols true "

            elif [ "${_tok_type}" = "bpe" ]; then
1 change: 0 additions & 1 deletion egs2/TEMPLATE/asr1/db.sh
@@ -155,7 +155,6 @@ ST_CMDS=downloads
MS_INDIC_IS18=
MARATHI=downloads
MLS=downloads
-VOXFORGE=downloads
VOXPOPULI=downloads
HARPERVALLEY=downloads
TALROMUR=downloads
5 changes: 4 additions & 1 deletion egs2/gigaspeech/asr1/local/data.sh
@@ -53,8 +53,11 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    abs_data_dir=$(readlink -f ${data_dir})
    log "making Kaldi format data directory in ${abs_data_dir}"
    pushd GigaSpeech
-   ./toolkits/kaldi/gigaspeech_data_prep.sh ${GIGASPEECH} ${abs_data_dir} true
+   ./toolkits/kaldi/gigaspeech_data_prep.sh --train-subset XL ${GIGASPEECH} ${abs_data_dir}
    popd
+   mv ${data_dir}/gigaspeech_train_xl ${data_dir}/train
+   mv ${data_dir}/gigaspeech_dev ${data_dir}/dev
+   mv ${data_dir}/gigaspeech_test ${data_dir}/test
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
13 changes: 12 additions & 1 deletion egs2/librispeech_100/asr1/README.md
@@ -72,6 +72,10 @@ Model: https://huggingface.co/pyf98/librispeech_100h_conformer
|beam20_ctc0.3/dev_other|2864|50948|84.6|13.9|1.5|2.1|17.4|79.9|
|beam20_ctc0.3/test_clean|2620|52576|94.3|5.3|0.4|0.8|6.5|57.0|
|beam20_ctc0.3/test_other|2939|52343|84.7|13.7|1.6|2.0|17.3|81.6|
+|timesync_beam20_ctc0.3/dev_clean|2703|54402|94.4|5.1|0.5|0.7|6.3|56.6|
+|timesync_beam20_ctc0.3/dev_other|2864|50948|83.9|13.4|2.7|1.8|17.8|80.3|
+|timesync_beam20_ctc0.3/test_clean|2620|52576|94.3|5.2|0.5|0.7|6.5|57.3|
+|timesync_beam20_ctc0.3/test_other|2939|52343|84.1|13.4|2.4|1.8|17.7|82.2|

### CER

@@ -85,6 +89,10 @@ Model: https://huggingface.co/pyf98/librispeech_100h_conformer
|beam20_ctc0.3/dev_other|2864|265951|93.3|4.2|2.5|2.0|8.7|79.9|
|beam20_ctc0.3/test_clean|2620|281530|98.1|1.1|0.8|0.6|2.5|57.0|
|beam20_ctc0.3/test_other|2939|272758|93.5|4.0|2.6|1.9|8.4|81.6|
+|timesync_beam20_ctc0.3/dev_clean|2703|288456|98.1|1.0|0.9|0.6|2.5|56.6|
+|timesync_beam20_ctc0.3/dev_other|2864|265951|92.0|3.9|4.2|1.8|9.8|80.3|
+|timesync_beam20_ctc0.3/test_clean|2620|281530|98.0|1.0|1.0|0.6|2.6|57.3|
+|timesync_beam20_ctc0.3/test_other|2939|272758|92.5|3.7|3.8|1.7|9.3|82.2|

### TER

@@ -98,7 +106,10 @@ Model: https://huggingface.co/pyf98/librispeech_100h_conformer
|beam20_ctc0.3/dev_other|2864|64524|81.0|13.5|5.5|2.3|21.3|79.9|
|beam20_ctc0.3/test_clean|2620|66983|92.0|5.0|3.0|0.6|8.6|57.0|
|beam20_ctc0.3/test_other|2939|66650|81.2|13.0|5.8|2.0|20.9|81.6|
-
+|timesync_beam20_ctc0.3/dev_clean|2703|69558|91.8|4.8|3.4|0.5|8.7|56.6|
+|timesync_beam20_ctc0.3/dev_other|2864|64524|80.0|12.5|7.5|1.8|21.8|80.3|
+|timesync_beam20_ctc0.3/test_clean|2620|66983|91.9|4.8|3.4|0.6|8.7|57.3|
+|timesync_beam20_ctc0.3/test_other|2939|66650|80.4|12.2|7.4|1.8|21.4|82.2|


## Environments
@@ -0,0 +1,7 @@
beam_size: 20
ctc_weight: 0.3
lm_weight: 0.0
maxlenratio: 0.0
minlenratio: 0.0
penalty: 0.0
time_sync: true
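
This new decode config turns on the time-synchronous decoder (`time_sync: true`) with no LM and no length penalty. As a hedged sketch of how such options plausibly map onto the `BeamSearchTimeSync` constructor added below — the `weights` keys come from the new `espnet/nets/beam_search_timesync.py`, and the `1 - ctc_weight` split for the attention decoder follows the usual espnet2 inference convention; this glue code is an assumption, not part of the commit:

```python
# Assumed glue (not from this commit): translate the YAML options above
# into BeamSearchTimeSync constructor arguments.
ctc_weight = 0.3  # ctc_weight in the config
weights = {
    "ctc": ctc_weight,
    "decoder": 1.0 - ctc_weight,  # attention decoder takes the remainder
    "lm": 0.0,                    # lm_weight
    "length_bonus": 0.0,          # penalty
}
beam_size = 20
```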
10 changes: 10 additions & 0 deletions espnet/nets/batch_beam_search_online_sim.py
@@ -249,6 +249,16 @@ def forward(
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
+        if best.yseq[1:-1].shape[0] == maxlen:
+            logging.warning(
+                "best hypo length: {} == max output length: {}".format(
+                    best.yseq[1:-1].shape[0], maxlen
+                )
+            )
+            logging.warning(
+                "decoding may be stopped by the max output length limitation; "
+                + "consider increasing maxlenratio."
+            )
        return nbest_hyps

    def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]:
10 changes: 10 additions & 0 deletions espnet/nets/beam_search.py
@@ -403,6 +403,16 @@ def forward(
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
+        if best.yseq[1:-1].shape[0] == maxlen:
+            logging.warning(
+                "best hypo length: {} == max output length: {}".format(
+                    best.yseq[1:-1].shape[0], maxlen
+                )
+            )
+            logging.warning(
+                "decoding may be stopped by the max output length limitation; "
+                + "consider increasing maxlenratio."
+            )
        return nbest_hyps

    def post_process(
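
The same warning block is added to both `batch_beam_search_online_sim.py` and `beam_search.py`: when the best hypothesis is exactly `maxlen` tokens long, decoding likely ran out of output-length budget rather than emitting an end-of-sequence token. For reference, a sketch of how `maxlen` is derived earlier in `forward` (paraphrased from the existing espnet logic; treat the exact lines as an approximation):

```python
# Paraphrase of BeamSearch.forward: maxlen caps hypothesis length.
# With the default maxlenratio == 0.0, hypotheses may grow up to the
# number of encoder frames; raising maxlenratio raises the cap.
if maxlenratio == 0:
    maxlen = x.shape[0]
else:
    maxlen = max(1, int(maxlenratio * x.size(0)))
```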
237 changes: 237 additions & 0 deletions espnet/nets/beam_search_timesync.py
@@ -0,0 +1,237 @@
"""
Time Synchronous One-Pass Beam Search.
Implements joint CTC/attention decoding where
hypotheses are expanded along the time (input) axis,
as described in https://arxiv.org/abs/2210.05200.
Supports CPU and GPU inference.
References: https://arxiv.org/abs/1408.2873 for CTC beam search
Author: Brian Yan
"""

import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

import numpy as np
import torch

from espnet.nets.beam_search import Hypothesis
from espnet.nets.scorer_interface import ScorerInterface


@dataclass
class CacheItem:
    """For caching attentional decoder and LM states."""

    state: Any
    scores: Any
    log_sum: float


class BeamSearchTimeSync(torch.nn.Module):
    """Time synchronous beam search algorithm."""

    def __init__(
        self,
        sos: int,
        beam_size: int,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        token_list: dict = None,
        pre_beam_ratio: float = 1.5,
        blank: int = 0,
        force_lid: bool = False,
        temp: float = 1.0,
    ):
        """Initialize beam search.

        Args:
            sos: sos index
            beam_size: number of hypotheses kept at each step
            scorers: dict of scorer modules; "ctc" and "decoder" are
                required, "lm" is optional
            weights: score weights for "ctc", "decoder", "lm",
                and "length_bonus"
            token_list: mapping from token id to token string
            pre_beam_ratio: pre_beam_ratio * beam_size = pre_beam;
                pre_beam is used to select candidate tokens from the
                vocabulary when extending hypotheses
            blank: blank index
        """
        super().__init__()
        self.ctc = scorers["ctc"]
        self.decoder = scorers["decoder"]
        self.lm = scorers["lm"] if "lm" in scorers else None
        self.beam_size = beam_size
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.ctc_weight = weights["ctc"]
        self.lm_weight = weights["lm"]
        self.decoder_weight = weights["decoder"]
        self.penalty = weights["length_bonus"]
        self.sos = sos
        self.sos_th = torch.tensor([self.sos])
        self.blank = blank
        self.attn_cache = dict()  # cache for p_attn(Y|X)
        self.lm_cache = dict()  # cache for p_lm(Y)
        self.enc_output = None  # log p_ctc(Z|X)
        self.force_lid = force_lid
        self.temp = temp
        self.token_list = token_list

    def reset(self, enc_output: torch.Tensor):
        """Reset object for a new utterance."""
        self.attn_cache = dict()
        self.lm_cache = dict()
        self.enc_output = enc_output
        self.sos_th = self.sos_th.to(enc_output.device)

        if self.decoder is not None:
            init_decoder_state = self.decoder.init_state(enc_output)
            decoder_scores, decoder_state = self.decoder.score(
                self.sos_th, init_decoder_state, enc_output
            )
            self.attn_cache[(self.sos,)] = CacheItem(
                state=decoder_state,
                scores=decoder_scores,
                log_sum=0.0,
            )
        if self.lm is not None:
            init_lm_state = self.lm.init_state(enc_output)
            lm_scores, lm_state = self.lm.score(self.sos_th, init_lm_state, enc_output)
            self.lm_cache[(self.sos,)] = CacheItem(
                state=lm_state,
                scores=lm_scores,
                log_sum=0.0,
            )

    def cached_score(self, h: Tuple[int], cache: dict, scorer: ScorerInterface) -> Any:
        """Retrieve decoder/LM scores which may be cached."""
        root = h[:-1]  # prefix
        if root in cache:
            root_scores = cache[root].scores
            root_state = cache[root].state
            root_log_sum = cache[root].log_sum
        else:  # run decoder fwd one step and update cache
            root_root = root[:-1]
            root_root_state = cache[root_root].state
            root_scores, root_state = scorer.score(
                torch.tensor(root, device=self.enc_output.device).long(),
                root_root_state,
                self.enc_output,
            )
            root_log_sum = cache[root_root].log_sum + float(
                cache[root_root].scores[root[-1]]
            )
            cache[root] = CacheItem(
                state=root_state, scores=root_scores, log_sum=root_log_sum
            )
        cand_score = float(root_scores[h[-1]])
        score = root_log_sum + cand_score

        return score

    def joint_score(self, hyps: Any, ctc_score_dp: Any) -> Any:
        """Calculate joint score for hyps."""
        scores = dict()
        for h in hyps:
            score = self.ctc_weight * np.logaddexp(*ctc_score_dp[h])  # ctc score
            if len(h) > 1 and self.decoder_weight > 0 and self.decoder is not None:
                score += (
                    self.cached_score(h, self.attn_cache, self.decoder)
                    * self.decoder_weight
                )  # attn score
            if len(h) > 1 and self.lm is not None and self.lm_weight > 0:
                score += (
                    self.cached_score(h, self.lm_cache, self.lm) * self.lm_weight
                )  # lm score
            score += self.penalty * (len(h) - 1)  # penalty score
            scores[h] = score
        return scores

    def time_step(self, p_ctc: Any, ctc_score_dp: Any, hyps: Any) -> Any:
        """Execute a single time step."""
        pre_beam_threshold = np.sort(p_ctc)[-self.pre_beam_size]
        cands = set(np.where(p_ctc >= pre_beam_threshold)[0])
        if len(cands) == 0:
            cands = {np.argmax(p_ctc)}
        new_hyps = set()
        ctc_score_dp_next = defaultdict(
            lambda: (float("-inf"), float("-inf"))
        )  # (p_nb, p_b)
        for hyp_l in hyps:
            p_prev_l = np.logaddexp(*ctc_score_dp[hyp_l])
            for c in cands:
                if c == self.blank:
                    logging.debug("blank cand, hypothesis is " + str(hyp_l))
                    p_nb, p_b = ctc_score_dp_next[hyp_l]
                    p_b = np.logaddexp(p_b, p_ctc[c] + p_prev_l)
                    ctc_score_dp_next[hyp_l] = (p_nb, p_b)
                    new_hyps.add(hyp_l)
                else:
                    l_plus = hyp_l + (int(c),)
                    logging.debug("non-blank cand, hypothesis is " + str(l_plus))
                    p_nb, p_b = ctc_score_dp_next[l_plus]
                    if c == hyp_l[-1]:
                        logging.debug("repeat cand, hypothesis is " + str(hyp_l))
                        p_nb_prev, p_b_prev = ctc_score_dp[hyp_l]
                        p_nb = np.logaddexp(p_nb, p_ctc[c] + p_b_prev)
                        p_nb_l, p_b_l = ctc_score_dp_next[hyp_l]
                        p_nb_l = np.logaddexp(p_nb_l, p_ctc[c] + p_nb_prev)
                        ctc_score_dp_next[hyp_l] = (p_nb_l, p_b_l)
                    else:
                        p_nb = np.logaddexp(p_nb, p_ctc[c] + p_prev_l)
                    if l_plus not in hyps and l_plus in ctc_score_dp:
                        p_b = np.logaddexp(
                            p_b, p_ctc[self.blank] + np.logaddexp(*ctc_score_dp[l_plus])
                        )
                        p_nb = np.logaddexp(p_nb, p_ctc[c] + ctc_score_dp[l_plus][0])
                    ctc_score_dp_next[l_plus] = (p_nb, p_b)
                    new_hyps.add(l_plus)

        scores = self.joint_score(new_hyps, ctc_score_dp_next)

        hyps = sorted(new_hyps, key=lambda l: scores[l], reverse=True)[: self.beam_size]
        ctc_score_dp = ctc_score_dp_next.copy()
        return ctc_score_dp, hyps, scores

    def forward(
        self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): encoder output of shape (T, D)
            maxlenratio, minlenratio: accepted for interface
                compatibility with other beam searches; unused here

        Returns:
            list[Hypothesis]: n-best hypotheses
        """
        logging.info("decoder input lengths: " + str(x.shape[0]))
        lpz = self.ctc.log_softmax(x.unsqueeze(0))
        lpz = lpz.squeeze(0)
        lpz = lpz.cpu().detach().numpy()
        self.reset(x)

        hyps = [(self.sos,)]
        ctc_score_dp = defaultdict(
            lambda: (float("-inf"), float("-inf"))
        )  # (p_nb, p_b) - dp object tracking p_ctc
        ctc_score_dp[(self.sos,)] = (float("-inf"), 0.0)
        for t in range(lpz.shape[0]):
            logging.debug("position " + str(t))
            ctc_score_dp, hyps, scores = self.time_step(lpz[t, :], ctc_score_dp, hyps)

        ret = [
            Hypothesis(yseq=torch.tensor(list(h) + [self.sos]), score=scores[h])
            for h in hyps
        ]  # sos is reused as the trailing eos token here
        best_hyp = "".join([self.token_list[x] for x in ret[0].yseq.tolist()])
        best_hyp_len = len(ret[0].yseq)
        best_score = ret[0].score
        logging.info(f"output length: {best_hyp_len}")
        logging.info(f"total log probability: {best_score:.2f}")
        logging.info(f"best hypo: {best_hyp}")

        return ret
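
The `(p_nb, p_b)` pair that `time_step` tracks per prefix is the standard CTC prefix beam search recursion of Hannun et al. (https://arxiv.org/abs/1408.2873), computed in log space via `np.logaddexp`. In probability space, extending a prefix ℓ with candidate token c at frame t updates:

```latex
% CTC prefix recursion mirrored by ctc_score_dp / ctc_score_dp_next
% (shown in probability space; the code stores log values)
\begin{aligned}
p_b(\ell, t) &\mathrel{+}= p_t(\mathrm{blank})\,\bigl(p_b(\ell, t-1) + p_{nb}(\ell, t-1)\bigr) \\
p_{nb}(\ell, t) &\mathrel{+}= p_t(c)\, p_{nb}(\ell, t-1) \quad \text{if } c = \mathrm{last}(\ell) \\
p_{nb}(\ell c, t) &\mathrel{+}=
  \begin{cases}
    p_t(c)\, p_b(\ell, t-1) & \text{if } c = \mathrm{last}(\ell) \\
    p_t(c)\,\bigl(p_b(\ell, t-1) + p_{nb}(\ell, t-1)\bigr) & \text{otherwise}
  \end{cases}
\end{aligned}
```

`joint_score` then adds the weighted attention-decoder and LM scores on top, using the prefix caches so decoder states are computed at most once per prefix.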

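A minimal usage sketch for the new class, under stated assumptions: `sos_id`, `ctc_scorer`, `decoder`, `lm`, `token_list`, and `enc_out` are placeholders the caller must supply (the scorers implement espnet's `ScorerInterface`; `enc_out` is a `(T, D)` encoder output for one utterance). This is not an invocation path shipped in this commit:

```python
from espnet.nets.beam_search_timesync import BeamSearchTimeSync

# Placeholders: sos_id, ctc_scorer, decoder, lm, token_list, and enc_out
# must come from an instantiated ASR model and its tokenizer.
beam_search = BeamSearchTimeSync(
    sos=sos_id,
    beam_size=20,
    scorers={"ctc": ctc_scorer, "decoder": decoder, "lm": lm},
    weights={"ctc": 0.3, "decoder": 0.7, "lm": 0.0, "length_bonus": 0.0},
    token_list=token_list,
)
nbest = beam_search(enc_out)  # nn.Module.__call__ dispatches to forward()
print(nbest[0].score, "".join(token_list[t] for t in nbest[0].yseq.tolist()))
```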