Add dynamic mixing in the speech separation task. #4387

Merged: 13 commits, Aug 10, 2022
48 changes: 38 additions & 10 deletions egs2/TEMPLATE/enh1/enh.sh
@@ -59,6 +59,7 @@ enh_config= # Config for enhancement model training.
enh_args= # Arguments for enhancement model training, e.g., "--max_epoch 10".
# Note that it will overwrite args in enhancement config.
spk_num=2 # Number of speakers
dynamic_mixing=false # Flag for dynamic mixing in the speech separation task.
noise_type_num=1
dereverb_ref_num=1

@@ -142,6 +143,7 @@ Options:
--enh_args # Arguments for enhancement model training, e.g., "--max_epoch 10" (default="${enh_args}").
# Note that it will overwrite args in enhancement config.
--spk_num # Number of speakers in the input audio (default="${spk_num}")
--dynamic_mixing # Flag for dynamic mixing in the speech separation task (default="${dynamic_mixing}").
--noise_type_num # Number of noise types in the input audio (default="${noise_type_num}")
--dereverb_ref_num # Number of references for dereverberation (default="${dereverb_ref_num}")

@@ -265,6 +267,8 @@ if [ -z "${inference_tag}" ]; then
fi
fi



# ========================== Main stages start from here. ==========================

if ! "${skip_data_prep}"; then
@@ -545,7 +549,14 @@ if ! "${skip_train}"; then
_opts+="--config ${enh_config} "
fi

_scp=wav.scp
if ${dynamic_mixing}; then
# In the current version, enabling dynamic mixing requires preparing the training
# set manually: all speech sources are assumed to be collected in "spk1.scp",
# and the other scp files (wav.scp, spk{N}.scp) are not used.
_scp=spk1.scp
else
_scp=wav.scp
fi
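As background: spk1.scp is a Kaldi-style scp file with one "<utt_id> <wav_path>" entry per line, listing every single-speaker source. A minimal sketch of how such a file is parsed, mirroring the loading logic in DynamicMixingPreprocessor further below (the path is illustrative):

```python
# Minimal sketch: parse a Kaldi-style scp file ("<utt_id> <wav_path>" per line),
# the same way DynamicMixingPreprocessor loads its source_scp (see below).
def load_scp(path: str) -> dict:
    sources = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            utt_id, wav_path = line.strip().split(None, 1)
            sources[utt_id] = wav_path
    return sources

# e.g. sources = load_scp("dump/raw/tr/spk1.scp")  # illustrative path
```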
# "sound" supports "wav", "flac", etc.
if [[ "${audio_format}" == *ark* ]]; then
_type=kaldi_ark
@@ -555,18 +566,35 @@ if ! "${skip_train}"; then
fi
_fold_length="$((enh_speech_fold_length * 100))"

# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/wav.scp,speech_mix,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_mix_shape "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "
_fold_length_param="--fold_length ${_fold_length} "

if ! ${dynamic_mixing} ; then

# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/${_scp},speech_mix,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_mix_shape "
_fold_length_param="--fold_length ${_fold_length} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "

for spk in $(seq "${spk_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
_fold_length_param+="--fold_length ${_fold_length} "
done

else
# prepare train and valid data parameters
Collaborator: Maybe it is better to echo some message here to warn the user that dynamic mixing is being used.
Contributor: I agree. Can you add a comment, @LiChenda?
Contributor (author): Done.

_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/${_scp},speech_ref1,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_ref1_shape "
_fold_length_param="--fold_length ${_fold_length} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "

fi

for spk in $(seq "${spk_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_valid_shape_param+="--valid_shape_file ${enh_stats_dir}/valid/speech_ref${spk}_shape "
_fold_length_param+="--fold_length ${_fold_length} "
done
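As background on the flags assembled above: each --*_data_path_and_name_and_type value is a "path,name,type" triplet; the loader reads the scp at path and exposes each utterance's audio to the model and preprocessor under the key name. A small sketch of the convention (values are hypothetical), which shows why dynamic mixing feeds spk1.scp in as speech_ref1 instead of reading a pre-mixed wav.scp as speech_mix:

```python
# Sketch of the "path,name,type" convention used by the flags above
# (values hypothetical). With dynamic mixing, the single-speaker sources
# are loaded under the key "speech_ref1"; "speech_mix" is then created
# on the fly by DynamicMixingPreprocessor instead of being read from disk.
triplet = "dump/raw/tr/spk1.scp,speech_ref1,sound"
path, name, dtype = triplet.split(",")
assert (path, name, dtype) == ("dump/raw/tr/spk1.scp", "speech_ref1", "sound")
```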

if $use_dereverb_ref; then
@@ -0,0 +1,72 @@
init: xavier_uniform
max_epoch: 150
batch_type: folded
# When dynamic mixing is enabled, the actual batch_size will
# be (batch_size / num_spk)
batch_size: 8
iterator_type: chunk
chunk_length: 16000
num_workers: 2
optim: adamw
optim_conf:
lr: 1.0e-03
eps: 1.0e-06
weight_decay: 0
patience: 20
grad_clip: 5.0
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
- si_snr
- max
- - valid
- loss
- min
keep_nbest_models: 10

scheduler: steplr
scheduler_conf:
step_size: 2
gamma: 0.97

Contributor: ditto
Contributor (author): Done.

dynamic_mixing: True
dynamic_mixing_gain_db: 2.0

encoder: conv
encoder_conf:
channel: 64
kernel_size: 2
stride: 1
decoder: conv
decoder_conf:
channel: 64
kernel_size: 2
stride: 1
separator: skim
separator_conf:
causal: False
num_spk: 2
layer: 6
nonlinear: relu
unit: 128
segment_size: 250
dropout: 0.05
mem_type: hc
seg_overlap: True

# A list of criterions
# The overall loss in the multi-task learning will be:
# loss = weight_1 * loss_1 + ... + weight_N * loss_N
# The default `weight` for each sub-loss is 1.0
criterions:
# The first criterion
- name: si_snr
conf:
eps: 1.0e-6
wrapper: pit
wrapper_conf:
weight: 1.0
independent_perm: True
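The criterions section above selects an SI-SNR objective wrapped in permutation-invariant training (PIT). As a rough numpy sketch of what this combination computes — not ESPnet's implementation, just the idea, with eps matching the config:

```python
import itertools

import numpy as np


def si_snr(ref: np.ndarray, est: np.ndarray, eps: float = 1.0e-6) -> float:
    """Scale-invariant SNR (dB) between a reference and an estimated signal."""
    ref = ref - ref.mean()
    est = est - est.mean()
    # Project the estimate onto the reference to get the scaled target.
    s_target = (est @ ref) * ref / (np.sum(ref**2) + eps)
    e_noise = est - s_target
    return 10 * np.log10((np.sum(s_target**2) + eps) / (np.sum(e_noise**2) + eps))


def pit_si_snr_loss(refs: list, ests: list) -> float:
    """Negative SI-SNR, maximized over all speaker permutations (PIT)."""
    best = max(
        np.mean([si_snr(refs[i], ests[p]) for i, p in enumerate(perm)])
        for perm in itertools.permutations(range(len(refs)))
    )
    return -best  # training minimizes this, i.e. maximizes SI-SNR
```

PIT scores every assignment of estimated channels to references and keeps the best one, which resolves the arbitrary ordering of separated outputs; independent_perm: True lets this criterion search its own permutation rather than reusing one computed by an earlier criterion.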

1 change: 1 addition & 0 deletions espnet2/enh/espnet_model.py
@@ -76,6 +76,7 @@ def forward(
espnet2/iterators/chunk_iter_factory.py
kwargs: "utt_id" is among the input.
"""

# clean speech signal of each speaker
speech_ref = [
kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
50 changes: 46 additions & 4 deletions espnet2/tasks/enh.py
@@ -54,7 +54,7 @@
from espnet2.torch_utils.initialize import initialize
from espnet2.train.class_choices import ClassChoices
from espnet2.train.collate_fn import CommonCollateFn
from espnet2.train.preprocessor import EnhPreprocessor
from espnet2.train.preprocessor import DynamicMixingPreprocessor, EnhPreprocessor
from espnet2.train.trainer import Trainer
from espnet2.utils.get_default_kwargs import get_default_kwargs
from espnet2.utils.nested_dict_action import NestedDictAction
@@ -277,6 +277,27 @@ def add_task_arguments(cls, parser: argparse.ArgumentParser):
        help="Whether to force all data to be single-channel.",
    )

    group.add_argument(
        "--dynamic_mixing",
        type=str2bool,
        default=False,
        help="Apply dynamic mixing",
    )

    group.add_argument(
        "--utt2spk",
        type=str_or_none,
        default=None,
        help="The file path of the utt2spk file. Only used in dynamic_mixing mode.",
    )

    group.add_argument(
        "--dynamic_mixing_gain_db",
        type=float,
        default=0.0,
        help="Random gain (in dB) for dynamic mixing sources",
    )

    for class_choices in cls.class_choices_list:
        # Append --<name> and --<name>_conf.
        # e.g. --encoder and --encoder_conf
@@ -298,7 +319,25 @@ def build_preprocess_fn(
    cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
    assert check_argument_types()
    if args.use_preprocessor:

    dynamic_mixing = getattr(args, "dynamic_mixing", False)
    use_preprocessor = getattr(args, "use_preprocessor", False)

    assert (
        dynamic_mixing and use_preprocessor
    ) is not True, (
        "'dynamic_mixing' and 'use_preprocessor' should not both be 'True'"
    )

    if dynamic_mixing:
        retval = DynamicMixingPreprocessor(
            train=train,
            source_scp=args.train_data_path_and_name_and_type[0][0],
            num_spk=args.separator_conf["num_spk"],
            dynamic_mixing_gain_db=args.dynamic_mixing_gain_db,
            utt2spk=args.utt2spk,
        )
    elif use_preprocessor:
        retval = EnhPreprocessor(
            train=train,
            # NOTE(kamo): Check attribute existence for backward compatibility
@@ -341,7 +380,7 @@ def required_data_names(
    cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
    if not inference:
        retval = ("speech_mix", "speech_ref1")
        retval = ("speech_ref1",)
    else:
        # Recognition mode
        retval = ("speech_mix",)
@@ -351,7 +390,10 @@ def optional_data_names(
    cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
    retval = ["dereverb_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
    retval = [
        "speech_mix",
    ]
    retval += ["dereverb_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
    retval += ["speech_ref{}".format(n) for n in range(2, MAX_REFERENCE_NUM + 1)]
    retval += ["noise_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)]
    retval = tuple(retval)
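Putting the pieces above together: in dynamic-mixing mode, build_preprocess_fn ends up constructing the preprocessor roughly like this (the values are hypothetical stand-ins for the parsed arguments):

```python
from espnet2.train.preprocessor import DynamicMixingPreprocessor

# Hypothetical stand-ins for the parsed command-line arguments. Note that
# the source scp is taken from the first --train_data_path_and_name_and_type
# triplet, which enh.sh sets to spk1.scp when dynamic mixing is enabled.
preprocessor = DynamicMixingPreprocessor(
    train=True,
    source_scp="dump/raw/tr/spk1.scp",  # args.train_data_path_and_name_and_type[0][0]
    num_spk=2,                          # args.separator_conf["num_spk"]
    dynamic_mixing_gain_db=2.0,         # args.dynamic_mixing_gain_db
    utt2spk="dump/raw/tr/utt2spk",      # args.utt2spk (may be None)
)
```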
132 changes: 132 additions & 0 deletions espnet2/train/preprocessor.py
@@ -1,4 +1,6 @@
import logging
import math
import random
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Collection, Dict, Iterable, List, Union
@@ -499,6 +501,136 @@ def _text_process(
        return data


class DynamicMixingPreprocessor(AbsPreprocessor):
    """Preprocessor that creates multi-speaker mixtures on the fly.

    During training, each single-speaker source ("speech_ref1") is mixed
    with (num_spk - 1) randomly picked sources from `source_scp`, after a
    random gain of up to +/- `dynamic_mixing_gain_db` dB is applied to each.
    """

    def __init__(
        self,
        train: bool,
        source_scp: str = None,
        num_spk: int = 2,
        dynamic_mixing_gain_db: float = 0.0,
        speech_name: str = "speech_mix",
        speech_ref_name_prefix: str = "speech_ref",
        utt2spk: str = None,
    ):

        super().__init__(
            train,
        )
        self.source_scp = source_scp
        self.num_spk = num_spk
        self.dynamic_mixing_gain_db = dynamic_mixing_gain_db
        self.speech_name = speech_name
        self.srnp = speech_ref_name_prefix

        self.sources = {}
        with open(source_scp, "r", encoding="utf-8") as f:
            for line in f:
                sps = line.strip().split(None, 1)
                assert len(sps) == 2
                self.sources[sps[0]] = sps[1]

        self.utt2spk = {}
        if utt2spk is None:
            # If utt2spk is not provided, create a dummy utt2spk with uid,
            # i.e. treat every utterance as its own speaker.
            for key in self.sources.keys():
                self.utt2spk[key] = key
        else:
            with open(utt2spk, "r", encoding="utf-8") as f:
                for line in f:
                    sps = line.strip().split(None, 1)
                    assert len(sps) == 2
                    self.utt2spk[sps[0]] = sps[1]

            for key in self.sources.keys():
                assert key in self.utt2spk

        self.source_keys = list(self.sources.keys())

    def _pick_source_utterances_(self, uid):
        # Return the uids of (num_spk - 1) additional reference sources.

        source_keys = [uid]
        spk_ids = [self.utt2spk[uid]]

        while len(source_keys) < self.num_spk:
            picked = random.choice(self.source_keys)
            spk_id = self.utt2spk[picked]

            # Make sure each utterance and each speaker appears only once in a mixture.
            if (picked not in source_keys) and (spk_id not in spk_ids):
                source_keys.append(picked)
                spk_ids.append(spk_id)

        return source_keys[1:]

    def _read_source_(self, key, speech_length):

        source, _ = soundfile.read(
            self.sources[key],
            dtype=np.float32,
            always_2d=False,
        )

        # Pad (by reflection) or truncate the source to match the anchor length.
        if speech_length > source.shape[0]:
            pad = speech_length - source.shape[0]
            source = np.pad(source, (0, pad), "reflect")
        else:
            source = source[0:speech_length]

        assert speech_length == source.shape[0]

        return source

    def _mix_speech_(self, uid, data):

        # pick additional source utterances
        source_keys = self._pick_source_utterances_(uid)

        # load audios, matched in length to the anchor source
        speech_length = data[f"{self.srnp}1"].shape[0]
        ref_audios = [self._read_source_(key, speech_length) for key in source_keys]
        ref_audios = [data[f"{self.srnp}1"]] + ref_audios

        # apply a random gain to each speech source,
        # converting dB to a linear scale factor: gain = 10 ** (gain_db / 20)
        gain_in_db = [
            random.uniform(-self.dynamic_mixing_gain_db, self.dynamic_mixing_gain_db)
            for i in range(len(ref_audios))
        ]
        gain = [math.pow(10, g_db / 20.0) for g_db in gain_in_db]

        ref_audios = [ref * g for ref, g in zip(ref_audios, gain)]

        # the mixture is the sample-wise sum of the gained sources
        speech_mix = np.sum(np.array(ref_audios), axis=0)

        for i, ref in enumerate(ref_audios):
            data[f"{self.srnp}{i+1}"] = ref
        data[self.speech_name] = speech_mix

        return data

    def __call__(
        self, uid: str, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:

        # TODO(Chenda): need to test for multi-channel data.
        assert (
            len(data[f"{self.srnp}1"].shape) == 1
        ), "Multi-channel input has not been tested"

        # Mixing is applied only during training; validation data is used as-is.
        if self.train:
            data = self._mix_speech_(uid, data)

        assert check_return_type(data)
        return data


class EnhPreprocessor(CommonPreprocessor):
"""Preprocessor for Speech Enhancement (Enh) task."""

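A minimal usage sketch of the new preprocessor (assumptions: "spk1.scp" exists and lists utterances from at least num_spk different speakers, including an entry "utt1"; the input signal is random noise standing in for real speech):

```python
import numpy as np

from espnet2.train.preprocessor import DynamicMixingPreprocessor

pre = DynamicMixingPreprocessor(
    train=True,                  # mixing only happens in training mode
    source_scp="spk1.scp",       # assumed to exist, "<utt_id> <wav_path>" per line
    num_spk=2,
    dynamic_mixing_gain_db=2.0,  # each source is scaled by a gain in [-2 dB, +2 dB]
)

# "speech_ref1" is the anchor source delivered by the data loader;
# "utt1" must be a key in spk1.scp (and in utt2spk, if provided).
data = {"speech_ref1": np.random.randn(16000).astype(np.float32)}
out = pre("utt1", data)

# out now contains "speech_ref1" and "speech_ref2" (randomly gained sources)
# and "speech_mix", their sample-wise sum.
```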