Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add dynamic mixing in the speech separation task. #4387

Merged
merged 13 commits into from
Aug 10, 2022
24 changes: 21 additions & 3 deletions egs2/TEMPLATE/enh1/enh.sh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ enh_config= # Config for enhancement model training.
enh_args= # Arguments for enhancement model training, e.g., "--max_epoch 10".
# Note that it will overwrite args in enhancement config.
spk_num=2 # Number of speakers
dynamic_mixing=false # Flag for dynamic mixing in speech separation task.
noise_type_num=1
dereverb_ref_num=1

Expand Down Expand Up @@ -140,6 +141,7 @@ Options:
--enh_args # Arguments for enhancement model training, e.g., "--max_epoch 10" (default="${enh_args}").
# Note that it will overwrite args in enhancement config.
--spk_num # Number of speakers in the input audio (default="${spk_num}")
--dynamic_mixing # Flag for dynamic mixing in speech separation task (default="${dynamic_mixing}").
--noise_type_num # Number of noise types in the input audio (default="${noise_type_num}")
--dereverb_ref_num # Number of references for dereverberation (default="${dereverb_ref_num}")

Expand Down Expand Up @@ -261,6 +263,8 @@ if [ -z "${inference_tag}" ]; then
fi
fi



# ========================== Main stages start from here. ==========================

if ! "${skip_data_prep}"; then
Expand Down Expand Up @@ -529,7 +533,14 @@ if ! "${skip_train}"; then
_opts+="--config ${enh_config} "
fi

# Select the training wav list.
# NOTE: `[ ${dynamic_mixing} ]` would be true for ANY non-empty value,
# including the default "false", because `[ string ]` only tests
# non-emptiness. Execute the boolean directly instead, matching the
# style used elsewhere in this script (e.g. `if ! "${skip_train}"`).
if "${dynamic_mixing}"; then
    # In the current version, if you want to enable dynamic mixing in
    # speech separation, you need to prepare the training set manually.
    # Here we assume all speech sources are collected in "spk1.scp",
    # and other scp files (wav.scp, spk{N}.scp) are not used.
    _scp=spk1.scp
else
    _scp=wav.scp
fi
# "sound" supports "wav", "flac", etc.
if [[ "${audio_format}" == *ark* ]]; then
_type=kaldi_ark
Expand All @@ -540,13 +551,20 @@ if ! "${skip_train}"; then
_fold_length="$((enh_speech_fold_length * 100))"

# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/wav.scp,speech_mix,${_type} "
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/${_scp},speech_mix,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_mix_shape "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "
_fold_length_param="--fold_length ${_fold_length} "

for spk in $(seq "${spk_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
# Use the boolean value itself as the test command: `[ ${dynamic_mixing} ]`
# is non-empty-string test and would be true even when dynamic_mixing=false.
if "${dynamic_mixing}"; then
    # TODO (Chenda): Not needed in the dynamic mixing mode.
    # Better to find a way to remove it without affecting compatibility.
    # All sources live in spk1.scp when dynamic mixing is enabled.
    _train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk1.scp,speech_ref${spk},${_type} "
else
    _train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
fi
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_valid_shape_param+="--valid_shape_file ${enh_stats_dir}/valid/speech_ref${spk}_shape "
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
init: xavier_uniform
max_epoch: 150
batch_type: folded
# When dynamic mixing is enabled, the actual batch_size will
# be (batch_size / num_spk)
batch_size: 32
iterator_type: chunk
chunk_length: 16000
num_workers: 2
optim: adamw
optim_conf:
lr: 1.0e-03
eps: 1.0e-06
weight_decay: 0
patience: 20
grad_clip: 5.0
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
- si_snr
- max
- - valid
- loss
- min
keep_nbest_models: 10

scheduler: steplr
scheduler_conf:
step_size: 2
gamma: 0.97

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

model_conf:
dynamic_mixing: True
dynamic_mixing_gain_db: 2.0

encoder: conv
encoder_conf:
channel: 64
kernel_size: 2
stride: 1
decoder: conv
decoder_conf:
channel: 64
kernel_size: 2
stride: 1
separator: skim
separator_conf:
causal: False
num_spk: 2
layer: 6
nonlinear: relu
unit: 128
segment_size: 250
dropout: 0.05
mem_type: hc
seg_overlap: True

# A list for criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
# The first criterion
- name: si_snr
conf:
eps: 1.0e-6
wrapper: pit
wrapper_conf:
weight: 1.0
independent_perm: True








Emrys365 marked this conversation as resolved.
Show resolved Hide resolved
51 changes: 45 additions & 6 deletions espnet2/enh/espnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def __init__(
stft_consistency: bool = False,
loss_type: str = "mask_mse",
mask_type: Optional[str] = None,
dynamic_mixing: bool = False,
dynamic_mixing_gain_db: float = 0.0,
):
assert check_argument_types()

Expand All @@ -43,6 +45,9 @@ def __init__(
self.loss_wrappers = loss_wrappers
self.num_spk = separator.num_spk
self.num_noise_type = getattr(self.separator, "num_noise_type", 1)
self.dynamic_mixing = dynamic_mixing
# Max +/- gain (dB) of sources in dynamic_mixing
self.dynamic_mixing_gain_db = dynamic_mixing_gain_db

# get mask type for TF-domain models
# (only used when loss_type="mask_*") (deprecated, keep for compatibility)
Expand Down Expand Up @@ -76,12 +81,46 @@ def forward(
espnet2/iterators/chunk_iter_factory.py
kwargs: "utt_id" is among the input.
"""
# clean speech signal of each speaker
speech_ref = [
kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
]
# (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
speech_ref = torch.stack(speech_ref, dim=1)

if self.dynamic_mixing and self.training:
# Dynamic mixing mode.
sources = kwargs["speech_ref1"]
pseudo_batch = sources.shape[0]
assert (
pseudo_batch % self.num_spk == 0
), f"In the dynamic mixing mode, real batchsize is batchsize/num_spk. \
Current batchsize on a single GPU is {pseudo_batch}, \
and num_spk is {self.num_spk}"

# Apply random gain to speech sources.
gain_in_db = (
torch.FloatTensor(pseudo_batch, 1)
.uniform_(-self.dynamic_mixing_gain_db, self.dynamic_mixing_gain_db)
.to(sources.device)
)
gain = torch.pow(10, gain_in_db / 20.0)
sources = sources * gain

# Create speech mixture.
rand_perm = torch.randperm(sources.shape[0])
rand_perm = rand_perm.view(self.num_spk, -1)
speech_ref = torch.stack([sources[p] for p in rand_perm], dim=1)
speech_mix = speech_ref.sum(dim=1)

# Calculate speech_mix_lengths.
if speech_mix_lengths is not None:
ref_lengths = torch.stack(
[speech_mix_lengths[p] for p in rand_perm], dim=1
)
speech_mix_lengths = ref_lengths.max(dim=1)[0]

else:
# clean speech signal of each speaker
speech_ref = [
kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
]
# (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
speech_ref = torch.stack(speech_ref, dim=1)

if "noise_ref1" in kwargs:
# noise signal (optional, required when using beamforming-based
Expand Down