Skip to content

Commit

Permalink
Merge pull request kaldi-asr#8 from jsalt2020-asrdiar/libricss
Browse files Browse the repository at this point in the history
Added BUT's VBx diarization
  • Loading branch information
desh2608 committed Jun 12, 2020
2 parents d6110de + ee69d97 commit 1bec93f
Show file tree
Hide file tree
Showing 7 changed files with 480 additions and 138 deletions.
249 changes: 121 additions & 128 deletions egs/callhome_diarization/v1/diarization/VB_diarization.py

Large diffs are not rendered by default.

115 changes: 115 additions & 0 deletions egs/callhome_diarization/v1/diarization/vb_hmm_xvector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#!/usr/bin/env python
# Copyright 2020 Johns Hopkins University (Author: Desh Raj)
# Apache 2.0

# This script is based on the Bayesian HMM-based xvector clustering
# code released by BUTSpeech at: https://github.com/BUTSpeechFIT/VBx.
# Note that this assumes that the provided labels are for a single
# recording. So this should be called from a script such as
# vb_hmm_xvector.sh which can divide all labels into per recording
# labels.

import sys, argparse, struct
import numpy as np
import itertools
import kaldi_io

from scipy.special import softmax

import VB_diarization

########### HELPER FUNCTIONS #####################################

def get_args():
parser = argparse.ArgumentParser(
description="""This script performs Bayesian HMM-based
clustering of x-vectors for one recording""",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--init-smoothing", type=float, default=10,
help="AHC produces hard assignments of x-vetors to speakers."
" These are smoothed to soft assignments as the initialization"
" for VB-HMM. This parameter controls the amount of smoothing."
" Not so important, high value (e.g. 10) is OK => keeping hard assigment")
parser.add_argument("--loop-prob", type=float, default=0.80,
help="probability of not switching speakers between frames")
parser.add_argument("--fa", type=float, default=0.4,
help="scale sufficient statistics collected using UBM")
parser.add_argument("--fb", type=float, default=11,
help="speaker regularization coefficient Fb (controls final # of speaker)")
parser.add_argument("xvector_ark_file", type=str,
help="Ark file containing xvectors for all subsegments")
parser.add_argument("plda", type=str,
help="path to PLDA model")
parser.add_argument("input_label_file", type=str,
help="path of input label file")
parser.add_argument("output_label_file", type=str,
help="path of output label file")
args = parser.parse_args()
return args

def read_labels_file(label_file):
segments = []
labels = []
with open(label_file, 'r') as f:
for line in f.readlines():
segment, label = line.strip().split()
segments.append(segment)
labels.append(int(label))
return segments, labels

def write_labels_file(seg2label, out_file):
f = open(out_file, 'w')
for seg in sorted(seg2label.keys()):
f.write("{} {}\n".format(seg, seg2label[seg]))
f.close()
return

def read_args(args):
segments, labels = read_labels_file(args.input_label_file)
xvec_all = dict(kaldi_io.read_vec_flt_ark(args.xvector_ark_file))
xvectors = []
for segment in segments:
xvectors.append(xvec_all[segment])
_, _, plda_psi = kaldi_io.read_plda(args.plda)
return xvectors, segments, labels, plda_psi


###################################################################

def vb_hmm(segments, in_labels, xvectors, plda_psi, init_smoothing, loop_prob, fa, fb):
x = np.array(xvectors)
dim = x.shape[1]

# Smooth the hard labels obtained from AHC to soft assignments of x-vectors to speakers
q_init = np.zeros((len(in_labels), np.max(in_labels)+1))
q_init[range(len(in_labels)), in_labels] = 1.0
q_init = softmax(q_init*init_smoothing, axis=1)

# Prepare model for VB-HMM clustering
ubmWeights = np.array([1.0])
ubmMeans = np.zeros((1,dim))
invSigma= np.ones((1,dim))
V=np.diag(np.sqrt(plda_psi[:dim]))[:,np.newaxis,:]

# Use VB-HMM for x-vector clustering. Instead of i-vector extractor model, we use PLDA
# => GMM with only 1 component, V derived across-class covariance, and invSigma is inverse
# within-class covariance (i.e. identity)
q, _, _ = VB_diarization.VB_diarization(x, ubmMeans, invSigma, ubmWeights, V, pi=None,
gamma=q_init, maxSpeakers=q_init.shape[1], maxIters=40, epsilon=1e-6, loopProb=loop_prob,
Fa=fa, Fb=fb)

labels = np.unique(q.argmax(1), return_inverse=True)[1]

return {seg:label for seg,label in zip(segments,labels)}

def main():
args = get_args()
xvectors, segments, labels, plda_psi = read_args(args)

seg2label_vb = vb_hmm(segments, labels, xvectors, plda_psi, args.init_smoothing,
args.loop_prob, args.fa, args.fb)
write_labels_file(seg2label_vb, args.output_label_file)

if __name__=="__main__":
main()

104 changes: 104 additions & 0 deletions egs/callhome_diarization/v1/diarization/vb_hmm_xvector.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#!/usr/bin/env bash

# Copyright 2020 Desh Raj
# Apache 2.0.

# This script performs Bayesian HMM on top of labels produced
# by a first-pass AHC clustering. See https://arxiv.org/abs/1910.08847
# for details about the model.

# Begin configuration section.
cmd="run.pl"
stage=0
nj=10
cleanup=true
rttm_channel=0

# The hyperparameters used here are taken from the DIHARD
# optimal hyperparameter values reported in:
# http://www.fit.vutbr.cz/research/groups/speech/publi/2019/diez_IEEE_ACM_2019_08910412.pdf
# These may require tuning for different datasets.
loop_prob=0.85
fa=0.2
fb=1

# End configuration section.

echo "$0 $@" # Print the command line for logging

if [ -f path.sh ]; then . ./path.sh; fi
. parse_options.sh || exit 1;


if [ $# != 3 ]; then
echo "Usage: $0 <dir> <xvector-dir> <plda>"
echo " e.g.: $0 exp/ exp/xvectors_dev exp/xvector_nnet_1a/plda"
echo "main options (for others, see top of script file)"
echo " --config <config-file> # config containing options"
echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
echo " --nj <n|10> # Number of jobs (also see num-processes and num-threads)"
echo " --stage <stage|0> # To control partial reruns"
echo " --cleanup <bool|false> # If true, remove temporary files"
exit 1;
fi

dir=$1
xvec_dir=$2
plda=$3

mkdir -p $dir/tmp

for f in $dir/labels ; do
[ ! -f $f ] && echo "No such file $f" && exit 1;
done

# check if numexpr is installed. Also install
# a modified version of kaldi_io with extra functions
# needed to read the PLDA file
result=`python3 -c "\
try:
import kaldi_io, numexpr
print (int(hasattr(kaldi_io, 'read_plda')))
except ImportError:
print('0')"`

if [ "$result" == "0" ]; then
echo "Installing kaldi_io and numexpr"
python3 -m pip install git+https://github.com/desh2608/kaldi-io-for-python.git@vbx
python3 -m pip install numexpr
fi

if [ $stage -le 0 ]; then
# Mean subtraction (If original x-vectors are high-dim, e.g. 512, you should
# consider also applying LDA to reduce dimensionality to, say, 200)
$cmd $xvec_dir/log/transform.log \
ivector-subtract-global-mean scp:$xvec_dir/xvector.scp ark:$xvec_dir/xvector_norm.ark
fi

echo -e "Performing bayesian HMM based x-vector clustering..\n"
# making a shell script for each job
for n in `seq $nj`; do
cat <<-EOF > $dir/tmp/vb_hmm.$n.sh
python3 diarization/vb_hmm_xvector.py \
--loop-prob $loop_prob --fa $fa --fb $fb \
$xvec_dir/xvector_norm.ark $plda $dir/labels.$n $dir/labels.vb.$n
EOF
done

chmod a+x $dir/tmp/vb_hmm.*.sh
$cmd JOB=1:$nj $dir/log/vb_hmm.JOB.log \
$dir/tmp/vb_hmm.JOB.sh

if [ $stage -le 1 ]; then
echo "$0: combining labels"
for j in $(seq $nj); do cat $dir/labels.vb.$j; done > $dir/labels.vb || exit 1;
fi

if [ $stage -le 2 ]; then
echo "$0: computing RTTM"
diarization/make_rttm.py --rttm-channel $rttm_channel $xvec_dir/plda_scores/segments $dir/labels.vb $dir/rttm.vb || exit 1;
fi

if $cleanup ; then
rm -r $dir/tmp || exit 1;
fi
8 changes: 4 additions & 4 deletions egs/libri_css/s5_mono/local/decode.sh
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ if [ $stage -le 3 ]; then

[ ! -d exp/xvector_nnet_1a ] && ./local/download_diarizer.sh

local/diarize.sh --nj $diar_nj --cmd "$train_cmd" --stage $diarizer_stage \
local/diarize_bhmm.sh --nj $diar_nj --cmd "$train_cmd" --stage $diarizer_stage \
--ref-rttm $ref_rttm \
exp/xvector_nnet_1a \
data/${datadir} \
Expand All @@ -127,7 +127,7 @@ if [ $stage -le 4 ]; then
asr_nj=$(wc -l < "data/$datadir/wav.scp")
local/decode_diarized.sh --nj $asr_nj --cmd "$decode_cmd" --stage $decode_diarize_stage \
--lm-suffix "_tgsmall" \
exp/${datadir}_diarization data/$datadir data/lang_nosp_test_tgsmall \
exp/${datadir}_diarization/rttm.vb data/$datadir data/lang_test_tgsmall \
exp/chain${nnet3_affix}/tdnn_${affix}_sp exp/nnet3${nnet3_affix} \
data/${datadir}_diarized || exit 1
done
Expand Down Expand Up @@ -163,7 +163,7 @@ if $rnnlm_rescore; then
rnnlm/lmrescore$pruned.sh \
--cmd "$decode_cmd --mem 8G" \
--weight 0.45 --max-ngram-order $ngram_order \
data/lang_nosp_test_tgsmall $rnnlm_dir \
data/lang_test_tgsmall $rnnlm_dir \
data/${decode_set}_diarized_hires ${decode_dir} \
${ac_model_dir}/decode_${decode_set}_diarized_2stage_rescore
done
Expand Down Expand Up @@ -207,7 +207,7 @@ fi
if [ $stage -le 9 ]; then
local/decode_oracle.sh --stage $decode_oracle_stage \
--affix $affix \
--lang-dir data/lang_nosp_test_tgsmall \
--lang-dir data/lang_test_tgsmall \
--lm-suffix "_tgsmall" \
--rnnlm-rescore $rnnlm_rescore \
--test_sets "$test_sets"
Expand Down
10 changes: 5 additions & 5 deletions egs/libri_css/s5_mono/local/decode_diarized.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ echo "$0 $@" # Print the command line for logging
if [ -f path.sh ]; then . ./path.sh; fi
. utils/parse_options.sh || exit 1;
if [ $# != 6 ]; then
echo "Usage: $0 <rttm-dir> <in-data-dir> <lang-dir> <model-dir> <ivector-dir> <out-dir>"
echo "Usage: $0 <rttm> <in-data-dir> <lang-dir> <model-dir> <ivector-dir> <out-dir>"
echo "e.g.: $0 data/rttm data/dev data/lang_chain exp/chain/tdnn_1a \
exp/nnet3_cleaned data/dev_diarized"
echo "Options: "
Expand All @@ -24,14 +24,14 @@ if [ $# != 6 ]; then
exit 1;
fi

rttm_dir=$1
rttm=$1
data_in=$2
lang_dir=$3
asr_model_dir=$4
ivector_extractor=$5
out_dir=$6

for f in $rttm_dir/rttm $data_in/wav.scp $data_in/text.bak \
for f in $rttm $data_in/wav.scp $data_in/text.bak \
$lang_dir/L.fst $asr_model_dir/graph${lm_suffix}/HCLG.fst \
$asr_model_dir/final.mdl; do
[ ! -f $f ] && echo "$0: No such file $f" && exit 1;
Expand All @@ -46,8 +46,8 @@ fi

if [ $stage -le 1 ]; then
echo "$0 creating segments file from rttm and utt2spk, reco2file_and_channel "
local/convert_rttm_to_utt2spk_and_segments.py --append-reco-id-to-spkr=true $rttm_dir/rttm \
<(awk '{print $2" "$2" "$3}' $rttm_dir/rttm |sort -u) \
local/convert_rttm_to_utt2spk_and_segments.py --append-reco-id-to-spkr=true $rttm \
<(awk '{print $2" "$2" "$3}' $rttm |sort -u) \
${out_dir}_hires/utt2spk ${out_dir}_hires/segments

utils/utt2spk_to_spk2utt.pl ${out_dir}_hires/utt2spk > ${out_dir}_hires/spk2utt
Expand Down

0 comments on commit 1bec93f

Please sign in to comment.