Merge pull request #4845 from m-koichi/s4-decoder
[WIP] Add S4 decoder in ESPnet2
Showing 25 changed files with 4,270 additions and 1 deletion.
egs2/librispeech/asr1/conf/tuning/train_asr_s4_decoder.yaml (96 additions & 0 deletions)
@@ -0,0 +1,96 @@
# Trained with NVIDIA A100 (40GB) x 4 GPUs. It takes about 2.5 days.
encoder: conformer
encoder_conf:
    output_size: 512
    attention_heads: 8
    linear_units: 2048
    num_blocks: 12
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    input_layer: conv2d
    normalize_before: true
    macaron_style: true
    rel_pos_type: latest
    pos_enc_layer_type: rel_pos
    selfattention_layer_type: rel_selfattn
    activation_type: swish
    use_cnn_module: true
    cnn_module_kernel: 31

decoder: s4
decoder_conf:
    dropinp: 0.0
    dropout: 0.1
    drop_path: 0.1
    prenorm: true
    n_layers: 6
    # Specify each block config here. For details, see espnet2/asr/state_spaces.
    # layer: stack of black-box modules; _name_ must be specified
    # residual: choose an option for the residual connection
    layer:
        - _name_: s4
          postact: glu
          dropout: 0.1
          n_ssm: 1
          lr: 0.0025
          dt_min: 0.001
          dt_max: 0.1
          measure: legs
        - _name_: mha
          n_head: 8
          dropout: 0.1
        - _name_: ff
          expand: 4
          activation: gelu
          dropout: 0.1
    residual: residual
    norm: layer

model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1
    length_normalized_loss: false

frontend_conf:
    n_fft: 512
    hop_length: 160

use_amp: true
num_workers: 4
batch_type: numel
batch_bins: 35000000
accum_grad: 4
max_epoch: 60
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 10

optim: adamw
exclude_weight_decay: true
optim_conf:
    lr: 0.0025
    weight_decay: 0.01
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 40000

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 10
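For context, here is a minimal sketch (not part of this PR's diff) of how the decoder_conf above maps onto the new decoder class: the layer list defines one block stack (an S4 state-space layer, multi-head attention over the encoder memory, and a feed-forward layer) that is repeated n_layers times. The module path espnet2.asr.decoder.s4_decoder and the vocab_size value are assumptions for illustration; the constructor arguments mirror S4Decoder.__init__ in the file shown further below.

# Hypothetical sketch: build the S4 decoder with the block stack from
# decoder_conf above. vocab_size is illustrative (set by the recipe's BPE model).
import torch

from espnet2.asr.decoder.s4_decoder import S4Decoder  # assumed module path

layer = [
    # S4 state-space block: GLU post-activation, HiPPO-LegS measure
    {"_name_": "s4", "postact": "glu", "dropout": 0.1, "n_ssm": 1,
     "lr": 0.0025, "dt_min": 0.001, "dt_max": 0.1, "measure": "legs"},
    # multi-head attention over the encoder memory
    {"_name_": "mha", "n_head": 8, "dropout": 0.1},
    # position-wise feed-forward with 4x expansion
    {"_name_": "ff", "expand": 4, "activation": "gelu", "dropout": 0.1},
]

decoder = S4Decoder(
    vocab_size=5000,            # illustrative
    encoder_output_size=512,    # matches encoder output_size above
    dropinp=0.0,
    dropout=0.1,
    drop_path=0.1,
    prenorm=True,
    n_layers=6,
    layer=layer,
    residual="residual",
    norm="layer",
)

In the recipe itself this file is not constructed by hand; it is typically passed to the stage script via asr.sh's --asr_config option.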
@@ -0,0 +1,173 @@
"""Decoder definition.""" | ||
from typing import Any, List, Tuple | ||
|
||
import torch | ||
from typeguard import check_argument_types | ||
|
||
from espnet2.asr.decoder.abs_decoder import AbsDecoder | ||
from espnet2.asr.state_spaces.model import SequenceModel | ||
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask | ||
from espnet.nets.scorer_interface import BatchScorerInterface | ||
|
||
|
||
class S4Decoder(AbsDecoder, BatchScorerInterface): | ||
"""S4 decoder module. | ||
Args: | ||
vocab_size: output dim | ||
encoder_output_size: dimension of hidden vector | ||
input_layer: input layer type | ||
dropinp: input dropout | ||
dropout: dropout parameter applied on every residual and every layer | ||
prenorm: pre-norm vs. post-norm | ||
n_layers: number of layers | ||
transposed: transpose inputs so each layer receives (batch, dim, length) | ||
tie_dropout: tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d | ||
n_repeat: each layer is repeated n times per stage before applying pooling | ||
layer: layer config, must be specified | ||
residual: residual config | ||
norm: normalization config (e.g. layer vs batch) | ||
pool: config for pooling layer per stage | ||
track_norms: log norms of each layer output | ||
drop_path: drop rate for stochastic depth | ||
""" | ||
|
||
def __init__( | ||
self, | ||
vocab_size: int, | ||
encoder_output_size: int, | ||
input_layer: str = "embed", | ||
dropinp: float = 0.0, | ||
dropout: float = 0.25, | ||
prenorm: bool = True, | ||
n_layers: int = 16, | ||
transposed: bool = False, | ||
tie_dropout: bool = False, | ||
n_repeat=1, | ||
layer=None, | ||
residual=None, | ||
norm=None, | ||
pool=None, | ||
track_norms=True, | ||
drop_path: float = 0.0, | ||
): | ||
assert check_argument_types() | ||
super().__init__() | ||
|
||
self.d_model = encoder_output_size | ||
self.sos = vocab_size - 1 | ||
self.eos = vocab_size - 1 | ||
self.odim = vocab_size | ||
self.dropout = dropout | ||
|
||
if input_layer == "embed": | ||
self.embed = torch.nn.Embedding(vocab_size, self.d_model) | ||
else: | ||
raise NotImplementedError | ||
self.dropout_emb = torch.nn.Dropout(p=dropout) | ||
|
||
self.decoder = SequenceModel( | ||
self.d_model, | ||
n_layers=n_layers, | ||
transposed=transposed, | ||
dropout=dropout, | ||
tie_dropout=tie_dropout, | ||
prenorm=prenorm, | ||
n_repeat=n_repeat, | ||
layer=layer, | ||
residual=residual, | ||
norm=norm, | ||
pool=pool, | ||
track_norms=track_norms, | ||
dropinp=dropinp, | ||
drop_path=drop_path, | ||
) | ||
|
||
self.output = torch.nn.Linear(self.d_model, vocab_size) | ||
|
||
def init_state(self, x: torch.Tensor): | ||
"""Initialize state.""" | ||
return self.decoder.default_state(1, device=x.device) | ||
|
||
def forward( | ||
self, | ||
hs_pad: torch.Tensor, | ||
hlens: torch.Tensor, | ||
ys_in_pad: torch.Tensor, | ||
ys_in_lens: torch.Tensor, | ||
state=None, | ||
) -> Tuple[torch.Tensor, torch.Tensor]: | ||
"""Forward decoder. | ||
Args: | ||
hs_pad: encoded memory, float32 (batch, maxlen_in, feat) | ||
hlens: (batch) | ||
ys_in_pad: | ||
input token ids, int64 (batch, maxlen_out) | ||
if input_layer == "embed" | ||
input tensor (batch, maxlen_out, #mels) in the other cases | ||
ys_in_lens: (batch) | ||
Returns: | ||
(tuple): tuple containing: | ||
x: decoded token score before softmax (batch, maxlen_out, token) | ||
if use_output_layer is True, | ||
olens: (batch, ) | ||
""" | ||
memory = hs_pad | ||
memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to( | ||
memory.device | ||
) | ||
|
||
emb = self.embed(ys_in_pad) | ||
z, state = self.decoder( | ||
emb, | ||
state=state, | ||
memory=memory, | ||
lengths=ys_in_lens, | ||
mask=memory_mask, | ||
) | ||
|
||
decoded = self.output(z) | ||
return decoded, ys_in_lens | ||
|
||
def score(self, ys, state, x): | ||
raise NotImplementedError | ||
|
||
def batch_score( | ||
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor | ||
) -> Tuple[torch.Tensor, List[Any]]: | ||
"""Score new token batch. | ||
Args: | ||
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). | ||
states (List[Any]): Scorer states for prefix tokens. | ||
xs (torch.Tensor): | ||
The encoder feature that generates ys (n_batch, xlen, n_feat). | ||
Returns: | ||
tuple[torch.Tensor, List[Any]]: Tuple of | ||
batchfied scores for next token with shape of `(n_batch, n_vocab)` | ||
and next state list for ys. | ||
""" | ||
# merge states | ||
n_batch = len(ys) | ||
ys = self.embed(ys[:, -1:]) | ||
|
||
# workaround for remaining beam width of 1 | ||
if type(states[0]) is list: | ||
states = states[0] | ||
|
||
assert ys.size(1) == 1, ys.shape | ||
ys = ys.squeeze(1) | ||
|
||
ys, states = self.decoder.step(ys, state=states, memory=xs) | ||
logp = self.output(ys).log_softmax(dim=-1) | ||
|
||
states_list = [ | ||
[state[b].unsqueeze(0) if state is not None else None for state in states] | ||
for b in range(n_batch) | ||
] | ||
|
||
return logp, states_list |
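As a sanity check, the snippet below (again not from the PR) continues the hypothetical sketch above and runs a dummy forward pass that mirrors the shapes documented in S4Decoder.forward(). The batch sizes, lengths, and vocab_size=5000 are illustrative; feat=512 matches the decoder's d_model.

# Dummy forward pass through the decoder built in the earlier sketch.
batch, maxlen_in, maxlen_out, feat = 2, 50, 12, 512

hs_pad = torch.randn(batch, maxlen_in, feat)             # encoder memory (batch, maxlen_in, feat)
hlens = torch.tensor([50, 42])                           # valid encoder frame counts per utterance
ys_in_pad = torch.randint(0, 5000, (batch, maxlen_out))  # int64 token ids (batch, maxlen_out)
ys_in_lens = torch.tensor([12, 9])                       # valid target lengths

decoded, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
print(decoded.shape)  # expected: torch.Size([2, 12, 5000]), token scores before softmax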
@@ -0,0 +1 @@
"""Initialize sub package.""" |