Add dynamic mixing in the speech separation task. #4387

Merged
13 commits merged on Aug 10, 2022
1 change: 1 addition & 0 deletions ci/test_integration_espnet2.sh
@@ -85,6 +85,7 @@ if python -c 'import torch as t; from packaging.version import parse as L; asser
echo "==== feats_type=${t} ==="
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}"
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}" --use_preprocessor true --extra_wav_list "rirs.scp noises.scp" --enh_config ./conf/train_with_preprocessor.yaml
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}" --enh_config conf/train_with_dynamic_mixing.yaml --dynamic_mixing true --spk-num 2
done
# Remove generated files in order to reduce the disk usage
rm -rf exp dump data
47 changes: 38 additions & 9 deletions egs2/TEMPLATE/enh1/enh.sh
@@ -59,6 +59,7 @@ enh_config= # Config for enhancement model training.
enh_args= # Arguments for enhancement model training, e.g., "--max_epoch 10".
# Note that it will overwrite args in enhancement config.
spk_num=2 # Number of speakers
dynamic_mixing=false # Flag for dynamic mixing in speech separation task.
noise_type_num=1
dereverb_ref_num=1

@@ -142,6 +143,7 @@ Options:
--enh_args # Arguments for enhancement model training, e.g., "--max_epoch 10" (default="${enh_args}").
# Note that it will overwrite args in enhancement config.
--spk_num # Number of speakers in the input audio (default="${spk_num}")
--dynamic_mixing # Flag for dynamic mixing in speech separation task (default="${dynamic_mixing}").
--noise_type_num # Number of noise types in the input audio (default="${noise_type_num}")
--dereverb_ref_num # Number of references for dereverberation (default="${dereverb_ref_num}")

@@ -265,6 +267,8 @@ if [ -z "${inference_tag}" ]; then
fi
fi



# ========================== Main stages start from here. ==========================

if ! "${skip_data_prep}"; then
@@ -545,7 +549,15 @@ if ! "${skip_train}"; then
_opts+="--config ${enh_config} "
fi

_scp=wav.scp
if ${dynamic_mixing}; then
# In the current version, if you want to enable dynamic mixing in speech separation,
# you need to prepare the training set manually. Here we assume all speech sources
# are collected in "spk1.scp", and other scp files (wav.scp, spk{N}.scp) are not used.
log "Dynamic mixing is enabled, use spk1.scp as the source file list."
_scp=spk1.scp
else
_scp=wav.scp
fi
# "sound" supports "wav", "flac", etc.
if [[ "${audio_format}" == *ark* ]]; then
_type=kaldi_ark
@@ -555,15 +567,32 @@ fi
fi
_fold_length="$((enh_speech_fold_length * 100))"

# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/wav.scp,speech_mix,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_mix_shape "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "
_fold_length_param="--fold_length ${_fold_length} "

if ! ${dynamic_mixing} ; then

# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/${_scp},speech_mix,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_mix_shape "
_fold_length_param="--fold_length ${_fold_length} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "

for spk in $(seq "${spk_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
done

else
# prepare train and valid data parameters
Collaborator: Maybe it is better to echo some message here to warn the user that dynamic mixing is being used.

Contributor: I agree. Can you add a comment, @LiChenda?

Contributor Author: Done.

_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/${_scp},speech_ref1,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_ref1_shape "
_fold_length_param="--fold_length ${_fold_length} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "
_opts+="--utt2spk ${_enh_train_dir}/utt2spk "
fi

for spk in $(seq "${spk_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_valid_shape_param+="--valid_shape_file ${enh_stats_dir}/valid/speech_ref${spk}_shape "
_fold_length_param+="--fold_length ${_fold_length} "
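For reference, the spk1.scp consumed in dynamic mixing mode follows the usual Kaldi-style two-column scp layout (utterance ID, then a path to one source signal), and the utt2spk file passed via --utt2spk maps each utterance to its speaker. The IDs and paths below are made-up placeholders:

spk1.scp:
spk1_utt1 /data/sources/spk1_utt1.wav
spk1_utt2 /data/sources/spk1_utt2.wav
spk2_utt1 /data/sources/spk2_utt1.wav

utt2spk:
spk1_utt1 spk1
spk1_utt2 spk1
spk2_utt1 spk2
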
3 changes: 2 additions & 1 deletion egs2/mini_an4/asr1/local/data.sh
@@ -71,9 +71,10 @@ fcaw-cen8-b fcaw-cen8-b_org 0.0 2.9
mmxg-cen8-b mmxg-cen8-b_org 0.0 2.3
EOF

# for enh task
# for enh and separation task
for x in test ${train_set} ${train_dev}; do
cp data/${x}/wav.scp data/${x}/spk1.scp
cp data/${x}/wav.scp data/${x}/spk2.scp
done

find downloads/noise/ -iname "*.wav" | awk '{print "noise" NR " " $1}' > data/${train_set}/noises.scp
37 changes: 37 additions & 0 deletions egs2/mini_an4/enh1/conf/train_with_dynamic_mixing.yaml
@@ -0,0 +1,37 @@

encoder: stft
encoder_conf:
n_fft: 512
hop_length: 128
decoder: stft
decoder_conf:
n_fft: 512
hop_length: 128
separator: rnn
separator_conf:
rnn_type: blstm
num_spk: 2
nonlinear: relu
layer: 1
unit: 128
dropout: 0.2

Contributor: Can you add a comment about this configuration?

Contributor Author: Done.

# dynamic_mixing related
# dynamic_mixing_gain_db:
# The maximum random gain (in dB) for each source before the mixing.
# The gain (in dB) of each source is uniformly sampled in
# [-dynamic_mixing_gain_db, dynamic_mixing_gain_db]
dynamic_mixing: True
dynamic_mixing_gain_db: 2.0

criterions:
# The first criterion
- name: mse
conf:
compute_on_mask: True
mask_type: PSM^2
# the wrapper for the current criterion
# for the single-talker case, we simply use the fixed_order wrapper
wrapper: fixed_order
wrapper_conf:
weight: 1.0
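
The dynamic_mixing_gain_db comment above boils down to one uniform draw per source. A minimal sketch of that sampling, assuming NumPy (apply_random_gain is a hypothetical helper for illustration, not part of espnet2):

import numpy as np

def apply_random_gain(source: np.ndarray, max_gain_db: float) -> np.ndarray:
    # Draw a gain uniformly from [-max_gain_db, max_gain_db] (in dB)
    gain_db = np.random.uniform(-max_gain_db, max_gain_db)
    # Convert the dB gain to a linear amplitude factor and apply it
    return source * 10.0 ** (gain_db / 20.0)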
@@ -0,0 +1,77 @@
init: xavier_uniform
max_epoch: 150
batch_type: folded
# When dynamic mixing is enabled, the actual batch_size will
# be (batch_size / num_spk), i.e., 8 / 2 = 4 mixtures per batch here.
batch_size: 8
iterator_type: chunk
chunk_length: 16000
num_workers: 2
optim: adamw
optim_conf:
lr: 1.0e-03
eps: 1.0e-06
weight_decay: 0
patience: 20
grad_clip: 5.0
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
- si_snr
- max
- - valid
- loss
- min
keep_nbest_models: 10

scheduler: steplr
scheduler_conf:
step_size: 2
gamma: 0.97

Contributor: ditto

Contributor Author: Done.

# dynamic_mixing related
# dynamic_mixing_gain_db:
# The maximum random gain (in dB) for each source before the mixing.
# The gain (in dB) of each source is uniformly sampled in
# [-dynamic_mixing_gain_db, dynamic_mixing_gain_db]
dynamic_mixing: True
dynamic_mixing_gain_db: 2.0

encoder: conv
encoder_conf:
channel: 64
kernel_size: 2
stride: 1
decoder: conv
decoder_conf:
channel: 64
kernel_size: 2
stride: 1
separator: skim
separator_conf:
causal: False
num_spk: 2
layer: 6
nonlinear: relu
unit: 128
segment_size: 250
dropout: 0.05
mem_type: hc
seg_overlap: True

# A list for criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
# The first criterion
- name: si_snr
conf:
eps: 1.0e-6
wrapper: pit
wrapper_conf:
weight: 1.0
independent_perm: True
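
A config like the one above is selected the same way as in the CI test at the top of this page: pass it via --enh_config together with --dynamic_mixing true and --spk-num 2.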

1 change: 1 addition & 0 deletions espnet2/enh/espnet_model.py
@@ -79,6 +79,7 @@ def forward(
espnet2/iterators/chunk_iter_factory.py
kwargs: "utt_id" is among the input.
"""

# clean speech signal of each speaker
speech_ref = [
kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
67 changes: 63 additions & 4 deletions espnet2/tasks/enh.py
@@ -1,4 +1,5 @@
import argparse
import copy
from typing import Callable, Collection, Dict, List, Optional, Tuple

import numpy as np
@@ -54,11 +55,13 @@
from espnet2.enh.separator.svoice_separator import SVoiceSeparator
from espnet2.enh.separator.tcn_separator import TCNSeparator
from espnet2.enh.separator.transformer_separator import TransformerSeparator
from espnet2.iterators.abs_iter_factory import AbsIterFactory
from espnet2.tasks.abs_task import AbsTask
from espnet2.torch_utils.initialize import initialize
from espnet2.train.class_choices import ClassChoices
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.preprocessor import EnhPreprocessor
from espnet2.train.distributed_utils import DistributedOption
from espnet2.train.preprocessor import DynamicMixingPreprocessor, EnhPreprocessor
from espnet2.train.trainer import Trainer
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet2.utils.nested_dict_action import NestedDictAction
@@ -292,6 +295,27 @@ def add_task_arguments(cls, parser: argparse.ArgumentParser):
help="Whether to force all data to be single-channel.",
)

group.add_argument(
"--dynamic_mixing",
type=str2bool,
default=False,
help="Apply dynamic mixing",
)

group.add_argument(
"--utt2spk",
type=str_or_none,
default=None,
help="The file path of utt2spk file. Only used in dynamic_mixing mode.",
)

group.add_argument(
"--dynamic_mixing_gain_db",
type=float,
default=0.0,
help="Random gain (in dB) for dynamic mixing sources",
)

for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
@@ -313,7 +337,25 @@ def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:

dynamic_mixing = getattr(args, "dynamic_mixing", False)
use_preprocessor = getattr(args, "use_preprocessor", False)

assert (
dynamic_mixing and use_preprocessor
) is not True, (
"'dynamic_mixing' and 'use_preprocessor' should not both be 'True'"
)

if dynamic_mixing and train:
retval = DynamicMixingPreprocessor(
train=train,
source_scp=args.train_data_path_and_name_and_type[0][0],
num_spk=args.separator_conf["num_spk"],
dynamic_mixing_gain_db=getattr(args, "dynamic_mixing_gain_db", 0.0),
utt2spk=getattr(args, "utt2spk", None),
)
elif use_preprocessor:
retval = EnhPreprocessor(
train=train,
# NOTE(kamo): Check attribute existence for backward compatibility
@@ -356,7 +398,7 @@ def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
retval = ("speech_mix", "speech_ref1")
retval = ("speech_ref1",)
else:
# Recognition mode
retval = ("speech_mix",)
@@ -366,7 +408,8 @@ def optional_data_names(
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ["dereverb_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
retval = ["speech_mix"]
retval += ["dereverb_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
retval += ["speech_ref{}".format(n) for n in range(2, MAX_REFERENCE_NUM + 1)]
retval += ["noise_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
retval = tuple(retval)
@@ -420,3 +463,19 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:

assert check_return_type(model)
return model

@classmethod
def build_iter_factory(
cls,
args: argparse.Namespace,
distributed_option: DistributedOption,
mode: str,
kwargs: dict = None,
) -> AbsIterFactory:

dynamic_mixing = getattr(args, "dynamic_mixing", False)
if dynamic_mixing and mode == "train":
args = copy.deepcopy(args)
args.fold_length = args.fold_length[0:1]

return super().build_iter_factory(args, distributed_option, mode, kwargs)
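
Putting the pieces together, here is a conceptual sketch of what dynamic mixing does per training example: sample one utterance from each of num_spk distinct speakers (the assumed role of utt2spk), apply the random gains, and sum the gained sources into a fresh mixture. This is an illustration of the idea under those assumptions, not the actual espnet2 DynamicMixingPreprocessor implementation:

import random
from typing import Dict, List

import numpy as np

def dynamic_mix(
    sources_by_speaker: Dict[str, List[np.ndarray]],
    num_spk: int,
    max_gain_db: float,
) -> Dict[str, np.ndarray]:
    # Choose num_spk distinct speakers, then one source utterance from each
    speakers = random.sample(sorted(sources_by_speaker), num_spk)
    sources = [random.choice(sources_by_speaker[spk]) for spk in speakers]
    # Truncate everything to a common length so the sources can be summed
    min_len = min(len(src) for src in sources)
    refs = []
    for src in sources:
        gain_db = np.random.uniform(-max_gain_db, max_gain_db)
        refs.append(src[:min_len] * 10.0 ** (gain_db / 20.0))
    # The new mixture is simply the sum of the gained sources
    example = {"speech_mix": np.sum(refs, axis=0)}
    for i, ref in enumerate(refs, start=1):
        example[f"speech_ref{i}"] = ref
    return example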