Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MFA format fix #5275

Merged
merged 3 commits into from
Jul 24, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
83 changes: 44 additions & 39 deletions egs2/TEMPLATE/asr1/pyscripts/utils/mfa_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def make_labs_[dataset]:
DURATIONS_PATH = os.path.join(WORK_DIR, "durations")
DICTIONARY_PATH = os.path.join(WORK_DIR, "train_dict.txt")

punctuation = "!',.?" + '"'
punctuation = '!,.?"'

# JP_DICT_URL =
# "https://raw.githubusercontent.com/r9y9/open_jtalk/1.11/src/mecab-naist-jdic/unidic-csj.csv"
Expand All @@ -70,8 +70,8 @@ def get_parser():
parser = argparse.ArgumentParser(
description=(
"Utilities to format from MFA to ESPnet.\n"
"Usage: python scripts/utils/mfa_format.py ACTION [dataset] [options]\n"
"python scripts/utils/mfa_format.py labs ljspeech\n"
"Usage: python scripts/utils/mfa_format.py TASK [--options]\n"
"python scripts/utils/mfa_format.py labs\n"
"python scripts/utils/mfa_format.py validate\n"
"python scripts/utils/mfa_format.py durations\n"
)
Expand Down Expand Up @@ -145,7 +145,7 @@ def get_parser():
def get_phoneme_durations(
data: Dict, original_text: str, fs: int, hop_size: int, n_samples: int
):
"""Get phohene durations."""
"""Get phoneme durations."""
orig_text = original_text.replace(" ", "").rstrip()
text_pos = 0
maxTimestamp = data["end"]
Expand All @@ -172,7 +172,7 @@ def get_phoneme_durations(
puncs = []
while text_pos < len(orig_text):
char = orig_text[text_pos]
if char.isalpha():
if char.isalpha() or char == "'":
break
else:
puncs.append(char)
Expand Down Expand Up @@ -359,64 +359,69 @@ def make_labs(args):
logging.info("Preparing data for %s", dset)
dset: Path = Path("data") / dset
# Generate directories according to spk2utt
speakers = dict()
with open(dset / "spk2utt") as reader:
utt2spk = dict()
with open(dset / "utt2spk") as reader:
for line in reader:
line = line.split()
(corpus_dir / line[0]).mkdir(parents=True, exist_ok=True)
speakers.update({key: line[0] for key in line[1:]})
utt, spk = line.strip().split(maxsplit=1)
utt2spk[utt] = spk
for spk in set(utt2spk.values()):
(corpus_dir / spk).mkdir(parents=True, exist_ok=True)

# Generate labs according to text file
with open(dset / "text", encoding="utf-8") as reader:
for line in reader:
key = line.split()[0]
text = " ".join(line.split()[1:])
utt, text = line.strip().split(maxsplit=1)
text = cleaner(text).lower()
# Convert single quotes into double quotes
# so that MFA doesn't confuse them with clitics.
# Find ' not preceded by a letter to the last ' not followed by a letter
text = re.sub(r"(\W|^)'(\w[\w .,!?']*)'(\W|$)", r'\1"\2"\3', text)

# Remove braces because MFA interprets them as enclosing a single word
text = re.sub(r"[\{\}]", "", text)

# In case of frontend, preprocess data.
if frontend is not None:
text = frontend(text)

spk = speakers.get(key, None)
if spk is None:
continue
with open(
corpus_dir / spk / f"{key}.lab", "w", encoding="utf-8"
) as writer:
writer.write(text)
try:
spk = utt2spk[utt]
with open(
corpus_dir / spk / f"{utt}.lab", "w", encoding="utf-8"
) as writer:
writer.write(text)
except KeyError:
logging.warning(f"{utt} is in text file but not in utt2spk")

# Generate wavs according to wav.scp and segment files
if (dset / "segments").exists():
wscp = (dset / "wav.scp").as_posix()
segments = (dset / "segments").as_posix()
with kaldiio.ReadHelper(f"scp:{wscp}", segments=segments) as reader:
for key, (rate, array) in reader:
spk: str = speakers.get(key, None)
if spk is None:
continue
dst_file = (corpus_dir / spk / f"{key}.wav").as_posix()
sf.write(dst_file, array, rate)
for utt, (rate, array) in reader:
try:
spk = utt2spk[utt]
dst_file = (corpus_dir / spk / f"{utt}.wav").as_posix()
sf.write(dst_file, array, rate)
except KeyError:
logging.warning(f"{utt} is in wav.scp file but not in utt2spk")
else:
with open(dset / "wav.scp") as reader:
rate = None
for line in reader:
line = line.split()
src_file = os.path.abspath(get_path(line))
spk: str = speakers.get(line[0], None)
if spk is None:
continue
dst_file = corpus_dir / spk / f"{line[0]}.wav"
if src_file.endswith(".wav"):
# Create symlink
dst_file.symlink_to(src_file)
else:
# Create wav file
rate, array = kaldiio.load_mat(" ".join(line[1:]))
sf.write(dst_file.as_posix(), array, rate)
utt, src_file = line.strip().split(maxsplit=1)
src_file = os.path.abspath(src_file)
try:
spk = utt2spk[utt]
dst_file = corpus_dir / spk / f"{utt}.wav"
if src_file.endswith(".wav"):
# Create symlink
dst_file.symlink_to(src_file)
else:
# Create wav file
rate, array = kaldiio.load_mat(src_file)
sf.write(dst_file.as_posix(), array, rate)
except KeyError:
logging.warning(f"{utt} is in wav.scp file but not in utt2spk")
logging.info("Finished writing .lab files")


Expand Down