zipformer wenetspeech (#1130)
* copy files

* update train.py

* small fixes

* Add decode.py

* Fix dataloader in decode.py

* add blank penalty

* Add blank-penalty to other decoding method

* Minor fixes

* add zipformer2 recipe

* Minor fixes

* Remove pruned7

* export and test models

* Replace bpe with tokens in export.py and pretrain.py

* Minor fixes

* Minor fixes

* Minor fixes

* Fix export

* Update results

* Fix zipformer-ctc

* Fix ci

* Fix ci

* Fix CI

* Fix CI

---------

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
pkufool and csukuangfj committed Jun 26, 2023
1 parent 4d5b836 commit 219bba1
Showing 49 changed files with 4,401 additions and 178 deletions.
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/jit_script_chunk_16_left_128.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
@@ -33,7 +34,7 @@ log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
@@ -46,7 +47,7 @@ ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"

./zipformer/jit_pretrained_streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \
$repo/test_wavs/1089-134686-0001.wav

@@ -60,7 +61,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
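Note: with --bpe-model replaced by --tokens, the decoding scripts only need the plain tokens.txt symbol table instead of a SentencePiece model. Below is a minimal sketch of how such a table can map hypothesis token IDs back to text; the helper names are illustrative, not icefall's actual API, and tokens.txt is assumed to hold one "<symbol> <id>" pair per line.

from typing import Dict, List

def load_token_table(filename: str) -> Dict[int, str]:
    # Assumes lines like "<blk> 0" or "▁THE 4" (symbol, then integer ID).
    id2token: Dict[int, str] = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            symbol, idx = line.split()
            id2token[int(idx)] = symbol
    return id2token

def token_ids_to_text(ids: List[int], id2token: Dict[int, str]) -> str:
    # BPE pieces mark word starts with "▁"; concatenate, then restore spaces.
    return "".join(id2token[i] for i in ids).replace("▁", " ").strip()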
7 changes: 4 additions & 3 deletions .github/scripts/run-librispeech-zipformer-2023-05-18.sh
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/jit_script.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
@@ -33,7 +34,7 @@ log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@@ -43,7 +44,7 @@ ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"

./zipformer/jit_pretrained.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--nn-model-filename $repo/exp/jit_script.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
@@ -56,7 +57,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
8 changes: 4 additions & 4 deletions .github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/LG.pt"
@@ -40,7 +41,7 @@ log "Export to torchscript model"
--use-transducer 1 \
--use-ctc 1 \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@@ -51,7 +52,7 @@ log "Decode with models exported by torch.jit.script()"

for method in ctc-decoding 1best; do
./zipformer/jit_pretrained_ctc.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--model-filename $repo/exp/jit_script.pt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
--words-file $repo/data/lang_bpe_500/words.txt \
@@ -71,8 +72,7 @@ for method in ctc-decoding 1best; do
--use-ctc 1 \
--method $method \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--words-file $repo/data/lang_bpe_500/words.txt \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
--G $repo/data/lm/G_4_gram.pt \
--words-file $repo/data/lang_bpe_500/words.txt \
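Note: the 1best method above decodes the CTC posteriors through the HLG graph. A rough sketch of that flow, using helpers that exist in icefall (get_lattice, one_best_decoding, get_texts); the beam values and supervision layout below are assumptions, not settings taken from this commit.

import torch
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import get_texts

def hlg_1best(ctc_log_probs: torch.Tensor, HLG, supervision_segments):
    # ctc_log_probs: (N, T, vocab) log-softmax output of the CTC head.
    # supervision_segments: int32 rows of (seq_idx, start_frame, num_frames).
    lattice = get_lattice(
        nnet_output=ctc_log_probs,
        decoding_graph=HLG,
        supervision_segments=supervision_segments,
        search_beam=20,          # assumed values, tune per setup
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        subsampling_factor=4,
    )
    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
    # Returns word IDs; map them to strings via words.txt.
    return get_texts(best_path)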
4 changes: 2 additions & 2 deletions .github/scripts/test-ncnn-export.sh
@@ -195,14 +195,14 @@ git lfs pull --include "data/lang_char_bpe/Linv.pt"
git lfs pull --include "exp/pretrained.pt"

cd exp
ln -s pretrained.pt epoch-99.pt
ln -s pretrained.pt epoch-9999.pt
popd

./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--exp-dir $repo/exp \
--use-averaged-model 0 \
--epoch 99 \
--epoch 9999 \
--avg 1 \
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
@@ -397,7 +397,6 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
subtract_ilme=True,
ilme_scale=params.ilme_scale,
)
for hyp in hyp_tokens:
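Note: with the boolean subtract_ilme flag removed, a non-zero ilme_scale alone now enables internal-LM subtraction. Reassembled from the hunk above, the updated call site reads roughly as follows; the leading arguments are assumed from the function's signature, which appears in the next file.

hyp_tokens = fast_beam_search_one_best(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=params.beam,
    max_contexts=params.max_contexts,
    max_states=params.max_states,
    ilme_scale=params.ilme_scale,  # 0.0 disables ILME subtraction
)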
47 changes: 40 additions & 7 deletions egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -22,6 +22,7 @@
import k2
import sentencepiece as spm
import torch
from torch import nn

from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
@@ -35,7 +36,6 @@
get_texts,
get_texts_with_timestamp,
)
from torch import nn


def fast_beam_search_one_best(
@@ -47,8 +47,8 @@ def fast_beam_search_one_best(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
ilme_scale: float = 0.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -90,8 +90,8 @@ max_states=max_states,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
subtract_ilme=subtract_ilme,
ilme_scale=ilme_scale,
blank_penalty=blank_penalty,
)

best_path = one_best_decoding(lattice)
@@ -114,6 +114,8 @@ def fast_beam_search_nbest_LG(
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
blank_penalty: float = 0.0,
ilme_scale: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -168,6 +170,8 @@ max_states=max_states,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
blank_penalty=blank_penalty,
ilme_scale=ilme_scale,
)

nbest = Nbest.from_lattice(
@@ -240,6 +244,7 @@ def fast_beam_search_nbest(
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -293,6 +298,7 @@ beam=beam,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
blank_penalty=blank_penalty,
temperature=temperature,
)

@@ -331,6 +337,7 @@ def fast_beam_search_nbest_oracle(
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -389,6 +396,7 @@ max_states=max_states,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
blank_penalty=blank_penalty,
)

nbest = Nbest.from_lattice(
@@ -432,8 +440,8 @@ def fast_beam_search(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
ilme_scale: float = 0.0,
blank_penalty: float = 0.0,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
@@ -503,8 +511,13 @@ project_input=False,
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)

if blank_penalty != 0:
logits[:, 0] -= blank_penalty

log_probs = (logits / temperature).log_softmax(dim=-1)
if subtract_ilme:

if ilme_scale != 0:
ilme_logits = model.joiner(
torch.zeros_like(
current_encoder_out, device=current_encoder_out.device
Expand All @@ -513,8 +526,11 @@ def fast_beam_search(
project_input=False,
)
ilme_logits = ilme_logits.squeeze(1).squeeze(1)
if blank_penalty != 0:
ilme_logits[:, 0] -= blank_penalty
ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
log_probs -= ilme_scale * ilme_log_probs

decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
@@ -526,6 +542,7 @@ def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
max_sym_per_frame: int,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""Greedy search for a single utterance.
@@ -595,6 +612,9 @@ )
)
# logits is (1, 1, 1, vocab_size)

if blank_penalty != 0:
logits[:, :, :, 0] -= blank_penalty

y = logits.argmax().item()
if y not in (blank_id, unk_id):
hyp.append(y)
@@ -626,6 +646,7 @@ def greedy_search_batch(
model: nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
blank_penalty: float = 0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
@@ -703,6 +724,10 @@

logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape

if blank_penalty != 0:
logits[:, 0] -= blank_penalty

y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
@@ -923,6 +948,7 @@ def modified_beam_search(
context_graph: Optional[ContextGraph] = None,
beam: int = 4,
temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@@ -1028,6 +1054,9 @@

logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)

if blank_penalty != 0:
logits[:, 0] -= blank_penalty

log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)

log_probs.add_(ys_log_probs)
@@ -1662,6 +1691,7 @@ def beam_search(
encoder_out: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""
@@ -1758,6 +1788,9 @@
project_input=False,
)

if blank_penalty != 0:
logits[:, :, :, 0] -= blank_penalty

# TODO(fangjun): Scale the blank posterior
log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
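Note: the pattern this commit threads through every decoding method is the same: subtract blank_penalty from the blank logit (index 0) before the softmax, and optionally subtract a scaled internal-LM estimate obtained by running the joiner on a zeroed encoder output. A self-contained sketch of just that adjustment, with shapes assumed to be (num_hyps, vocab_size) as in the hunks above; it is not a drop-in replacement for the icefall functions.

from typing import Optional
import torch

def adjusted_log_probs(
    logits: torch.Tensor,                        # joiner output
    ilme_logits: Optional[torch.Tensor] = None,  # joiner output, encoder frames zeroed
    blank_penalty: float = 0.0,
    ilme_scale: float = 0.0,
    temperature: float = 1.0,
) -> torch.Tensor:
    logits = logits.clone()
    if blank_penalty != 0:
        # Lowering the blank logit discourages emitting blank, trading
        # deletion errors for insertion errors.
        logits[:, 0] -= blank_penalty
    log_probs = (logits / temperature).log_softmax(dim=-1)
    if ilme_scale != 0 and ilme_logits is not None:
        ilme_logits = ilme_logits.clone()
        if blank_penalty != 0:
            ilme_logits[:, 0] -= blank_penalty
        # Subtract the internal-LM estimate so an external LM can later be
        # integrated with less bias.
        log_probs = log_probs - ilme_scale * (ilme_logits / temperature).log_softmax(dim=-1)
    return log_probs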
