Skip to content

Commit

Permalink
Merge pull request #5275 from iamanigeeit/fix_mfa_format
Browse files Browse the repository at this point in the history
MFA format fix
  • Loading branch information
mergify[bot] committed Jul 24, 2023
2 parents c9cd4de + 93b136a commit 4847b5f
Showing 1 changed file with 44 additions and 39 deletions.
83 changes: 44 additions & 39 deletions egs2/TEMPLATE/asr1/pyscripts/utils/mfa_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def make_labs_[dataset]:
DURATIONS_PATH = os.path.join(WORK_DIR, "durations")
DICTIONARY_PATH = os.path.join(WORK_DIR, "train_dict.txt")

punctuation = "!',.?" + '"'
punctuation = '!,.?"'

# JP_DICT_URL =
# "https://raw.githubusercontent.com/r9y9/open_jtalk/1.11/src/mecab-naist-jdic/unidic-csj.csv"
Expand All @@ -70,8 +70,8 @@ def get_parser():
parser = argparse.ArgumentParser(
description=(
"Utilities to format from MFA to ESPnet.\n"
"Usage: python scripts/utils/mfa_format.py ACTION [dataset] [options]\n"
"python scripts/utils/mfa_format.py labs ljspeech\n"
"Usage: python scripts/utils/mfa_format.py TASK [--options]\n"
"python scripts/utils/mfa_format.py labs\n"
"python scripts/utils/mfa_format.py validate\n"
"python scripts/utils/mfa_format.py durations\n"
)
Expand Down Expand Up @@ -145,7 +145,7 @@ def get_parser():
def get_phoneme_durations(
data: Dict, original_text: str, fs: int, hop_size: int, n_samples: int
):
"""Get phohene durations."""
"""Get phoneme durations."""
orig_text = original_text.replace(" ", "").rstrip()
text_pos = 0
maxTimestamp = data["end"]
Expand All @@ -172,7 +172,7 @@ def get_phoneme_durations(
puncs = []
while text_pos < len(orig_text):
char = orig_text[text_pos]
if char.isalpha():
if char.isalpha() or char == "'":
break
else:
puncs.append(char)
Expand Down Expand Up @@ -359,64 +359,69 @@ def make_labs(args):
logging.info("Preparing data for %s", dset)
dset: Path = Path("data") / dset
# Generate directories according to spk2utt
speakers = dict()
with open(dset / "spk2utt") as reader:
utt2spk = dict()
with open(dset / "utt2spk") as reader:
for line in reader:
line = line.split()
(corpus_dir / line[0]).mkdir(parents=True, exist_ok=True)
speakers.update({key: line[0] for key in line[1:]})
utt, spk = line.strip().split(maxsplit=1)
utt2spk[utt] = spk
for spk in set(utt2spk.values()):
(corpus_dir / spk).mkdir(parents=True, exist_ok=True)

# Generate labs according to text file
with open(dset / "text", encoding="utf-8") as reader:
for line in reader:
key = line.split()[0]
text = " ".join(line.split()[1:])
utt, text = line.strip().split(maxsplit=1)
text = cleaner(text).lower()
# Convert single quotes into double quotes
# so that MFA doesn't confuse them with clitics.
# Find ' not preceded by a letter to the last ' not followed by a letter
text = re.sub(r"(\W|^)'(\w[\w .,!?']*)'(\W|$)", r'\1"\2"\3', text)

# Remove braces because MFA interprets them as enclosing a single word
text = re.sub(r"[\{\}]", "", text)

# In case of frontend, preprocess data.
if frontend is not None:
text = frontend(text)

spk = speakers.get(key, None)
if spk is None:
continue
with open(
corpus_dir / spk / f"{key}.lab", "w", encoding="utf-8"
) as writer:
writer.write(text)
try:
spk = utt2spk[utt]
with open(
corpus_dir / spk / f"{utt}.lab", "w", encoding="utf-8"
) as writer:
writer.write(text)
except KeyError:
logging.warning(f"{utt} is in text file but not in utt2spk")

# Generate wavs according to wav.scp and segment files
if (dset / "segments").exists():
wscp = (dset / "wav.scp").as_posix()
segments = (dset / "segments").as_posix()
with kaldiio.ReadHelper(f"scp:{wscp}", segments=segments) as reader:
for key, (rate, array) in reader:
spk: str = speakers.get(key, None)
if spk is None:
continue
dst_file = (corpus_dir / spk / f"{key}.wav").as_posix()
sf.write(dst_file, array, rate)
for utt, (rate, array) in reader:
try:
spk = utt2spk[utt]
dst_file = (corpus_dir / spk / f"{utt}.wav").as_posix()
sf.write(dst_file, array, rate)
except KeyError:
logging.warning(f"{utt} is in wav.scp file but not in utt2spk")
else:
with open(dset / "wav.scp") as reader:
rate = None
for line in reader:
line = line.split()
src_file = os.path.abspath(get_path(line))
spk: str = speakers.get(line[0], None)
if spk is None:
continue
dst_file = corpus_dir / spk / f"{line[0]}.wav"
if src_file.endswith(".wav"):
# Create symlink
dst_file.symlink_to(src_file)
else:
# Create wav file
rate, array = kaldiio.load_mat(" ".join(line[1:]))
sf.write(dst_file.as_posix(), array, rate)
utt, src_file = line.strip().split(maxsplit=1)
src_file = os.path.abspath(src_file)
try:
spk = utt2spk[utt]
dst_file = corpus_dir / spk / f"{utt}.wav"
if src_file.endswith(".wav"):
# Create symlink
dst_file.symlink_to(src_file)
else:
# Create wav file
rate, array = kaldiio.load_mat(src_file)
sf.write(dst_file.as_posix(), array, rate)
except KeyError:
logging.warning(f"{utt} is in wav.scp file but not in utt2spk")
logging.info("Finished writing .lab files")


Expand Down

0 comments on commit 4847b5f

Please sign in to comment.