Streaming Zipformer with multi-dataset #984

Merged
```diff
@@ -387,14 +387,16 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info(f"About to get train_{self.args.subset} cuts")
-        path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
+        path = self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
         cuts_train = CutSet.from_jsonl_lazy(path)
         return cuts_train

     @lru_cache()
     def dev_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
+        cuts_valid = load_manifest_lazy(
+            self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
+        )
         if self.args.small_dev:
             return cuts_valid.subset(first=1000)
         else:
@@ -403,4 +405,6 @@ def dev_cuts(self) -> CutSet:
     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.info("About to get test cuts")
-        return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
+        return load_manifest_lazy(
+            self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
+        )
```
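For reference, here is a minimal sketch of how these renamed manifests can be loaded lazily with lhotse outside the data module. The `data/fbank` directory and the `XL` subset name are illustrative assumptions, not taken from this PR:

```python
from lhotse import CutSet, load_manifest_lazy

manifest_dir = "data/fbank"  # assumed manifest location

# CutSet.from_jsonl_lazy streams cuts from disk instead of loading the
# entire manifest into memory, which matters at GigaSpeech scale.
train_cuts = CutSet.from_jsonl_lazy(f"{manifest_dir}/gigaspeech_cuts_XL.jsonl.gz")

# load_manifest_lazy infers the manifest type from the file contents.
dev_cuts = load_manifest_lazy(f"{manifest_dir}/gigaspeech_cuts_DEV.jsonl.gz")

# Mirror --small-dev: keep only the first 1000 cuts for faster validation.
small_dev = dev_cuts.subset(first=1000)
```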
3 changes: 2 additions & 1 deletion egs/librispeech/ASR/.gitignore

```diff
@@ -1,2 +1,3 @@
 log-*
-.DS_Store
+.DS_Store
+run*.sh
```
1 change: 1 addition & 0 deletions egs/librispeech/ASR/README.md
```diff
@@ -26,6 +26,7 @@ The following table lists the differences among them.
 | `pruned_transducer_stateless7_ctc` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but with extra CTC head|
 | `pruned_transducer_stateless7_ctc_bs` | Zipformer | Embedding + Conv1d | pruned_transducer_stateless7_ctc + blank skip |
 | `pruned_transducer_stateless7_streaming` | Streaming Zipformer | Embedding + Conv1d | streaming version of pruned_transducer_stateless7 |
+| `pruned_transducer_stateless7_streaming_multi` | Streaming Zipformer | Embedding + Conv1d | same as pruned_transducer_stateless7_streaming, trained on LibriSpeech + GigaSpeech |
 | `pruned_transducer_stateless8` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but using extra data from GigaSpeech|
 | `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR|
 | `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model |
```
140 changes: 138 additions & 2 deletions egs/librispeech/ASR/RESULTS.md
## Results

### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer + Multi-Dataset)

#### [pruned_transducer_stateless7_streaming_multi](./pruned_transducer_stateless7_streaming_multi)

See <https://github.com/k2-fsa/icefall/pull/984> for more details.

You can find a pretrained model, training logs, decoding logs, and decoding
results at: <https://huggingface.co/marcoyang/icefall-libri-giga-pruned-transducer-stateless7-streaming-2023-04-04>

Number of model parameters: 70369391, i.e., 70.37 M
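The reported figure is simply the total element count over all parameter tensors; a minimal sketch of how to reproduce it from a constructed model (assuming `get_params` and `get_transducer_model` from this recipe's `train.py`):

```python
import torch


def count_parameters(model: torch.nn.Module) -> int:
    # Sum the number of elements over every parameter tensor.
    return sum(p.numel() for p in model.parameters())


# Assumed usage with this recipe's helpers:
#   params = get_params()
#   model = get_transducer_model(params)
#   print(f"Number of model parameters: {count_parameters(model)}")
```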

##### Training on full LibriSpeech + full GigaSpeech (with giga_prob=0.9)

The WERs are:


| decoding method      | chunk size | test-clean | test-other | comment            | decoding mode       |
|----------------------|------------|------------|------------|--------------------|---------------------|
| greedy search        | 320ms      | 2.43       | 6.00       | --epoch 20 --avg 4 | simulated streaming |
| greedy search        | 320ms      | 2.47       | 6.13       | --epoch 20 --avg 4 | chunk-wise          |
| fast beam search     | 320ms      | 2.43       | 5.99       | --epoch 20 --avg 4 | simulated streaming |
| fast beam search     | 320ms      | 2.80       | 6.46       | --epoch 20 --avg 4 | chunk-wise          |
| modified beam search | 320ms      | 2.40       | 5.96       | --epoch 20 --avg 4 | simulated streaming |
| modified beam search | 320ms      | 2.42       | 6.03       | --epoch 20 --avg 4 | chunk-wise          |
| greedy search        | 640ms      | 2.26       | 5.58       | --epoch 20 --avg 4 | simulated streaming |
| greedy search        | 640ms      | 2.33       | 5.76       | --epoch 20 --avg 4 | chunk-wise          |
| fast beam search     | 640ms      | 2.27       | 5.54       | --epoch 20 --avg 4 | simulated streaming |
| fast beam search     | 640ms      | 2.37       | 5.75       | --epoch 20 --avg 4 | chunk-wise          |
| modified beam search | 640ms      | 2.22       | 5.50       | --epoch 20 --avg 4 | simulated streaming |
| modified beam search | 640ms      | 2.25       | 5.69       | --epoch 20 --avg 4 | chunk-wise          |

The model also performs well on GigaSpeech. The following WERs are achieved on the GigaSpeech dev and test sets:

| decoding method      | chunk size | dev   | test  | comment            | decoding mode       |
|----------------------|------------|-------|-------|--------------------|---------------------|
| greedy search        | 320ms      | 12.08 | 11.98 | --epoch 20 --avg 4 | simulated streaming |
| greedy search        | 640ms      | 11.66 | 11.71 | --epoch 20 --avg 4 | simulated streaming |
| modified beam search | 320ms      | 11.95 | 11.83 | --epoch 20 --avg 4 | simulated streaming |
| modified beam search | 640ms      | 11.65 | 11.56 | --epoch 20 --avg 4 | simulated streaming |


Note: `simulated streaming` means the full utterance is fed at once during decoding using `decode.py`,
while `chunk-wise` means a fixed number of frames is fed at a time using `streaming_decode.py`.
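In code terms, the two modes differ roughly as follows. This is a hand-written sketch with a hypothetical streaming-encoder interface; `encode`, `init_states`, and `encode_chunk` are not the recipe's actual method names:

```python
import torch

CHUNK_LEN = 32  # frames per chunk; --decode-chunk-len 32 corresponds to ~320ms


def simulated_streaming(model, features: torch.Tensor) -> torch.Tensor:
    # decode.py style: the whole utterance goes through one forward pass;
    # causal masking inside the encoder imitates the streaming condition.
    return model.encode(features)


def chunk_wise(model, features: torch.Tensor) -> torch.Tensor:
    # streaming_decode.py style: feed a fixed number of frames at a time
    # and carry the encoder states across chunks.
    states = model.init_states()
    outputs = []
    for start in range(0, features.size(0), CHUNK_LEN):
        chunk = features[start : start + CHUNK_LEN]
        out, states = model.encode_chunk(chunk, states)
        outputs.append(out)
    return torch.cat(outputs)
```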

The training command is:

```bash
./pruned_transducer_stateless7_streaming_multi/train.py \
--world-size 4 \
--num-epochs 20 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7_streaming_multi/exp \
--full-libri 1 \
--giga-prob 0.9 \
--max-duration 750 \
--master-port 12345
```
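The `--giga-prob 0.9` flag makes training draw roughly 90% of its cuts from GigaSpeech and 10% from LibriSpeech. A minimal sketch of this kind of weighted mixing with lhotse; the manifest paths are illustrative, and `train.py` may implement the mixing differently:

```python
from lhotse import CutSet

giga_prob = 0.9

libri_cuts = CutSet.from_jsonl_lazy(
    "data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz"
)
giga_cuts = CutSet.from_jsonl_lazy("data/fbank/gigaspeech_cuts_XL.jsonl.gz")

# CutSet.mux lazily interleaves the two streams, sampling each source
# with probability proportional to its weight.
train_cuts = CutSet.mux(libri_cuts, giga_cuts, weights=[1 - giga_prob, giga_prob])
```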

The tensorboard log can be found at
<https://tensorboard.dev/experiment/G4yDMLXGQXexf41i4MA2Tg/#scalars>

The simulated streaming decoding command (e.g., chunk-size=320ms) is:
```bash
for m in greedy_search fast_beam_search modified_beam_search; do
./pruned_transducer_stateless7_streaming_multi/decode.py \
--epoch 20 \
--avg 4 \
--exp-dir ./pruned_transducer_stateless7_streaming_multi/exp \
--max-duration 600 \
--decode-chunk-len 32 \
--right-padding 64 \
--decoding-method $m
done
```

The chunk-wise streaming decoding command (e.g., chunk-size=320ms) is:
```bash
for m in greedy_search modified_beam_search fast_beam_search; do
./pruned_transducer_stateless7_streaming_multi/streaming_decode.py \
--epoch 20 \
--avg 4 \
--exp-dir ./pruned_transducer_stateless7_streaming_multi/exp \
--decoding-method $m \
--decode-chunk-len 32 \
--num-decode-streams 2000
done
```


#### Smaller model

We also provide a much smaller model (only 6.1 M parameters) of this setup. The training command for the small model is:

```bash
./pruned_transducer_stateless7_streaming_multi/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless7_streaming_multi/exp \
--full-libri 1 \
--giga-prob 0.9 \
--num-encoder-layers "2,2,2,2,2" \
--feedforward-dims "256,256,512,512,256" \
--nhead "4,4,4,4,4" \
--encoder-dims "128,128,128,128,128" \
--attention-dims "96,96,96,96,96" \
--encoder-unmasked-dims "96,96,96,96,96" \
--max-duration 1200 \
--master-port 12345
```

You can find this pretrained small model and its training logs, decoding logs, and decoding
results at:
<https://huggingface.co/marcoyang/icefall-libri-giga-pruned-transducer-stateless7-streaming-6M-2023-04-03>


| decoding method | chunk size | test-clean | test-other | comment | decoding mode |
|----------------------|------------|------------|------------|---------------------|----------------------|
| greedy search | 320ms | 5.95 | 15.03 | --epoch 30 --avg 1 | simulated streaming |
| greedy search | 640ms | 5.61 | 13.86 | --epoch 30 --avg 1 | simulated streaming |
| modified beam search | 320ms | 5.72 | 14.34 | --epoch 30 --avg 1 | simulated streaming |
| modified beam search | 640ms | 5.43 | 13.16 | --epoch 30 --avg 1 | simulated streaming |
| fast beam search | 320ms | 5.88 | 14.45 | --epoch 30 --avg 1 | simulated streaming |
| fast beam search | 640ms | 5.48 | 13.31 | --epoch 30 --avg 1 | simulated streaming |

This small model achieves the following WERs on the GigaSpeech dev and test sets:

| decoding method      | chunk size | dev   | test  | comment            | decoding mode       |
|----------------------|------------|-------|-------|--------------------|---------------------|
| greedy search        | 320ms      | 17.57 | 17.20 | --epoch 30 --avg 1 | simulated streaming |
| modified beam search | 320ms      | 16.98 | 11.98 | --epoch 30 --avg 1 | simulated streaming |

You can find the tensorboard logs at <https://tensorboard.dev/experiment/tAc5iXxTQrCQxky5O5OLyw/#scalars>.

### Streaming Zipformer-Transducer (Pruned Stateless Transducer + Streaming Zipformer)

#### [pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
@@ -53,7 +189,7 @@ The tensorboard log can be found at

The simulated streaming decoding command (e.g., chunk-size=320ms) is:
```bash
for m in greedy_search fast_beam_search modified_beam_search; do
  ./pruned_transducer_stateless7_streaming/decode.py \
    --epoch 30 \
    --avg 9 \
    ...
done
```

Note that a small change is made to the `pruned_transducer_stateless7/decoder.py` in
this [PR](https://github.com/k2-fsa/icefall/pull/942) to address the
problem of emitting the first symbol at the very beginning. If you need a
model without this issue, please download the model from here: <https://huggingface.co/marcoyang/icefall-asr-librispeech-pruned-transducer-stateless7-2023-03-10>

5 changes: 3 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
```diff
@@ -777,8 +777,9 @@ def forward(ctx, x: Tensor, y: Tensor):

     @staticmethod
     def backward(ctx, ans_grad: Tensor):
-        return ans_grad, torch.ones(
-            ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device
+        return (
+            ans_grad,
+            torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
         )
```
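The reformatted `backward` above follows the standard `torch.autograd.Function` contract: one gradient per `forward` input, returned as a tuple. A self-contained toy example of the same pattern (this is not the actual class from `scaling.py`):

```python
import torch
from torch import Tensor


class AttachAuxGrad(torch.autograd.Function):
    """Returns x unchanged while sending a constant gradient of 1.0 to y."""

    @staticmethod
    def forward(ctx, x: Tensor, y: Tensor) -> Tensor:
        ctx.y_shape = y.shape
        return x

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        # One gradient per forward() input, as an explicit tuple:
        # x receives the incoming gradient, y receives all-ones.
        return (
            ans_grad,
            torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device),
        )


x = torch.randn(3, requires_grad=True)
y = torch.randn(2, requires_grad=True)
AttachAuxGrad.apply(x, y).sum().backward()
print(y.grad)  # tensor([1., 1.])
```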
```diff
@@ -33,7 +33,6 @@
 --use-averaged-model 0 \
 --epoch 99 \
 --avg 1 \
-  \
 --decode-chunk-len 32 \
 --num-encoder-layers "2,4,3,2,4" \
 --feedforward-dims "1024,1024,2048,2048,1024" \
```
```diff
@@ -71,13 +71,12 @@
 import argparse
 import logging

+import torch
 from onnx_pretrained import OnnxModel
 from zipformer import stack_states

 from icefall import is_module_available

-import torch
-

 def get_parser():
     parser = argparse.ArgumentParser(
```
```diff
@@ -109,6 +109,29 @@ def __init__(self, args):
         self.init_joiner(args)

         # Please change the parameters according to your model
+
+        # 20M
+        # self.num_encoder_layers = to_int_tuple("2,2,2,2,2")
+        # self.encoder_dims = to_int_tuple("256,256,256,256,256")  # also known as d_model
+        # self.attention_dims = to_int_tuple("192,192,192,192,192")
+        # self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2")
+        # self.cnn_module_kernels = to_int_tuple("31,31,31,31,31")
+
+        # 9.6M
+        # self.num_encoder_layers = to_int_tuple("2,3,2,2,3")
+        # self.encoder_dims = to_int_tuple("160,160,160,160,160")  # also known as d_model
+        # self.attention_dims = to_int_tuple("96,96,96,96,96")
+        # self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2")
+        # self.cnn_module_kernels = to_int_tuple("31,31,31,31,31")
+
+        # 5.5M or 6M
+
+        # self.num_encoder_layers = to_int_tuple("2,2,2,2,2")
+        # self.encoder_dims = to_int_tuple("128,128,128,128,128")  # also known as d_model
+        # self.attention_dims = to_int_tuple("96,96,96,96,96")
+        # self.zipformer_downsampling_factors = to_int_tuple("1,2,4,8,2")
+        # self.cnn_module_kernels = to_int_tuple("31,31,31,31,31")
+
         self.num_encoder_layers = to_int_tuple("2,4,3,2,4")
         self.encoder_dims = to_int_tuple("384,384,384,384,384")  # also known as d_model
         self.attention_dims = to_int_tuple("192,192,192,192,192")
```
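The `to_int_tuple` helper used above just parses these comma-separated configuration strings; a minimal sketch of what such a helper does (the actual implementation lives in the recipe's code and may differ in detail):

```python
from typing import Tuple


def to_int_tuple(s: str) -> Tuple[int, ...]:
    """Parse a comma-separated string such as "2,4,3,2,4" into (2, 4, 3, 2, 4)."""
    return tuple(int(x) for x in s.split(","))


assert to_int_tuple("2,4,3,2,4") == (2, 4, 3, 2, 4)
```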
```diff
@@ -62,6 +62,43 @@ def test_model():
     model = torch.jit.script(model)


+def test_model_small():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = "2,2,2,2,2"
+    params.feedforward_dims = "256,256,512,512,256"
+    params.nhead = "4,4,4,4,4"
+    params.encoder_dims = "128,128,128,128,128"
+    params.attention_dims = "96,96,96,96,96"
+    params.encoder_unmasked_dims = "96,96,96,96,96"
+    params.zipformer_downsampling_factors = "1,2,4,8,2"
+    params.cnn_module_kernels = "31,31,31,31,31"
+    params.decoder_dim = 320
+    params.joiner_dim = 320
+    params.num_left_chunks = 4
+    params.short_chunk_size = 50
+    params.decode_chunk_len = 32
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+    import pdb
+
+    pdb.set_trace()
+
+    # Test jit script
+    convert_scaled_to_non_scaled(model, inplace=True)
+    # We won't use the forward() method of the model in C++, so just ignore
+    # it here.
+    # Otherwise, one of its arguments is a ragged tensor and is not
+    # torch scriptabe.
+    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+    print("Using torch.jit.script")
+    model = torch.jit.script(model)
+
+
 def test_model_jit_trace():
     params = get_params()
     params.vocab_size = 500
@@ -142,7 +179,7 @@ def _test_joiner():


 def main():
-    test_model()
+    test_model_small()
     test_model_jit_trace()
```
24 changes: 12 additions & 12 deletions egs/librispeech/ASR/pruned_transducer_stateless7_streaming/train.py
```diff
@@ -1049,10 +1049,10 @@ def run(rank, world_size, args):

     librispeech = LibriSpeechAsrDataModule(args)

-    train_cuts = librispeech.train_clean_100_cuts()
     if params.full_libri:
-        train_cuts += librispeech.train_clean_360_cuts()
-        train_cuts += librispeech.train_other_500_cuts()
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()

     def remove_short_and_long_utt(c: Cut):
         # Keep only utterances with duration between 1 second and 20 seconds
@@ -1091,7 +1091,7 @@ def remove_short_and_long_utt(c: Cut):

         return True

-    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+    # train_cuts = train_cuts.filter(remove_short_and_long_utt)

     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
@@ -1108,14 +1108,14 @@ def remove_short_and_long_utt(c: Cut):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
```
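For context, the `scan_pessimistic_batches_for_oom` call that is commented out above exists to fail fast: before training starts, it pushes a few of the most memory-hungry batches through the model so that an out-of-memory error surfaces immediately rather than hours into an epoch. A rough sketch of the idea (not icefall's actual implementation):

```python
import torch


def scan_largest_batches(model, largest_batches, compute_loss):
    """Run the most memory-hungry batches once to surface OOM early.

    `largest_batches` is assumed to be a small list of the biggest batches
    the sampler can produce; `compute_loss` is a stand-in for the recipe's
    loss computation.
    """
    for batch in largest_batches:
        loss = compute_loss(model, batch)
        loss.backward()  # allocate gradient memory too, like real training
        model.zero_grad(set_to_none=True)
```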