zipformer wenetspeech #1130

Merged: 27 commits, Jun 26, 2023
@@ -397,7 +397,6 @@ def decode_one_batch(
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
subtract_ilme=True,
ilme_scale=params.ilme_scale,
)
for hyp in hyp_tokens:
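The hunk above removes the boolean `subtract_ilme` flag from the decoding script; internal-LM-estimation (ILME) subtraction is now requested through `ilme_scale` alone, with `0.0` meaning disabled. A purely illustrative sketch of the call-site arguments (values are placeholders; only the argument names come from the diff):

```python
# Sketch only: the keyword arguments a decoding script passes after this change.
fast_beam_search_kwargs = dict(
    beam=20.0,
    max_contexts=8,
    max_states=64,
    # Previously subtract_ilme=True together with ilme_scale; now a single knob:
    ilme_scale=0.2,  # 0.0 (the new default) turns ILME subtraction off
)
```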
47 changes: 40 additions & 7 deletions egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -22,6 +22,7 @@
import k2
import sentencepiece as spm
import torch
from torch import nn

from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
@@ -35,7 +36,6 @@
get_texts,
get_texts_with_timestamp,
)
from torch import nn


def fast_beam_search_one_best(
@@ -47,8 +47,8 @@ def fast_beam_search_one_best(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
ilme_scale: float = 0.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -90,8 +90,8 @@
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
subtract_ilme=subtract_ilme,
ilme_scale=ilme_scale,
blank_penalty=blank_penalty,
)

best_path = one_best_decoding(lattice)
@@ -114,6 +114,8 @@ def fast_beam_search_nbest_LG(
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
blank_penalty: float = 0.0,
ilme_scale: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -168,6 +170,8 @@
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
blank_penalty=blank_penalty,
ilme_scale=ilme_scale,
)

nbest = Nbest.from_lattice(
@@ -240,6 +244,7 @@ def fast_beam_search_nbest(
nbest_scale: float = 0.5,
use_double_scores: bool = True,
temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -293,6 +298,7 @@ def fast_beam_search_nbest(
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
blank_penalty=blank_penalty,
temperature=temperature,
)

@@ -331,6 +337,7 @@ def fast_beam_search_nbest_oracle(
use_double_scores: bool = True,
nbest_scale: float = 0.5,
temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
@@ -389,6 +396,7 @@ def fast_beam_search_nbest_oracle(
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
blank_penalty=blank_penalty,
)

nbest = Nbest.from_lattice(
@@ -432,8 +440,8 @@ def fast_beam_search(
max_states: int,
max_contexts: int,
temperature: float = 1.0,
subtract_ilme: bool = False,
ilme_scale: float = 0.1,
ilme_scale: float = 0.0,
blank_penalty: float = 0.0,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.

@@ -503,8 +511,13 @@ def fast_beam_search(
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)

if blank_penalty != 0:
logits[:, 0] -= blank_penalty

log_probs = (logits / temperature).log_softmax(dim=-1)
if subtract_ilme:

if ilme_scale != 0:
ilme_logits = model.joiner(
torch.zeros_like(
current_encoder_out, device=current_encoder_out.device
@@ -513,8 +526,11 @@
project_input=False,
)
ilme_logits = ilme_logits.squeeze(1).squeeze(1)
if blank_penalty != 0:
ilme_logits[:, 0] -= blank_penalty
ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1)
log_probs -= ilme_scale * ilme_log_probs

decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
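For clarity, the per-frame score adjustment implemented in the hunk above can be summarized as follows. This is a self-contained illustration rather than the PR's code: the blank logit (index 0) is penalized before the softmax, and the internal-LM estimate, obtained by running the joiner on a zeroed encoder output, is scaled by `ilme_scale` and subtracted from the log-probabilities.

```python
import torch


def adjust_frame_scores(
    logits: torch.Tensor,       # (num_hyps, vocab_size), joiner output
    ilme_logits: torch.Tensor,  # (num_hyps, vocab_size), joiner output on zeroed encoder
    blank_penalty: float = 0.0,
    ilme_scale: float = 0.0,
    temperature: float = 1.0,
) -> torch.Tensor:
    if blank_penalty != 0:
        # Column 0 is the blank symbol; penalizing it moves probability mass
        # towards non-blank tokens after the softmax.
        logits = logits.clone()
        logits[:, 0] -= blank_penalty
        ilme_logits = ilme_logits.clone()
        ilme_logits[:, 0] -= blank_penalty
    log_probs = (logits / temperature).log_softmax(dim=-1)
    if ilme_scale != 0:
        log_probs = log_probs - ilme_scale * (ilme_logits / temperature).log_softmax(dim=-1)
    return log_probs
```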
@@ -526,6 +542,7 @@ def greedy_search(
model: nn.Module,
encoder_out: torch.Tensor,
max_sym_per_frame: int,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""Greedy search for a single utterance.
@@ -595,6 +612,9 @@ def greedy_search(
)
# logits is (1, 1, 1, vocab_size)

if blank_penalty != 0:
logits[:, :, :, 0] -= blank_penalty

y = logits.argmax().item()
if y not in (blank_id, unk_id):
hyp.append(y)
@@ -626,6 +646,7 @@ def greedy_search_batch(
model: nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
blank_penalty: float = 0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
@@ -703,6 +724,10 @@ def greedy_search_batch(

logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape

if blank_penalty != 0:
logits[:, 0] -= blank_penalty

y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
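To make the effect of the new argument concrete, here is a tiny self-contained example (not from the PR) showing how subtracting `blank_penalty` from the blank logit can flip a frame from emitting blank to emitting a real token during greedy search:

```python
import torch

# Toy logits for one frame over a 5-symbol vocabulary; index 0 is <blk>.
logits = torch.tensor([[2.0, 1.9, 0.5, -1.0, -2.0]])

for blank_penalty in (0.0, 0.5):
    penalized = logits.clone()
    penalized[:, 0] -= blank_penalty
    print(blank_penalty, penalized.argmax(dim=1).item())
# penalty 0.0 -> argmax 0 (blank, nothing emitted for this frame)
# penalty 0.5 -> argmax 1 (a real token is emitted instead)
```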
@@ -923,6 +948,7 @@ def modified_beam_search(
context_graph: Optional[ContextGraph] = None,
beam: int = 4,
temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
@@ -1028,6 +1054,9 @@ def modified_beam_search(

logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)

if blank_penalty != 0:
logits[:, 0] -= blank_penalty

log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)

log_probs.add_(ys_log_probs)
@@ -1662,6 +1691,7 @@ def beam_search(
encoder_out: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
blank_penalty: float = 0.0,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""
@@ -1758,6 +1788,9 @@ def beam_search(
project_input=False,
)

if blank_penalty != 0:
logits[:, :, :, 0] -= blank_penalty

# TODO(fangjun): Scale the blank posterior
log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
36 changes: 18 additions & 18 deletions egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
@@ -19,7 +19,7 @@
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"

cd exp
@@ -29,7 +29,7 @@
2. Export the model to ONNX

./zipformer/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@@ -57,9 +57,9 @@

It will generate the following 3 files inside $repo/exp:

- encoder-epoch-99-avg-1.onnx
- decoder-epoch-99-avg-1.onnx
- joiner-epoch-99-avg-1.onnx
- encoder-epoch-99-avg-1-chunk-16-left-64.onnx
- decoder-epoch-99-avg-1-chunk-16-left-64.onnx
- joiner-epoch-99-avg-1-chunk-16-left-64.onnx

See ./onnx_pretrained-streaming.py for how to use the exported ONNX models.
"""
@@ -69,14 +69,15 @@
from pathlib import Path
from typing import Dict, List, Tuple

import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_model
from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2

from icefall.checkpoint import (
@@ -85,7 +86,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask
from icefall.utils import make_pad_mask, str2bool


def get_parser():
@@ -142,10 +143,10 @@ def get_parser():
)

parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
@@ -585,12 +586,9 @@ def main():

logging.info(f"device: {device}")

sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)

# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)
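Both export scripts now derive `blank_id` and `vocab_size` from data/lang_bpe_500/tokens.txt instead of the BPE model. A hedged sketch of how such a table is read, assuming the usual k2 symbol-table layout of one "symbol id" pair per line (the miniature tokens.txt below is hypothetical):

```python
import k2

# Hypothetical miniature tokens.txt in the usual "symbol id" layout.
with open("tokens.txt", "w", encoding="utf-8") as f:
    f.write("<blk> 0\n<sos/eos> 1\n▁THE 2\nS 3\n")

token_table = k2.SymbolTable.from_file("tokens.txt")
print(token_table["<blk>"])  # 0, used as params.blank_id above
print(token_table["▁THE"])   # 2
```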

@@ -709,6 +707,8 @@ def main():
suffix = f"epoch-{params.epoch}"

suffix += f"-avg-{params.avg}"
suffix += f"-chunk-{params.chunk_size}"
suffix += f"-left-{params.left_context_frames}"

opset_version = 13

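The renamed outputs encode the streaming configuration (chunk size and left-context frames) in the file name. As a rough sketch, not part of this PR, the exported streaming encoder can be inspected with onnxruntime to check its inputs and any metadata written at export time:

```python
import onnxruntime as ort

# File name follows the pattern listed in the docstring above; adjust epoch/avg/chunk/left.
session = ort.InferenceSession(
    "exp/encoder-epoch-99-avg-1-chunk-16-left-64.onnx",
    providers=["CPUExecutionProvider"],
)
print([i.name for i in session.get_inputs()])       # feature input plus cached-state inputs
print(session.get_modelmeta().custom_metadata_map)  # any metadata written at export time
```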
29 changes: 13 additions & 16 deletions egs/librispeech/ASR/zipformer/export-onnx.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3
#
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang)
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Wei Kang)
# Copyright 2023 Danqing Fu (danqing.fu@gmail.com)

"""
@@ -19,7 +19,7 @@
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"

cd exp
@@ -29,12 +29,11 @@
2. Export the model to ONNX

./zipformer/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
--exp-dir $repo/exp \
\
--num-encoder-layers "2,2,3,4,3,2" \
--downsampling-factor "1,2,4,8,4,2" \
--feedforward-dim "512,768,1024,1536,1024,768" \
@@ -67,14 +66,15 @@
from pathlib import Path
from typing import Dict, Tuple

import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_model
from train import add_model_arguments, get_model, get_params
from zipformer import Zipformer2

from icefall.checkpoint import (
@@ -83,7 +83,7 @@
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool, make_pad_mask
from icefall.utils import make_pad_mask, str2bool


def get_parser():
@@ -140,10 +140,10 @@ def get_parser():
)

parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)

parser.add_argument(
@@ -434,12 +434,9 @@ def main():

logging.info(f"device: {device}")

sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)

# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1

logging.info(params)

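Both export scripts also import `quantize_dynamic` from onnxruntime, so the exported float32 models can additionally be shrunk to int8 after export. A minimal sketch, under the assumption that the file names follow the epoch/avg pattern shown in the docstrings above:

```python
from onnxruntime.quantization import QuantType, quantize_dynamic

# Sketch only: dynamic int8 quantization of an exported decoder model.
quantize_dynamic(
    model_input="exp/decoder-epoch-99-avg-1.onnx",
    model_output="exp/decoder-epoch-99-avg-1.int8.onnx",
    weight_type=QuantType.QInt8,
)
```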