Skip to content

Commit

Permalink
Merge pull request #4983 from hhhaaahhhaa/master
Browse files Browse the repository at this point in the history
LID score v2
  • Loading branch information
ftshijt committed Mar 6, 2023
2 parents b5b2b11 + cafe59b commit c44538a
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 83 deletions.
5 changes: 2 additions & 3 deletions egs2/msuperb/asr1/local/lid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ def main(args):
for line in f:
if line == "\n":
continue
[pred, gt, name] = line.strip().split("\t")
gt = f"[{gt}]"
[gt, pred, name] = line.strip().split("\t")
y_true.append(gt)
y_pred.append(pred)
if pred == gt:
Expand All @@ -25,7 +24,7 @@ def main(args):
# f.write(lid_report)

with open(f"{args.dir}/scores.txt", "w") as f:
f.write(f"Acc: {correct / total * 100:.2f}%")
f.write(f"Acc: {correct / total * 100:.2f}%\n")


if __name__ == "__main__":
Expand Down
31 changes: 13 additions & 18 deletions egs2/msuperb/asr1/local/score.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,35 +51,30 @@ python local/split_results.py \
--lid ${lid} \
--only_lid ${only_lid}

if "${only_lid}"; then
# directories=$(find ${asr_exp} -wholename "*/*/score_wer/independent/*" -type d -not -path '/\.')
if "${only_lid}" || "${lid}"; then
# directories=$(find ${asr_exp} -wholename "*/*/score_lid/independent/*" -type d -not -path '/\.')
# directories+=" "
directories=$(find ${asr_exp} -wholename "*/*/score_wer/few_shot/*" -type d -not -path '/\.')
directories=$(find ${asr_exp} -wholename "*/*/score_lid/few_shot/*" -type d -not -path '/\.')
# directories+=" "
# directories+=$(find ${asr_exp} -wholename "*/*/score_wer/language_family/*" -type d -not -path '/\.')
# directories+=$(find ${asr_exp} -wholename "*/*/score_lid/language_family/*" -type d -not -path '/\.')
# directories+=" "
# directories+=$(find ${asr_exp} -wholename "*/*/score_wer/all/*" -type d -not -path '/\.')
# directories+=$(find ${asr_exp} -wholename "*/*/score_lid/all/*" -type d -not -path '/\.')
for _scoredir in ${directories}
do
log "Write result in ${_scoredir}/scores.txt"
python local/lid.py --dir ${_scoredir}
cat "${_scoredir}/scores.txt"
done
else
if "${lid}"; then
directories=$(find ${asr_exp} -wholename "*/*/score_wer/few_shot/*" -type d -not -path '/\.')
for _scoredir in ${directories}
do
log "Write result in ${_scoredir}/scores.txt"
python local/lid.py --dir ${_scoredir}
done
fi
# directories=$(find ${asr_exp} -wholename "*/*/*/independent/*" -type d -not -path '/\.')
fi

if ! "${only_lid}"; then
# directories=$(find ${asr_exp} -wholename "*/*/score_cer/independent/*" -type d -not -path '/\.')
# directories+=" "
directories=$(find ${asr_exp} -wholename "*/*/*/few_shot/*" -type d -not -path '/\.')
directories=$(find ${asr_exp} -wholename "*/*/score_cer/few_shot/*" -type d -not -path '/\.')
# directories+=" "
# directories+=$(find ${asr_exp} -wholename "*/*/*/language_family/*" -type d -not -path '/\.')
# directories+=$(find ${asr_exp} -wholename "*/*/score_cer/language_family/*" -type d -not -path '/\.')
# directories+=" "
# directories+=$(find ${asr_exp} -wholename "*/*/*/all/*" -type d -not -path '/\.')
# directories+=$(find ${asr_exp} -wholename "*/*/score_cer/all/*" -type d -not -path '/\.')
for _scoredir in ${directories}
do
log "Write result in ${_scoredir}/result.txt"
Expand Down
151 changes: 89 additions & 62 deletions egs2/msuperb/asr1/local/split_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import glob
import json
import os
import re
from typing import Callable, List

from linguistic_tree import LanguageTree
Expand Down Expand Up @@ -88,49 +89,55 @@ def write_lines(lines, path):
f.write(f"{line}\n")


def get_info_from_line(line):
def get_info_from_trn_line(line):
    """Extract the language code and utterance ID from one trn-style line.

    The last tab-separated field is expected to look like
    ``(<utt_id>-<utt_id>)``, i.e. the utterance ID written twice.

    Args:
        line: One tab-separated line whose last field is the doubled ID.

    Returns:
        A tuple ``(iso, utt_id)`` where ``iso`` is the second-to-last
        underscore-separated component of ``utt_id``.

    Raises:
        IndexError: If ``utt_id`` has fewer than two underscore-separated
            components.
    """
    # Strip surrounding whitespace first: the half-length slice below only
    # lands on the ID boundary when the field is exactly "(id-id)".
    tag = line.split("\t")[-1].strip()
    # "(id-id)" has length 2 * len(id) + 3, so the first copy of the ID
    # occupies positions 1 .. len(tag) // 2 (exclusive).
    utt_id = tag[1 : len(tag) // 2]
    iso = utt_id.split("_")[-2]
    return iso, utt_id


# Two-argument variant of lid_parse: takes a score directory path (root) and
# the lines read from a trn file, and returns the pair (new_lines, lid_info).
# NOTE(review): this listing has lost its indentation, so the nesting described
# in the comments below is inferred from the statement order -- confirm against
# the repository before relying on it.
def lid_parse(root, lines):
new_lines = []
# lid_info stays None unless the condition below holds.
lid_info = None
# Extract LID info from WER results if LID exists
if (LID or ONLY_LID) and "score_wer" in root:
lid_info = []
for line in lines:
# A line that starts with a tab has an empty first field, so "[UNK]" is
# recorded as the language ID.
if line[0] == "\t": # prediction is NULL...
lid = "[UNK]"
else:
# Otherwise the first space-separated word of the first tab field is used.
words = line.split("\t")[0].split(" ")
lid = words[0]
iso, utt_id = get_info_from_line(line)
# Each record is "<lid>\t<iso>\t<utt_id>".
lid_info.append(f"{lid}\t{iso}\t{utt_id}")

# Remove LID in Multilingual + LID case
for line in lines:
if LID and (not ONLY_LID):
if line[0] == "\t": # prediction is NULL...
pass
else:
text = line.split("\t")[0]
if "score_wer" in root:
# Word-level file: drop everything up to and including the first space, or
# drop the whole first field when it contains no space.
if " " in text:
line = line.split(" ", 1)[1]
else:
line = line[len(text) :]
elif "score_cer" in root:
# Character-level file: drop everything up to and including the first
# "<space>" token, or drop the whole first field when there is none.
chars = line.split(" ")
if "<space>" in chars:
idx = chars.index("<space>")
line = " ".join(chars[idx + 1 :])
else:
line = line[len(text) :]
# presumably executed once per input line (inside the loop, outside the
# LID check) -- TODO confirm.
new_lines.append(line)
return new_lines, lid_info
def get_info_from_raw_line(line):
    """Extract the LID tags and the utterance ID from one raw text line.

    The line is split at the first space into ``<utt_id>`` and the text.
    A line without any space is treated as a null prediction (empty text).

    Args:
        line: One line of the form ``"<utt_id> <text>"``.

    Returns:
        A tuple ``(isos, utt_id)`` where ``isos`` lists every ``[xxx]`` or
        ``[xxx_yyy]`` tag (lowercase ASCII letters) found in the text, in
        order of appearance.
    """
    # partition() yields an empty text for a line with no space.
    utt_id, _, text = line.partition(" ")
    # One pattern covers both plain tags and compound tags such as
    # [org_jpn]. Two separate findall passes would list every plain tag
    # before every compound tag, so isos[0] would not necessarily be the
    # first tag in the text.
    isos = re.findall(r"\[[a-z]{3}(?:_[a-z]{3})?\]", text)
    return isos, utt_id


def lid_parse(ref_path, hyp_path):
    """Pair the LID tag of every hypothesis utterance with its reference.

    Both files are read with ``read_lines`` and every line is parsed with
    ``get_info_from_raw_line``. If an utterance ID occurs more than once
    in a file, the last occurrence wins.

    Args:
        ref_path: Path of the reference text file.
        hyp_path: Path of the hypothesis text file.

    Returns:
        A list with one tab-separated record per hypothesis utterance:
        reference tag, hypothesis tag, then ``(<utt_id>-<utt_id>)``. The
        hypothesis tag is the first one found, or ``[UNK]`` when the
        hypothesis contains none.

    Raises:
        AssertionError: If a hypothesis utterance is missing from the
            reference, or its reference does not hold exactly one tag.
    """

    def read_lid(path):
        # Map each utterance ID to the list of LID tags found on its line.
        res = {}
        for line in read_lines(path):
            try:
                isos, utt_id = get_info_from_raw_line(line)
            except Exception:
                # Show the offending line before propagating the error.
                print(line)
                raise
            res[utt_id] = isos
        return res

    ref_lid_info = read_lid(ref_path)
    hyp_lid_info = read_lid(hyp_path)

    lid_info = []
    for utt_id, hyp_isos in hyp_lid_info.items():
        # Raise explicitly instead of using ``assert`` so the checks are
        # still performed when Python runs with -O. The exception type is
        # kept as AssertionError for backward compatibility.
        if utt_id not in ref_lid_info:
            raise AssertionError(
                f"Can not find groundtruth of utterance ({utt_id}) in {ref_path}."
            )
        ref_isos = ref_lid_info[utt_id]
        if len(ref_isos) != 1:
            raise AssertionError(
                f"Utterance ({utt_id}) should have exactly one LID in {ref_path}."
            )
        # Fall back to [UNK] without modifying the parsed list.
        hyp_iso = hyp_isos[0] if hyp_isos else "[UNK]"
        lid_info.append(f"{ref_isos[0]}\t{hyp_iso}\t({utt_id}-{utt_id})")

    return lid_info


def no_rule(iso):
Expand Down Expand Up @@ -161,39 +168,59 @@ def language_family_rule(iso):

# Reads trn_path, groups its lines with `categorizer` according to rule_fn,
# and writes each group to {root}/{name}/{group}/{basename of trn_path}.
# NOTE(review): this listing is a flattened diff view -- indentation is lost
# and lines from before and after the change appear side by side. Confirm
# which lines are live against the repository.
def split_trn_by_rule(root, name, rule_fn, trn_path):
lines = read_lines(trn_path)
# This call passes (root, lines) and unpacks two return values, which matches
# the lid_parse variant that returns (new_lines, lid_info).
lines, lid_info = lid_parse(root, lines)

# Two consecutive set_category_func calls that differ only in the helper
# name; presumably the first is the pre-rename line -- TODO confirm.
categorizer.set_category_func(lambda line: rule_fn(get_info_from_line(line)[0]))
categorizer.set_category_func(lambda line: rule_fn(get_info_from_trn_line(line)[0]))
set2lines = categorizer.exec(lines)
for k, v in set2lines.items():
write_lines(v, f"{root}/{name}/{k}/{os.path.basename(trn_path)}")

# When lid_info is set and the path contains "hyp.trn", the LID records are
# grouped by their second tab-separated field and written to lid.trn in each
# group directory.
if lid_info is not None and "hyp.trn" in trn_path:
categorizer.set_category_func(lambda line: rule_fn(line.split("\t")[1]))
set2lid_results = categorizer.exec(lid_info)
for k, v in set2lid_results.items():
write_lines(v, f"{root}/{name}/{k}/lid.trn")


# Entry point: splits scoring results found under args.dir into the
# "independent", "few_shot", "language_family" and "all" groupings.
# NOTE(review): this listing is a flattened diff view -- indentation is lost
# and two versions of the `if not ONLY_LID:` section appear back to back.
# Presumably the first is the pre-change text -- confirm against the
# repository.
def main(args):
roots = []
# First copy of the section: collects every score_cer / score_wer directory
# that contains a result.txt, then splits ref.trn and hyp.trn under each rule.
if not ONLY_LID:
for txt_paths in glob.glob(f"{args.dir}/*/*/score_cer/result.txt"):
roots.append(os.path.dirname(txt_paths))
for txt_paths in glob.glob(f"{args.dir}/*/*/score_wer/result.txt"):
roots.append(os.path.dirname(txt_paths))
for root in roots:
print(f"Parsing results in {root}...")
ref_trn_path = f"{root}/ref.trn"
hyp_trn_path = f"{root}/hyp.trn"
split_trn_by_rule(root, "independent", independent_rule, ref_trn_path)
split_trn_by_rule(root, "independent", independent_rule, hyp_trn_path)
split_trn_by_rule(root, "few_shot", few_shot_rule, ref_trn_path)
split_trn_by_rule(root, "few_shot", few_shot_rule, hyp_trn_path)
split_trn_by_rule(root, "language_family", language_family_rule, ref_trn_path)
split_trn_by_rule(root, "language_family", language_family_rule, hyp_trn_path)
split_trn_by_rule(root, "all", no_rule, ref_trn_path)
split_trn_by_rule(root, "all", no_rule, hyp_trn_path)
# Second copy of the section: same globbing and the same eight split calls,
# with a different progress message and a renamed loop variable.
if not ONLY_LID: # TER will be parsed from trn and using sclite as the usual case
for txt_path in glob.glob(f"{args.dir}/*/*/score_cer/result.txt"):
roots.append(os.path.dirname(txt_path))
for txt_path in glob.glob(f"{args.dir}/*/*/score_wer/result.txt"):
roots.append(os.path.dirname(txt_path))
for root in roots:
print(f"Parsing TER results in {root}...")
ref_trn_path = f"{root}/ref.trn"
hyp_trn_path = f"{root}/hyp.trn"

split_trn_by_rule(root, "independent", independent_rule, ref_trn_path)
split_trn_by_rule(root, "independent", independent_rule, hyp_trn_path)
split_trn_by_rule(root, "few_shot", few_shot_rule, ref_trn_path)
split_trn_by_rule(root, "few_shot", few_shot_rule, hyp_trn_path)
split_trn_by_rule(
root, "language_family", language_family_rule, ref_trn_path
)
split_trn_by_rule(
root, "language_family", language_family_rule, hyp_trn_path
)
split_trn_by_rule(root, "all", no_rule, ref_trn_path)
split_trn_by_rule(root, "all", no_rule, hyp_trn_path)

# LID section: every {args.dir}/*/*/text file is a hypothesis; its reference
# is data/<name of the directory holding that file>/text.
if LID or ONLY_LID: # LID will be parsed from inferenced text file directly
tasks = []
for hyp_txt_path in glob.glob(f"{args.dir}/*/*/text"):
root = os.path.dirname(hyp_txt_path)
data_dirname = os.path.basename(root)
ref_txt_path = f"data/{data_dirname}/text"
tasks.append((root, ref_txt_path, hyp_txt_path))
for root, ref_txt_path, hyp_txt_path in tasks:
print(f"Parsing LID results in {root}...")
# From here on, root refers to the score_lid subdirectory.
root = f"{root}/score_lid"
lid_trn_path = f"{root}/lid.trn"

# Write the paired LID records to lid.trn, then split that file under each rule.
lid_info = lid_parse(ref_txt_path, hyp_txt_path)
write_lines(lid_info, lid_trn_path)

split_trn_by_rule(root, "independent", independent_rule, lid_trn_path)
split_trn_by_rule(root, "few_shot", few_shot_rule, lid_trn_path)
split_trn_by_rule(
root, "language_family", language_family_rule, lid_trn_path
)
split_trn_by_rule(root, "all", no_rule, lid_trn_path)


if __name__ == "__main__":
Expand Down

0 comments on commit c44538a

Please sign in to comment.