Merge pull request #4845 from m-koichi/s4-decoder
[WIP] Add S4 decoder in ESPnet2
mergify[bot] committed Jan 11, 2023
2 parents 0b1cc6c + 15a2f0e commit ef202b0
Showing 25 changed files with 4,270 additions and 1 deletion.
56 changes: 56 additions & 0 deletions egs2/librispeech/asr1/README.md
@@ -301,6 +301,62 @@



# Conformer, S4 Decoder
- Params: 113.20M
- ASR config: [conf/tuning/train_asr_s4_decoder.yaml](conf/tuning/train_asr_s4_decoder.yaml)
- LM config: [conf/tuning/train_lm_transformer2.yaml](conf/tuning/train_lm_transformer2.yaml)
- Pretrained model: [https://huggingface.co/espnet/kmiyazaki_librispeech_asr_s4_decoder](https://huggingface.co/espnet/kmiyazaki_librispeech_asr_s4_decoder)
# RESULTS
## Environments
- date: `Thu Dec 29 11:58:25 UTC 2022`
- python version: `3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0]`
- espnet version: `espnet 202211`
- pytorch version: `pytorch 1.12.0`
- Git hash: `617189d2d7e060bbcf670ab54b88776333b5137e`
- Commit date: `Mon Dec 26 18:01:58 2022 +0900`

## asr_train_asr_s4_decoder_raw_en_bpe5000_sp
### WER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|beam60_ctc0.3/dev_clean|2703|54402|98.2|1.6|0.2|0.2|2.0|25.9|
|beam60_ctc0.3/dev_other|2864|50948|95.5|4.2|0.4|0.5|5.0|42.2|
|beam60_ctc0.3/test_clean|2620|52576|98.0|1.8|0.2|0.3|2.3|27.2|
|beam60_ctc0.3/test_other|2939|52343|95.6|4.0|0.4|0.6|5.0|44.4|
|beam60_ctc0.3_lm0.6/dev_clean|2703|54402|98.5|1.3|0.2|0.2|1.7|23.0|
|beam60_ctc0.3_lm0.6/dev_other|2864|50948|96.4|3.3|0.3|0.4|4.0|36.6|
|beam60_ctc0.3_lm0.6/test_clean|2620|52576|98.3|1.5|0.2|0.2|1.9|23.7|
|beam60_ctc0.3_lm0.6/test_other|2939|52343|96.3|3.3|0.4|0.4|4.1|39.5|

### CER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|beam60_ctc0.3/dev_clean|2703|288456|99.5|0.3|0.2|0.2|0.6|25.9|
|beam60_ctc0.3/dev_other|2864|265951|98.4|1.0|0.6|0.5|2.1|42.2|
|beam60_ctc0.3/test_clean|2620|281530|99.5|0.3|0.2|0.2|0.7|27.2|
|beam60_ctc0.3/test_other|2939|272758|98.6|0.8|0.6|0.6|2.0|44.4|
|beam60_ctc0.3_lm0.6/dev_clean|2703|288456|99.6|0.2|0.2|0.2|0.6|23.0|
|beam60_ctc0.3_lm0.6/dev_other|2864|265951|98.6|0.8|0.5|0.5|1.8|36.6|
|beam60_ctc0.3_lm0.6/test_clean|2620|281530|99.6|0.2|0.2|0.2|0.6|23.7|
|beam60_ctc0.3_lm0.6/test_other|2939|272758|98.8|0.7|0.6|0.5|1.7|39.5|

### TER

|dataset|Snt|Wrd|Corr|Sub|Del|Ins|Err|S.Err|
|---|---|---|---|---|---|---|---|---|
|beam60_ctc0.3/dev_clean|2703|68010|97.8|1.6|0.6|0.4|2.5|25.9|
|beam60_ctc0.3/dev_other|2864|63110|94.5|4.3|1.3|0.9|6.4|42.2|
|beam60_ctc0.3/test_clean|2620|65818|97.5|1.7|0.7|0.4|2.8|27.2|
|beam60_ctc0.3/test_other|2939|65101|94.6|3.9|1.5|0.8|6.2|44.4|
|beam60_ctc0.3_lm0.6/dev_clean|2703|68010|98.1|1.4|0.5|0.4|2.2|23.0|
|beam60_ctc0.3_lm0.6/dev_other|2864|63110|95.4|3.5|1.1|0.9|5.5|36.6|
|beam60_ctc0.3_lm0.6/test_clean|2620|65818|98.0|1.4|0.6|0.4|2.4|23.7|
|beam60_ctc0.3_lm0.6/test_other|2939|65101|95.5|3.2|1.3|0.8|5.4|39.5|
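The error columns in these tables are related by simple accounting: Corr, Sub, and Del are percentages of reference tokens and sum to 100, while Err = Sub + Del + Ins. A quick sanity check on the `beam60_ctc0.3/dev_clean` WER row:

```python
# Sanity-check the scoring arithmetic using the beam60_ctc0.3/dev_clean
# WER row above (all values are percentages of reference words).
corr, sub, dele, ins = 98.2, 1.6, 0.2, 0.2

# Correct, substituted, and deleted words account for every reference word.
assert abs((corr + sub + dele) - 100.0) < 0.1

# The word error rate counts substitutions, deletions, and insertions.
err = sub + dele + ins
print(round(err, 1))  # 2.0
```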



# Conformer, `hop_length=160`
- Params: 116.15M
- ASR config: [conf/tuning/train_asr_conformer10_hop_length160.yaml](conf/tuning/train_asr_conformer10_hop_length160.yaml)
96 changes: 96 additions & 0 deletions egs2/librispeech/asr1/conf/tuning/train_asr_s4_decoder.yaml
@@ -0,0 +1,96 @@
# Trained with Tesla A100 (40GB) x 4 GPUs. It takes about 2.5 days.
encoder: conformer
encoder_conf:
output_size: 512
attention_heads: 8
linear_units: 2048
num_blocks: 12
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
input_layer: conv2d
normalize_before: true
macaron_style: true
rel_pos_type: latest
pos_enc_layer_type: rel_pos
selfattention_layer_type: rel_selfattn
activation_type: swish
use_cnn_module: true
cnn_module_kernel: 31

decoder: s4
decoder_conf:
dropinp: 0.0
dropout: 0.1
drop_path: 0.1
prenorm: true
n_layers: 6
  # Specify each block config here. For details, see espnet2/asr/state_spaces.
  # layer: stack of black-box modules; _name_ must be specified
  # residual: choose an option for the residual connection
layer:
- _name_: s4
postact: glu
dropout: 0.1
n_ssm: 1
lr: 0.0025
dt_min: 0.001
dt_max: 0.1
measure: legs
- _name_: mha
n_head: 8
dropout: 0.1
- _name_: ff
expand: 4
activation: gelu
dropout: 0.1
residual: residual
norm: layer

model_conf:
ctc_weight: 0.3
lsm_weight: 0.1
length_normalized_loss: false

frontend_conf:
n_fft: 512
hop_length: 160

use_amp: true
num_workers: 4
batch_type: numel
batch_bins: 35000000
accum_grad: 4
max_epoch: 60
patience: none
init: none
best_model_criterion:
- - valid
- acc
- max
keep_nbest_models: 10

optim: adamw
exclude_weight_decay: true
optim_conf:
lr: 0.0025
weight_decay: 0.01
scheduler: warmuplr
scheduler_conf:
warmup_steps: 40000

specaug: specaug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 27
num_freq_mask: 2
apply_time_mask: true
time_mask_width_ratio_range:
- 0.
- 0.05
num_time_mask: 10
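The `layer:` list above defines the sub-layer stack inside each of the `n_layers: 6` decoder blocks (an S4 layer, multi-head attention, then a feed-forward layer), wired with pre-norm residual connections (`prenorm: true`, `residual: residual`, `norm: layer`). A minimal, framework-free sketch of that wiring pattern (not ESPnet's actual implementation; the toy sub-layers are hypothetical stand-ins):

```python
# Sketch of the pre-norm residual pattern the config encodes:
# within each block, every sub-layer is applied as x = x + sublayer(norm(x)).

def prenorm_residual_block(x, sublayers, norm):
    """Apply each sub-layer with a pre-norm residual connection."""
    for sublayer in sublayers:
        x = x + sublayer(norm(x))
    return x

# Toy demonstration with scalar "features" and an identity norm.
identity = lambda v: v
double = lambda v: 2 * v  # stands in for an s4/mha/ff sub-layer
out = prenorm_residual_block(1.0, [double, double], identity)
print(out)  # 1 + 2 = 3, then 3 + 6 = 9 -> 9.0
```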
173 changes: 173 additions & 0 deletions espnet2/asr/decoder/s4_decoder.py
@@ -0,0 +1,173 @@
"""Decoder definition."""
from typing import Any, List, Tuple

import torch
from typeguard import check_argument_types

from espnet2.asr.decoder.abs_decoder import AbsDecoder
from espnet2.asr.state_spaces.model import SequenceModel
from espnet.nets.pytorch_backend.nets_utils import make_pad_mask
from espnet.nets.scorer_interface import BatchScorerInterface


class S4Decoder(AbsDecoder, BatchScorerInterface):
"""S4 decoder module.
Args:
vocab_size: output dim
encoder_output_size: dimension of hidden vector
input_layer: input layer type
dropinp: input dropout
dropout: dropout parameter applied on every residual and every layer
prenorm: pre-norm vs. post-norm
n_layers: number of layers
transposed: transpose inputs so each layer receives (batch, dim, length)
tie_dropout: tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d
n_repeat: each layer is repeated n times per stage before applying pooling
layer: layer config, must be specified
residual: residual config
norm: normalization config (e.g. layer vs batch)
pool: config for pooling layer per stage
track_norms: log norms of each layer output
drop_path: drop rate for stochastic depth
"""

def __init__(
self,
vocab_size: int,
encoder_output_size: int,
input_layer: str = "embed",
dropinp: float = 0.0,
dropout: float = 0.25,
prenorm: bool = True,
n_layers: int = 16,
transposed: bool = False,
tie_dropout: bool = False,
n_repeat=1,
layer=None,
residual=None,
norm=None,
pool=None,
track_norms=True,
drop_path: float = 0.0,
):
assert check_argument_types()
super().__init__()

self.d_model = encoder_output_size
self.sos = vocab_size - 1
self.eos = vocab_size - 1
self.odim = vocab_size
self.dropout = dropout

if input_layer == "embed":
self.embed = torch.nn.Embedding(vocab_size, self.d_model)
else:
raise NotImplementedError
self.dropout_emb = torch.nn.Dropout(p=dropout)

self.decoder = SequenceModel(
self.d_model,
n_layers=n_layers,
transposed=transposed,
dropout=dropout,
tie_dropout=tie_dropout,
prenorm=prenorm,
n_repeat=n_repeat,
layer=layer,
residual=residual,
norm=norm,
pool=pool,
track_norms=track_norms,
dropinp=dropinp,
drop_path=drop_path,
)

self.output = torch.nn.Linear(self.d_model, vocab_size)

def init_state(self, x: torch.Tensor):
"""Initialize state."""
return self.decoder.default_state(1, device=x.device)

def forward(
self,
hs_pad: torch.Tensor,
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
state=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
Args:
hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
hlens: (batch)
ys_in_pad:
input token ids, int64 (batch, maxlen_out)
if input_layer == "embed"
input tensor (batch, maxlen_out, #mels) in the other cases
ys_in_lens: (batch)
Returns:
(tuple): tuple containing:
x: decoded token score before softmax (batch, maxlen_out, token)
if use_output_layer is True,
olens: (batch, )
"""
memory = hs_pad
memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
memory.device
)

emb = self.embed(ys_in_pad)
z, state = self.decoder(
emb,
state=state,
memory=memory,
lengths=ys_in_lens,
mask=memory_mask,
)

decoded = self.output(z)
return decoded, ys_in_lens

def score(self, ys, state, x):
raise NotImplementedError

def batch_score(
self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
) -> Tuple[torch.Tensor, List[Any]]:
"""Score new token batch.
Args:
ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
states (List[Any]): Scorer states for prefix tokens.
xs (torch.Tensor):
The encoder feature that generates ys (n_batch, xlen, n_feat).
Returns:
tuple[torch.Tensor, List[Any]]: Tuple of
batchfied scores for next token with shape of `(n_batch, n_vocab)`
and next state list for ys.
"""
# merge states
n_batch = len(ys)
ys = self.embed(ys[:, -1:])

        # workaround for remaining beam width of 1
        if isinstance(states[0], list):
            states = states[0]

assert ys.size(1) == 1, ys.shape
ys = ys.squeeze(1)

ys, states = self.decoder.step(ys, state=states, memory=xs)
logp = self.output(ys).log_softmax(dim=-1)

states_list = [
[state[b].unsqueeze(0) if state is not None else None for state in states]
for b in range(n_batch)
]

return logp, states_list
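The state re-nesting at the end of `batch_score` converts the decoder's layer-major, batched states into one state list per hypothesis, which is the layout the scorer interface expects. A torch-free sketch of that bookkeeping (the data here is hypothetical, and the real code additionally calls `unsqueeze(0)` on each tensor slice):

```python
# Re-nest layer-major batched states into per-hypothesis state lists,
# mirroring the states_list construction in S4Decoder.batch_score.

def split_states(states, n_batch):
    """Return one list of per-layer states for each batch element."""
    return [
        [state[b] if state is not None else None for state in states]
        for b in range(n_batch)
    ]

# Two layers, batch of 3; `None` marks a layer that carries no state.
layer_states = [["s0_h0", "s0_h1", "s0_h2"], None]
per_hyp = split_states(layer_states, n_batch=3)
print(per_hyp[1])  # ['s0_h1', None]
```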
1 change: 1 addition & 0 deletions espnet2/asr/state_spaces/__init__.py
@@ -0,0 +1 @@
"""Initialize sub package."""
