In [92]:
# Free-form description of this word-alignment script plus example invocations.
# The file-name pattern below mirrors what save_to_disk() writes:
#   {word}__{utt_id}__occ{count}__seqlen{tensor.size(0)}.pt
docstring = """
Helper script that takes a folder of speech reps (wav2vec2, mel-spec, etc.)
and aligns them at word-level using MFA alignments.

Speech reps corresponding to word tokens in the corpus are then saved individually to an output folder
with the following structure:
- data_path
    - word1
        - word1__LJ010-0292__occ1__seqlen12.pt
        - word1__LJ010-0292__occ2__seqlen9.pt
        - ...
    - word2
        - word2__LJ001-0012__occ1__seqlen7.pt
        - word2__LJ002-0024__occ1__seqlen15.pt
        - ...
    - ...

- word1, word2, ... subfolders refer to a particular wordtype in the corpus.
- .pt files contain speech representations that map to a particular example of a wordtype.
  It is named as:
    <wordtype>__<utt id>__occ<numbered occurrence in the utterance>__seqlen<number of timesteps>.pt

Example usage:
    #hubert w/ padding offset
    cd ~/fairseq
    python examples/lexicon_learner/wordalign_speechreps.py \
        -t hubert \
        --padding_idx_offset 1 \
        -s /home/s1785140/fairseq/examples/lexicon_learner/lj_speech_quantized.txt \
        -a /home/s1785140/data/ljspeech_MFA_alignments \
        -o /home/s1785140/data/ljspeech_hubert_reps/hubert-base/layer-6/word_level_with_padding_idx_offset

    #hubert w/o padding offset
    cd ~/fairseq
    python examples/lexicon_learner/wordalign_speechreps.py \
        -t hubert \
        --padding_idx_offset 0 \
        -s /home/s1785140/fairseq/examples/lexicon_learner/lj_speech_quantized.txt \
        -a /home/s1785140/data/ljspeech_MFA_alignments \
        -o /home/s1785140/data/ljspeech_hubert_reps/hubert-base/layer-6/word_level_without_padding_idx_offset

    #wav2vec2
    cd ~/fairseq
    python examples/lexicon_learner/wordalign_speechreps.py \
        -t wav2vec2 \
        -s /home/s1785140/data/ljspeech_wav2vec2_reps/wav2vec2-large-960h/layer-15/utt_level \
        -a /home/s1785140/data/ljspeech_MFA_alignments \
        -o /home/s1785140/data/ljspeech_wav2vec2_reps/wav2vec2-large-960h/layer-15/word_level
"""

In [93]:
# imitate CLAs
import sys
# Stand-in for a real command line: the values are listed as (flag, value)
# pairs and flattened into sys.argv behind a placeholder program name.
_flag_value_pairs = [
    ('--type', 'mel'),
    ('--padding_idx_offset', '0'),
    ('--input_directory', '/home/s1785140/data/ljspeech_fastpitch/mels'),
    ('--alignments', '/home/s1785140/data/ljspeech_fastpitch/aligns'),
    ('--output_directory', '/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels'),

    # FOR TESTING (use these in place of the three path entries above)
    # ('--input_directory', '/home/s1785140/data/ljspeech_fastpitch/mels_test'),
    # ('--alignments', '/home/s1785140/data/ljspeech_fastpitch/aligns_test'),
    # ('--output_directory', '/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels_test'),
]
sys.argv = ['train.py']
for _flag, _value in _flag_value_pairs:
    sys.argv += [_flag, _value]

In [94]:
import os
import argparse
import torch
from tqdm import tqdm
from collections import Counter
import numpy as np
import tgt
import string

# Framing constants handed to parse_textgrid(), which maps an interval time to
# a frame index as ceil(time * SAMPLING_RATE / HOP_LENGTH).
SAMPLING_RATE = 22_050  # samples per unit of TextGrid time (presumably Hz -- confirm)
HOP_LENGTH = 256        # samples per frame

In [95]:
def parse_textgrid(tier, sampling_rate, hop_length, ignore_all_pauses=True):
    """Turn one TextGrid tier into labels plus frame-level timing.

    Interval boundaries are converted to frame indices with
    ``ceil(time * sampling_rate / hop_length)``.

    Args:
        tier: tier object that supports ``tier[0]``, ``tier[-1]`` and
            ``len(tier)`` and exposes its intervals as ``tier._objects``; every
            interval needs ``start_time``, ``end_time`` and ``text``. An empty
            tier is not supported (indexing ``tier[0]`` fails).
        sampling_rate: samples per unit of interval time.
        hop_length: samples per frame.
        ignore_all_pauses: when True (default) silence intervals are left out
            of every returned list. When False they are kept and labelled
            "sil" (first or last interval of the tier) or "sp" (anywhere else).

    Returns:
        ``(phones, durations, start_frames, end_frames, utt_start_time,
        utt_end_time)``. The four lists are always parallel and satisfy
        ``durations[i] == end_frames[i] - start_frames[i]``; all four are empty
        when no interval is kept.
    """
    # latest MFA replaces silence phones with "" in output TextGrids
    sil_phones = ["sil", "sp", "spn", ""]
    utt_start_time = tier[0].start_time
    utt_end_time = tier[-1].end_time

    def to_frame(time):
        return int(np.ceil(time * sampling_rate / hop_length))

    phones = []
    durations = []
    start_frames = []
    end_frames = []
    last_index = len(tier) - 1
    last_kept_index = None
    for i, interval in enumerate(tier._objects):
        label = interval.text
        if label in sil_phones:
            if ignore_all_pauses:
                continue
            # leading/trailing silence vs. short pause between words
            label = "sil" if i == 0 or i == last_index else "sp"
        first_frame = to_frame(interval.start_time)
        past_last_frame = to_frame(interval.end_time)
        # Every kept interval contributes to all four lists, so they stay
        # index-aligned whether or not pauses are kept.
        phones.append(label)
        start_frames.append(first_frame)
        end_frames.append(past_last_frame)
        durations.append(past_last_frame - first_frame)
        last_kept_index = i

    n_samples = utt_end_time * sampling_rate
    n_frames = n_samples / hop_length
    # fix occasional length mismatches at the end of utterances when
    # duration in samples is an integer multiple of hop_length.
    # The extra frame belongs to the final interval of the tier, so it is only
    # applied when that interval was kept, and it is applied to the end frame
    # as well so that duration == end - start still holds.
    if n_frames.is_integer() and last_kept_index == last_index:
        durations[-1] += 1
        end_frames[-1] += 1
    return phones, durations, start_frames, end_frames, utt_start_time, utt_end_time

def save_to_disk(tensor, word, utt_id, count, output_directory):
    """Store one word-level tensor under ``<output_directory>/<word>/``.

    The per-word folder is created if needed. The file name is built from the
    word, the utterance id, the occurrence count and ``tensor.size(0)``.
    """
    word_dir = os.path.join(output_directory, word)
    os.makedirs(word_dir, exist_ok=True)
    seq_len = tensor.size(0)
    file_name = f'{word}__{utt_id}__occ{count}__seqlen{seq_len}.pt'
    torch.save(tensor, os.path.join(word_dir, file_name))

In [97]:
# ---- command-line interface ---------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--type', type=str, default='hubert',
                    help='type of input speech reps that we are using, i.e. hubert wav2vec2 etc.')
parser.add_argument('--padding_idx_offset', type=int, default=0,
                    help='add 1 to token id of discrete reps in order to allow for padding_idx==0')
parser.add_argument('-s', '--input_directory', type=str, required=True,
                    help='path to single non-nested folder containing speech representations (.pt files) or txt file (hubert)')
parser.add_argument('-a', '--alignments', type=str, required=True,
                    help='path to single non-nested folder containing MFA alignments (.TextGrid files)')
parser.add_argument('-o', '--output_directory', type=str, required=True,
                    help='where to write word-level data')
args = parser.parse_args()

# ---- collect utterance ids (and the reps, for hubert / mel) ----------------------
if args.type == "hubert":
    # a single text file, one utterance per line: "<utt id>|<space-separated ids>"
    with open(args.input_directory, 'r') as f:
        lines = f.readlines()
    num_of_utts = len(lines)
    utt_id2speechreps = {line.split('|')[0]: line.split('|')[1] for line in lines}
    utt_ids = sorted(utt_id2speechreps.keys()) # ensure we always process utts in same alphabetical order
elif args.type == "wav2vec2":
    # NOTE: the splitting loop below has no wav2vec2 branch, so this type
    # still stops with ValueError at the first utterance.
    num_of_utts = len(os.listdir(args.input_directory))
    utt_ids = sorted(file.split('.')[0] for file in os.listdir(args.input_directory))
elif args.type == "mel":
    utt_ids = sorted(file.split('.')[0] for file in os.listdir(args.input_directory))
    num_of_utts = len(utt_ids)
    utt_id2speechreps = {}
    print("loading mels from disk")
    for utt_id in tqdm(utt_ids):
        # load mel data
        mel_path = os.path.join(args.input_directory, f'{utt_id}.pt')
        utt_id2speechreps[utt_id] = torch.load(mel_path).transpose(0,1) #[seqlen, feats]
else:
    raise ValueError(f"invalid input type {args.type}")

# sanity check: one alignment file per utterance. The failure message also
# names utterances that have no alignment file, to make a mismatch debuggable.
alignment_files = os.listdir(args.alignments)
aligned_ids = {file.split('.')[0] for file in alignment_files}
utts_without_alignment = [u for u in utt_ids if u not in aligned_ids]
assert num_of_utts == len(alignment_files), (
    f"{num_of_utts}, {len(alignment_files)} "
    f"(utts without an alignment file: {len(utts_without_alignment)}, e.g. {utts_without_alignment[:5]})"
)

longest_word = ''
longest_word_utt_id = ''
longest_word_num_frames = 0

# ---- split each speech reps file using the word-level alignments -----------------
print("split speech reps using word alignments")
for utt_id in tqdm(utt_ids):
    # load speech reps
    if args.type == "hubert":
        reps = utt_id2speechreps[utt_id]
        # shift every id by padding_idx_offset (an offset of 1 frees 0 for use as padding_idx)
        reps = [int(s)+args.padding_idx_offset for s in reps.split(' ')]
        reps = torch.tensor(reps)
        reps.requires_grad = False

        # check dimensions
        if reps.dim() != 1:
            raise ValueError("speech representations have an incorrect number of dimensions")
    elif args.type == "mel":
        reps = utt_id2speechreps[utt_id]
        reps.requires_grad = False

        # check dimensions
        if not (reps.dim() == 2 and reps.size(1) == 80):
            raise ValueError("speech representations have an incorrect number of dimensions")
    else:
        raise ValueError(f"invalid input type {args.type}")

    # NOTE(review): the same SAMPLING_RATE / HOP_LENGTH are used for every
    # --type; confirm they match the frame rate of non-mel reps.
    tg_path = os.path.join(args.alignments, f"{utt_id}.TextGrid")
    tg = tgt.io.read_textgrid(tg_path, include_empty_intervals=True)
    words, all_durs, start_frames, end_frames, utt_start, utt_end = parse_textgrid(tg.get_tier_by_name('words'), SAMPLING_RATE, HOP_LENGTH)

    word_occ_in_utt_counter = Counter()
    # Slice the reps prepared above (with the offset applied and the shape
    # checked), not the raw dictionary entry.
    num_frames = reps.size(0)
    # all_durs only covers words (pauses are dropped by parse_textgrid's
    # default), so its sum is smaller than the utterance length whenever there
    # is silence. Check instead that the alignment does not run past the reps;
    # the per-word assertion below verifies each individual slice.
    assert not end_frames or end_frames[-1] <= num_frames, f"{utt_id=}, {end_frames[-1]=}, {num_frames=}"
    for word, dur, start_frame, end_frame in zip(words, all_durs, start_frames, end_frames):
        for c in word:
            if c not in string.ascii_lowercase:
                print(f'WARNING: char {c} in word {word}')

        # keep track of the longest word seen so far (in frames)
        word_dur = end_frame - start_frame
        if word_dur > longest_word_num_frames:
            longest_word_num_frames = word_dur
            longest_word = word
            longest_word_utt_id = utt_id

        word_occ_in_utt_counter[word] += 1
        wordaligned_reps = reps[start_frame:end_frame]
        extracted_timesteps = wordaligned_reps.size(0)
        assert dur == extracted_timesteps == word_dur, f"{dur=}, {extracted_timesteps=}, {word_dur=}"
        save_to_disk(wordaligned_reps, word, utt_id, word_occ_in_utt_counter[word], args.output_directory)

print("wordtype with longest num of timesteps is", longest_word, "from", longest_word_utt_id, "with len", longest_word_num_frames)
print("you can set transformer max_source_positions to this")

loading mels from disk


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 186/186 [00:00<00:00, 1714.08it/s]


AssertionError: 186, 125

In [None]:
# Scratch check: point tg_path at one utterance's alignment file for manual inspection.
tg_path = os.path.join(args.alignments, 'LJ001-0007.TextGrid')

In [None]:
# Read the TextGrid at tg_path, keeping empty intervals (include_empty_intervals=True).
tg = tgt.io.read_textgrid(tg_path, include_empty_intervals=True)

In [None]:
# Show the tier named 'words' as this cell's output.
tg.get_tier_by_name('words')

In [None]:
# Run the parser on the 'words' tier of the loaded TextGrid and print every
# returned field in order.
words_tier = tg.get_tier_by_name('words')
parsed = parse_textgrid(words_tier, SAMPLING_RATE, HOP_LENGTH)
words, durs, start_frames, end_frames, utt_start, utt_end = parsed
print(*parsed)