Add few shot subset for mSuperb multilingual setting #4923

Merged · 3 commits · Feb 8, 2023
264 changes: 203 additions & 61 deletions egs2/msuperb/asr1/local/data_prep.py
@@ -38,11 +38,154 @@
"ceb",
"luo",
"kea",
"sum",
"sun",
"tso",
"tos",
]

FEW_SHOT_SELECTED_DATA = {
"lit": [
"cv_lit_000026",
"cv_lit_000105",
"fleurs_lit_000097",
"mls_lit_000040",
"voxpopuli_lit_001707",
],
"dan": [
"NST_dan_000007",
"NST_dan_000072",
"cv_dan_000158",
"fleurs_dan_000058",
"fleurs_dan_000071",
],
"tur": [
"cv_tur_000095",
"cv_tur_000101",
"fleurs_tur_000060",
"fleurs_tur_000069",
"fleurs_tur_000095",
],
"srp": [
"cv_srp_000078",
"cv_srp_000145",
"fleurs_srp_000068",
"fleurs_srp_000093",
"fleurs_srp_000105",
],
"vie": [
"cv_vie_000105",
"cv_vie_000128",
"fleurs_vie_000067",
"fleurs_vie_000068",
"fleurs_vie_000077",
],
"kaz": [
"cv_kaz_000080",
"cv_kaz_000097",
"cv_kaz_000111",
"fleurs_kaz_000036",
"fleurs_kaz_000066",
],
"zul": [
"fleurs_zul_000049",
"fleurs_zul_000057",
"nchlt_zul_000027",
"nchlt_zul_000090",
"nchlt_zul_000104",
],
"tsn": [
"googlei18n-tts_tsn_000026",
"googlei18n-tts_tsn_000044",
"googlei18n-tts_tsn_000108",
"nchlt_tsn_000001",
"nchlt_tsn_000032",
],
"epo": [
"cv_epo_000006",
"cv_epo_000039",
"cv_epo_000063",
"cv_epo_000066",
"cv_epo_000076",
],
"frr": [
"cv_frr_000023",
"cv_frr_000086",
"cv_frr_000095",
"cv_frr_000102",
"cv_frr_000104",
],
"tok": [
"cv_tok_000004",
"cv_tok_000011",
"cv_tok_000030",
"cv_tok_000084",
"cv_tok_000101",
],
"umb": [
"fleurs_umb_000028",
"fleurs_umb_000029",
"fleurs_umb_000033",
"fleurs_umb_000040",
"fleurs_umb_000047",
],
"bos": [
"fleurs_bos_000067",
"fleurs_bos_000078",
"fleurs_bos_000080",
"fleurs_bos_000088",
"fleurs_bos_000090",
],
"ful": [
"fleurs_ful_000055",
"fleurs_ful_000059",
"fleurs_ful_000067",
"fleurs_ful_000076",
"fleurs_ful_000081",
],
"ceb": [
"fleurs_ceb_000054",
"fleurs_ceb_000064",
"fleurs_ceb_000069",
"fleurs_ceb_000071",
"fleurs_ceb_000080",
],
"luo": [
"fleurs_luo_000056",
"fleurs_luo_000062",
"fleurs_luo_000067",
"fleurs_luo_000073",
"fleurs_luo_000077",
],
"kea": [
"fleurs_kea_000052",
"fleurs_kea_000070",
"fleurs_kea_000078",
"fleurs_kea_000083",
"fleurs_kea_000085",
],
"sun": [
"googlei18n-asr_sun_000001",
"googlei18n-asr_sun_000007",
"googlei18n-asr_sun_000041",
"googlei18n-asr_sun_000099",
"googlei18n-asr_sun_000106",
],
"tso": [
"nchlt_tso_000035",
"nchlt_tso_000040",
"nchlt_tso_000089",
"nchlt_tso_000104",
"nchlt_tso_000125",
],
"tos": [
"mexico-el_tos_000006",
"mexico-el_tos_000122",
"mexico-el_tos_000152",
"mexico-el_tos_000496",
"mexico-el_tos_000563",
],
}

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--train_set", type=str, default="train_10min")
@@ -95,69 +238,68 @@
langs_info[lang] = []
langs_info[lang].append(dataset)

if not reserve_flag:
# process train
train_transcript = open(
os.path.join(
args.source,
dataset,
lang,
"transcript_{}_train.txt".format(args.duration),
),
"r",
encoding="utf-8",
)
for line in train_transcript.readlines():
line = line.strip().split(maxsplit=2)
utt_id, _, text = line
train_wavscp.write(
"{} sox {} -c 1 -t wavpcm -|\n".format(
utt_id,
os.path.join(
args.source,
dataset,
lang,
"wav",
"{}.wav".format(utt_id),
),
)
# process train
train_transcript = open(
os.path.join(
args.source,
dataset,
lang,
"transcript_{}_train.txt".format(args.duration),
),
"r",
encoding="utf-8",
)
for line in train_transcript.readlines():
line = line.strip().split(maxsplit=2)
utt_id, _, text = line
if reserve_flag and utt_id not in FEW_SHOT_SELECTED_DATA[lang]:
continue
train_wavscp.write(
"{} sox {} -c 1 -t wavpcm -|\n".format(
utt_id,
os.path.join(
args.source,
dataset,
lang,
"wav",
"{}.wav".format(utt_id),
),
)
if args.lid:
train_text.write("{} [{}] {}\n".format(utt_id, lang, text))
else:
train_text.write("{} {}\n".format(utt_id, text))
train_utt2spk.write("{} {}\n".format(utt_id, utt_id))
train_transcript.close()

# process dev
dev_transcript = open(
os.path.join(
args.source, dataset, lang, "transcript_10min_dev.txt"
),
"r",
encoding="utf-8",
)
for line in dev_transcript.readlines():
line = line.strip().split(maxsplit=2)
utt_id, _, text = line
dev_wavscp.write(
"{} sox {} -c 1 -t wavpcm -|\n".format(
utt_id,
os.path.join(
args.source,
dataset,
lang,
"wav",
"{}.wav".format(utt_id),
),
)
if args.lid:
train_text.write("{} [{}] {}\n".format(utt_id, lang, text))
else:
train_text.write("{} {}\n".format(utt_id, text))
train_utt2spk.write("{} {}\n".format(utt_id, utt_id))
train_transcript.close()

# process dev
dev_transcript = open(
os.path.join(args.source, dataset, lang, "transcript_10min_dev.txt"),
"r",
encoding="utf-8",
)
for line in dev_transcript.readlines():
line = line.strip().split(maxsplit=2)
utt_id, _, text = line
dev_wavscp.write(
"{} sox {} -c 1 -t wavpcm -|\n".format(
utt_id,
os.path.join(
args.source,
dataset,
lang,
"wav",
"{}.wav".format(utt_id),
),
)
if args.lid:
dev_text.write("{} [{}] {}\n".format(utt_id, lang, text))
else:
dev_text.write("{} {}\n".format(utt_id, text))
dev_utt2spk.write("{} {}\n".format(utt_id, utt_id))
dev_transcript.close()
)
if args.lid:
dev_text.write("{} [{}] {}\n".format(utt_id, lang, text))
else:
dev_text.write("{} {}\n".format(utt_id, text))
dev_utt2spk.write("{} {}\n".format(utt_id, utt_id))
dev_transcript.close()

# process test
test_transcript = open(
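
For reference, the core of the change is the new gate on reserved (few-shot) languages: instead of skipping them entirely when building the train set, data_prep.py now keeps only the utterance IDs listed in FEW_SHOT_SELECTED_DATA (five per language). Below is a minimal, self-contained sketch of that filter, not the actual script: the keep_for_few_shot helper and the sample transcript lines are hypothetical illustrations, while the dictionary excerpt and the split/skip logic mirror the diff above.

    # Minimal sketch of the few-shot filtering added in this PR (illustrative only).
    # Only the "lit" entry of FEW_SHOT_SELECTED_DATA is reproduced here.
    FEW_SHOT_SELECTED_DATA = {
        "lit": [
            "cv_lit_000026",
            "cv_lit_000105",
            "fleurs_lit_000097",
            "mls_lit_000040",
            "voxpopuli_lit_001707",
        ],
    }


    def keep_for_few_shot(utt_id, lang, reserve_flag):
        # Mirrors the added condition in data_prep.py:
        #   if reserve_flag and utt_id not in FEW_SHOT_SELECTED_DATA[lang]: continue
        return not reserve_flag or utt_id in FEW_SHOT_SELECTED_DATA[lang]


    # Hypothetical transcript lines in the "<utt_id> <field> <text>" layout that
    # data_prep.py splits with maxsplit=2.
    transcript = [
        "cv_lit_000026 spk1 sample text one",
        "cv_lit_000999 spk2 sample text two",
    ]

    for line in transcript:
        utt_id, _, text = line.strip().split(maxsplit=2)
        if keep_for_few_shot(utt_id, "lit", reserve_flag=True):
            print(utt_id, text)  # only cv_lit_000026 passes the few-shot filter

Since every language in FEW_SHOT_SELECTED_DATA lists exactly five utterance IDs, the few-shot subset contributes five training utterances per reserved language.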