Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Conv2dSubsampling1 module and test it in AphasiaBank ASR recipe #4892

Merged
merged 5 commits into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 18 additions & 0 deletions egs2/aphasiabank/asr1/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,24 @@
- Git hash: `39c1ec0509904f16ac36d25efc971e2a94ff781f`
- Commit date: `Wed Dec 21 12:50:18 2022 -0500`

## asr_train_asr_ebranchformer_small_wavlm_large1

- [train_asr_ebranchformer_small_wavlm_large.yaml](conf/tuning/train_asr_ebranchformer_small_wavlm_large.yaml)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a short note here to explain the difference? E.g., this config uses xxx as the input layer which does not perform downsampling.

Copy link
Collaborator

@pyf98 pyf98 Jan 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are this config name and path correct? It looks same as the previous one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Fixed in 4734d75

- Control group data is included
- [Hugging Face](https://huggingface.co/espnet/jiyang_tang_aphsiabank_english_asr_ebranchformer_small_wavlm_large1)

### WER

| dataset | Snt | Wrd | Corr | Sub | Del | Ins | Err | S.Err |
|-------------------------------------|-------|--------|------|------|-----|-----|------|-------|
| decode_asr_model_valid.acc.ave/test | 16380 | 120684 | 77.5 | 16.4 | 6.1 | 4.2 | 26.7 | 70.8 |

### CER

| dataset | Snt | Wrd | Corr | Sub | Del | Ins | Err | S.Err |
|-------------------------------------|-------|--------|------|-----|-----|-----|------|-------|
| decode_asr_model_valid.acc.ave/test | 16380 | 530731 | 87.6 | 5.4 | 6.9 | 4.7 | 17.0 | 70.8 |

## asr_train_asr_ebranchformer_small_wavlm_large

- [train_asr_ebranchformer_small_wavlm_large.yaml](conf/tuning/train_asr_ebranchformer_small_wavlm_large.yaml)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ encoder_conf:
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
layer_drop_rate: 0.1
input_layer: conv2d2 # subsampling rate = 2
input_layer: conv2d2 # subsampling rate = 2 (WavLM) * 2 (conv2d2)
macaron_ffn: true
pos_enc_layer_type: rel_pos
attention_layer_type: rel_selfattn
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# https://github.com/espnet/espnet/blob/master/egs2/librispeech/asr1/conf/tuning/train_asr_conformer7_wavlm_large.yaml
unused_parameters: true
freeze_param: [
"frontend.upstream"
]
frontend: s3prl
frontend_conf:
frontend_conf:
upstream: wavlm_large # Note: If the upstream is changed, please change the input_size in the preencoder.
download_dir: ./hub
multilayer_feature: True

preencoder: linear
preencoder_conf:
input_size: 1024 # Note: If the upstream is changed, please change this value accordingly.
output_size: 80

model_conf:
ctc_weight: 0.3
lsm_weight: 0.1
length_normalized_loss: false
extract_feats_in_collect_stats: false # Note: "False" means during collect stats (stage 10), generating dummy stats files rather than extract_feats by forward frontend.


# Based on https://github.com/espnet/espnet/blob/master/egs2/librispeech/asr1/conf/tuning/train_asr_e_branchformer.yaml
# The encoder is smaller as we keep the size roughly the same as the conformer and transformer experiments
encoder: e_branchformer
encoder_conf:
output_size: 256
attention_heads: 4
linear_units: 1024
num_blocks: 12
dropout_rate: 0.1
positional_dropout_rate: 0.1
attention_dropout_rate: 0.1
layer_drop_rate: 0.1
input_layer: conv2d1 # subsampling rate = 2 (WavLM) * 1 (conv2d1)
macaron_ffn: true
pos_enc_layer_type: rel_pos
attention_layer_type: rel_selfattn
rel_pos_type: latest
cgmlp_linear_units: 3072
cgmlp_conv_kernel: 31
use_linear_after_conv: false
gate_activation: identity
positionwise_layer_type: linear
use_ffn: true
merge_conv_kernel: 31

decoder: transformer
decoder_conf:
attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1
layer_drop_rate: 0.2

seed: 2022
log_interval: 200
num_att_plot: 0
num_workers: 4
sort_in_batch: descending
sort_batch: descending
batch_type: numel
batch_bins: 3000000
accum_grad: 16
grad_clip: 5
max_epoch: 30
patience: none
init: none
best_model_criterion:
- - valid
- acc
- max
keep_nbest_models: 10

use_amp: true
cudnn_deterministic: false
cudnn_benchmark: false

optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
warmup_steps: 2500

specaug: specaug
specaug_conf:
apply_time_warp: true
time_warp_window: 5
time_warp_mode: bicubic
apply_freq_mask: true
freq_mask_width_range:
- 0
- 27
num_freq_mask: 2
apply_time_mask: true
time_mask_width_ratio_range:
- 0.
- 0.05
num_time_mask: 5
2 changes: 1 addition & 1 deletion egs2/aphasiabank/asr1/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ valid_set="val"
test_sets="test"
include_control=true

asr_config=conf/train_asr.yaml
asr_config=conf/tuning/train_asr_ebranchformer_small_wavlm_large.yaml

feats_normalize=global_mvn
if [[ ${asr_config} == *"hubert"* ]] || [[ ${asr_config} == *"wavlm"* ]]; then
Expand Down
59 changes: 59 additions & 0 deletions espnet/nets/pytorch_backend/transformer/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def __init__(self, message, actual_size, limit):

def check_short_utt(ins, size):
"""Check if the utterance is too short for subsampling."""
if isinstance(ins, Conv2dSubsampling1) and size < 5:
return True, 5
if isinstance(ins, Conv2dSubsampling2) and size < 7:
return True, 7
if isinstance(ins, Conv2dSubsampling) and size < 7:
Expand Down Expand Up @@ -100,6 +102,63 @@ def __getitem__(self, key):
return self.out[key]


class Conv2dSubsampling1(torch.nn.Module):
"""Convolutional 2D subsampling (to the same length).

Args:
idim (int): Input dimension.
odim (int): Output dimension.
dropout_rate (float): Dropout rate.
pos_enc (torch.nn.Module): Custom position encoding layer.

"""

def __init__(self, idim, odim, dropout_rate, pos_enc=None):
"""Construct an Conv2dSubsampling1 object."""
super(Conv2dSubsampling1, self).__init__()
self.conv = torch.nn.Sequential(
torch.nn.Conv2d(1, odim, 3, 1),
torch.nn.ReLU(),
torch.nn.Conv2d(odim, odim, 3, 1),
torch.nn.ReLU(),
)
self.out = torch.nn.Sequential(
torch.nn.Linear(odim * (idim - 4), odim),
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
)

def forward(self, x, x_mask):
"""Subsample x with a ratio of 1.

Args:
x (torch.Tensor): Input tensor (#batch, time, idim).
x_mask (torch.Tensor): Input mask (#batch, 1, time).

Returns:
torch.Tensor: Subsampled tensor (#batch, time, odim).
torch.Tensor: Subsampled mask (#batch, 1, time).

"""
x = x.unsqueeze(1) # (b, c, t, f)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
if x_mask is None:
return x, None
return x, x_mask[:, :, :-4]

def __getitem__(self, key):
"""Get item.

When reset_parameters() is called, if use_scaled_pos_enc is used,
return the positioning encoding.

"""
if key != -1:
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
return self.out[key]


class Conv2dSubsampling2(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).

Expand Down
9 changes: 9 additions & 0 deletions espnet2/asr/encoder/branchformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
Conv2dSubsampling,
Conv2dSubsampling1,
Conv2dSubsampling2,
Conv2dSubsampling6,
Conv2dSubsampling8,
Expand Down Expand Up @@ -366,6 +367,13 @@ def __init__(
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d1":
self.embed = Conv2dSubsampling1(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(
input_size,
Expand Down Expand Up @@ -521,6 +529,7 @@ def forward(

if (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling1)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
Expand Down
9 changes: 9 additions & 0 deletions espnet2/asr/encoder/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
Conv2dSubsampling,
Conv2dSubsampling1,
Conv2dSubsampling2,
Conv2dSubsampling6,
Conv2dSubsampling8,
Expand Down Expand Up @@ -155,6 +156,13 @@ def __init__(
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
)
elif input_layer == "conv2d1":
self.embed = Conv2dSubsampling1(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(
input_size,
Expand Down Expand Up @@ -313,6 +321,7 @@ def forward(

if (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling1)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
Expand Down
9 changes: 9 additions & 0 deletions espnet2/asr/encoder/e_branchformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
Conv2dSubsampling,
Conv2dSubsampling1,
Conv2dSubsampling2,
Conv2dSubsampling6,
Conv2dSubsampling8,
Expand Down Expand Up @@ -252,6 +253,13 @@ def __init__(
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
)
elif input_layer == "conv2d1":
self.embed = Conv2dSubsampling1(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(
input_size,
Expand Down Expand Up @@ -395,6 +403,7 @@ def forward(

if (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling1)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
Expand Down
9 changes: 9 additions & 0 deletions espnet2/asr/encoder/longformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
Conv2dSubsampling,
Conv2dSubsampling1,
Conv2dSubsampling2,
Conv2dSubsampling6,
Conv2dSubsampling8,
Expand Down Expand Up @@ -158,6 +159,13 @@ def __init__(
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d1":
self.embed = Conv2dSubsampling1(
input_size,
output_size,
dropout_rate,
pos_enc_class(output_size, positional_dropout_rate),
)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(
input_size,
Expand Down Expand Up @@ -304,6 +312,7 @@ def forward(
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling1)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
Expand Down
4 changes: 4 additions & 0 deletions espnet2/asr/encoder/transformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
Conv2dSubsampling,
Conv2dSubsampling1,
Conv2dSubsampling2,
Conv2dSubsampling6,
Conv2dSubsampling8,
Expand Down Expand Up @@ -92,6 +93,8 @@ def __init__(
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d1":
self.embed = Conv2dSubsampling1(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
Expand Down Expand Up @@ -183,6 +186,7 @@ def forward(
xs_pad = xs_pad
elif (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling1)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
Expand Down
4 changes: 4 additions & 0 deletions espnet2/asr/encoder/transformer_encoder_multispkr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from espnet.nets.pytorch_backend.transformer.repeat import repeat
from espnet.nets.pytorch_backend.transformer.subsampling import (
Conv2dSubsampling,
Conv2dSubsampling1,
Conv2dSubsampling2,
Conv2dSubsampling6,
Conv2dSubsampling8,
Expand Down Expand Up @@ -92,6 +93,8 @@ def __init__(
)
elif input_layer == "conv2d":
self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
elif input_layer == "conv2d1":
self.embed = Conv2dSubsampling1(input_size, output_size, dropout_rate)
elif input_layer == "conv2d2":
self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
elif input_layer == "conv2d6":
Expand Down Expand Up @@ -193,6 +196,7 @@ def forward(

if (
isinstance(self.embed, Conv2dSubsampling)
or isinstance(self.embed, Conv2dSubsampling1)
or isinstance(self.embed, Conv2dSubsampling2)
or isinstance(self.embed, Conv2dSubsampling6)
or isinstance(self.embed, Conv2dSubsampling8)
Expand Down