forked from kaldi-asr/kaldi
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request kaldi-asr#8 from jsalt2020-asrdiar/libricss
Added BUT's VBx diarization
- Loading branch information
Showing
7 changed files
with
480 additions
and
138 deletions.
There are no files selected for viewing
249 changes: 121 additions & 128 deletions
249
egs/callhome_diarization/v1/diarization/VB_diarization.py
Large diffs are not rendered by default.
Oops, something went wrong.
115 changes: 115 additions & 0 deletions
115
egs/callhome_diarization/v1/diarization/vb_hmm_xvector.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
#!/usr/bin/env python | ||
# Copyright 2020 Johns Hopkins University (Author: Desh Raj) | ||
# Apache 2.0 | ||
|
||
# This script is based on the Bayesian HMM-based xvector clustering | ||
# code released by BUTSpeech at: https://github.com/BUTSpeechFIT/VBx. | ||
# Note that this assumes that the provided labels are for a single | ||
# recording. So this should be called from a script such as | ||
# vb_hmm_xvector.sh which can divide all labels into per recording | ||
# labels. | ||
|
||
import sys, argparse, struct | ||
import numpy as np | ||
import itertools | ||
import kaldi_io | ||
|
||
from scipy.special import softmax | ||
|
||
import VB_diarization | ||
|
||
########### HELPER FUNCTIONS ##################################### | ||
|
||
def get_args():
    """Parse and return command-line arguments for VB-HMM x-vector clustering.

    Positional arguments name the x-vector ark, the PLDA model, and the
    input/output label files; optional flags expose the VB-HMM hyperparameters.
    """
    parser = argparse.ArgumentParser(
        description="""This script performs Bayesian HMM-based
        clustering of x-vectors for one recording""",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Fixed typos in the original help text: "x-vetors" -> "x-vectors",
    # "assigment" -> "assignment".
    parser.add_argument("--init-smoothing", type=float, default=10,
                        help="AHC produces hard assignments of x-vectors to speakers."
                        " These are smoothed to soft assignments as the initialization"
                        " for VB-HMM. This parameter controls the amount of smoothing."
                        " Not so important, high value (e.g. 10) is OK => keeping hard assignment")
    parser.add_argument("--loop-prob", type=float, default=0.80,
                        help="probability of not switching speakers between frames")
    parser.add_argument("--fa", type=float, default=0.4,
                        help="scale sufficient statistics collected using UBM")
    parser.add_argument("--fb", type=float, default=11,
                        help="speaker regularization coefficient Fb (controls final # of speaker)")
    parser.add_argument("xvector_ark_file", type=str,
                        help="Ark file containing xvectors for all subsegments")
    parser.add_argument("plda", type=str,
                        help="path to PLDA model")
    parser.add_argument("input_label_file", type=str,
                        help="path of input label file")
    parser.add_argument("output_label_file", type=str,
                        help="path of output label file")
    args = parser.parse_args()
    return args
|
||
def read_labels_file(label_file):
    """Read a label file whose lines are "<segment-id> <integer-label>".

    Args:
        label_file: path to the input label file.

    Returns:
        (segments, labels): parallel lists — segment-id strings in file
        order, and their integer speaker labels.
    """
    segments = []
    labels = []
    with open(label_file, 'r') as f:
        # Iterate the handle directly instead of materializing readlines().
        for line in f:
            segment, label = line.strip().split()
            segments.append(segment)
            labels.append(int(label))
    return segments, labels
|
||
def write_labels_file(seg2label, out_file):
    """Write "<segment-id> <label>" lines to out_file, sorted by segment id.

    Args:
        seg2label: dict mapping segment-id -> speaker label.
        out_file: path of the output label file.

    Uses a context manager so the handle is closed even if a write raises
    (the original open()/close() pair leaked the handle on error).
    """
    with open(out_file, 'w') as f:
        for seg in sorted(seg2label.keys()):
            f.write("{} {}\n".format(seg, seg2label[seg]))
    return
|
||
def read_args(args):
    """Load everything the clustering step needs from the parsed arguments.

    Returns:
        (xvectors, segments, labels, plda_psi): x-vectors ordered to match
        the segments in the input label file, the segment ids, their initial
        AHC labels, and the PLDA across-class covariance diagonal (psi).
    """
    segments, labels = read_labels_file(args.input_label_file)
    # The ark may contain vectors for many recordings; pick out, in order,
    # only the segments listed in this recording's label file.
    xvec_by_seg = dict(kaldi_io.read_vec_flt_ark(args.xvector_ark_file))
    xvectors = [xvec_by_seg[seg] for seg in segments]
    _, _, plda_psi = kaldi_io.read_plda(args.plda)
    return xvectors, segments, labels, plda_psi
|
||
|
||
################################################################### | ||
|
||
def vb_hmm(segments, in_labels, xvectors, plda_psi, init_smoothing, loop_prob, fa, fb):
    """Refine AHC speaker labels with Bayesian HMM clustering of x-vectors.

    Returns:
        dict mapping each segment id to its refined integer speaker label.
    """
    feats = np.array(xvectors)
    xvec_dim = feats.shape[1]

    # Turn the hard AHC labels into one-hot rows, then soften them so that
    # VB-HMM starts from (nearly) hard per-speaker posteriors.
    num_speakers = np.max(in_labels) + 1
    one_hot = np.eye(num_speakers)[in_labels]
    posteriors = softmax(one_hot * init_smoothing, axis=1)

    # PLDA stands in for the i-vector extractor model: a GMM with a single
    # component, V derived from the across-class covariance, and invSigma
    # the inverse within-class covariance (i.e. identity).
    ubm_weights = np.array([1.0])
    ubm_means = np.zeros((1, xvec_dim))
    inv_sigma = np.ones((1, xvec_dim))
    V = np.diag(np.sqrt(plda_psi[:xvec_dim]))[:, np.newaxis, :]

    q, _, _ = VB_diarization.VB_diarization(
        feats, ubm_means, inv_sigma, ubm_weights, V, pi=None,
        gamma=posteriors, maxSpeakers=posteriors.shape[1], maxIters=40,
        epsilon=1e-6, loopProb=loop_prob, Fa=fa, Fb=fb)

    # Relabel the surviving speakers as consecutive integers starting at 0.
    labels = np.unique(q.argmax(1), return_inverse=True)[1]

    return dict(zip(segments, labels))
|
||
def main():
    """Entry point: parse args, run VB-HMM clustering, write refined labels."""
    args = get_args()
    xvectors, segments, labels, plda_psi = read_args(args)

    refined = vb_hmm(
        segments, labels, xvectors, plda_psi,
        args.init_smoothing, args.loop_prob, args.fa, args.fb)
    write_labels_file(refined, args.output_label_file)


if __name__ == "__main__":
    main()
|
104 changes: 104 additions & 0 deletions
104
egs/callhome_diarization/v1/diarization/vb_hmm_xvector.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
#!/usr/bin/env bash

# Copyright 2020 Desh Raj
# Apache 2.0.

# This script performs Bayesian HMM on top of labels produced
# by a first-pass AHC clustering. See https://arxiv.org/abs/1910.08847
# for details about the model.

# Begin configuration section.
cmd="run.pl"
stage=0
nj=10
cleanup=true
rttm_channel=0

# The hyperparameters used here are taken from the DIHARD
# optimal hyperparameter values reported in:
# http://www.fit.vutbr.cz/research/groups/speech/publi/2019/diez_IEEE_ACM_2019_08910412.pdf
# These may require tuning for different datasets.
loop_prob=0.85
fa=0.2
fb=1

# End configuration section.

echo "$0 $@"  # Print the command line for logging

if [ -f path.sh ]; then . ./path.sh; fi
. parse_options.sh || exit 1;


if [ $# != 3 ]; then
  echo "Usage: $0 <dir> <xvector-dir> <plda>"
  echo " e.g.: $0 exp/ exp/xvectors_dev exp/xvector_nnet_1a/plda"
  echo "main options (for others, see top of script file)"
  echo "  --config <config-file>                           # config containing options"
  echo "  --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
  echo "  --nj <n|10>                                      # Number of jobs (also see num-processes and num-threads)"
  echo "  --stage <stage|0>                                # To control partial reruns"
  # Fixed: usage previously advertised the default as "false", but the
  # script's actual default above is cleanup=true.
  echo "  --cleanup <bool|true>                            # If true, remove temporary files"
  exit 1;
fi

dir=$1
xvec_dir=$2
plda=$3

mkdir -p $dir/tmp

# The per-recording label files are produced upstream; fail early if the
# combined labels file is missing.
for f in $dir/labels ; do
  [ ! -f $f ] && echo "No such file $f" && exit 1;
done

# check if numexpr is installed. Also install
# a modified version of kaldi_io with extra functions
# needed to read the PLDA file
result=`python3 -c "\
try:
    import kaldi_io, numexpr
    print (int(hasattr(kaldi_io, 'read_plda')))
except ImportError:
    print('0')"`

if [ "$result" == "0" ]; then
  echo "Installing kaldi_io and numexpr"
  python3 -m pip install git+https://github.com/desh2608/kaldi-io-for-python.git@vbx
  python3 -m pip install numexpr
fi

if [ $stage -le 0 ]; then
  # Mean subtraction (If original x-vectors are high-dim, e.g. 512, you should
  # consider also applying LDA to reduce dimensionality to, say, 200)
  $cmd $xvec_dir/log/transform.log \
    ivector-subtract-global-mean scp:$xvec_dir/xvector.scp ark:$xvec_dir/xvector_norm.ark
fi

echo -e "Performing bayesian HMM based x-vector clustering..\n"
# making a shell script for each job; each job refines the labels of one
# shard ($dir/labels.$n) independently.
for n in `seq $nj`; do
  cat <<-EOF > $dir/tmp/vb_hmm.$n.sh
python3 diarization/vb_hmm_xvector.py \
  --loop-prob $loop_prob --fa $fa --fb $fb \
  $xvec_dir/xvector_norm.ark $plda $dir/labels.$n $dir/labels.vb.$n
EOF
done

chmod a+x $dir/tmp/vb_hmm.*.sh
$cmd JOB=1:$nj $dir/log/vb_hmm.JOB.log \
  $dir/tmp/vb_hmm.JOB.sh

if [ $stage -le 1 ]; then
  echo "$0: combining labels"
  for j in $(seq $nj); do cat $dir/labels.vb.$j; done > $dir/labels.vb || exit 1;
fi

if [ $stage -le 2 ]; then
  echo "$0: computing RTTM"
  diarization/make_rttm.py --rttm-channel $rttm_channel $xvec_dir/plda_scores/segments $dir/labels.vb $dir/rttm.vb || exit 1;
fi

if $cleanup ; then
  rm -r $dir/tmp || exit 1;
fi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.