
Add support for K2 pruned transducer loss #5268

Merged · 11 commits · Jul 13, 2023
22 changes: 20 additions & 2 deletions doc/espnet2_tutorial.md
@@ -415,7 +415,7 @@ Latency: 52581.004 [ms/sentence]

## Transducer ASR

- > ***Important***: If you encounter any issue related to Transducer loss, please open an issue in [our fork of warp-transducer](https://github.com/b-flo/warp-transducer).
+ > ***Important***: If you encounter any issue related to `warp-transducer`, please open an issue in [our forked repo](https://github.com/b-flo/warp-transducer).

ESPnet2 supports models trained with the (RNN-)Transducer loss, a.k.a. Transducer models. Currently, two versions of these models exist within ESPnet2: one under `asr` and the other under `asr_transducer`. The first is designed as a supplement to CTC-Attention ASR models, while the second is designed independently and purely for the Transducer task. For the latter, we rely on `ESPnetASRTransducerModel` instead of `ESPnetASRModel`, and a new task called `ASRTransducerTask` is used in place of `ASRTask`.

@@ -431,13 +431,31 @@ To enable Transducer model training or decoding in your experiments, the followi
asr.sh --asr_task asr_transducer [...]
```

- For Transducer loss computation during training, we rely on a fork of `warp-transducer`. The installation procedure is described [here](https://espnet.github.io/espnet/installation.html#step-3-optional-custom-tool-installation).
+ For Transducer loss computation during training, we rely by default on a fork of `warp-transducer`. The installation procedure is described [here](https://espnet.github.io/espnet/installation.html#step-3-optional-custom-tool-installation).

**Note:** FastEmit regularization [[Yu et al., 2021]](https://arxiv.org/pdf/2010.11148) is available during loss computation. To enable it, `fastemit_lambda` needs to be set in `model_conf`:

```
model_conf:
    fastemit_lambda: Regularization parameter for FastEmit. (float, default = 0.0)
```

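For intuition, FastEmit works by boosting the gradient that flows through the non-blank (label) emissions while leaving the blank gradient untouched, which nudges the model toward emitting labels earlier. The helper below is purely illustrative (it is not ESPnet or warp-transducer code) and only shows the role of `fastemit_lambda`:

```python
def fastemit_scale(label_grad, blank_grad, fastemit_lambda=0.0):
    """Illustrative only: FastEmit scales the label (non-blank) gradient
    by (1 + lambda); the blank gradient is unchanged. With lambda = 0.0
    (the default), the loss behaves like the plain Transducer loss."""
    return (1.0 + fastemit_lambda) * label_grad, blank_grad

print(fastemit_scale(2.0, 3.0, fastemit_lambda=0.5))  # (3.0, 3.0)
```
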
Optionally, we also support training with the Pruned RNN-T loss [[Kuang et al., 2022]](https://arxiv.org/pdf/2206.13236.pdf) made available in the [k2](https://github.com/k2-fsa/k2) toolkit. To use it, set the parameter `use_k2_pruned_loss` to `True` in `model_conf`. The loss computation can then be controlled by setting the following parameters through `k2_pruned_loss_args` in `model_conf`:

```
model_conf:
    use_k2_pruned_loss: True
    k2_pruned_loss_args:
        prune_range: How many tokens per frame are used to compute the pruned loss. (int, default = 5)
        simple_loss_scaling: The weight used to scale the simple loss after warm-up. (float, default = 0.5)
        lm_scale: The scale factor to smooth the LM part. (float, default = 0.0)
        am_scale: The scale factor to smooth the AM part. (float, default = 0.0)
        loss_type: The type of path to take for loss computation, either 'regular', 'smoothed' or 'constrained'. (str, default = "regular")
```

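For intuition, k2's pruned recipe first computes a cheap "simple" loss with a trivial joiner, derives pruning bounds of width `prune_range` from it, and then computes the exact loss only inside the pruned region. The sketch below only illustrates how the two resulting terms are commonly combined; the function name is hypothetical and this is not the actual ESPnet or k2 API:

```python
def combine_pruned_rnnt_losses(simple_loss, pruned_loss, simple_loss_scaling=0.5):
    """Illustrative only: the training loss is the pruned (exact) loss
    plus a scaled-down simple loss; the scaling corresponds to the
    `simple_loss_scaling` parameter described above."""
    return simple_loss_scaling * simple_loss + pruned_loss

print(combine_pruned_rnnt_losses(4.0, 1.0))  # 3.0
```
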
**Note:** Because the number of tokens emitted per timestep can be restricted during training with this version, we also provide the parameter `validation_nstep`. It lets users apply a similar constraint during validation, when reporting CER and/or WER:

```
model_conf:
    validation_nstep: Maximum number of symbol expansions at each time step when reporting CER and/or WER using mAES.
```

For more information, see the Inference section and the description of the "modified Adaptive Expansion Search" (mAES) algorithm.

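To illustrate what `validation_nstep` (and mAES's `nstep`) constrains, here is a toy, hypothetical decoding loop (not the actual mAES implementation) that emits at most `nstep` non-blank symbols per encoder frame:

```python
def constrained_decode(next_symbol, num_frames, nstep, blank=0):
    """Toy illustration: at each frame, expand at most `nstep` non-blank
    symbols before advancing, as mAES does with its `nstep` parameter.
    `next_symbol(t, hyp)` is a stand-in for the decoder/joint network."""
    hyp = []
    for t in range(num_frames):
        for _ in range(nstep):
            sym = next_symbol(t, hyp)
            if sym == blank:
                break  # blank emitted: move on to the next frame
            hyp.append(sym)
    return hyp

# A model that "wants" to emit five symbols in the first frame is capped
# to two expansions per frame:
wants_five = lambda t, hyp: 1 if len(hyp) < 5 else 0
print(constrained_decode(wants_five, num_frames=2, nstep=2))  # [1, 1, 1, 1]
```
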
### Architecture

The architecture is composed of three modules: encoder, decoder and joint network. Each module has one (or three) config(s) with various parameters in order to configure the internal parts. The following sections describe the mandatory and optional parameters for each module.
1 change: 1 addition & 0 deletions espnet2/asr_transducer/beam_search_transducer.py
@@ -362,6 +362,7 @@ def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
[hyp for hyp in kept_hyps if hyp.score > hyps_max],
key=lambda x: x.score,
)

if len(kept_most_prob) >= self.beam_size:
kept_hyps = kept_most_prob
break
4 changes: 3 additions & 1 deletion espnet2/asr_transducer/decoder/mega_decoder.py
@@ -264,7 +264,9 @@ def batch_score(
states:

"""
- labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
+ labels = torch.tensor(
+     [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
+ )
states = self.create_batch_states([h.dec_state for h in hyps])

out, states = self.inference(labels, states=states)
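The hunk above (repeated in the other decoders below) swaps the legacy `torch.LongTensor` constructor, which does not reliably accept a `device` keyword, for the `torch.tensor` factory. A minimal sketch of the fixed pattern, with toy data standing in for the hypotheses:

```python
import torch

# Toy stand-in for [[h.yseq[-1]] for h in hyps]: the last label of each
# of three hypotheses, shaped (batch, 1) for the embedding lookup.
last_labels = [[3], [7], [1]]
labels = torch.tensor(last_labels, dtype=torch.long, device="cpu")
print(labels.shape, labels.dtype)  # torch.Size([3, 1]) torch.int64
```
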
4 changes: 3 additions & 1 deletion espnet2/asr_transducer/decoder/rnn_decoder.py
@@ -170,7 +170,9 @@ def batch_score(
states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

"""
- labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
+ labels = torch.tensor(
+     [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
+ )
embed = self.embed(labels)

states = self.create_batch_states([h.dec_state for h in hyps])
4 changes: 3 additions & 1 deletion espnet2/asr_transducer/decoder/rwkv_decoder.py
@@ -207,7 +207,9 @@ def batch_score(
states: Decoder hidden states. [5 x (B, 1, D_att/D_dec, N)]

"""
- labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
+ labels = torch.tensor(
+     [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
+ )
states = self.create_batch_states([h.dec_state for h in hyps])

out, states = self.inference(labels, states)
4 changes: 3 additions & 1 deletion espnet2/asr_transducer/decoder/stateless_decoder.py
@@ -105,7 +105,9 @@ def batch_score(self, hyps: List[Hypothesis]) -> Tuple[torch.Tensor, None]:
states: Decoder hidden states. None

"""
- labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
+ labels = torch.tensor(
+     [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
+ )
embed = self.embed(labels)

return embed.squeeze(1), None
19 changes: 17 additions & 2 deletions espnet2/asr_transducer/error_calculator.py
@@ -18,6 +18,7 @@ class ErrorCalculator:
token_list: List of token units.
sym_space: Space symbol.
sym_blank: Blank symbol.
nstep: Maximum number of symbol expansions at each time step w/ mAES.
report_cer: Whether to compute CER.
report_wer: Whether to compute WER.

@@ -30,17 +31,31 @@ def __init__(
token_list: List[int],
sym_space: str,
sym_blank: str,
nstep: int = 2,
report_cer: bool = False,
report_wer: bool = False,
) -> None:
"""Construct an ErrorCalculator object."""
super().__init__()

# (b-flo): Since commit #8c9c851 we rely on the mAES algorithm for
# validation instead of the default algorithm.
#
# With the addition of k2 pruned transducer loss, the number of emitted symbols
# at each timestep can be restricted during training. Performing an unrestricted
# (/ unconstrained) decoding without regard to the training conditions can lead
# to huge performance degradation. It won't be an issue with mAES and the user
# can now control the number of emitted symbols during validation.
#
# Also, under certain conditions, using the default algorithm can lead to a long
# decoding procedure due to the loop break condition. Other algorithms,
# such as mAES, won't be impacted by that.
self.beam_search = BeamSearchTransducer(
decoder=decoder,
joint_network=joint_network,
- beam_size=1,
- search_type="default",
+ beam_size=2,
+ search_type="maes",
**Contributor:**

Is it OK to change it from default to maes?
What is the intention of this change?

**Member Author:**

> Is it OK to change it from default to maes?

It is, it won't make a difference for the reporting.

> What is the intention of this change?

To control the number of emitted symbols at each timestep in both training and validation.
With k2, we can end up with a big performance degradation if we enforce unrestricted decoding without regard to the training conditions/parameters (see https://arxiv.org/pdf/2211.00484.pdf). With maes, it'll be "stable" under all conditions.

Also, in some cases you can end up with a really long looping procedure with default (e.g., during early stages with some architectures such as MEGA, or if there is a training/decoding condition mismatch, as I said). It won't happen with other decoding algorithms.

**Contributor:**

Thanks for the clarification.
I also encountered the issue of the number of emitted symbols, and this change sounds good to me.
Can you leave this discussion somewhere in the source code?

**Member Author:**

Done, but I'm not sure it's in the right place or in the right form!

**Member Author:**

@sw005320 Is that okay for you as is, or do you want me to add a note in the tutorial doc instead? I'll merge after that.

**Contributor (@sw005320), Jul 13, 2023:**

Sure, the current one is OK.
(I'm trying to fix the CI issue...)

nstep=nstep,
score_norm=False,
)
