zipformer wenetspeech #1130

Merged: 27 commits, Jun 26, 2023
Changes from all commits
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/jit_script_chunk_16_left_128.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
@@ -33,7 +34,7 @@ log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
@@ -46,7 +47,7 @@ ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"

./zipformer/jit_pretrained_streaming.py \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
--nn-model-filename $repo/exp/jit_script_chunk_16_left_128.pt \
$repo/test_wavs/1089-134686-0001.wav

@@ -60,7 +61,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
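A note on the recurring `--bpe-model` → `--tokens` change in these scripts: at export/decode time the model only needs an id → token mapping to turn hypotheses into text, not the full SentencePiece model. Below is a minimal sketch of reading `tokens.txt` (one `<symbol> <id>` pair per line) and mapping ids back to text; it is an illustration only, not the icefall implementation, which typically loads the file via `k2.SymbolTable.from_file`:

```python
def read_token_table(filename: str) -> dict:
    # tokens.txt stores one "<symbol> <id>" pair per line, e.g. "<blk> 0".
    table = {}
    with open(filename, encoding="utf-8") as f:
        for line in f:
            sym, idx = line.rsplit(maxsplit=1)
            table[int(idx)] = sym
    return table

def token_ids_to_text(table: dict, ids: list) -> str:
    # BPE pieces mark word starts with U+2581 ("▁"); restore them as spaces.
    return "".join(table[i] for i in ids).replace("\u2581", " ").strip()
```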
7 changes: 4 additions & 3 deletions .github/scripts/run-librispeech-zipformer-2023-05-18.sh
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/jit_script.pt"
git lfs pull --include "exp/pretrained.pt"
ln -s pretrained.pt epoch-99.pt
@@ -33,7 +34,7 @@ log "Export to torchscript model"
./zipformer/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@@ -43,7 +44,7 @@ ls -lh $repo/exp/*.pt
log "Decode with models exported by torch.jit.script()"

./zipformer/jit_pretrained.py \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
--nn-model-filename $repo/exp/jit_script.pt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
@@ -56,7 +57,7 @@ for method in greedy_search modified_beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
8 changes: 4 additions & 4 deletions .github/scripts/run-librispeech-zipformer-ctc-2023-06-14.sh
@@ -23,6 +23,7 @@ ls -lh $repo/test_wavs/*.wav

pushd $repo/exp
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "data/lang_bpe_500/HLG.pt"
git lfs pull --include "data/lang_bpe_500/L.pt"
git lfs pull --include "data/lang_bpe_500/LG.pt"
@@ -40,7 +41,7 @@ log "Export to torchscript model"
--use-transducer 1 \
--use-ctc 1 \
--use-averaged-model false \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@@ -51,7 +52,7 @@ log "Decode with models exported by torch.jit.script()"

for method in ctc-decoding 1best; do
./zipformer/jit_pretrained_ctc.py \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
--model-filename $repo/exp/jit_script.pt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
--words-file $repo/data/lang_bpe_500/words.txt \
@@ -71,8 +72,7 @@ for method in ctc-decoding 1best; do
--use-ctc 1 \
--method $method \
--checkpoint $repo/exp/pretrained.pt \
---bpe-model $repo/data/lang_bpe_500/bpe.model \
---words-file $repo/data/lang_bpe_500/words.txt \
+--tokens $repo/data/lang_bpe_500/tokens.txt \
--HLG $repo/data/lang_bpe_500/HLG.pt \
--G $repo/data/lm/G_4_gram.pt \
--words-file $repo/data/lang_bpe_500/words.txt \
4 changes: 2 additions & 2 deletions .github/scripts/test-ncnn-export.sh
@@ -195,14 +195,14 @@ git lfs pull --include "data/lang_char_bpe/Linv.pt"
git lfs pull --include "exp/pretrained.pt"

cd exp
-ln -s pretrained.pt epoch-99.pt
+ln -s pretrained.pt epoch-9999.pt
popd

./pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py \
--lang-dir $repo/data/lang_char_bpe \
--exp-dir $repo/exp \
--use-averaged-model 0 \
---epoch 99 \
+--epoch 9999 \
--avg 1 \
--decode-chunk-len 32 \
--num-encoder-layers "2,4,3,2,4" \
@@ -397,7 +397,6 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
-subtract_ilme=True,
ilme_scale=params.ilme_scale,
)
for hyp in hyp_tokens:
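The call site above tracks the API change in `beam_search.py` below: the boolean `subtract_ilme` flag is dropped, and a non-zero `ilme_scale` now acts as the on/off switch. A sketch of the convention (the helper name here is made up for illustration):

```python
def apply_ilme(log_probs, ilme_log_probs, ilme_scale: float = 0.0):
    # ilme_scale == 0.0 disables the subtraction entirely, which is why a
    # separate subtract_ilme boolean is no longer needed.
    if ilme_scale != 0:
        log_probs = log_probs - ilme_scale * ilme_log_probs
    return log_probs
```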
47 changes: 40 additions & 7 deletions egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -22,6 +22,7 @@
import k2
import sentencepiece as spm
import torch
+from torch import nn

from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
@@ -35,7 +36,6 @@
get_texts,
get_texts_with_timestamp,
)
-from torch import nn


def fast_beam_search_one_best(
@@ -47,8 +47,8 @@ def fast_beam_search_one_best(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
-subtract_ilme: bool = False,
-ilme_scale: float = 0.1,
+ilme_scale: float = 0.0,
+blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -90,8 +90,8 @@
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
-subtract_ilme=subtract_ilme,
ilme_scale=ilme_scale,
+blank_penalty=blank_penalty,
)

best_path = one_best_decoding(lattice)
@@ -114,6 +114,8 @@ def fast_beam_search_nbest_LG(
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
+blank_penalty: float = 0.0,
+ilme_scale: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -168,6 +170,8 @@
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
+blank_penalty=blank_penalty,
+ilme_scale=ilme_scale,
)

nbest = Nbest.from_lattice(
@@ -240,6 +244,7 @@ def fast_beam_search_nbest(
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
+blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -293,6 +298,7 @@
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
+blank_penalty=blank_penalty,
temperature=temperature,
)

@@ -331,6 +337,7 @@ def fast_beam_search_nbest_oracle(
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
+blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -389,6 +396,7 @@
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
+blank_penalty=blank_penalty,
)

nbest = Nbest.from_lattice(
@@ -432,8 +440,8 @@ def fast_beam_search(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
-subtract_ilme: bool = False,
-ilme_scale: float = 0.1,
+ilme_scale: float = 0.0,
+blank_penalty: float = 0.0,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.

@@ -503,8 +511,13 @@
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)

+if blank_penalty != 0:
+    logits[:, 0] -= blank_penalty

log_probs = (logits / temperature).log_softmax(dim=-1)

-if subtract_ilme:
+if ilme_scale != 0:
ilme_logits = model.joiner(
torch.zeros_like(
current_encoder_out, device=current_encoder_out.device
@@ -513,8 +526,11 @@
project_input=False,
)
ilme_logits = ilme_logits.squeeze(1).squeeze(1)
+if blank_penalty != 0:
+    ilme_logits[:, 0] -= blank_penalty
ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
log_probs -= ilme_scale * ilme_log_probs

decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
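For context on the ILME branch above: the internal language model of a transducer is estimated by running the joiner on a zeroed-out encoder output, so only the prediction network contributes, and its scaled log-probs are subtracted from the joint scores. A self-contained sketch of the per-frame computation; the blank index and the order of operations follow the diff above, while the stand-in joiner and the toy usage are illustrative only:

```python
import torch

def frame_log_probs(joiner, encoder_out, decoder_out,
                    ilme_scale=0.2, blank_penalty=0.5, temperature=1.0):
    # joiner: callable fusing acoustic and label features into logits.
    logits = joiner(encoder_out, decoder_out)          # (N, vocab_size)
    if blank_penalty != 0:
        logits[:, 0] -= blank_penalty                  # index 0 is <blk>
    log_probs = (logits / temperature).log_softmax(dim=-1)

    if ilme_scale != 0:
        # Zeroed acoustic input: the joiner output now approximates the
        # internal LM implied by the prediction network alone.
        ilme_logits = joiner(torch.zeros_like(encoder_out), decoder_out)
        if blank_penalty != 0:
            ilme_logits[:, 0] -= blank_penalty
        log_probs -= ilme_scale * (ilme_logits / temperature).log_softmax(dim=-1)
    return log_probs

# Toy usage with a trivial additive "joiner" over random features.
enc, dec = torch.randn(4, 10), torch.randn(4, 10)
print(frame_log_probs(lambda a, b: a + b, enc, dec).shape)  # torch.Size([4, 10])
```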
@@ -526,6 +542,7 @@ def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
max_sym_per_frame: int,
+blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""Greedy search for a single utterance.
@@ -595,6 +612,9 @@
)
# logits is (1, 1, 1, vocab_size)

+if blank_penalty != 0:
+    logits[:, :, :, 0] -= blank_penalty

y = logits.argmax().item()
if y not in (blank_id, unk_id):
hyp.append(y)
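The blank penalty added throughout this file is simply a constant subtracted from the blank logit (index 0) before the argmax or softmax, biasing decoding toward emitting symbols and thereby trading insertions against deletions. A minimal self-contained illustration with toy numbers (not from the model):

```python
import torch

def penalize_blank(logits: torch.Tensor, blank_penalty: float = 0.0) -> torch.Tensor:
    # logits: (..., vocab_size); index 0 is assumed to be <blk>, as above.
    if blank_penalty != 0:
        logits = logits.clone()
        logits[..., 0] -= blank_penalty
    return logits

logits = torch.tensor([[2.0, 1.5, 0.3]])
print(logits.argmax(dim=-1))                       # tensor([0]): blank wins
print(penalize_blank(logits, 1.0).argmax(dim=-1))  # tensor([1]): a symbol is emitted
```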
@@ -626,6 +646,7 @@ def greedy_search_batch(
model: nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
+blank_penalty: float = 0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
@@ -703,6 +724,10 @@

logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape

+if blank_penalty != 0:
+    logits[:, 0] -= blank_penalty

y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
Expand Down Expand Up @@ -923,6 +948,7 @@ def modified_beam_search(
context_graph: Optional[ContextGraph] = None,
beam: int = 4,
temperature: float = 1.0,
+blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Expand Down Expand Up @@ -1028,6 +1054,9 @@ def modified_beam_search(

logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)

if blank_penalty != 0:
logits[:, 0] -= blank_penalty

log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)

log_probs.add_(ys_log_probs)
Expand Down Expand Up @@ -1662,6 +1691,7 @@ def beam_search(
encoder_out: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
+blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""
@@ -1758,6 +1788,9 @@
project_input=False,
)

+if blank_penalty != 0:
+    logits[:, :, :, 0] -= blank_penalty

# TODO(fangjun): Scale the blank posterior
log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
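Since both new knobs default to off (`blank_penalty=0.0`, `ilme_scale=0.0`), existing call sites keep their behavior unchanged. A hypothetical call enabling both, assuming `model`, `decoding_graph`, `encoder_out`, and `encoder_out_lens` come from a trained icefall transducer; the numeric values are illustrative, not recommendations:

```python
hyp_tokens = fast_beam_search_one_best(
    model=model,
    decoding_graph=decoding_graph,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    beam=4.0,
    max_contexts=8,
    max_states=64,
    ilme_scale=0.2,     # subtract a scaled internal-LM estimate
    blank_penalty=0.5,  # discourage blank emissions
)
```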