forked from kaldi-asr/kaldi
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request kaldi-asr#3 from jsalt2020-asrdiar/libricss
Decoding and scoring with diarized output
- Loading branch information
Showing
13 changed files
with
806 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright 2020 Desh Raj | ||
# Apache 2.0. | ||
|
||
import sys, io | ||
import itertools | ||
import numpy as np | ||
from scipy.optimize import linear_sum_assignment | ||
import math | ||
|
||
# Helper function to group the list by ref/hyp ids
def groupby(iterable, keyfunc):
    """Wrapper around ``itertools.groupby`` which sorts data first.

    ``itertools.groupby`` only merges *adjacent* items with equal keys, so
    the input is sorted by the same key function before grouping.
    Yields (key, group-iterator) pairs in ascending key order.
    """
    yield from itertools.groupby(sorted(iterable, key=keyfunc), keyfunc)
|
||
# This class stores all information about a ref/hyp matching
class WerObject:
    """Statistics parsed from one Kaldi per-utterance WER line.

    Expected line shape (as produced by compute-wer style output):
        r<ref>h<hyp> %WER <wer> [ <errs> / <wc>, <ins> ins, <del> del, <sub> sub ]
    """
    # Class-level defaults so attribute access is safe even if parsing
    # of a line stops part-way through.
    id = ''
    ref_id = ''
    hyp_id = ''
    wer = 0
    num_ins = 0
    num_del = 0
    num_sub = 0
    wc = 0

    def __init__(self, line):
        # First token is the pair id "r<ref>h<hyp>"; the rest holds the stats.
        self.id, details = line.strip().split(maxsplit=1)
        fields = details.split()
        self.wer = float(fields[1])
        self.wc = int(fields[5][:-1])  # strip the trailing ',' after the word count
        self.num_ins, self.num_del, self.num_sub = (
            int(fields[i]) for i in (6, 8, 10))
        # Drop the leading 'r' and split on 'h' to recover the two ids.
        self.ref_id, self.hyp_id = self.id[1:].split('h')
|
||
|
||
# Read per-pair WER lines from stdin as UTF-8 regardless of locale.
infile = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')

# First we read all lines and create a list of WER objects
wer_objects = []
for line in infile:
    if line.strip() == "":
        continue
    wer_objects.append(WerObject(line))

# Fail clearly instead of raising IndexError below on empty input.
if not wer_objects:
    sys.exit("{}: no WER entries read from stdin".format(sys.argv[0]))

# Now we create a matrix of costs (WER) which we will use to solve
# a linear sum assignment problem
wer_object_matrix = [list(g) for ref_id, g in groupby(wer_objects, lambda x: x.ref_id)]
if len(wer_object_matrix) > len(wer_object_matrix[0]):
    # More references than hypotheses; take transpose so rows <= columns,
    # as required for a feasible assignment of every row.
    wer_object_matrix = [*zip(*wer_object_matrix)]
# NaN WERs (e.g. empty references) get a large finite cost so the
# assignment solver avoids them but does not fail.
wer_matrix = np.array([[1000 if math.isnan(obj.wer) else obj.wer
                        for obj in row]
                       for row in wer_object_matrix])

# Solve the assignment problem and compute WER statistics
row_ind, col_ind = linear_sum_assignment(wer_matrix)
total_ins = 0
total_del = 0
total_sub = 0
total_wc = 0
for row, col in zip(row_ind, col_ind):
    obj = wer_object_matrix[row][col]
    total_ins += obj.num_ins
    total_del += obj.num_del
    total_sub += obj.num_sub
    total_wc += obj.wc
total_error = total_ins + total_del + total_sub
# Guard against division by zero when the matched references had no words.
wer = 100.0 * total_error / total_wc if total_wc else 0.0

# Write the final statistics to stdout
print("%WER {:.2f} [ {} / {}, {} ins, {} del, {} sub ]".format(
    wer, total_error, total_wc, total_ins, total_del, total_sub))
98 changes: 98 additions & 0 deletions
98
egs/libri_css/s5_mono/local/convert_rttm_to_utt2spk_and_segments.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
#! /usr/bin/env python | ||
# Copyright 2019 Vimal Manohar | ||
# Apache 2.0. | ||
|
||
"""This script converts an RTTM with | ||
speaker info into kaldi utt2spk and segments""" | ||
|
||
import argparse | ||
|
||
def get_args():
    """Parse and validate command-line arguments for the RTTM converter.

    Converts the two string-valued boolean flags to real ``bool``s and
    rejects the contradictory combination of both being true.
    """
    parser = argparse.ArgumentParser(
        description="""This script converts an RTTM with
        speaker info into kaldi utt2spk and segments""")
    parser.add_argument("--use-reco-id-as-spkr", type=str,
                        choices=["true", "false"], default="false",
                        help="Use the recording ID based on RTTM and "
                        "reco2file_and_channel as the speaker")
    parser.add_argument("--append-reco-id-to-spkr", type=str,
                        choices=["true", "false"], default="false",
                        help="Append recording ID to the speaker ID")
    parser.add_argument("rttm_file", type=str,
                        help="""Input RTTM file.
                        The format of the RTTM file is
                        <type> <file-id> <channel-id> <begin-time> """
                        """<end-time> <NA> <NA> <speaker> <conf>""")
    parser.add_argument("reco2file_and_channel", type=str,
                        help="""Input reco2file_and_channel.
                        The format is <recording-id> <file-id> <channel-id>.""")
    parser.add_argument("utt2spk", type=str,
                        help="Output utt2spk file")
    parser.add_argument("segments", type=str,
                        help="Output segments file")

    args = parser.parse_args()

    # argparse delivers the choices as strings; convert to booleans here.
    args.use_reco_id_as_spkr = args.use_reco_id_as_spkr == "true"
    args.append_reco_id_to_spkr = args.append_reco_id_to_spkr == "true"

    if args.use_reco_id_as_spkr and args.append_reco_id_to_spkr:
        raise Exception("Appending recording ID to speaker does not make sense when using --use-reco-id-as-spkr=true")

    return args
|
||
def main():
    """Convert an RTTM file into Kaldi utt2spk and segments files.

    Reads the RTTM and reco2file_and_channel paths from the parsed
    arguments, builds one utterance per SPEAKER line, and writes the
    two output files sorted by utterance id.

    Raises:
        Exception: if an RTTM (file-id, channel) pair has no entry in
            reco2file_and_channel.
    """
    args = get_args()

    # Map (file-id, channel) -> recording-id from reco2file_and_channel.
    file_and_channel2reco = {}
    utt2spk = {}
    segments = {}
    # 'with' ensures every file handle is closed even on error
    # (the original leaked all four handles).
    with open(args.reco2file_and_channel) as reco2file_reader:
        for line in reco2file_reader:
            parts = line.strip().split()
            file_and_channel2reco[(parts[1], parts[2])] = parts[0]

    with open(args.rttm_file) as rttm_reader:
        for line in rttm_reader:
            parts = line.strip().split()
            if parts[0] != "SPEAKER":
                continue

            file_id = parts[1]
            channel = parts[2]

            try:
                reco = file_and_channel2reco[(file_id, channel)]
            except KeyError as e:
                raise Exception("Could not find recording with "
                                "(file_id, channel) "
                                "= ({0},{1}) in {2}: {3}\n".format(
                                    file_id, channel,
                                    args.reco2file_and_channel, str(e)))

            start_time = float(parts[3])
            # RTTM stores duration in field 4; convert to an end time.
            end_time = start_time + float(parts[4])

            if args.use_reco_id_as_spkr:
                spkr = reco
            else:
                if args.append_reco_id_to_spkr:
                    spkr = parts[7] + "_" + reco
                else:
                    spkr = parts[7]

            # Kaldi convention: utterance ids embed times in centiseconds,
            # zero-padded to 6 digits so lexicographic sort == time sort.
            st = int(start_time * 100)
            end = int(end_time * 100)
            utt = "{0}_{1:06d}_{2:06d}".format(spkr, st, end)
            utt2spk[utt] = spkr
            segments[utt] = (reco, start_time, end_time)

    # Open the writers only after a successful read so a failure above
    # cannot leave behind truncated output files.
    with open(args.utt2spk, 'w') as utt2spk_writer, \
         open(args.segments, 'w') as segments_writer:
        for uttid_id in sorted(utt2spk):
            utt2spk_writer.write("{0} {1}\n".format(uttid_id, utt2spk[uttid_id]))
            segments_writer.write("{0} {1} {2:7.2f} {3:7.2f}\n".format(
                uttid_id, segments[uttid_id][0],
                segments[uttid_id][1], segments[uttid_id][2]))


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#!/usr/bin/env bash
# Copyright 2019 Ashish Arora, Vimal Manohar
# Apache 2.0.
# This script takes an rttm file and performs decoding on a test directory.
# The output directory contains a text file which can be used for scoring.

stage=0
nj=8
cmd=queue.pl
lm_suffix=

echo "$0 $@"  # Print the command line for logging

if [ -f path.sh ]; then . ./path.sh; fi
. utils/parse_options.sh || exit 1;
if [ $# != 6 ]; then
  echo "Usage: $0 <rttm-dir> <in-data-dir> <lang-dir> <model-dir> <ivector-dir> <out-dir>"
  echo "e.g.: $0 data/rttm data/dev data/lang_chain exp/chain/tdnn_1a \
    exp/nnet3_cleaned data/dev_diarized"
  echo "Options: "
  echo "  --nj <nj>                                        # number of parallel jobs."
  echo "  --cmd (utils/run.pl|utils/queue.pl <queue opts>) # how to run jobs."
  exit 1;
fi

rttm_dir=$1
data_in=$2
lang_dir=$3
asr_model_dir=$4
ivector_extractor=$5
out_dir=$6

# Fail early if any required input is missing.
for f in $rttm_dir/rttm $data_in/wav.scp $data_in/text.bak \
         $lang_dir/L.fst $asr_model_dir/graph${lm_suffix}/HCLG.fst \
         $asr_model_dir/final.mdl; do
  [ ! -f $f ] && echo "$0: No such file $f" && exit 1;
done

if [ $stage -le 0 ]; then
  echo "$0 copying data files in output directory"
  mkdir -p ${out_dir}_hires
  cp ${data_in}/{wav.scp,utt2spk,utt2spk.bak} ${out_dir}_hires
  utils/data/get_reco2dur.sh ${out_dir}_hires
fi

if [ $stage -le 1 ]; then
  echo "$0 creating segments file from rttm and utt2spk, reco2file_and_channel "
  # reco2file_and_channel is derived on the fly from the rttm itself.
  local/convert_rttm_to_utt2spk_and_segments.py --append-reco-id-to-spkr=true $rttm_dir/rttm \
    <(awk '{print $2" "$2" "$3}' $rttm_dir/rttm | sort -u) \
    ${out_dir}_hires/utt2spk ${out_dir}_hires/segments

  utils/utt2spk_to_spk2utt.pl ${out_dir}_hires/utt2spk > ${out_dir}_hires/spk2utt
  utils/fix_data_dir.sh ${out_dir}_hires || exit 1;
fi

if [ $stage -le 2 ]; then
  echo "$0 extracting mfcc features using segments file"
  # Use the configurable $cmd (was hardcoded to queue.pl, which ignored --cmd).
  steps/make_mfcc.sh --mfcc-config conf/mfcc_hires.conf --nj $nj --cmd "$cmd" ${out_dir}_hires
  steps/compute_cmvn_stats.sh ${out_dir}_hires
  cp $data_in/text.bak ${out_dir}_hires/text
fi

if [ $stage -le 3 ]; then
  echo "$0 performing decoding on the extracted features"
  local/nnet3/decode.sh --affix 2stage --acwt 1.0 --post-decode-acwt 10.0 \
    --frames-per-chunk 150 --nj $nj --ivector-dir $ivector_extractor \
    $out_dir $lang_dir $asr_model_dir/graph${lm_suffix} $asr_model_dir/
fi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.