Skip to content

Commit

Permalink
Add MixIT support. It is unsupervised only. Semi-supervised config is…
Browse files Browse the repository at this point in the history
… not available for now.
  • Loading branch information
simpleoier committed Sep 4, 2022
1 parent 6d52365 commit b70d77b
Show file tree
Hide file tree
Showing 8 changed files with 310 additions and 95 deletions.
6 changes: 3 additions & 3 deletions ci/test_integration_espnet2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ if python -c 'import torch as t; from packaging.version import parse as L; asser
feats_types="raw"
for t in ${feats_types}; do
echo "==== feats_type=${t} ==="
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}"
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}" --use_preprocessor true --extra_wav_list "rirs.scp noises.scp" --enh_config ./conf/train_with_preprocessor.yaml
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}" --enh_config conf/train_with_dynamic_mixing.yaml --dynamic_mixing true --spk-num 2
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --ref-num 1 --enh-args "--max_epoch=1" --python "${python}"
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --ref-num 1 --enh-args "--max_epoch=1" --python "${python}" --use_preprocessor true --extra_wav_list "rirs.scp noises.scp" --enh_config ./conf/train_with_preprocessor.yaml
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --ref-num 1 --enh-args "--max_epoch=1" --python "${python}" --enh_config conf/train_with_dynamic_mixing.yaml --dynamic_mixing true --ref-num 2
done
# Remove generated files in order to reduce the disk usage
rm -rf exp dump data
Expand Down
56 changes: 33 additions & 23 deletions egs2/TEMPLATE/enh1/enh.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,16 @@ enh_tag= # Suffix to the result dir for enhancement model training.
enh_config= # Config for enhancement model training.
enh_args= # Arguments for enhancement model training, e.g., "--max_epoch 10".
# Note that it will overwrite args in enhancement config.
spk_num=2 # Number of speakers
ref_num=2 # Number of references (similar to speakers)
inf_num= # Number of inferences output by the model
# If not specified, it will be the same as ref_num. If specified, it will be overwritten.
dynamic_mixing=false # Flag for dynamic mixing in speech separation task.
noise_type_num=1
dereverb_ref_num=1

# Training data related
use_dereverb_ref=false
use_noise_ref=false
use_preprocessor=false
extra_wav_list= # Extra list of scp files for wav formatting

# Pretrained model related
Expand Down Expand Up @@ -142,7 +143,9 @@ Options:
--enh_config # Config for enhancement model training (default="${enh_config}").
--enh_args # Arguments for enhancement model training, e.g., "--max_epoch 10" (default="${enh_args}").
# Note that it will overwrite args in enhancement config.
--spk_num # Number of speakers in the input audio (default="${spk_num}")
--ref_num # Number of reference audios for each mixture (default="${ref_num}")
--inf_num # Number of inference audio generated by the model (default="${ref_num}")
# If not specified, it will be the same as ref_num. If specified, it will be overwritten.
--dynamic_mixing # Flag for dynamic mixing in speech separation task (default="${dynamic_mixing}").
--noise_type_num # Number of noise types in the input audio (default="${noise_type_num}")
--dereverb_ref_num # Number of references for dereverberation (default="${dereverb_ref_num}")
Expand All @@ -152,7 +155,6 @@ Options:
for training a dereverberation model (default="${use_dereverb_ref}")
--use_noise_ref # Whether or not to use noise signal as an additional reference
for training a denoising model (default="${use_noise_ref}")
--use_preprocessor # Whether or not to apply preprocessing (default="${use_preprocessor}")
--extra_wav_list # Extra list of scp files for wav formatting (default="${extra_wav_list}")
# Pretrained model related
Expand Down Expand Up @@ -215,6 +217,7 @@ utt_extra_files="utt2category"

data_feats=${dumpdir}/raw

inf_num=${inf_num:=${ref_num}}

# Set tag for naming of model directory
if [ -z "${enh_tag}" ]; then
Expand Down Expand Up @@ -283,7 +286,7 @@ if ! "${skip_data_prep}"; then
log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"

_scp_list="wav.scp "
for i in $(seq ${spk_num}); do
for i in $(seq ${ref_num}); do
_scp_list+="spk${i}.scp "
done

Expand Down Expand Up @@ -338,7 +341,7 @@ if ! "${skip_data_prep}"; then


_spk_list=" "
for i in $(seq ${spk_num}); do
for i in $(seq ${ref_num}); do
_spk_list+="spk${i} "
done
if $use_noise_ref && [ -n "${_suf}" ]; then
Expand Down Expand Up @@ -385,7 +388,7 @@ if ! "${skip_data_prep}"; then

_spk_list=" "
_scp_list=" "
for i in $(seq ${spk_num}); do
for i in $(seq ${ref_num}); do
_spk_list+="spk${i} "
_scp_list+="spk${i}.scp "
done
Expand Down Expand Up @@ -489,7 +492,7 @@ if ! "${skip_train}"; then
# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/wav.scp,speech_mix,${_type} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
done
Expand Down Expand Up @@ -518,7 +521,6 @@ if ! "${skip_train}"; then
${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
${python} -m espnet2.bin.enh_train \
--collect_stats true \
${use_preprocessor:+--use_preprocessor $use_preprocessor} \
${_train_data_param} \
${_valid_data_param} \
--train_shape_file "${_logdir}/train.JOB.scp" \
Expand Down Expand Up @@ -577,7 +579,7 @@ if ! "${skip_train}"; then
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "

for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
done
Expand All @@ -592,7 +594,7 @@ if ! "${skip_train}"; then
_opts+="--utt2spk ${_enh_train_dir}/utt2spk "
fi

for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_valid_shape_param+="--valid_shape_file ${enh_stats_dir}/valid/speech_ref${spk}_shape "
_fold_length_param+="--fold_length ${_fold_length} "
Expand Down Expand Up @@ -639,7 +641,6 @@ if ! "${skip_train}"; then
--init_file_prefix "${enh_exp}"/.dist_init_ \
--multiprocessing_distributed true -- \
${python} -m espnet2.bin.enh_train \
${use_preprocessor:+--use_preprocessor $use_preprocessor} \
${_train_data_param} \
${_valid_data_param} \
${_train_shape_param} \
Expand Down Expand Up @@ -713,7 +714,7 @@ if ! "${skip_eval}"; then


_spk_list=" "
for i in $(seq ${spk_num}); do
for i in $(seq ${inf_num}); do
_spk_list+="spk${i} "
done

Expand Down Expand Up @@ -765,18 +766,26 @@ if ! "${skip_eval}"; then


_ref_scp=
for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_ref_scp+="--ref_scp ${_data}/spk${spk}.scp "
done
_inf_scp=
for spk in $(seq "${spk_num}"); do
if "${score_obs}"; then
if "${score_obs}"; then
for spk in $(seq "${ref_num}"); do
# To compute the score of observation, input original wav.scp
_inf_scp+="--inf_scp ${data_feats}/${dset}/wav.scp "
else
done
flexible_numspk=false
else
for spk in $(seq "${inf_num}"); do
_inf_scp+="--inf_scp ${enh_exp}/${inference_tag}_${dset}/spk${spk}.scp "
done
if [[ "${ref_num}" -ne "${inf_num}" ]]; then
flexible_numspk=true
else
flexible_numspk=false
fi
done
fi

# 2. Submit scoring jobs
log "Scoring started... log: '${_logdir}/enh_scoring.*.log'"
Expand All @@ -787,9 +796,10 @@ if ! "${skip_eval}"; then
--output_dir "${_logdir}"/output.JOB \
${_ref_scp} \
${_inf_scp} \
--ref_channel ${ref_channel}
--ref_channel ${ref_channel} \
--flexible_numspk ${flexible_numspk}

for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
for protocol in ${scoring_protocol} wav; do
for i in $(seq "${_nj}"); do
cat "${_logdir}/output.${i}/${protocol}_spk${spk}"
Expand All @@ -800,7 +810,7 @@ if ! "${skip_eval}"; then

for protocol in ${scoring_protocol}; do
# shellcheck disable=SC2046
paste $(for j in $(seq ${spk_num}); do echo "${_dir}"/"${protocol}"_spk"${j}" ; done) |
paste $(for j in $(seq ${ref_num}); do echo "${_dir}"/"${protocol}"_spk"${j}" ; done) |
awk 'BEGIN{sum=0}
{n=0;score=0;for (i=2; i<=NF; i+=2){n+=1;score+=$i}; sum+=score/n}
END{printf ("%.2f\n",sum/NR)}' > "${_dir}/result_${protocol,,}.txt"
Expand Down Expand Up @@ -857,7 +867,7 @@ if "${score_with_asr}"; then
_dir="${enh_exp}/${inference_asr_tag}/${dset}"
fi

for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_ddir=${_dir}/spk_${spk}
_logdir="${_ddir}/logdir"
_decode_dir="${_ddir}/decode"
Expand Down Expand Up @@ -953,7 +963,7 @@ if "${score_with_asr}"; then
_dir="${enh_exp}/${inference_asr_tag}/${dset}/"
fi

for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_ddir=${_dir}/spk_${spk}
_logdir="${_ddir}/logdir"
_decode_dir="${_ddir}/decode"
Expand Down
74 changes: 74 additions & 0 deletions egs2/wsj0_2mix/enh1/conf/tuning/train_enh_mixit_conv_tasnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
optim: adam
init: xavier_uniform
max_epoch: 100
batch_type: folded
# When dynamic mixing is enabled, the actual batch_size will
# be (batch_size / num_spk)
batch_size: 16 # batch_size 16 can be trained on 4 RTX 2080ti
iterator_type: chunk
chunk_length: 32000
num_workers: 4
optim_conf:
lr: 1.0e-03
eps: 1.0e-08
weight_decay: 0
patience: 4
val_scheduler_criterion:
- valid
- loss
best_model_criterion:
- - valid
- si_snr
- max
- - valid
- loss
- min
keep_nbest_models: 1

scheduler: reducelronplateau
scheduler_conf:
mode: min
factor: 0.5
patience: 1
encoder: conv
encoder_conf:
channel: 256
kernel_size: 20
stride: 10
decoder: conv
decoder_conf:
channel: 256
kernel_size: 20
stride: 10
separator: tcn
separator_conf:
num_spk: 4
layer: 8
stack: 4
bottleneck_dim: 256
hidden_dim: 512
kernel: 3
causal: False
norm_type: "gLN"
nonlinear: relu

# dynamic_mixing related
# dynamic_mixing_gain_db:
# The maximum random gain (in dB) for each source before the mixing.
# The gain (in dB) of each source is unifromly sampled in
# [-dynamic_mixing_gain_db, dynamic_mixing_gain_db]
preprocessor: dynamic_mixing
preprocessor_conf:
num_utts: 2
dynamic_mixing_gain_db: 0.0
source_scp_name: "wav.scp"
mixture_source_name: "speech_mix"

criterions:
# The first criterion
- name: snr
conf:
eps: 1.0e-7
wrapper: mixit
wrapper_conf:
weight: 1.0
3 changes: 2 additions & 1 deletion espnet2/bin/enh_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def scoring(
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)

assert len(ref_scp) == len(inf_scp), ref_scp
if not flexible_numspk:
assert len(ref_scp) == len(inf_scp), ref_scp
num_spk = len(ref_scp)

keys = [
Expand Down
10 changes: 7 additions & 3 deletions espnet2/enh/espnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,14 @@ def forward(
espnet2/iterators/chunk_iter_factory.py
kwargs: "utt_id" is among the input.
"""

# clean speech signal of each speaker
# reference speech signal of each speaker
assert "speech_ref1" in kwargs, "At least 1 reference signal input is required."
speech_ref = [
kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk)
kwargs.get(
f"speech_ref{spk + 1}",
torch.zeros_like(kwargs["speech_ref1"]),
)
for spk in range(self.num_spk)
]
# (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels)
speech_ref = torch.stack(speech_ref, dim=1)
Expand Down
86 changes: 86 additions & 0 deletions espnet2/enh/loss/wrappers/mixit_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import itertools

import torch

from espnet2.enh.loss.criterions.abs_loss import AbsEnhLoss
from espnet2.enh.loss.wrappers.abs_wrapper import AbsLossWrapper


class MixITSolver(AbsLossWrapper):
def __init__(
self,
criterion: AbsEnhLoss,
weight=1.0,
):
"""Mixture Invariant Training Solver.
Args:
criterion (AbsEnhLoss): an instance of AbsEnhLoss
weight (float): weight (between 0 and 1) of current loss
for multi-task learning.
"""
super().__init__()
self.criterion = criterion
self.weight = weight

@property
def type(self):
return "mixit"

def forward(self, ref, inf, others={}):
"""MixIT solver.
Args:
ref (List[torch.Tensor]): [(batch, ...), ...] x n_spk
inf (List[torch.Tensor]): [(batch, ...), ...] x n_est
Returns:
loss: (torch.Tensor): minimum loss with the best permutation
stats: dict, for collecting training status
others: dict, in this PIT solver, permutation order will be returned
"""
num_inf = len(inf)
num_ref = num_inf // 2
device = ref[0].device

ref_tensor = torch.stack(ref[:num_ref], dim=1) # (batch, num_ref, ...)
inf_tensor = torch.stack(inf, dim=1) # (batch, num_inf, ...)

# all permutation assignments:
# [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 0), ..., (1, 1, 1, 1)]
all_assignments = list(itertools.product(range(num_ref), repeat=num_inf))
all_mixture_matrix = torch.stack(
[
torch.nn.functional.one_hot(
torch.tensor(asm, dtype=torch.int64, device=device),
num_classes=num_ref,
).transpose(1, 0)
for asm in all_assignments
],
dim=0,
).float() # (num_ref ^ num_inf, num_ref, num_inf)

def pair_loss(matrix):
mix_estimated = torch.matmul(matrix[None], inf_tensor)
return (
sum(
[
self.criterion(ref_tensor[:, i], mix_estimated[:, i])
for i in range(num_ref)
]
)
/ num_ref
)

losses = torch.stack(
[pair_loss(matrix) for matrix in all_mixture_matrix],
dim=1,
) # (batch, num_ref ^ num_inf)
loss, perm = torch.min(losses, dim=1)
perm = torch.index_select(all_mixture_matrix, 0, perm)

loss = loss.mean()

stats = dict()
stats[f"{self.criterion.name}_{self.type}"] = loss.detach()

return loss.mean(), stats, {"perm": perm}

0 comments on commit b70d77b

Please sign in to comment.