Skip to content

Commit

Permalink
Add MixIT support. It is unsupervised only. Semi-supervised config is…
Browse files Browse the repository at this point in the history
… not available for now.
  • Loading branch information
simpleoier committed Sep 5, 2022
1 parent 6d52365 commit 1754de7
Show file tree
Hide file tree
Showing 40 changed files with 682 additions and 137 deletions.
6 changes: 3 additions & 3 deletions ci/test_integration_espnet2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ if python -c 'import torch as t; from packaging.version import parse as L; asser
feats_types="raw"
for t in ${feats_types}; do
echo "==== feats_type=${t} ==="
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}"
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}" --use_preprocessor true --extra_wav_list "rirs.scp noises.scp" --enh_config ./conf/train_with_preprocessor.yaml
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --spk-num 1 --enh-args "--max_epoch=1" --python "${python}" --enh_config conf/train_with_dynamic_mixing.yaml --dynamic_mixing true --spk-num 2
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --ref-num 1 --enh-args "--max_epoch=1" --python "${python}"
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --ref-num 1 --enh-args "--max_epoch=1" --python "${python}" --extra_wav_list "rirs.scp noises.scp" --enh_config ./conf/train_with_preprocessor.yaml
./run.sh --ngpu 0 --stage 2 --stop-stage 10 --skip-upload false --feats-type "${t}" --ref-num 1 --enh-args "--max_epoch=1" --python "${python}" --enh_config conf/train_with_dynamic_mixing.yaml --ref-num 2
done
# Remove generated files in order to reduce the disk usage
rm -rf exp dump data
Expand Down
107 changes: 50 additions & 57 deletions egs2/TEMPLATE/enh1/enh.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ enh_tag= # Suffix to the result dir for enhancement model training.
enh_config= # Config for enhancement model training.
enh_args= # Arguments for enhancement model training, e.g., "--max_epoch 10".
# Note that it will overwrite args in enhancement config.
spk_num=2 # Number of speakers
dynamic_mixing=false # Flag for dynamic mixing in speech separation task.
ref_num=2 # Number of references (similar to speakers)
inf_num= # Number of inferences output by the model
# If not specified, it will be the same as ref_num. If specified, it will be overwritten.
noise_type_num=1
dereverb_ref_num=1

# Training data related
use_dereverb_ref=false
use_noise_ref=false
use_preprocessor=false
extra_wav_list= # Extra list of scp files for wav formatting

# Pretrained model related
Expand Down Expand Up @@ -142,8 +142,9 @@ Options:
--enh_config # Config for enhancement model training (default="${enh_config}").
--enh_args # Arguments for enhancement model training, e.g., "--max_epoch 10" (default="${enh_args}").
# Note that it will overwrite args in enhancement config.
--spk_num # Number of speakers in the input audio (default="${spk_num}")
--dynamic_mixing # Flag for dynamic mixing in speech separation task (default="${dynamic_mixing}").
--ref_num # Number of reference audios for each mixture (default="${ref_num}")
--inf_num # Number of inference audio generated by the model (default="${ref_num}")
# If not specified, it will be the same as ref_num. If specified, it will be overwritten.
--noise_type_num # Number of noise types in the input audio (default="${noise_type_num}")
--dereverb_ref_num # Number of references for dereverberation (default="${dereverb_ref_num}")
Expand All @@ -152,7 +153,6 @@ Options:
for training a dereverberation model (default="${use_dereverb_ref}")
--use_noise_ref # Whether or not to use noise signal as an additional reference
for training a denoising model (default="${use_noise_ref}")
--use_preprocessor # Whether or not to apply preprocessing (default="${use_preprocessor}")
--extra_wav_list # Extra list of scp files for wav formatting (default="${extra_wav_list}")
# Pretrained model related
Expand Down Expand Up @@ -215,6 +215,7 @@ utt_extra_files="utt2category"

data_feats=${dumpdir}/raw

inf_num=${inf_num:=${ref_num}}

# Set tag for naming of model directory
if [ -z "${enh_tag}" ]; then
Expand Down Expand Up @@ -283,7 +284,7 @@ if ! "${skip_data_prep}"; then
log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"

_scp_list="wav.scp "
for i in $(seq ${spk_num}); do
for i in $(seq ${ref_num}); do
_scp_list+="spk${i}.scp "
done

Expand Down Expand Up @@ -338,7 +339,7 @@ if ! "${skip_data_prep}"; then


_spk_list=" "
for i in $(seq ${spk_num}); do
for i in $(seq ${ref_num}); do
_spk_list+="spk${i} "
done
if $use_noise_ref && [ -n "${_suf}" ]; then
Expand Down Expand Up @@ -373,6 +374,10 @@ if ! "${skip_data_prep}"; then

echo "${feats_type}" > "${data_feats}${_suf}/${dset}/feats_type"

for f in ${utt_extra_files}; do
[ -f data/${dset}/${f} ] && cp data/${dset}/${f} ${data_feats}${_suf}/${dset}/${f}
done

done
fi

Expand All @@ -385,7 +390,7 @@ if ! "${skip_data_prep}"; then

_spk_list=" "
_scp_list=" "
for i in $(seq ${spk_num}); do
for i in $(seq ${ref_num}); do
_spk_list+="spk${i} "
_scp_list+="spk${i}.scp "
done
Expand All @@ -406,6 +411,11 @@ if ! "${skip_data_prep}"; then
for spk in ${_spk_list};do
cp "${data_feats}/org/${dset}/${spk}.scp" "${data_feats}/${dset}/${spk}.scp"
done
for f in ${utt_extra_files}; do
if [ -f "${data_feats}/org/${dset}/${f}" ]; then
cp "${data_feats}/org/${dset}/${f}" "${data_feats}/${dset}/${f}"
fi
done

_fs=$(python3 -c "import humanfriendly as h;print(h.parse_size('${fs}'))")
_min_length=$(python3 -c "print(int(${min_wav_duration} * ${_fs}))")
Expand All @@ -423,7 +433,7 @@ if ! "${skip_data_prep}"; then
done

# fix_data_dir.sh leaves only utts which exist in all files
utils/fix_data_dir.sh --utt_extra_files "${_scp_list}" "${data_feats}/${dset}"
utils/fix_data_dir.sh --utt_extra_files "${_scp_list} ${utt_extra_files}" "${data_feats}/${dset}"
done
fi
else
Expand Down Expand Up @@ -489,7 +499,7 @@ if ! "${skip_train}"; then
# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/wav.scp,speech_mix,${_type} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
done
Expand Down Expand Up @@ -518,7 +528,6 @@ if ! "${skip_train}"; then
${train_cmd} JOB=1:"${_nj}" "${_logdir}"/stats.JOB.log \
${python} -m espnet2.bin.enh_train \
--collect_stats true \
${use_preprocessor:+--use_preprocessor $use_preprocessor} \
${_train_data_param} \
${_valid_data_param} \
--train_shape_file "${_logdir}/train.JOB.scp" \
Expand Down Expand Up @@ -549,15 +558,7 @@ if ! "${skip_train}"; then
_opts+="--config ${enh_config} "
fi

if ${dynamic_mixing}; then
# In current version, if you want to enable dynamic mixing in speech separation,
# you need to prepare the training set manually. Here we assume all speech sources
# are collected in "spk1.scp", and other scp files (wav.scp, spk{N}.scp) are not used.
log "Dynamic mixing is enabled, use spk1.scp as the source file list."
_scp=spk1.scp
else
_scp=wav.scp
fi
_scp="wav.scp"
# "sound" supports "wav", "flac", etc.
if [[ "${audio_format}" == *ark* ]]; then
_type=kaldi_ark
Expand All @@ -567,32 +568,19 @@ if ! "${skip_train}"; then
fi
_fold_length="$((enh_speech_fold_length * 100))"


if ! ${dynamic_mixing} ; then

# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/${_scp},speech_mix,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_mix_shape "
_fold_length_param="--fold_length ${_fold_length} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "
# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/${_scp},speech_mix,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_mix_shape "
_fold_length_param="--fold_length ${_fold_length} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "

for spk in $(seq "${spk_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
done

else
# prepare train and valid data parameters
_train_data_param="--train_data_path_and_name_and_type ${_enh_train_dir}/${_scp},speech_ref1,${_type} "
_train_shape_param="--train_shape_file ${enh_stats_dir}/train/speech_ref1_shape "
_fold_length_param="--fold_length ${_fold_length} "
_valid_data_param="--valid_data_path_and_name_and_type ${_enh_valid_dir}/wav.scp,speech_mix,${_type} "
_valid_shape_param="--valid_shape_file ${enh_stats_dir}/valid/speech_mix_shape "
_opts+="--utt2spk ${_enh_train_dir}/utt2spk "
fi
for spk in $(seq "${ref_num}"); do
_train_data_param+="--train_data_path_and_name_and_type ${_enh_train_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_train_shape_param+="--train_shape_file ${enh_stats_dir}/train/speech_ref${spk}_shape "
done

for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_valid_data_param+="--valid_data_path_and_name_and_type ${_enh_valid_dir}/spk${spk}.scp,speech_ref${spk},${_type} "
_valid_shape_param+="--valid_shape_file ${enh_stats_dir}/valid/speech_ref${spk}_shape "
_fold_length_param+="--fold_length ${_fold_length} "
Expand Down Expand Up @@ -639,7 +627,6 @@ if ! "${skip_train}"; then
--init_file_prefix "${enh_exp}"/.dist_init_ \
--multiprocessing_distributed true -- \
${python} -m espnet2.bin.enh_train \
${use_preprocessor:+--use_preprocessor $use_preprocessor} \
${_train_data_param} \
${_valid_data_param} \
${_train_shape_param} \
Expand Down Expand Up @@ -713,7 +700,7 @@ if ! "${skip_eval}"; then


_spk_list=" "
for i in $(seq ${spk_num}); do
for i in $(seq ${inf_num}); do
_spk_list+="spk${i} "
done

Expand Down Expand Up @@ -765,18 +752,23 @@ if ! "${skip_eval}"; then


_ref_scp=
for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_ref_scp+="--ref_scp ${_data}/spk${spk}.scp "
done
_inf_scp=
for spk in $(seq "${spk_num}"); do
if "${score_obs}"; then
if "${score_obs}"; then
for spk in $(seq "${ref_num}"); do
# To compute the score of observation, input original wav.scp
_inf_scp+="--inf_scp ${data_feats}/${dset}/wav.scp "
else
done
else
for spk in $(seq "${inf_num}"); do
_inf_scp+="--inf_scp ${enh_exp}/${inference_tag}_${dset}/spk${spk}.scp "
done
if [[ "${ref_num}" -ne "${inf_num}" ]]; then
flexible_numspk=true
fi
done
fi

# 2. Submit scoring jobs
log "Scoring started... log: '${_logdir}/enh_scoring.*.log'"
Expand All @@ -787,9 +779,10 @@ if ! "${skip_eval}"; then
--output_dir "${_logdir}"/output.JOB \
${_ref_scp} \
${_inf_scp} \
--ref_channel ${ref_channel}
--ref_channel ${ref_channel} \
${flexible_numspk:+"--flexible_numspk"}

for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
for protocol in ${scoring_protocol} wav; do
for i in $(seq "${_nj}"); do
cat "${_logdir}/output.${i}/${protocol}_spk${spk}"
Expand All @@ -800,7 +793,7 @@ if ! "${skip_eval}"; then

for protocol in ${scoring_protocol}; do
# shellcheck disable=SC2046
paste $(for j in $(seq ${spk_num}); do echo "${_dir}"/"${protocol}"_spk"${j}" ; done) |
paste $(for j in $(seq ${ref_num}); do echo "${_dir}"/"${protocol}"_spk"${j}" ; done) |
awk 'BEGIN{sum=0}
{n=0;score=0;for (i=2; i<=NF; i+=2){n+=1;score+=$i}; sum+=score/n}
END{printf ("%.2f\n",sum/NR)}' > "${_dir}/result_${protocol,,}.txt"
Expand Down Expand Up @@ -857,7 +850,7 @@ if "${score_with_asr}"; then
_dir="${enh_exp}/${inference_asr_tag}/${dset}"
fi

for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_ddir=${_dir}/spk_${spk}
_logdir="${_ddir}/logdir"
_decode_dir="${_ddir}/decode"
Expand Down Expand Up @@ -953,7 +946,7 @@ if "${score_with_asr}"; then
_dir="${enh_exp}/${inference_asr_tag}/${dset}/"
fi

for spk in $(seq "${spk_num}"); do
for spk in $(seq "${ref_num}"); do
_ddir=${_dir}/spk_${spk}
_logdir="${_ddir}/logdir"
_decode_dir="${_ddir}/decode"
Expand Down
2 changes: 1 addition & 1 deletion egs2/TEMPLATE/enh_diar1/enh_diar.sh
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ if ! "${skip_eval}"; then
${_ref_scp} \
${_inf_scp} \
--ref_channel ${ref_channel} \
--flexible_numspk True
--flexible_numspk

for spk in $(seq "${spk_num}"); do
for protocol in ${scoring_protocol}; do
Expand Down
8 changes: 6 additions & 2 deletions egs2/mini_an4/enh1/conf/train_with_dynamic_mixing.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ separator_conf:
# The maximum random gain (in dB) for each source before the mixing.
# The gain (in dB) of each source is unifromly sampled in
# [-dynamic_mixing_gain_db, dynamic_mixing_gain_db]
dynamic_mixing: True
dynamic_mixing_gain_db: 2.0
preprocessor: dynamic_mixing
preprocessor_conf:
num_utts: 2
dynamic_mixing_gain_db: 2.0
source_scp_name: "spk1.scp"
mixture_source_name: "speech_mix"

criterions:
# The first criterion
Expand Down
2 changes: 2 additions & 0 deletions egs2/mini_an4/enh1/conf/train_with_preprocessor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ separator_conf:
unit: 128
dropout: 0.2

preprocessor: enh

criterions:
# The first criterion
- name: mse
Expand Down
2 changes: 1 addition & 1 deletion egs2/mini_an4/enh1/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ set -o pipefail
./enh.sh \
--fs 16k \
--lang en \
--spk-num 1 \
--ref-num 1 \
--train_set train_nodev \
--valid_set train_dev \
--test_sets "train_dev test" \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ scheduler_conf:
# The maximum random gain (in dB) for each source before the mixing.
# The gain (in dB) of each source is unifromly sampled in
# [-dynamic_mixing_gain_db, dynamic_mixing_gain_db]
dynamic_mixing: True
dynamic_mixing_gain_db: 2.0
preprocessor: dynamic_mixing
preprocessor_conf:
num_utts: 2
dynamic_mixing_gain_db: 0.0
source_scp_name: "spk1.scp"
mixture_source_name: "speech_mix"

encoder: conv
encoder_conf:
Expand Down
18 changes: 18 additions & 0 deletions egs2/wsj0_2mix/mixit_enh1/RESULTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# RESULTS
## Environments
- date: `Mon Sep 5 14:55:27 EDT 2022`
- python version: `3.9.12 (main, Apr 5 2022, 06:56:58) [GCC 7.5.0]`
- espnet version: `espnet 202207`
- pytorch version: `pytorch 1.10.1`
- Git hash: `6d5236553b7fb3e653907c447bbbbb0790a013f9`
- Commit date: `Wed Aug 31 08:17:56 2022 -0400`


## enh_train_enh_mixit_conv_tasnet_raw

config: conf/tuning/train_enh_conv_tasnet.yaml

|dataset|STOI|SAR|SDR|SIR|SI_SNR|
|---|---|---|---|---|---|
|enhanced_cv_min_8k|91.43|14.55|13.96|24.12|13.34|
|enhanced_tt_min_8k|91.32|13.68|12.91|22.61|12.25|

0 comments on commit 1754de7

Please sign in to comment.