Merge pull request kaldi-asr#3 from jsalt2020-asrdiar/libricss
Decoding and scoring with diarized output
desh2608 committed May 25, 2020
2 parents 581cee8 + 61fa30f commit bba30e5
Showing 13 changed files with 806 additions and 22 deletions.
78 changes: 78 additions & 0 deletions egs/libri_css/s5_mono/local/best_wer_matching.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright 2020 Desh Raj
# Apache 2.0.

import sys, io
import itertools
import numpy as np
from scipy.optimize import linear_sum_assignment
import math

# Helper function to group the list by ref/hyp ids
def groupby(iterable, keyfunc):
"""Wrapper around ``itertools.groupby`` which sorts data first."""
iterable = sorted(iterable, key=keyfunc)
for key, group in itertools.groupby(iterable, keyfunc):
yield key, group

# This class stores all information about a ref/hyp matching
class WerObject:
    # Class-level defaults; a NaN WER is handled later, when the
    # assignment matrix is built, by substituting a very high cost.
    id = ''
    ref_id = ''
    hyp_id = ''
    wer = 0
    num_ins = 0
    num_del = 0
    num_sub = 0
    wc = 0

    def __init__(self, line):
        # Expected input is a Kaldi best_wer line per ref/hyp pair, e.g.:
        # r1h1 %WER 8.25 [ 33 / 400, 10 ins, 10 del, 13 sub ]
        self.id, details = line.strip().split(maxsplit=1)
        tokens = details.split()
        self.wer = float(tokens[1])
        self.wc = int(tokens[5][:-1])
        self.num_ins = int(tokens[6])
        self.num_del = int(tokens[8])
        self.num_sub = int(tokens[10])
        self.ref_id, self.hyp_id = self.id[1:].split('h')


infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')

# First we read all lines and create a list of WER objects
wer_objects = []
for line in infile:
    if line.strip() == "":
        continue
    wer_object = WerObject(line)
    wer_objects.append(wer_object)

# Now we create a matrix of costs (WER) which we will use to solve
# a linear sum assignment problem
wer_object_matrix = [list(g) for ref_id, g in groupby(wer_objects, lambda x: x.ref_id)]
if len(wer_object_matrix) > len(wer_object_matrix[0]):
    # More references than hypotheses; take transpose
    wer_object_matrix = [*zip(*wer_object_matrix)]
wer_matrix = np.array([[1000 if math.isnan(obj.wer) else obj.wer
                        for obj in row]
                       for row in wer_object_matrix])

# Solve the assignment problem and compute WER statistics
row_ind, col_ind = linear_sum_assignment(wer_matrix)
total_ins = 0
total_del = 0
total_sub = 0
total_wc = 0
for row, col in zip(row_ind, col_ind):
    total_ins += wer_object_matrix[row][col].num_ins
    total_del += wer_object_matrix[row][col].num_del
    total_sub += wer_object_matrix[row][col].num_sub
    total_wc += wer_object_matrix[row][col].wc
total_error = total_ins + total_del + total_sub
wer = float(100 * total_error) / total_wc

# Write the final statistics to stdout
print("%WER {:.2f} [ {} / {}, {} ins, {} del, {} sub ]".format(
    wer, total_error, total_wc, total_ins, total_del, total_sub))
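For intuition, a minimal standalone sketch (not part of the commit; WER values are made up) of the matching step above: rows are references, columns are hypotheses, and linear_sum_assignment picks the pairing with the lowest total cost.

#!/usr/bin/env python3
# Toy illustration of the assignment step (hypothetical WER values).
import numpy as np
from scipy.optimize import linear_sum_assignment

toy_wer = np.array([[10.0, 55.0, 80.0],   # ref r1 scored against hyps h1..h3
                    [60.0, 12.0, 70.0]])  # ref r2 scored against hyps h1..h3
row_ind, col_ind = linear_sum_assignment(toy_wer)
print([(int(r), int(c)) for r, c in zip(row_ind, col_ind)])
# [(0, 0), (1, 1)], i.e. r1 is paired with h1 and r2 with h2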
98 changes: 98 additions & 0 deletions egs/libri_css/s5_mono/local/convert_rttm_to_utt2spk_and_segments.py
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
# Copyright 2019 Vimal Manohar
# Apache 2.0.

"""This script converts an RTTM with
speaker info into kaldi utt2spk and segments"""

import argparse

def get_args():
    parser = argparse.ArgumentParser(
        description="""This script converts an RTTM with
        speaker info into kaldi utt2spk and segments""")
    parser.add_argument("--use-reco-id-as-spkr", type=str,
                        choices=["true", "false"], default="false",
                        help="Use the recording ID based on RTTM and "
                        "reco2file_and_channel as the speaker")
    parser.add_argument("--append-reco-id-to-spkr", type=str,
                        choices=["true", "false"], default="false",
                        help="Append recording ID to the speaker ID")

    parser.add_argument("rttm_file", type=str,
                        help="""Input RTTM file.
                        The format of the RTTM file is
                        <type> <file-id> <channel-id> <begin-time> """
                        """<end-time> <NA> <NA> <speaker> <conf>""")
    parser.add_argument("reco2file_and_channel", type=str,
                        help="""Input reco2file_and_channel.
                        The format is <recording-id> <file-id> <channel-id>.""")
    parser.add_argument("utt2spk", type=str,
                        help="Output utt2spk file")
    parser.add_argument("segments", type=str,
                        help="Output segments file")

    args = parser.parse_args()

    args.use_reco_id_as_spkr = bool(args.use_reco_id_as_spkr == "true")
    args.append_reco_id_to_spkr = bool(args.append_reco_id_to_spkr == "true")

    if args.use_reco_id_as_spkr:
        if args.append_reco_id_to_spkr:
            raise Exception("Appending recording ID to speaker does not make "
                            "sense when using --use-reco-id-as-spkr=true")

    return args

def main():
    args = get_args()

    file_and_channel2reco = {}
    utt2spk = {}
    segments = {}
    for line in open(args.reco2file_and_channel):
        parts = line.strip().split()
        file_and_channel2reco[(parts[1], parts[2])] = parts[0]

    utt2spk_writer = open(args.utt2spk, 'w')
    segments_writer = open(args.segments, 'w')
    for line in open(args.rttm_file):
        parts = line.strip().split()
        if parts[0] != "SPEAKER":
            continue

        file_id = parts[1]
        channel = parts[2]

        try:
            reco = file_and_channel2reco[(file_id, channel)]
        except KeyError as e:
            raise Exception("Could not find recording with "
                            "(file_id, channel) "
                            "= ({0},{1}) in {2}: {3}\n".format(
                                file_id, channel,
                                args.reco2file_and_channel, str(e)))

        start_time = float(parts[3])
        end_time = start_time + float(parts[4])

        if args.use_reco_id_as_spkr:
            spkr = reco
        else:
            if args.append_reco_id_to_spkr:
                spkr = parts[7] + "_" + reco
            else:
                spkr = parts[7]

        st = int(start_time * 100)
        end = int(end_time * 100)
        utt = "{0}_{1:06d}_{2:06d}".format(spkr, st, end)
        utt2spk[utt] = spkr
        segments[utt] = (reco, start_time, end_time)

    for utt_id in sorted(utt2spk):
        utt2spk_writer.write("{0} {1}\n".format(utt_id, utt2spk[utt_id]))
        segments_writer.write("{0} {1} {2:7.2f} {3:7.2f}\n".format(
            utt_id, segments[utt_id][0], segments[utt_id][1],
            segments[utt_id][2]))

if __name__ == '__main__':
    main()
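As an illustration (recording and speaker names are hypothetical), a single RTTM SPEAKER line maps to one utterance exactly as in main() above:

# Hypothetical walk-through, using the defaults (speaker taken from
# field 8 of the RTTM, no recording ID appended).
parts = "SPEAKER rec1 1 10.25 3.10 <NA> <NA> spk2 <NA>".split()
spkr = parts[7]
start = float(parts[3])
end = start + float(parts[4])
utt = "{0}_{1:06d}_{2:06d}".format(spkr, int(start * 100), int(end * 100))
print(utt)  # spk2_001025_001335
# resulting utt2spk entry:  spk2_001025_001335 spk2
# resulting segments entry: spk2_001025_001335 rec1   10.25   13.35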
25 changes: 18 additions & 7 deletions egs/libri_css/s5_mono/local/decode.sh
@@ -11,10 +11,11 @@
 nj=8
 stage=0
 score_sad=true
-diarizer_stage=4
+diarizer_stage=0
 decode_diarize_stage=0
 decode_oracle_stage=1
 score_stage=0
+affix=1d # This should be the affix of the tdnn model you want to decode with
 
 # If the following is set to true, we use the oracle speaker and segment
 # information instead of performing SAD and diarization.
@@ -107,19 +108,28 @@ fi
 #######################################################################
 if [ $stage -le 4 ]; then
   for datadir in ${test_sets}; do
-    local/decode_diarized.sh --nj $nj --cmd "$decode_cmd" --stage $decode_diarize_stage \
-      exp/${datadir}_${nnet_type}_seg_diarization data/$datadir data/lang \
-      exp/chain_${train_set}_cleaned_rvb exp/nnet3_${train_set}_cleaned_rvb \
+    asr_nj=$(wc -l < "data/$datadir/wav.scp")
+    local/decode_diarized.sh --nj $asr_nj --cmd "$decode_cmd" --stage $decode_diarize_stage \
+      --lm-suffix "_tgsmall" \
+      exp/${datadir}_diarization data/$datadir data/lang_nosp_test_tgsmall \
+      exp/chain_cleaned/tdnn_${affix}_sp exp/nnet3_cleaned \
       data/${datadir}_diarized || exit 1
   done
 fi
 
 #######################################################################
 # Score decoded dev/eval sets
 #######################################################################
-# if [ $stage -le 5 ]; then
-#   # TODO
-# fi
+if [ $stage -le 5 ]; then
+  # final scoring to get the challenge result
+  # please specify both dev and eval set directories so that the search parameters
+  # (insertion penalty and language model weight) will be tuned using the dev set
+  local/score_reco_diarized.sh --stage $score_stage \
+      --dev_decodedir exp/chain_cleaned/tdnn_${affix}_sp/decode_dev_diarized_2stage \
+      --dev_datadir dev_diarized_hires \
+      --eval_decodedir exp/chain_cleaned/tdnn_${affix}_sp/decode_eval_diarized_2stage \
+      --eval_datadir eval_diarized_hires
+fi
 
 $use_oracle_segments || exit 0
 
@@ -148,6 +158,7 @@ fi

 if [ $stage -le 7 ]; then
   local/decode_oracle.sh --stage $decode_oracle_stage \
+    --affix $affix \
     --lang-dir data/lang_nosp_test_tgsmall \
     --lm-suffix "_tgsmall" \
     --test_sets "$test_sets"
70 changes: 70 additions & 0 deletions egs/libri_css/s5_mono/local/decode_diarized.sh
@@ -0,0 +1,70 @@
#!/usr/bin/env bash
# Copyright 2019 Ashish Arora, Vimal Manohar
# Apache 2.0.
# This script takes an rttm file and performs decoding on a test directory.
# The output directory contains a text file which can be used for scoring.


stage=0
nj=8
cmd=queue.pl
lm_suffix=

echo "$0 $@" # Print the command line for logging

if [ -f path.sh ]; then . ./path.sh; fi
. utils/parse_options.sh || exit 1;
if [ $# != 6 ]; then
  echo "Usage: $0 <rttm-dir> <in-data-dir> <lang-dir> <model-dir> <ivector-dir> <out-dir>"
  echo "e.g.: $0 data/rttm data/dev data/lang_chain exp/chain/tdnn_1a \
  exp/nnet3_cleaned data/dev_diarized"
  echo "Options: "
  echo " --nj <nj> # number of parallel jobs."
  echo " --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
  exit 1;
fi

rttm_dir=$1
data_in=$2
lang_dir=$3
asr_model_dir=$4
ivector_extractor=$5
out_dir=$6

for f in $rttm_dir/rttm $data_in/wav.scp $data_in/text.bak \
    $lang_dir/L.fst $asr_model_dir/graph${lm_suffix}/HCLG.fst \
    $asr_model_dir/final.mdl; do
  [ ! -f $f ] && echo "$0: No such file $f" && exit 1;
done

if [ $stage -le 0 ]; then
  echo "$0 copying data files in output directory"
  mkdir -p ${out_dir}_hires
  cp ${data_in}/{wav.scp,utt2spk,utt2spk.bak} ${out_dir}_hires
  utils/data/get_reco2dur.sh ${out_dir}_hires
fi

if [ $stage -le 1 ]; then
  echo "$0 creating utt2spk and segments files from the rttm and reco2file_and_channel"
  local/convert_rttm_to_utt2spk_and_segments.py --append-reco-id-to-spkr=true $rttm_dir/rttm \
    <(awk '{print $2" "$2" "$3}' $rttm_dir/rttm | sort -u) \
    ${out_dir}_hires/utt2spk ${out_dir}_hires/segments

  utils/utt2spk_to_spk2utt.pl ${out_dir}_hires/utt2spk > ${out_dir}_hires/spk2utt
  utils/fix_data_dir.sh ${out_dir}_hires || exit 1;
fi

if [ $stage -le 2 ]; then
  echo "$0 extracting mfcc features using segments file"
  steps/make_mfcc.sh --mfcc-config conf/mfcc_hires.conf --nj $nj --cmd "$cmd" ${out_dir}_hires
  steps/compute_cmvn_stats.sh ${out_dir}_hires
  cp $data_in/text.bak ${out_dir}_hires/text
fi

if [ $stage -le 3 ]; then
  echo "$0 performing decoding on the extracted features"
  local/nnet3/decode.sh --affix 2stage --acwt 1.0 --post-decode-acwt 10.0 \
    --frames-per-chunk 150 --nj $nj --ivector-dir $ivector_extractor \
    $out_dir $lang_dir $asr_model_dir/graph${lm_suffix} $asr_model_dir/
fi
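One detail worth noting in stage 1 above: the awk one-liner synthesizes the reco2file_and_channel argument on the fly from the RTTM itself, so the recording ID doubles as the file ID. A rough Python sketch of that transformation (the input path is illustrative):

# Rough equivalent of: awk '{print $2" "$2" "$3}' rttm | sort -u
mapping = set()
for line in open("rttm"):  # hypothetical path to the RTTM file
    parts = line.split()
    mapping.add((parts[1], parts[1], parts[2]))  # <reco> <file-id> <channel>
for reco, file_id, channel in sorted(mapping):
    print(reco, file_id, channel)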

2 changes: 1 addition & 1 deletion egs/libri_css/s5_mono/local/decode_oracle.sh
@@ -13,14 +13,14 @@ stage=0
 test_sets=
 lang_dir=
 lm_suffix=
+affix=1d # affix for the TDNN directory name
 
 # End configuration section
 . ./utils/parse_options.sh
 
 . ./cmd.sh
 . ./path.sh
 
-affix=1d # affix for the TDNN directory name
 dir=exp/chain${nnet3_affix}/tdnn_${affix}_sp
 
 