Commit

Merge pull request #4892 from tjysdsg/conv2d1

Add Conv2dSubsampling1 module and test it in AphasiaBank ASR recipe

sw005320 committed Jan 30, 2023
2 parents 756519a + c1dc65f commit e37ee27
Showing 15 changed files with 241 additions and 6 deletions.
20 changes: 20 additions & 0 deletions egs2/aphasiabank/asr1/README.md
@@ -9,10 +9,30 @@
- Git hash: `39c1ec0509904f16ac36d25efc971e2a94ff781f`
- Commit date: `Wed Dec 21 12:50:18 2022 -0500`

## asr_train_asr_ebranchformer_small_wavlm_large1

- [train_asr_ebranchformer_small_wavlm_large1.yaml](conf/tuning/train_asr_ebranchformer_small_wavlm_large1.yaml)
- Control group data is included
- Downsampling rate = 2 (WavLM) * 1 (`Conv2dSubsampling1`) = 2
- [Hugging Face](https://huggingface.co/espnet/jiyang_tang_aphsiabank_english_asr_ebranchformer_small_wavlm_large1)

### WER

| dataset | Snt | Wrd | Corr | Sub | Del | Ins | Err | S.Err |
|-------------------------------------|-------|--------|------|------|-----|-----|------|-------|
| decode_asr_model_valid.acc.ave/test | 16380 | 120684 | 77.5 | 16.4 | 6.1 | 4.2 | 26.7 | 70.8 |

### CER

| dataset | Snt | Wrd | Corr | Sub | Del | Ins | Err | S.Err |
|-------------------------------------|-------|--------|------|-----|-----|-----|------|-------|
| decode_asr_model_valid.acc.ave/test | 16380 | 530731 | 87.6 | 5.4 | 6.9 | 4.7 | 17.0 | 70.8 |

## asr_train_asr_ebranchformer_small_wavlm_large

- [train_asr_ebranchformer_small_wavlm_large.yaml](conf/tuning/train_asr_ebranchformer_small_wavlm_large.yaml)
- Control group data is included
- Downsampling rate = 2 (WavLM) * 2 (`Conv2dSubsampling2`) = 4
- [Hugging Face](https://huggingface.co/espnet/jiyang_tang_aphsiabank_english_asr_ebranchformer_small_wavlm_large)

### WER

@@ -34,7 +34,7 @@ encoder_conf:
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    layer_drop_rate: 0.1
-   input_layer: conv2d2 # subsampling rate = 2
+   input_layer: conv2d2 # subsampling rate = 2 (WavLM) * 2 (conv2d2)
    macaron_ffn: true
    pos_enc_layer_type: rel_pos
    attention_layer_type: rel_selfattn
@@ -0,0 +1,106 @@
# https://github.com/espnet/espnet/blob/master/egs2/librispeech/asr1/conf/tuning/train_asr_conformer7_wavlm_large.yaml
unused_parameters: true
freeze_param: [
    "frontend.upstream"
]
frontend: s3prl
frontend_conf:
    frontend_conf:
        upstream: wavlm_large  # Note: If the upstream is changed, please change the input_size in the preencoder.
    download_dir: ./hub
    multilayer_feature: True

preencoder: linear
preencoder_conf:
    input_size: 1024  # Note: If the upstream is changed, please change this value accordingly.
    output_size: 80

model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1
    length_normalized_loss: false
    extract_feats_in_collect_stats: false  # Note: "False" means during collect stats (stage 10), generating dummy stats files rather than extract_feats by forward frontend.


# Based on https://github.com/espnet/espnet/blob/master/egs2/librispeech/asr1/conf/tuning/train_asr_e_branchformer.yaml
# The encoder is smaller as we keep the size roughly the same as the conformer and transformer experiments
encoder: e_branchformer
encoder_conf:
    output_size: 256
    attention_heads: 4
    linear_units: 1024
    num_blocks: 12
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.1
    layer_drop_rate: 0.1
    input_layer: conv2d1  # subsampling rate = 2 (WavLM) * 1 (conv2d1)
    macaron_ffn: true
    pos_enc_layer_type: rel_pos
    attention_layer_type: rel_selfattn
    rel_pos_type: latest
    cgmlp_linear_units: 3072
    cgmlp_conv_kernel: 31
    use_linear_after_conv: false
    gate_activation: identity
    positionwise_layer_type: linear
    use_ffn: true
    merge_conv_kernel: 31

decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.1
    src_attention_dropout_rate: 0.1
    layer_drop_rate: 0.2

seed: 2022
log_interval: 200
num_att_plot: 0
num_workers: 4
sort_in_batch: descending
sort_batch: descending
batch_type: numel
batch_bins: 3000000
accum_grad: 16
grad_clip: 5
max_epoch: 30
patience: none
init: none
best_model_criterion:
-   - valid
    - acc
    - max
keep_nbest_models: 10

use_amp: true
cudnn_deterministic: false
cudnn_benchmark: false

optim: adam
optim_conf:
    lr: 0.001
    weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
    warmup_steps: 2500

specaug: specaug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 27
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_ratio_range:
    - 0.
    - 0.05
    num_time_mask: 5
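
For orientation, here is a minimal sketch of the dimension flow this configuration implies. The shapes follow the comments in the YAML above, but this is illustrative only, not the actual ESPnet pipeline code.

```python
import torch

# Hedged sketch of the feature dimensions implied by the config above (not the real ESPnet forward pass).
batch, frames = 2, 200                          # arbitrary example sizes
wavlm_feats = torch.randn(batch, frames, 1024)  # s3prl WavLM-large features (preencoder input_size: 1024)
preencoder = torch.nn.Linear(1024, 80)          # "preencoder: linear", 1024 -> 80
enc_in = preencoder(wavlm_feats)                # (batch, frames, 80), fed to the e_branchformer encoder
print(enc_in.shape)                             # torch.Size([2, 200, 80])
# input_layer "conv2d1" then maps 80 -> output_size 256 while trimming only 4 frames,
# so the overall downsampling stays at the WavLM factor of 2 noted in the README.
```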
2 changes: 1 addition & 1 deletion egs2/aphasiabank/asr1/run.sh
@@ -10,7 +10,7 @@ valid_set="val"
test_sets="test"
include_control=true

-asr_config=conf/train_asr.yaml
+asr_config=conf/tuning/train_asr_ebranchformer_small_wavlm_large.yaml

feats_normalize=global_mvn
if [[ ${asr_config} == *"hubert"* ]] || [[ ${asr_config} == *"wavlm"* ]]; then
61 changes: 61 additions & 0 deletions espnet/nets/pytorch_backend/transformer/subsampling.py
@@ -30,6 +30,8 @@ def __init__(self, message, actual_size, limit):

def check_short_utt(ins, size):
    """Check if the utterance is too short for subsampling."""
    if isinstance(ins, Conv2dSubsampling1) and size < 5:
        return True, 5
    if isinstance(ins, Conv2dSubsampling2) and size < 7:
        return True, 7
    if isinstance(ins, Conv2dSubsampling) and size < 7:
@@ -100,6 +102,65 @@ def __getitem__(self, key):
        return self.out[key]


class Conv2dSubsampling1(torch.nn.Module):
    """Similar to the Conv2dSubsampling module, but without any subsampling performed.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.

    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct a Conv2dSubsampling1 object."""
        super(Conv2dSubsampling1, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (idim - 4), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Pass x through 2 Conv2d layers without subsampling.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time - 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time - 4.

        """
        x = x.unsqueeze(1)  # (b, c, t, f)
        x = self.conv(x)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return x, None
        return x, x_mask[:, :, :-4]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positional encoding.

        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
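
For reference, a minimal usage sketch of the new module with illustrative values: the two 3x3, stride-1 convolutions each trim 2 frames, so the output keeps the input frame rate and is only 4 frames shorter, which is also why `check_short_utt` above requires at least 5 input frames.

```python
import torch

from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling1

# Shape check for the stride-1 "subsampling" layer added above (illustrative sizes).
layer = Conv2dSubsampling1(idim=80, odim=256, dropout_rate=0.1)
x = torch.randn(4, 100, 80)                       # (#batch, time, idim)
x_mask = torch.ones(4, 1, 100, dtype=torch.bool)  # (#batch, 1, time)
y, y_mask = layer(x, x_mask)
print(y.shape, y_mask.shape)                      # torch.Size([4, 96, 256]) torch.Size([4, 1, 96])
# time' = time - 4: the frame count shrinks by 4, but the frame rate is unchanged (no subsampling).
```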


class Conv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).
9 changes: 9 additions & 0 deletions espnet2/asr/encoder/branchformer_encoder.py
@@ -37,6 +37,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling1,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,

@@ -366,6 +367,13 @@ def __init__(
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d1":
            self.embed = Conv2dSubsampling1(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,

@@ -521,6 +529,7 @@ def forward(

        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling1)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
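
The same `conv2d1` branch is added to the conformer, e_branchformer, longformer, and transformer encoders below. As a toy illustration of the shared pattern (a hypothetical helper, not ESPnet code), the `input_layer` string from the YAML simply selects which subsampling module becomes `self.embed`:

```python
from espnet.nets.pytorch_backend.transformer.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling1,
    Conv2dSubsampling2,
)


def build_input_layer(name, input_size, output_size, dropout_rate):
    """Toy stand-in for the elif chains added in the encoders (hypothetical helper)."""
    table = {
        "conv2d": Conv2dSubsampling,     # 1/4 length
        "conv2d1": Conv2dSubsampling1,   # new: stride 1, no rate reduction
        "conv2d2": Conv2dSubsampling2,   # 1/2 length
    }
    return table[name](input_size, output_size, dropout_rate)


embed = build_input_layer("conv2d1", 80, 256, 0.1)
```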
9 changes: 9 additions & 0 deletions espnet2/asr/encoder/conformer_encoder.py
@@ -36,6 +36,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling1,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,

@@ -155,6 +156,13 @@ def __init__(
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d1":
            self.embed = Conv2dSubsampling1(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,

@@ -313,6 +321,7 @@ def forward(

        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling1)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
9 changes: 9 additions & 0 deletions espnet2/asr/encoder/e_branchformer_encoder.py
@@ -37,6 +37,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling1,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,

@@ -252,6 +253,13 @@ def __init__(
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d1":
            self.embed = Conv2dSubsampling1(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,

@@ -395,6 +403,7 @@ def forward(

        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling1)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
9 changes: 9 additions & 0 deletions espnet2/asr/encoder/longformer_encoder.py
@@ -25,6 +25,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling1,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,

@@ -158,6 +159,13 @@ def __init__(
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d1":
            self.embed = Conv2dSubsampling1(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,

@@ -304,6 +312,7 @@ def forward(
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling1)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
4 changes: 4 additions & 0 deletions espnet2/asr/encoder/transformer_encoder.py
@@ -25,6 +25,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling1,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,

@@ -92,6 +93,8 @@ def __init__(
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d1":
            self.embed = Conv2dSubsampling1(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
        elif input_layer == "conv2d6":

@@ -183,6 +186,7 @@ def forward(
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling1)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
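
End to end, selecting the new input layer is a one-line config change (`input_layer: conv2d1`). Below is a hedged sketch of what that looks like at the Python level, assuming the constructor and forward signatures suggested by the diffs above; check them against your ESPnet version.

```python
import torch

from espnet2.asr.encoder.transformer_encoder import TransformerEncoder

# Assumed keyword arguments; other options are left at their defaults.
enc = TransformerEncoder(input_size=80, output_size=256, input_layer="conv2d1")
feats = torch.randn(2, 100, 80)       # (#batch, time, idim)
ilens = torch.tensor([100, 80])
out, olens, _ = enc(feats, ilens)
print(out.shape, olens)               # expected: torch.Size([2, 96, 256]) tensor([96, 76])
# Only the 4 frames trimmed by the two stride-1 convolutions are lost; the frame rate is unchanged.
```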
