Skip to content

Commit

Permalink
Merge pull request #4925 from popcornell/chime7task1
Browse files Browse the repository at this point in the history
Fixing some issues with chime7-task1 baseline
  • Loading branch information
mergify[bot] committed Feb 9, 2023
2 parents 7cf99cf + ecf3bd3 commit ffbf7e0
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 59 deletions.
93 changes: 50 additions & 43 deletions egs2/chime7_task1/asr1/local/data.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ gss_dsets=
manifests_root=
gss_dump_root=
augm_num_data_reps=4
decode_only=0
foreground_snrs="20:10:15:5:0"
background_snrs="20:10:15:5:0"

Expand All @@ -25,6 +26,11 @@ background_snrs="20:10:15:5:0"
gss_dsets=$(echo $gss_dsets | tr "," " ") # split by commas


# In decode-only mode there is nothing to prepare beyond the GSS output:
# clamp stop_stage so the script stops right after the GSS parsing stage.
# Use a quoted, defaulted expansion and a numeric test (portable; the
# original unquoted `[ $decode_only == 1 ]` breaks if the var is unset).
if [ "${decode_only:-0}" -eq 1 ]; then
  # stop after gss
  stop_stage=1
fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
log "Dumping all lhotse manifests to kaldi manifests and merging everything for dev set close mics,
you may want these for validation."
Expand All @@ -43,7 +49,49 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
fi


if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  # Stage 1: convert GSS-enhanced output into Kaldi-style data directories
  # for ASR training and validation.
  log "Parsing the GSS output to Kaldi manifests"
  cv_kaldi_manifests_gss=()
  tr_kaldi_manifests_gss=() # if gss is used also for training
  for dset in $gss_dsets; do
    # each entry is "<name>_<part>", e.g. chime6_train -> name=chime6, part=train
    dset_name="$(cut -d'_' -f1 <<<"${dset}")"
    dset_part="$(cut -d'_' -f2 <<<"${dset}")"
    # dump lhotse manifests for the GSS-enhanced audio ...
    python local/gss2lhotse.py -i "${gss_dump_root}/${dset_name}/${dset_part}" \
      -o "$manifests_root/gss/${dset_name}/${dset_part}/${dset_name}_${dset_part}_gss"

    # ... and export them to a Kaldi-style data directory
    lhotse kaldi export -p "$manifests_root/gss/${dset_name}/${dset_part}/${dset_name}_${dset_part}_gss_recordings.jsonl.gz" \
      "$manifests_root/gss/${dset_name}/${dset_part}/${dset_name}_${dset_part}_gss_supervisions.jsonl.gz" \
      "data/kaldi/${dset_name}/${dset_part}/gss"

    ./utils/utt2spk_to_spk2utt.pl "data/kaldi/${dset_name}/${dset_part}/gss/utt2spk" > "data/kaldi/${dset_name}/${dset_part}/gss/spk2utt"
    ./utils/fix_data_dir.sh "data/kaldi/${dset_name}/${dset_part}/gss"

    # FIX: append the directory that was actually created above, i.e.
    # data/kaldi/${dset_name}/... — the original appended data/kaldi/${dset}/...
    # (e.g. "chime6_train/train/gss"), a path that does not exist.
    if [ "${dset_part}" = train ]; then
      tr_kaldi_manifests_gss+=( "data/kaldi/${dset_name}/${dset_part}/gss" )
    fi

    if [ "${dset_part}" = dev ]; then
      cv_kaldi_manifests_gss+=( "data/kaldi/${dset_name}/${dset_part}/gss" )
    fi
  done

  if (( ${#tr_kaldi_manifests_gss[@]} )); then
    # possibly combine with all training data the gss training data
    tr_kaldi_manifests_gss+=( data/kaldi/train_all_mdm_ihm_rvb )
    ./utils/combine_data.sh data/kaldi/train_all_mdm_ihm_rvb_gss "${tr_kaldi_manifests_gss[@]}"
    ./utils/fix_data_dir.sh data/kaldi/train_all_mdm_ihm_rvb_gss
  fi

  if (( ${#cv_kaldi_manifests_gss[@]} )); then # concatenate all gss data to use for validation
    ./utils/combine_data.sh data/kaldi/dev_all_gss "${cv_kaldi_manifests_gss[@]}"
    ./utils/fix_data_dir.sh data/kaldi/dev_all_gss
  fi
fi


if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
all_tr_manifests=()
all_tr_manifests_ihm=()
log "Dumping all lhotse manifests to kaldi manifests and merging everything for training set."
Expand Down Expand Up @@ -78,7 +126,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
fi


if [ $stage -le 2 ] && [ $stop_stage -ge 3 ]; then
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Augmenting close-talk data with MUSAN and CHiME-6 extracted noises."
local/extract_noises.py ${chime6_root}/audio/train ${chime6_root}/transcriptions/train \
local/distant_audio_list distant_noises
Expand Down Expand Up @@ -116,47 +164,6 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 3 ]; then
fi


# Converts GSS-enhanced output into Kaldi-style data directories for ASR
# training/validation (same logic as the stage-1 block earlier in this file).
if [ ${stage} -le 3 ] && [ $stop_stage -ge 3 ]; then
# Preparing ASR training and validation data;
log "Parsing the GSS output to Kaldi manifests"
# Kaldi data dirs built from GSS output, collected for combine_data.sh below.
cv_kaldi_manifests_gss=()
tr_kaldi_manifests_gss=() # if gss is used also for training
for dset in $gss_dsets; do
# for each dataset get the name and part (dev or train)
# e.g. dset="chime6_train" -> dset_name="chime6", dset_part="train"
dset_name="$(cut -d'_' -f1 <<<${dset})"
dset_part="$(cut -d'_' -f2 <<<${dset})"
# dump lhotse manifests for the GSS-enhanced audio ...
python local/gss2lhotse.py -i ${gss_dump_root}/${dset_name}/${dset_part} \
-o $manifests_root/gss/${dset_name}/${dset_part}/${dset_name}_${dset_part}_gss

# ... then export them to a Kaldi-style data directory
lhotse kaldi export -p $manifests_root/gss/${dset_name}/${dset_part}/${dset_name}_${dset_part}_gss_recordings.jsonl.gz \
$manifests_root/gss/${dset_name}/${dset_part}/${dset_name}_${dset_part}_gss_supervisions.jsonl.gz \
data/kaldi/${dset_name}/${dset_part}/gss

./utils/utt2spk_to_spk2utt.pl data/kaldi/${dset_name}/${dset_part}/gss/utt2spk > data/kaldi/${dset_name}/${dset_part}/gss/spk2utt
./utils/fix_data_dir.sh data/kaldi/${dset_name}/${dset_part}/gss

# NOTE(review): the paths appended below use ${dset} (e.g. "chime6_train")
# while the data dir above was created under ${dset_name} (e.g. "chime6") —
# this looks like a mismatch; confirm before relying on the combine step.
if [ $dset_part == train ]; then
tr_kaldi_manifests_gss+=( "data/kaldi/${dset}/${dset_part}/gss")
fi

if [ $dset_part == dev ]; then
cv_kaldi_manifests_gss+=( "data/kaldi/${dset}/${dset_part}/gss")
fi
done

if (( ${#tr_kaldi_manifests_gss[@]} )); then
# possibly combine with all training data the gss training data
tr_kaldi_manifests_gss+=( data/kaldi/train_all_mdm_ihm_rvb)
./utils/combine_data.sh data/kaldi/train_all_mdm_ihm_rvb_gss "${tr_kaldi_manifests_gss[@]}"
./utils/fix_data_dir.sh data/kaldi/train_all_mdm_ihm_rvb_gss
fi

if (( ${#cv_kaldi_manifests_gss[@]} )); then # concatenate all gss data to use for validation
./utils/combine_data.sh data/kaldi/dev_all_gss "${cv_kaldi_manifests_gss[@]}"
./utils/fix_data_dir.sh data/kaldi/dev_all_gss
fi
fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
log "stage 2: Create non linguistic symbols: ${nlsyms_file}"
Expand Down
6 changes: 5 additions & 1 deletion egs2/chime7_task1/asr1/local/gen_task1_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,11 @@ def _get_time(x):
sess2audio = {}
for x in audio_files:
session_name = Path(x).stem.split("_")[0]
if Path(x).stem.split("_")[-1].startswith("P") and eval_opt != 2:
if (
split == "eval"
and Path(x).stem.split("_")[-1].startswith("P")
and eval_opt != 2
):
continue
if session_name not in sess2audio:
sess2audio[session_name] = [x]
Expand Down
10 changes: 8 additions & 2 deletions egs2/chime7_task1/asr1/local/install_dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
set -euo pipefail
[ -f ./path.sh ] && . ./path.sh


# Warn early if conda is not available (the later `conda install` steps need it).
# Use `command -v` to TEST for the command; the original `command conda`
# actually tries to EXECUTE conda, which is not an existence check.
# Informational only: diagnostics go to stderr and the script continues.
if ! command -v conda &>/dev/null; then
  echo "Conda command not found, please follow the instructions on
this recipe README.md on how to install ESPNet with conda as the venv." >&2
fi

# install lhotse from master, we need the most up-to-date one
pip install git+https://github.com/lhotse-speech/lhotse

Expand All @@ -13,11 +19,11 @@ if ! command -v wav-reverberate &>/dev/null; then
fi

# install s3prl
./tools/installers/install_s3prl.sh
${MAIN_ROOT}/tools/installers/install_s3prl.sh

# Install GSS if missing: pin cupy (required by gss) then run the installer.
# FIX: the installer path had a stray trailing dot ("install_gss.sh."), which
# would make the command fail with "No such file or directory".
if ! command -v gss &>/dev/null; then
  conda install -yc conda-forge cupy=10.2
  "${MAIN_ROOT}"/tools/installers/install_gss.sh
fi

sox_conda=`command -v ../../../tools/venv/bin/sox 2>/dev/null`
Expand Down
6 changes: 3 additions & 3 deletions egs2/chime7_task1/asr1/local/run_gss.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
echo "Stage 4: Enhance segments using GSS"

affix=()
affix=
if ! [ $channels == all ]; then
affix+=("--channels=$channels")
affix+="--channels=$channels"
fi

$cmd JOB=1:$nj ${exp_dir}/${dset_name}/${dset_part}/log/enhance.JOB.log \
Expand All @@ -66,5 +66,5 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
--num-workers 4 \
--force-overwrite \
--duration-tolerance 3.0 \
"${affix[@]}" || exit 1
${affix} || exit 1
fi
42 changes: 32 additions & 10 deletions egs2/chime7_task1/asr1/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ log() {
echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
######################################################################################
# CHiME-7 Task 1 SUB-TASK 1 baseline system script: GSS + ASR using oracle diarization
# CHiME-7 Task 1 SUB-TASK baseline system script: GSS + ASR using oracle diarization
######################################################################################

stage=3
Expand All @@ -19,7 +19,7 @@ stop_stage=100
chime7_root=${PWD}/chime7_task1
chime5_root= # you can leave it empty if you have already generated CHiME-6 data
chime6_root=/raid/users/popcornell/CHiME6/espnet/egs2/chime6/asr1/CHiME6 # will be created automatically from chime5
# but if you have it already it will be skipped
# but if you have it already it will be skipped, please put your own path
dipco_root=${PWD}/../../chime7/task1/datasets/dipco # this will be automatically downloaded
mixer6_root=/raid/users/popcornell/mixer6/

Expand All @@ -38,6 +38,7 @@ train_max_segment_length=20 # also reduce if you get OOM, here A100 40GB
gss_max_batch_dur=360 # set accordingly to your GPU VRAM, here A100 40GB
cmd_gss=run.pl # change to suit your needs e.g. slurm !
gss_dsets="chime6_train,chime6_dev,dipco_dev,mixer6_dev" # no mixer6 train in baseline
# but you can try to add it.


# ASR CONFIG
Expand All @@ -52,26 +53,31 @@ use_lm=false
use_word_lm=false
word_vocab_size=65000
nbpe=500


# and not contribute much (but you may use all)
asr_max_epochs=8
# ESPNet does not scale parameters with num of GPUs by default, doing it
# here for you
# put popcornell/chime7_task1_asr1_baseline if you want to test with pretrained model
use_pretrained=
decode_only=0

. ./path.sh
. ./cmd.sh
. ./utils/parse_options.sh


# ESPNet does not scale parameters with num of GPUs by default, doing it
# here for you
asr_batch_size=$(calc_int 128*$ngpu) # reduce 128 bsz if you get OOMs errors
asr_max_lr=$(calc_float $ngpu/10000.0)
asr_warmup=$(calc_int 40000.0/$ngpu)

# In decode-only mode, GSS is only needed on the dev sets (no training data).
# Quoted, defaulted numeric test: the original unquoted `[ $decode_only == 1 ]`
# is a bashism and errors out if the variable is unset.
if [ "${decode_only:-0}" -eq 1 ]; then
  # apply gss only on dev
  gss_dsets="chime6_dev,dipco_dev,mixer6_dev"
fi

if [ ${stage} -le 0 ] && [ $stop_stage -ge 0 ]; then
# this script creates the task1 dataset
local/gen_task1_data.sh --chime6-root $chime6_root --stage $dprep_stage --chime7-root $chime7_root \
--chime5_root $chime5_root \
--chime5_root "$chime5_root" \
--dipco-root $dipco_root \
--mixer6-root $mixer6_root \
--stage $dprep_stage \
Expand All @@ -86,6 +92,11 @@ if [ ${stage} -le 1 ] && [ $stop_stage -ge 1 ]; then
if [ $dset == dipco ] && [ $dset_part == train ]; then
continue # dipco has no train set
fi

if [ $decode_only == 1 ] && [ $dset_part == train ]; then
continue
fi

log "Creating lhotse manifests for ${dset} in $manifests_root/${dset}"
python local/get_lhotse_manifests.py -c $chime7_root \
-d $dset \
Expand Down Expand Up @@ -117,7 +128,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then

if [ ${dset_name} == dipco ]; then
channels=2,5,9,12,16,19,23,26,30,33 # in dipco only using opposite mics on each array, works better
elif [ ${dset_name} == chime6 ] && [ ${dset_part} == dev ]; then
elif [ ${dset_name} == chime6 ] && [ ${dset_part} == dev ]; then # use only outer mics
channels=0,3,4,7,8,11,12,15,16,19
fi

Expand All @@ -140,17 +151,28 @@ fi


if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then

# ASR training and inference on dev set
asr_train_set=kaldi/train_all_mdm_ihm_rvb_gss
asr_cv_set=kaldi/chime6/dev/gss # use chime only for validation
# Decoding on dev set because test is blind for now
asr_tt_set="kaldi/chime6/dev/gss/ kaldi/dipco/dev/gss/ kaldi/mixer6/dev/gss/"
# these are args to ASR data prep, done in local/data.sh
data_opts="--stage $asr_dprep_stage --chime6-root ${chime6_root} --train-set ${asr_train_set}"
data_opts+=" --manifests-root $manifests_root --gss_dsets $gss_dsets --gss-dump-root $gss_dump_root"
data_opts+=" --decode-only $decode_only"
# override ASR conf/tuning to scale automatically with num of GPUs
asr_args="--batch_size ${asr_batch_size} --scheduler_conf warmup_steps=${asr_warmup}"
asr_args+=" --max_epoch=${asr_max_epochs} --optim_conf lr=${asr_max_lr}"

# Build extra asr.sh flags for running with a pretrained model.
# FIX: the original tested `-z` (string EMPTY), which triggered the
# pretrained path when no model was requested and passed an empty
# --download_model argument. The intent (per the comment where
# use_pretrained is declared) is: when a model name IS given, skip data
# prep/training and download that model instead — i.e. test `-n`.
pretrained_affix=
if [ -n "${use_pretrained}" ]; then
  pretrained_affix+="--skip_data_prep true --skip_train true "
  pretrained_affix+="--download_model ${use_pretrained}"
fi



./asr.sh \
--lang en \
--local_data_opts "${data_opts}" \
Expand All @@ -177,5 +199,5 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--valid_set "${asr_cv_set}" \
--test_sets "${asr_tt_set}" \
--bpe_train_text "data/${asr_train_set}/text" \
--lm_train_text "data/${asr_train_set}/text" "$@"
--lm_train_text "data/${asr_train_set}/text" ${pretrained_affix}
fi

0 comments on commit ffbf7e0

Please sign in to comment.