Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LID score v2 #4983

Merged
merged 3 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 2 additions & 3 deletions egs2/msuperb/asr1/local/lid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ def main(args):
for line in f:
if line == "\n":
continue
[pred, gt, name] = line.strip().split("\t")
gt = f"[{gt}]"
[gt, pred, name] = line.strip().split("\t")
y_true.append(gt)
y_pred.append(pred)
if pred == gt:
Expand All @@ -25,7 +24,7 @@ def main(args):
# f.write(lid_report)

with open(f"{args.dir}/scores.txt", "w") as f:
f.write(f"Acc: {correct / total * 100:.2f}%")
f.write(f"Acc: {correct / total * 100:.2f}%\n")


if __name__ == "__main__":
Expand Down
31 changes: 13 additions & 18 deletions egs2/msuperb/asr1/local/score.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,35 +51,30 @@ python local/split_results.py \
--lid ${lid} \
--only_lid ${only_lid}

if "${only_lid}"; then
# directories=$(find ${asr_exp} -wholename "*/*/score_wer/independent/*" -type d -not -path '/\.')
if "${only_lid}" || "${lid}"; then
# directories=$(find ${asr_exp} -wholename "*/*/score_lid/independent/*" -type d -not -path '/\.')
# directories+=" "
directories=$(find ${asr_exp} -wholename "*/*/score_wer/few_shot/*" -type d -not -path '/\.')
directories=$(find ${asr_exp} -wholename "*/*/score_lid/few_shot/*" -type d -not -path '/\.')
# directories+=" "
# directories+=$(find ${asr_exp} -wholename "*/*/score_wer/language_family/*" -type d -not -path '/\.')
# directories+=$(find ${asr_exp} -wholename "*/*/score_lid/language_family/*" -type d -not -path '/\.')
# directories+=" "
# directories+=$(find ${asr_exp} -wholename "*/*/score_wer/all/*" -type d -not -path '/\.')
# directories+=$(find ${asr_exp} -wholename "*/*/score_lid/all/*" -type d -not -path '/\.')
for _scoredir in ${directories}
do
log "Write result in ${_scoredir}/scores.txt"
python local/lid.py --dir ${_scoredir}
cat "${_scoredir}/scores.txt"
done
else
if "${lid}"; then
directories=$(find ${asr_exp} -wholename "*/*/score_wer/few_shot/*" -type d -not -path '/\.')
for _scoredir in ${directories}
do
log "Write result in ${_scoredir}/scores.txt"
python local/lid.py --dir ${_scoredir}
done
fi
# directories=$(find ${asr_exp} -wholename "*/*/*/independent/*" -type d -not -path '/\.')
fi

if ! "${only_lid}"; then
# directories=$(find ${asr_exp} -wholename "*/*/score_cer/independent/*" -type d -not -path '/\.')
# directories+=" "
directories=$(find ${asr_exp} -wholename "*/*/*/few_shot/*" -type d -not -path '/\.')
directories=$(find ${asr_exp} -wholename "*/*/score_cer/few_shot/*" -type d -not -path '/\.')
# directories+=" "
# directories+=$(find ${asr_exp} -wholename "*/*/*/language_family/*" -type d -not -path '/\.')
# directories+=$(find ${asr_exp} -wholename "*/*/score_cer/language_family/*" -type d -not -path '/\.')
# directories+=" "
# directories+=$(find ${asr_exp} -wholename "*/*/*/all/*" -type d -not -path '/\.')
# directories+=$(find ${asr_exp} -wholename "*/*/score_cer/all/*" -type d -not -path '/\.')
for _scoredir in ${directories}
do
log "Write result in ${_scoredir}/result.txt"
Expand Down
151 changes: 89 additions & 62 deletions egs2/msuperb/asr1/local/split_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import glob
import json
import os
import re
from typing import Callable, List

from linguistic_tree import LanguageTree
Expand Down Expand Up @@ -88,49 +89,55 @@ def write_lines(lines, path):
f.write(f"{line}\n")


def get_info_from_line(line):
def get_info_from_trn_line(line):
    """Extract ``(iso, utt_id)`` from a sclite trn line.

    The last tab-separated field of a trn line is ``(<utt_id>-<utt_id>)``;
    the utterance id is recovered as the first half of that field (minus
    the opening parenthesis), and the ISO language code is its
    second-to-last underscore-separated token.
    """
    tail = line.split("\t")[-1]
    # "(<id>-<id>)" -> "<id>": skip "(" and keep the first half.
    uid = tail[1 : len(tail) // 2]
    return uid.split("_")[-2], uid


# NOTE(review): pre-change version of lid_parse, shown as removed in this
# diff.  It pulled the leading LID token out of already-scored trn lines
# in place; the replacement version reads LIDs from the raw text files.
# Returns (new_lines, lid_info) where lid_info is None unless LID scoring
# applies to a score_wer directory.
def lid_parse(root, lines):
new_lines = []
lid_info = None
# Extract LID info from WER results if LID exists
if (LID or ONLY_LID) and "score_wer" in root:
lid_info = []
for line in lines:
if line[0] == "\t": # prediction is NULL...
lid = "[UNK]"
else:
words = line.split("\t")[0].split(" ")
# First token of the hypothesis text is assumed to be the LID tag.
lid = words[0]
iso, utt_id = get_info_from_line(line)
lid_info.append(f"{lid}\t{iso}\t{utt_id}")

# Remove LID in Multilingual + LID case
for line in lines:
if LID and (not ONLY_LID):
if line[0] == "\t": # prediction is NULL...
pass
else:
text = line.split("\t")[0]
if "score_wer" in root:
if " " in text:
line = line.split(" ", 1)[1]
else:
line = line[len(text) :]
elif "score_cer" in root:
# In CER scoring words are split to chars; drop everything
# up to and including the first "<space>" marker.
chars = line.split(" ")
if "<space>" in chars:
idx = chars.index("<space>")
line = " ".join(chars[idx + 1 :])
else:
line = line[len(text) :]
new_lines.append(line)
return new_lines, lid_info
def get_info_from_raw_line(line):
    """Extract ``(isos, utt_id)`` from a raw text line ``<utt_id> <text>``.

    ``isos`` lists every bracketed language tag found in the text:
    plain three-letter tags such as ``[eng]`` first, then combined tags
    such as ``[org_jpn]``.  A line with no text yields an empty list.
    """
    utt_id, _, text = line.partition(" ")
    tags = re.findall(r"\[[a-z]{3}\]", text)
    # Some LID tags come as a pair, e.g. [org_jpn]; collect those too.
    tags += re.findall(r"\[[a-z]{3}_[a-z]{3}\]", text)
    return tags, utt_id


def lid_parse(ref_path, hyp_path):
    """Build lid.trn lines of the form ``<ref_iso> TAB <hyp_iso> TAB (<utt_id>-<utt_id>)``.

    Reads the reference and hypothesis raw text files, extracts the
    bracketed language tags per utterance via ``get_info_from_raw_line``,
    and pairs the single groundtruth tag with the first predicted tag
    (``[UNK]`` when the hypothesis contains no tag at all).

    Raises AssertionError when a hypothesis utterance is missing from the
    reference, or when the reference does not carry exactly one LID.
    """

    def read_lid(path):
        # Map utt_id -> list of bracketed LID tags found on that line.
        res = {}
        for line in read_lines(path):
            try:
                isos, utt_id = get_info_from_raw_line(line)
            except Exception:
                # Surface the offending line before re-raising.
                print(line)
                raise
            res[utt_id] = isos
        return res

    lid_info = []
    ref_lid_info = read_lid(ref_path)
    hyp_lid_info = read_lid(hyp_path)
    for utt_id, hyp_isos in hyp_lid_info.items():
        assert (
            utt_id in ref_lid_info
        ), f"Can not find groundtruth of utterance ({utt_id}) in {ref_path}."
        # Fixed typo in the error message: "Utternace" -> "Utterance".
        assert (
            len(ref_lid_info[utt_id]) == 1
        ), f"Utterance ({utt_id}) should have exactly one LID in {ref_path}."
        if len(hyp_isos) == 0:
            # No tag predicted: score it as unknown.
            hyp_isos.append("[UNK]")
        ref_iso = ref_lid_info[utt_id][0]
        hyp_iso = hyp_isos[0]
        lid_info.append(f"{ref_iso}\t{hyp_iso}\t({utt_id}-{utt_id})")

    return lid_info


def no_rule(iso):
Expand Down Expand Up @@ -161,39 +168,59 @@ def language_family_rule(iso):

def split_trn_by_rule(root, name, rule_fn, trn_path):
lines = read_lines(trn_path)
lines, lid_info = lid_parse(root, lines)

categorizer.set_category_func(lambda line: rule_fn(get_info_from_line(line)[0]))
categorizer.set_category_func(lambda line: rule_fn(get_info_from_trn_line(line)[0]))
set2lines = categorizer.exec(lines)
for k, v in set2lines.items():
write_lines(v, f"{root}/{name}/{k}/{os.path.basename(trn_path)}")

if lid_info is not None and "hyp.trn" in trn_path:
categorizer.set_category_func(lambda line: rule_fn(line.split("\t")[1]))
set2lid_results = categorizer.exec(lid_info)
for k, v in set2lid_results.items():
write_lines(v, f"{root}/{name}/{k}/lid.trn")


def main(args):
roots = []
if not ONLY_LID:
for txt_paths in glob.glob(f"{args.dir}/*/*/score_cer/result.txt"):
roots.append(os.path.dirname(txt_paths))
for txt_paths in glob.glob(f"{args.dir}/*/*/score_wer/result.txt"):
roots.append(os.path.dirname(txt_paths))
for root in roots:
print(f"Parsing results in {root}...")
ref_trn_path = f"{root}/ref.trn"
hyp_trn_path = f"{root}/hyp.trn"
split_trn_by_rule(root, "independent", independent_rule, ref_trn_path)
split_trn_by_rule(root, "independent", independent_rule, hyp_trn_path)
split_trn_by_rule(root, "few_shot", few_shot_rule, ref_trn_path)
split_trn_by_rule(root, "few_shot", few_shot_rule, hyp_trn_path)
split_trn_by_rule(root, "language_family", language_family_rule, ref_trn_path)
split_trn_by_rule(root, "language_family", language_family_rule, hyp_trn_path)
split_trn_by_rule(root, "all", no_rule, ref_trn_path)
split_trn_by_rule(root, "all", no_rule, hyp_trn_path)
if not ONLY_LID: # TER will be parsed from trn and using sclite as the usual case
for txt_path in glob.glob(f"{args.dir}/*/*/score_cer/result.txt"):
roots.append(os.path.dirname(txt_path))
for txt_path in glob.glob(f"{args.dir}/*/*/score_wer/result.txt"):
roots.append(os.path.dirname(txt_path))
for root in roots:
print(f"Parsing TER results in {root}...")
ref_trn_path = f"{root}/ref.trn"
hyp_trn_path = f"{root}/hyp.trn"

split_trn_by_rule(root, "independent", independent_rule, ref_trn_path)
split_trn_by_rule(root, "independent", independent_rule, hyp_trn_path)
split_trn_by_rule(root, "few_shot", few_shot_rule, ref_trn_path)
split_trn_by_rule(root, "few_shot", few_shot_rule, hyp_trn_path)
split_trn_by_rule(
root, "language_family", language_family_rule, ref_trn_path
)
split_trn_by_rule(
root, "language_family", language_family_rule, hyp_trn_path
)
split_trn_by_rule(root, "all", no_rule, ref_trn_path)
split_trn_by_rule(root, "all", no_rule, hyp_trn_path)

if LID or ONLY_LID: # LID will be parsed from inferenced text file directly
tasks = []
for hyp_txt_path in glob.glob(f"{args.dir}/*/*/text"):
root = os.path.dirname(hyp_txt_path)
data_dirname = os.path.basename(root)
ref_txt_path = f"data/{data_dirname}/text"
tasks.append((root, ref_txt_path, hyp_txt_path))
for root, ref_txt_path, hyp_txt_path in tasks:
print(f"Parsing LID results in {root}...")
root = f"{root}/score_lid"
lid_trn_path = f"{root}/lid.trn"

lid_info = lid_parse(ref_txt_path, hyp_txt_path)
write_lines(lid_info, lid_trn_path)

split_trn_by_rule(root, "independent", independent_rule, lid_trn_path)
split_trn_by_rule(root, "few_shot", few_shot_rule, lid_trn_path)
split_trn_by_rule(
root, "language_family", language_family_rule, lid_trn_path
)
split_trn_by_rule(root, "all", no_rule, lid_trn_path)


if __name__ == "__main__":
Expand Down