Skip to content

Commit

Permalink
Merge pull request #5265 from ftshijt/master
Browse files Browse the repository at this point in the history
A few minor fixes for SSL
  • Loading branch information
sw005320 committed Jul 25, 2023
2 parents 6071ab5 + 8c97948 commit 47f989a
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# This config was trained on 8 x A100 (40GB) for 5 days
use_amp: true
grad_clip: 5.0
batch_type: numel
batch_bins: 20000000
num_iters_per_epoch: 4000
num_workers: 8
accum_grad: 4
max_epoch: 250
patience: none
# Use self-defined function for initialization
init: none
best_model_criterion:
- - valid
- loss
- min
keep_nbest_models: 10

unused_parameters: true

input_size: 1

collate_fn_conf:
label_downsampling: 2
pad: False
rand_crop: True

encoder: torchaudio_hubert
encoder_conf:
encoder_projection_dropout: 0.1
encoder_attention_dropout: 0.1
encoder_ff_interm_dropout: 0.0
encoder_dropout: 0.1
encoder_layer_drop: 0.05
extractor_conv_layer_config: [ [512, 10, 5], [512,5,4] ,[512,3,2],[512,3,2],[512,3,2],[512,3,2],[512,3,2]]

model: torchaudio

optim: adam
optim_conf:
lr: 0.0005
scheduler: warmuplr
scheduler_conf:
warmup_steps: 32000

frontend: null

normalize: null

specaug: null
18 changes: 9 additions & 9 deletions espnet2/asr/encoder/hubert_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class TorchAudioHuBERTPretrainEncoder(AbsEncoder):
Valid values are "group_norm" or "layer_norm".
extractor_conv_layer_config: Configuration of convolution layers in feature
extractor. List of convolution configuration,
i.e. [(output_channel, kernel_size, stride), ...]
i.e. [[output_channel, kernel_size, stride], ...]
extractor_conv_bias: Whether to include bias term to each convolution
operation.
encoder_embed_dim: The dimension of embedding in encoder.
Expand Down Expand Up @@ -89,14 +89,14 @@ def __init__(
self,
input_size: int = None,
extractor_mode: str = "group_norm",
extractor_conv_layer_config: Optional[List[Tuple[int, int, int]]] = [
(512, 10, 5),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 3, 2),
(512, 2, 2),
(512, 2, 2),
extractor_conv_layer_config: Optional[List[List[int]]] = [
[512, 10, 5],
[512, 3, 2],
[512, 3, 2],
[512, 3, 2],
[512, 3, 2],
[512, 2, 2],
[512, 2, 2],
],
extractor_conv_bias: bool = False,
encoder_embed_dim: int = 768,
Expand Down
2 changes: 1 addition & 1 deletion espnet2/tasks/abs_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,7 +1502,7 @@ def build_iter_factory(
e.g. If The number of mini-batches equals to 4, the following two are same:
- 1 epoch without "--num_iters_per_epoch"
- 4 epoch with "--num_iters_per_epoch" == 4
- 4 epoch with "--num_iters_per_epoch" == 1
"""
assert check_argument_types()
Expand Down
27 changes: 26 additions & 1 deletion espnet2/tasks/hubert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
from typing import Callable, Collection, Dict, List, Optional, Tuple, Union

import humanfriendly
import numpy as np
import torch
from typeguard import check_argument_types, check_return_type
Expand Down Expand Up @@ -275,13 +276,37 @@ def build_collate_fn(
Tuple[List[str], Dict[str, torch.Tensor]],
]:
assert check_argument_types()

# default sampling rate is 16000
fs = args.frontend_conf.get("fs", 16000)
if isinstance(fs, str):
fs = humanfriendly.parse_size(fs)
sample_rate = fs / 1000

if args.encoder_conf.get("extractor_conv_layer_config", None) is None:
# corresponding to default conv extractor
# refer to espnet2/asr/encoder/hubert_encoder.py
reception_field = 400
stride_field = 320
else:
stride_field, reception_field = 1, 1
for conv_config in args.encoder_conf["extractor_conv_layer_config"][::-1]:
_, kernel, stride = conv_config
stride_field *= stride
reception_field = stride * (reception_field - 1) + kernel

window_size = reception_field / sample_rate
window_shift = stride_field / sample_rate
return HuBERTCollateFn(
float_pad_value=0.0,
int_pad_value=-1,
label_downsampling=args.collate_fn_conf.get("label_downsampling", 1),
pad=args.collate_fn_conf.get("pad", False),
rand_crop=args.collate_fn_conf.get("rand_crop", True),
crop_audio=not args.collect_stats,
window_size=window_size,
window_shift=window_shift,
sample_rate=sample_rate,
)

@classmethod
Expand Down Expand Up @@ -367,7 +392,7 @@ def build_model(
else:
# Give features from data-loader
args.frontend = None
args.frontend_conf = {}
args.frontend_conf = {**args.frontend_conf}
frontend = None
input_size = args.input_size

Expand Down
34 changes: 28 additions & 6 deletions espnet2/train/collate_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def __init__(
rand_crop: bool = True,
crop_audio: bool = True,
not_sequence: Collection[str] = (),
window_size: float = 25,
window_shift: float = 20,
sample_rate: float = 16,
):
assert check_argument_types()
super().__init__(
Expand All @@ -65,6 +68,9 @@ def __init__(
self.rand_crop = rand_crop
self.crop_audio = crop_audio
self.not_sequence = set(not_sequence)
self.window_size = window_size
self.window_shift = window_shift
self.sample_rate = sample_rate

def __repr__(self):
return (
Expand Down Expand Up @@ -96,7 +102,14 @@ def __call__(
label = label[:: self.label_downsampling]
if self.crop_audio:
waveform, label, length = _crop_audio_label(
waveform, label, length, num_frames, self.rand_crop
waveform,
label,
length,
num_frames,
self.rand_crop,
self.window_size,
self.window_shift,
self.sample_rate,
)
new_data.append((uid, dict(speech=waveform, text=label)))

Expand All @@ -114,6 +127,9 @@ def _crop_audio_label(
length: torch.Tensor,
num_frames: int,
rand_crop: bool,
window_size: int = 25,
window_shift: int = 20,
sample_rate: int = 16,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Collate the audio and label at the same time.
Expand All @@ -125,29 +141,35 @@ def _crop_audio_label(
rand_crop (bool): if ``rand_crop`` is True, the starting index of the
waveform and label is random if the length is longer than the minimum
length in the mini-batch.
window_size (int): reception field of conv feature extractor (in ms).
In default, calculated by [400 (samples) / 16 (sample_rate)].
window_shift (int): the stride of conv feature extractor (in ms).
In default, calculated by [320 (samples) / 16 (sample_rate)].
sample_rate (int): number of samples in audio signal per millisecond.
Returns:
(Tuple(Tensor, Tensor, Tensor)): Returns the Tensors for the waveform,
label, and the waveform length.
"""

kernel_size = 25
stride = 20
sample_rate = 16 # 16 per millisecond
frame_offset = 0
if waveform.size > num_frames and rand_crop:
diff = waveform.size - num_frames
frame_offset = torch.randint(diff, size=(1,))
elif waveform.size < num_frames:
num_frames = waveform.size
label_offset = max(
math.floor((frame_offset - kernel_size * sample_rate) / (stride * sample_rate))
math.floor(
(frame_offset - window_size * sample_rate) / (window_shift * sample_rate)
)
+ 1,
0,
)
num_label = (
math.floor((num_frames - kernel_size * sample_rate) / (stride * sample_rate))
math.floor(
(num_frames - window_size * sample_rate) / (window_shift * sample_rate)
)
+ 1
)
waveform = waveform[frame_offset : frame_offset + num_frames]
Expand Down
2 changes: 1 addition & 1 deletion test/espnet2/asr/encoder/test_hubert_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_Encoder_forward_backward(finetuning, eval, freeze_encoder_updates):

encoder = TorchAudioHuBERTPretrainEncoder(
20,
extractor_conv_layer_config=[(3, 3, 2)],
extractor_conv_layer_config=[[3, 3, 2]],
encoder_pos_conv_kernel=16,
encoder_pos_conv_groups=4,
encoder_embed_dim=4,
Expand Down
2 changes: 1 addition & 1 deletion test/espnet2/hubert/test_hubert_espnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_forward_backward_finetuning_false(finetuning):

encoder = TorchAudioHuBERTPretrainEncoder(
20,
extractor_conv_layer_config=[(3, 3, 2)],
extractor_conv_layer_config=[[3, 3, 2]],
encoder_pos_conv_kernel=16,
encoder_pos_conv_groups=4,
encoder_embed_dim=4,
Expand Down
48 changes: 38 additions & 10 deletions test/espnet2/train/test_collate_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,26 @@ def test_CommonCollateFn_repr(float_pad_value, int_pad_value, not_sequence):


@pytest.mark.parametrize(
"float_pad_value, int_pad_value, not_sequence, label_downsampling, pad, rand_crop",
(
"float_pad_value, int_pad_value, not_sequence, label_downsampling, pad,"
"rand_crop, window_size, window_shift, sample_rate"
),
[
(0.0, -1, (), 1, True, False),
(3.0, 2, ("a",), 1, False, False),
(np.inf, 100, ("a", "b"), 2, True, False),
(0.0, -1, (), 1, True, True, 25, 20, 16),
(3.0, 2, ("a",), 1, False, False, 25, 20, 16),
(np.inf, 100, ("a", "b"), 2, True, False, 25, 20, 16),
],
)
def test_HuBERT_(
float_pad_value, int_pad_value, not_sequence, label_downsampling, pad, rand_crop
float_pad_value,
int_pad_value,
not_sequence,
label_downsampling,
pad,
rand_crop,
window_size,
window_shift,
sample_rate,
):
_hubert_collate_fn = HuBERTCollateFn(
float_pad_value=float_pad_value,
Expand All @@ -140,6 +151,9 @@ def test_HuBERT_(
label_downsampling=label_downsampling,
pad=pad,
rand_crop=rand_crop,
window_size=window_size,
window_shift=window_shift,
sample_rate=sample_rate,
)
data = [
(
Expand Down Expand Up @@ -217,15 +231,26 @@ def test_HuBERT_(


@pytest.mark.parametrize(
"float_pad_value, int_pad_value, not_sequence, label_downsampling, pad, rand_crop",
(
"float_pad_value, int_pad_value, not_sequence, label_downsampling, pad, "
"rand_crop, window_size, window_shift, sample_rate"
),
[
(0.0, -1, (), 1, True, True),
(3.0, 2, ("a",), 1, False, False),
(np.inf, 100, ("a", "b"), 2, True, False),
(0.0, -1, (), 1, True, True, 25, 20, 16),
(3.0, 2, ("a",), 1, False, False, 80, 40, 16),
(np.inf, 100, ("a", "b"), 2, True, False, 25, 20, 16),
],
)
def test_HuBERTCollateFn_repr(
float_pad_value, int_pad_value, not_sequence, label_downsampling, pad, rand_crop
float_pad_value,
int_pad_value,
not_sequence,
label_downsampling,
pad,
rand_crop,
window_size,
window_shift,
sample_rate,
):
print(
HuBERTCollateFn(
Expand All @@ -235,5 +260,8 @@ def test_HuBERTCollateFn_repr(
label_downsampling=label_downsampling,
pad=pad,
rand_crop=rand_crop,
window_size=window_size,
window_shift=window_shift,
sample_rate=sample_rate,
)
)

0 comments on commit 47f989a

Please sign in to comment.