Merge pull request #4387 from LiChenda/dynamic_mixing
Add dynamic mixing in the speech separation task.
mergify[bot] committed Aug 10, 2022
2 parents 96bd746 + d0117f7 commit 24b12f8
Showing 8 changed files with 356 additions and 14 deletions.
1 change: 1 addition & 0 deletions ci/test_integration_espnet2.sh
@@ -85,6 +85,7 @@ if python -c 'import torch as t; from packaging.version import parse as L; asser
echo "==== feats_type=${t} ==="
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}"
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}" --use_preprocessor true --extra_wav_list "rirs.scp noises.scp" --enh_config ./conf/train_with_preprocessor.yaml
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}" --enh_config conf/train_with_dynamic_mixing.yaml --dynamic_mixing true --spk-num 2
done
# Remove generated files in order to reduce the disk usage
rm -rf exp dump data
47 changes: 38 additions & 9 deletions egs2/TEMPLATE/enh1/enh.sh
@@ -59,6 +59,7 @@ enh_config= # Config for enhancement model training.
enh_args= # Arguments for enhancement model training, e.g., "--max_epoch 10".
# Note that it will overwrite args in enhancement config.
spk_num=2 # Number of speakers
dynamic_mixing=false # Flag for dynamic mixing in the speech separation task.
noise_type_num=1
dereverb_ref_num=1

@@ -142,6 +143,7 @@ Options:
--enh_args # Arguments for enhancement model training, e.g., "--max_epoch 10" (default="${enh_args}").
# Note that it will overwrite args in enhancement config.
--spk_num # Number of speakers in the input audio (default="${spk_num}")
--dynamic_mixing # Flag for dynamic mixing in the speech separation task (default="${dynamic_mixing}").
--noise_type_num # Number of noise types in the input audio (default="${noise_type_num}")
--dereverb_ref_num # Number of references for dereverberation (default="${dereverb_ref_num}")
@@ -265,6 +267,8 @@ if [ -z "${inference_tag}" ]; then
fi
fi



# ========================== Main stages start from here. ==========================

if ! "${skip_data_prep}"; then
@@ -545,7 +549,15 @@ if ! "${skip_train}"; then
_opts+="--config ${enh_config} "
fi

_scp=wav.scp
if ${dynamic_mixing}; then
# In the current version, to enable dynamic mixing in speech separation
# you need to prepare the training set manually. All speech sources are
# assumed to be collected in "spk1.scp"; the other scp files
# (wav.scp, spk{N}.scp) are not used.
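# An illustrative spk1.scp in the usual Kaldi scp layout (utterance ID
# followed by an audio path); the IDs and paths here are hypothetical:
#   utt1_spkA /path/to/utt1_spkA.wav
#   utt2_spkB /path/to/utt2_spkB.wav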
log "Dynamic mixing is enabled, use spk1.scp as the source file list."
_scp=spk1.scp
else
_scp=wav.scp
fi
# "sound" supports "wav", "flac", etc.
if [[ "${audio_format}" == *ark* ]]; then
_type=kaldi_ark
@@ -555,15 +567,32 @@ if ! "${skip_train}"; then
fi
_fold_length="$((enh_speech_fold_length * 100))"

# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/wav.scp,speech_mix,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_mix_shape "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "
_fold_length_param="--fold_length ${_fold_length} "

if ! ${dynamic_mixing} ; then

# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/${_scp},speech_mix,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_mix_shape "
_fold_length_param="--fold_length ${_fold_length} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "

for spk in $(seq "${spk_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
done

else
# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/${_scp},speech_ref1,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_ref1_shape "
_fold_length_param="--fold_length ${_fold_length} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "
_opts+="--utt2spk ${_enh_train_dir}/utt2spk "
fi

for spk in $(seq "${spk_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_valid_shape_param+="--valid_shape_file ${enh_stats_dir}/valid/speech_ref${spk}_shape "
_fold_length_param+="--fold_length ${_fold_length} "
3 changes: 2 additions & 1 deletion egs2/mini_an4/asr1/local/data.sh
@@ -71,9 +71,10 @@ fcaw-cen8-b fcaw-cen8-b_org 0.0 2.9
mmxg-cen8-b mmxg-cen8-b_org 0.0 2.3
EOF

# for enh task
# for the enh and separation tasks
for x in test ${train_set} ${train_dev}; do
cp data/${x}/wav.scp data/${x}/spk1.scp
cp data/${x}/wav.scp data/${x}/spk2.scp
done
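# NOTE: mini_an4 is a toy recipe used for CI, so the per-speaker "sources"
# above are simply copies of the input wav.scp; a real separation recipe
# would point spk{N}.scp at true single-speaker source signals.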

find downloads/noise/ -iname "*.wav" | awk '{print "noise" NR " " $1}' > data/${train_set}/noises.scp
37 changes: 37 additions & 0 deletions egs2/mini_an4/enh1/conf/train_with_dynamic_mixing.yaml
@@ -0,0 +1,37 @@

encoder: stft
encoder_conf:
n_fft: 512
hop_length: 128
decoder: stft
decoder_conf:
n_fft: 512
hop_length: 128
separator: rnn
separator_conf:
rnn_type: blstm
num_spk: 2
nonlinear: relu
layer: 1
unit: 128
dropout: 0.2

# dynamic_mixing related
# dynamic_mixing_gain_db:
# The maximum random gain (in dB) for each source before mixing.
# The gain (in dB) of each source is uniformly sampled in
# [-dynamic_mixing_gain_db, dynamic_mixing_gain_db]
dynamic_mixing: True
dynamic_mixing_gain_db: 2.0
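# For example, with dynamic_mixing_gain_db: 2.0 each source is scaled by a
# linear factor of 10 ** (g / 20), with g drawn uniformly from [-2.0, 2.0] dB,
# before the sources are summed (a sketch of the sampling, not the exact code).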

criterions:
# The first criterion
- name: mse
conf:
compute_on_mask: True
mask_type: PSM^2
# the wrapper for the current criterion
# for the single-talker case, we simply use the fixed_order wrapper
wrapper: fixed_order
wrapper_conf:
weight: 1.0
@@ -0,0 +1,77 @@
init: xavier_uniform
max_epoch: 150
batch_type: folded
# When dynamic mixing is enabled, the actual batch_size will
# be (batch_size / num_spk)
batch_size: 8
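# e.g., with batch_size: 8 and num_spk: 2 (see separator_conf below),
# each mini-batch holds 8 / 2 = 4 dynamically mixed utterances.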
iterator_type: chunk
chunk_length: 16000
num_workers: 2
optim: adamw
optim_conf:
lr: 1.0e-03
eps: 1.0e-06
weight_decay: 0
patience: 20
grad_clip: 5.0
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
- si_snr
- max
- - valid
- loss
- min
keep_nbest_models: 10

scheduler: steplr
scheduler_conf:
step_size: 2
gamma: 0.97

# dynamic_mixing related
# dynamic_mixing_gain_db:
# The maximum random gain (in dB) for each source before mixing.
# The gain (in dB) of each source is uniformly sampled in
# [-dynamic_mixing_gain_db, dynamic_mixing_gain_db]
dynamic_mixing: True
dynamic_mixing_gain_db: 2.0

encoder: conv
encoder_conf:
channel: 64
kernel_size: 2
stride: 1
decoder: conv
decoder_conf:
channel: 64
kernel_size: 2
stride: 1
separator: skim
separator_conf:
causal: False
num_spk: 2
layer: 6
nonlinear: relu
unit: 128
segment_size: 250
dropout: 0.05
mem_type: hc
seg_overlap: True

# A list of criterions
# The overall loss in multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
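# In this config there is a single criterion (si_snr with weight 1.0),
# so the overall loss reduces to 1.0 * loss_si_snr.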
criterions:
# The first criterion
- name: si_snr
conf:
eps: 1.0e-6
wrapper: pit
wrapper_conf:
weight: 1.0
independent_perm: True

1 change: 1 addition & 0 deletions espnet2/enh/espnet_model.py
Expand Up @@ -84,6 +84,7 @@ def forward(
espnet2/iterators/chunk_iter_factory.py
kwargs: "utt_id" is among the input.
"""

# clean speech signal of each speaker
speech_ref = [
kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
67 changes: 63 additions & 4 deletions espnet2/tasks/enh.py
@@ -1,4 +1,5 @@
import argparse
import copy
from typing import Callable, Collection, Dict, List, Optional, Tuple

import numpy as np
@@ -56,11 +57,13 @@
from espnet2.enh.separator.svoice_separator import SVoiceSeparator
from espnet2.enh.separator.tcn_separator import TCNSeparator
from espnet2.enh.separator.transformer_separator import TransformerSeparator
from espnet2.iterators.abs_iter_factory import AbsIterFactory
from espnet2.tasks.abs_task import AbsTask
from espnet2.torch_utils.initialize import initialize
from espnet2.train.class_choices import ClassChoices
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.preprocessor import EnhPreprocessor
from espnet2.train.distributed_utils import DistributedOption
from espnet2.train.preprocessor import DynamicMixingPreprocessor, EnhPreprocessor
from espnet2.train.trainer import Trainer
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet2.utils.nested_dict_action import NestedDictAction
@@ -296,6 +299,27 @@ def add_task_arguments(cls, parser: argparse.ArgumentParser):
help="Whether to force all data to be single-channel.",
)

group.add_argument(
"--dynamic_mixing",
type=str2bool,
default=False,
help="Apply dynamic mixing",
)

group.add_argument(
"--utt2spk",
type=str_or_none,
default=None,
help="The file path of utt2spk file. Only used in dynamic_mixing mode.",
)

group.add_argument(
"--dynamic_mixing_gain_db",
type=float,
default=0.0,
help="Random gain (in dB) for dynamic mixing sources",
)
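# On the command line these surface as, e.g.,
#   --dynamic_mixing true --dynamic_mixing_gain_db 2.0 --utt2spk <dir>/utt2spk
# (enh.sh passes --utt2spk automatically in dynamic-mixing mode; the <dir>
# placeholder is illustrative).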

for class_choices in cls.class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
@@ -317,7 +341,25 @@ def build_preprocess_fn(
cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
assert check_argument_types()
if args.use_preprocessor:

dynamic_mixing = getattr(args, "dynamic_mixing", False)
use_preprocessor = getattr(args, "use_preprocessor", False)

assert not (
dynamic_mixing and use_preprocessor
), "'dynamic_mixing' and 'use_preprocessor' should not both be 'True'"

if dynamic_mixing and train:
retval = DynamicMixingPreprocessor(
train=train,
source_scp=args.train_data_path_and_name_and_type[0][0],
num_spk=args.separator_conf["num_spk"],
dynamic_mixing_gain_db=getattr(args, "dynamic_mixing_gain_db", 0.0),
utt2spk=getattr(args, "utt2spk", None),
)
elif use_preprocessor:
retval = EnhPreprocessor(
train=train,
# NOTE(kamo): Check attribute existence for backward compatibility
@@ -360,7 +402,7 @@ def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
if not inference:
retval = ("speech_mix", "speech_ref1")
retval = ("speech_ref1",)
else:
# Recognition mode
retval = ("speech_mix",)
@@ -370,7 +412,8 @@ def optional_data_names(
def optional_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ["dereverb_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
retval = ["speech_mix"]
retval += ["dereverb_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
retval += ["speech_ref{}".format(n) for n in range(2, MAX_REFERENCE_NUM + 1)]
retval += ["noise_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
retval = tuple(retval)
@@ -424,3 +467,19 @@ def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel:

assert check_return_type(model)
return model

@classmethod
def build_iter_factory(
cls,
args: argparse.Namespace,
distributed_option: DistributedOption,
mode: str,
kwargs: dict = None,
) -> AbsIterFactory:

dynamic_mixing = getattr(args, "dynamic_mixing", False)
if dynamic_mixing and mode == "train":
args = copy.deepcopy(args)
args.fold_length = args.fold_length[0:1]
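# Only one data stream (speech_ref1, i.e. spk1.scp) is loaded in
# dynamic-mixing mode, so only the first fold_length entry applies.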

return super().build_iter_factory(args, distributed_option, mode, kwargs)
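
Taken together, the new code builds mixtures on the fly during training: sources are read from spk1.scp (with utt2spk available for speaker-aware sampling), each source receives a random gain, and the sum becomes speech_mix. A minimal sketch of that idea, assuming NumPy arrays as input (dynamic_mix and its signature are illustrative, not the actual DynamicMixingPreprocessor API):

import numpy as np

def dynamic_mix(sources, gain_db=2.0, rng=np.random):
    """Sum source signals after applying a random per-source gain.

    sources: list of equal-length 1-D numpy arrays (one per speaker).
    gain_db: maximum absolute gain in dB (cf. dynamic_mixing_gain_db).
    """
    refs = []
    for src in sources:
        g = rng.uniform(-gain_db, gain_db)  # per-source gain in dB
        refs.append(src * 10 ** (g / 20))   # dB -> linear amplitude scale
    mix = np.sum(refs, axis=0)              # the on-the-fly "speech_mix"
    return mix, refs                        # mixture plus scaled references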
