Skip to content

Commit

Permalink
Merge pull request #4582 from b-flo/branchformer_transducer
Browse files Browse the repository at this point in the history
Offline/Online Branchformer Transducer
  • Loading branch information
sw005320 committed Sep 11, 2022
2 parents d42af0c + 21d6bb2 commit 36e824b
Show file tree
Hide file tree
Showing 24 changed files with 597 additions and 160 deletions.
55 changes: 36 additions & 19 deletions doc/espnet2_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -446,28 +446,25 @@ The architecture is composed of three modules: encoder, decoder and joint networ

#### Encoder

For the encoder, we propose a unique encoder type encapsulating the following blocks: Conformer and Conv 1D (other X-former such as Branchformer or Enformer will be supported soon).
For the encoder, we propose a unique encoder type encapsulating the following blocks: Branchformer, Conformer and Conv 1D (other X-former such as Squeezeformer or Enformer will be supported later).
It is similar to the custom encoder in ESPnet1, meaning we don't need to set the parameter `encoder: [type]` here. Instead, the encoder architecture is defined by three configurations passed to `encoder_conf`:

1. `input_conf` (**Dict**): The configuration for the input block.
2. `main_conf` (**Dict**): The main configuration for the parameters shared across all blocks.
3. `body_conf` (**List[Dict]**): The list of configurations for each block of the encoder architecture but the input block.

The first and second configurations are optional. If needed, fhe following parameters can be modified in each configuration:
The first and second configurations are optional. If needed, the following parameters can be modified in each configuration:

main_conf:
pos_wise_act_type: Position-wise activation type. (str, default = "swish")
conv_mod_act_type: Convolutional module activation type. (str, default = "swish")
pos_wise_act_type: Conformer position-wise feed-forward activation type. (str, default = "swish")
conv_mod_act_type: Conformer convolution module activation type. (str, default = "swish")
pos_enc_dropout_rate: Dropout rate for the positional encoding layer, if used. (float, default = 0.0)
pos_enc_max_len: Positional encoding maximum length. (int, default = 5000)
simplified_att_score: Whether to use simplified attention score computation. (bool, default = False)
after_norm_type: Final normalization type. (str, default = "layer_norm")
after_norm_eps: Epsilon value for the final normalization. (float, default = None)
after_norm_partial: Value for the final normalization with RMSNorm. (float, default = None)
dynamic_chunk_training: Whether to train streaming model with dynamic chunks. (bool, default = False)
short_chunk_threshold: Chunk length threshold (in percent) for dynamic chunk selection. (int, default = 0.75)
short_chunk_size: Minimum number of frames during dynamic chunk training. (int, default = 25)
left_chunk_size: Number of frames in left context. (int, default = 0)
norm_type: X-former normalization module type. (str, default = "layer_norm")
conv_mod_norm_type: Branchformer convolution module normalization type. (str, default = "layer_norm")
after_norm_eps: Epsilon value for the final normalization module. (float, default = 1e-05 or 0.25 for BasicNorm)
after_norm_partial: Partial value for the final normalization module, if norm_type = 'rms_norm'. (float, default = -1.0)
# For more information on the parameters below, please refer to espnet2/asr_transducer/activation.py
ftswish_threshold: Threshold value for FTSwish activation formulation.
ftswish_mean_shift: Mean shifting value for FTSwish activation formulation.
Expand All @@ -490,7 +487,7 @@ The only mandatory configuration is `body_conf`, defining the encoder body archi
# Conv 1D
- block_type: conv1d
output_size: Output size. (int)
kernel_size: Size of the context window. (int or Tuple)
kernel_size: Size of the convolving kernel. (int or Tuple)
stride (optional): Stride of the sliding blocks. (int or tuple, default = 1)
dilation (optional): Parameter to control the stride of elements within the neighborhood. (int or tuple, default = 1)
groups (optional): Number of blocked connections from input channels to output channels. (int, default = 1)
Expand All @@ -499,18 +496,32 @@ The only mandatory configuration is `body_conf`, defining the encoder body archi
batch_norm: Whether to use batch normalization after convolution. (bool, default = False)
dropout_rate (optional): Dropout rate for the Conv1d outputs. (float, default = 0.0)

# Branchformer
- block_type: branchformer
hidden_size: Hidden (and output) dimension. (int)
linear_size: Dimension of the Linear layers. (int)
conv_mod_kernel_size: Size of the convolving kernel in the convolutional module. (int)
heads (optional): Number of heads in multi-head attention. (int, default = 4)
norm_eps (optional): Epsilon value for the normalization module. (float, default = 1e-05 or 0.25 for BasicNorm)
norm_partial (optional): Partial value for the normalization module, if norm_type = 'rms_norm'. (float, default = -1.0)
conv_mod_norm_eps (optional): Epsilon value for convolutional module normalization. (float, default = 1e-05 or 0.25 for BasicNorm)
conv_mod_norm_partial (optional): Partial value for the convolutional module normalization, if conv_norm_type = 'rms_norm'. (float, default = -1.0)
dropout_rate (optional): Dropout rate for some intermediate layers. (float, default = 0.0)
att_dropout_rate (optional): Dropout rate for the attention module. (float, default = 0.0)

# Conformer
- block_type: conformer
hidden_size: Hidden (and output) dimension. (int)
linear_size: Dimension of feed-forward module. (int)
conv_mod_kernel_size: Number of kernel in convolutional module. (int)
conv_mod_kernel_size: Size of the convolving kernel in the convolutional module. (int)
heads (optional): Number of heads in multi-head attention. (int, default = 4)
norm_eps (optional): Epsilon value for Conformer normalization. (float, default = None)
norm_partial (optional): Value for the Conformer normalization with RMSNorm. (float, default = None)
conv_mod_norm_eps (optional): Epsilon value for convolutional module normalization. (float, default = None)
norm_eps (optional): Epsilon value for normalization module. (float, default = 1e-05 or 0.25 for BasicNorm)
norm_partial (optional): Partial value for the normalization module, if norm_type = 'rms_norm'. (float, default = -1.0)
conv_mod_norm_eps (optional): Epsilon value for Batchnorm1d in the convolutional module. (float, default = 1e-05)
conv_mod_norm_momentum (optional): Momentum value for Batchnorm1d in the convolutional module. (float, default = 0.1)
dropout_rate (optional): Dropout rate for some intermediate layers. (float, default = 0.0)
att_dropout_rate (optional): Dropout rate for the attention module. (float, default = 0.0)
pos_wise_dropout_rate (optional): Dropout rate for the position-wise module. (float, default = 0.0)
pos_wise_dropout_rate (optional): Dropout rate for the position-wise feed-forward module. (float, default = 0.0)

In addition, each block has a parameter `num_blocks` to build **N** times the defined block (int, default = 1). This is useful if you want to use a group of blocks sharing the same parameters without writing each configuration.

Expand Down Expand Up @@ -622,9 +633,14 @@ For a complete explanation on the different procedure and parameters, we refer t

#### Training

To train a streaming model, the parameter `dynamic_chunk_training` should be set to `True` in the encoder `main_conf`.
To train a streaming model, the parameter `dynamic_chunk_training` should be set to `True` in the encoder `main_conf`. From here, the user has access to two parameters in order to control the dynamic chunk selection (`short_chunk_threshold` and `short_chunk_size`) and another one to control the left context in the causal convolution and the attention module (`left_chunk_size`).

All these parameters can be configured through `main_conf`, introduced in the Encoder section:

From here, the user has access to two parameters in order to control the dynamic chunk selection (`short_chunk_threshold` and `short_chunk_size`) and another one to control the left context in the causal convolution and the attention module (`left_chunk_size`). All these parameters can be configured through the `main_conf`. The Encoder section provides a short description of the parameters.
dynamic_chunk_training: Whether to train streaming model with dynamic chunks. (bool, default = False)
short_chunk_threshold: Chunk length threshold (in percent) for dynamic chunk selection. (int, default = 0.75)
short_chunk_size: Minimum number of frames during dynamic chunk training. (int, default = 25)
left_chunk_size: Number of frames in left context. (int, default = 0)

#### Decoding

Expand All @@ -637,6 +653,7 @@ To perform chunk-by-chunk inference, the parameter `streaming` should be set to
For each parameter, the number of frames is defined AFTER subsampling, meaning the input chunk will be bigger than the one provided. The input size is determined by the frontend and the input block's subsampling, given `chunk_size + right_context` defining the decoding window.

***Note:*** Because the training part does not consider the right context, relying on `right_context` during decoding may result in a mismatch and performance degration.

***Note 2:*** All search algorithms but ALSD are available with chunk-by-chunk inference.

### FAQ
Expand Down
1 change: 1 addition & 0 deletions espnet2/asr_transducer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def get_activation(
{"beta": swish_beta, "use_builtin": torch_version >= V("1.8")},
),
"tanh": (torch.nn.Tanh, {}),
"identity": (torch.nn.Identity, {}),
}

act_func, act_args = activations[activation_type]
Expand Down
48 changes: 25 additions & 23 deletions espnet2/asr_transducer/beam_search_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Hypothesis:
yseq: Label sequence as integer ID sequence.
dec_state: RNNDecoder or StatelessDecoder state.
((N, 1, D_dec), (N, 1, D_dec) or None) or None
lm_state: RNNLM state. ((N, D_lm), (N, D_lm))
lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None
"""

Expand All @@ -45,7 +45,26 @@ class ExtendedHypothesis(Hypothesis):


class BeamSearchTransducer:
"""Beam search implementation for Transducer."""
"""Beam search implementation for Transducer.
Args:
decoder: Decoder module.
joint_network: Joint network module.
beam_size: Size of the beam.
lm: LM class.
lm_weight: LM weight for soft fusion.
search_type: Search algorithm to use during inference.
max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
u_max: Maximum expected target sequence length. (ALSD)
nstep: Number of maximum expansion steps at each time step. (mAES)
expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
expansion_beta:
Number of additional candidates for expanded hypotheses selection. (mAES)
score_norm: Normalize final scores by length.
nbest: Number of final hypothesis.
streaming: Whether to perform chunk-by-chunk beam search.
"""

def __init__(
self,
Expand All @@ -63,27 +82,10 @@ def __init__(
score_norm: bool = False,
nbest: int = 1,
streaming: bool = False,
):
"""Initialize Transducer search module.
) -> None:
"""Construct a BeamSearchTransducer object."""
super().__init__()

Args:
decoder: Decoder module.
joint_network: Joint network module.
beam_size: Beam size.
lm: LM class.
lm_weight: LM weight for soft fusion.
search_type: Search algorithm to use during inference.
max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
u_max: Maximum expected target sequence length. (ALSD)
nstep: Number of maximum expansion steps at each time step. (mAES)
expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
expansion_beta:
Number of additional candidates for expanded hypotheses selection. (mAES)
score_norm: Normalize final scores by length.
nbest: Number of final hypothesis.
streaming: Whether to perform chunk-by-chunk beam search.
"""
self.decoder = decoder
self.joint_network = joint_network

Expand Down Expand Up @@ -174,7 +176,7 @@ def __call__(

return hyps

def reset_inference_cache(self):
def reset_inference_cache(self) -> None:
"""Reset cache for decoder scoring and streaming."""
self.decoder.score_cache = {}
self.search_cache = None
Expand Down
5 changes: 3 additions & 2 deletions espnet2/asr_transducer/decoder/rnn_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ def __init__(
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
) -> None:
"""Construct a RNNDecoder object."""
super().__init__()

assert check_argument_types()

if rnn_type not in ("lstm", "gru"):
raise ValueError(f"Not supported: rnn_type={rnn_type}")

super().__init__()

self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

Expand Down
5 changes: 3 additions & 2 deletions espnet2/asr_transducer/decoder/stateless_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ def __init__(
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
) -> None:
assert check_argument_types()

"""Construct a StatelessDecoder object."""
super().__init__()

assert check_argument_types()

self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
self.embed_dropout_rate = torch.nn.Dropout(p=embed_dropout_rate)

Expand Down

0 comments on commit 36e824b

Please sign in to comment.