In [8]:
# Usage notes for this script, kept as a plain string (the variable is only assigned here).
# The file-naming description matches what save_to_disk() writes:
#   <word>__<utt_id>__occ<count>__seqlen<T>.pt
docstring = """
(Use fastpitch conda env)

Helper script that takes a folder of speech reps (wav2vec2, mel-spec, etc.)
and aligns them at word-level using MFA alignments.

Speech reps corresponding to word tokens in the corpus are then saved individually to an output folder
with the following structure:
- data_path
    - word1
        - word1__LJ010-0292__occ1__seqlen12.pt
        - word1__LJ010-0292__occ2__seqlen9.pt
        - ...
    - word2
        - word2__LJ001-0012__occ1__seqlen20.pt
        - word2__LJ002-0024__occ1__seqlen17.pt
        - ...
    - ...

- word1, word2, ... subfolders refer to a particular wordtype in the corpus.
- .pt files contain speech representations that map to a particular example of a wordtype.
  It is named as:
    <wordtype>__<utt id>__occ<numbered occurrence in the utterance>__seqlen<number of timesteps>.pt

Example usage:
    #hubert w/ padding offset
    cd ~/fairseq
    python examples/lexicon_learner/wordalign_speechreps.py \
        -t hubert \
        --padding_idx_offset 1 \
        -s /home/s1785140/fairseq/examples/lexicon_learner/lj_speech_quantized.txt \
        -a /home/s1785140/data/ljspeech_MFA_alignments \
        -o /home/s1785140/data/ljspeech_hubert_reps/hubert-base/layer-6/word_level_with_padding_idx_offset

    #hubert w/o padding offset
    cd ~/fairseq
    python examples/lexicon_learner/wordalign_speechreps.py \
        -t hubert \
        --padding_idx_offset 0 \
        -s /home/s1785140/fairseq/examples/lexicon_learner/lj_speech_quantized.txt \
        -a /home/s1785140/data/ljspeech_MFA_alignments \
        -o /home/s1785140/data/ljspeech_hubert_reps/hubert-base/layer-6/word_level_without_padding_idx_offset

    #wav2vec2
    cd ~/fairseq
    python examples/lexicon_learner/wordalign_speechreps.py \
        -t wav2vec2 \
        -s /home/s1785140/data/ljspeech_wav2vec2_reps/wav2vec2-large-960h/layer-15/utt_level \
        -a /home/s1785140/data/ljspeech_MFA_alignments \
        -o /home/s1785140/data/ljspeech_wav2vec2_reps/wav2vec2-large-960h/layer-15/word_level
"""

In [9]:
# Imitate command-line arguments (CLAs): overwrite sys.argv so the argparse
# parser defined in a later cell can be used unchanged inside the notebook.
# The first active entry is a placeholder program name; the rest are flag/value pairs.
# NOTE(review): all paths are absolute and machine-specific; edit them before running elsewhere.
import sys
sys.argv = [
    # fastpitch features (alternative configuration, currently disabled)
    # 'train.py',
    # '--type', 'mel',
    # '--utt_id_list', '/home/s1785140/data/ljspeech_fastpitch/respeller_uttids.txt', 
    # '--input_directory', '/home/s1785140/data/ljspeech_fastpitch/mels',
    # '--alignments', '/home/s1785140/data/ljspeech_fastpitch/aligns', 
    # '--output_directory', '/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels',

    # speechbrain features (active configuration)
    'train.py',
    '--type', 'mel',
    '--utt_id_list', '/home/s1785140/data/ljspeech_fastpitch/respeller_uttids.txt', 
    '--input_directory', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats',
    '--alignments', '/home/s1785140/data/ljspeech_fastpitch/aligns', 
    '--output_directory', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned',
    
    # FOR TESTING (small input/alignment/output folders, currently disabled)
    # '--input_directory', '/home/s1785140/data/ljspeech_fastpitch/mels_test',
    # '--alignments', '/home/s1785140/data/ljspeech_fastpitch/aligns_test', 
    # '--output_directory', '/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels_test',
]

In [10]:
import argparse
import os
import string
from collections import Counter

import numpy as np
import tgt
import torch
import unidecode
from tqdm import tqdm

# Used to convert TextGrid times to frame indices:
#   frame = ceil(time * SAMPLING_RATE / HOP_LENGTH)
# NOTE(review): presumably these must match the settings used to extract the speech reps - confirm.
SAMPLING_RATE = 22050
HOP_LENGTH = 256
# When True, every word containing a character outside a-z is skipped instead of normalised.
SKIP_NON_ASCII = False
# Words that are skipped (rather than normalised) when they contain a character outside a-z.
WORDS_TO_SKIP = ["wdsu-tv"]

In [11]:
def parse_textgrid(tier, sampling_rate, hop_length, ignore_all_pauses=True):
    """Convert one TextGrid tier into labels and frame-level timings.

    Times are mapped to frame indices with ``ceil(time * sampling_rate / hop_length)``.

    Args:
        tier: tier object that supports indexing and ``len()`` and exposes its
            intervals via ``_objects``; each interval provides ``start_time``,
            ``end_time`` and ``text``.
        sampling_rate: multiplier applied to times before dividing by ``hop_length``.
        hop_length: divisor that turns sample positions into frame positions.
        ignore_all_pauses: if False, silence intervals add a ``"sil"`` (first/last
            interval) or ``"sp"`` (elsewhere) label. Only the label list gets an
            entry for them, so it can then be longer than the other lists.

    Returns:
        Tuple ``(phones, durations, start_frames, end_frames, utt_start_time,
        utt_end_time)``. ``durations``, ``start_frames`` and ``end_frames`` hold
        one entry per non-silence interval.

    Note:
        When the utterance length is an exact multiple of ``hop_length``, the last
        duration is incremented by one, so it can differ from
        ``end_frames[-1] - start_frames[-1]``.
    """
    # latest MFA replaces silence phones with "" in output TextGrids
    sil_phones = ["sil", "sp", "spn", ""]
    utt_start_time = tier[0].start_time
    utt_end_time = tier[-1].end_time
    phones = []
    durations = []  # one entry per non-silence interval (silences are NOT included)
    start_frames = []
    end_frames = []
    last_index = len(tier) - 1
    for i, t in enumerate(tier._objects):
        s, e, p = t.start_time, t.end_time, t.text
        if p not in sil_phones:
            # compute each boundary once and derive the duration from it
            start_frame = int(np.ceil(s * sampling_rate / hop_length))
            end_frame = int(np.ceil(e * sampling_rate / hop_length))
            phones.append(p)
            start_frames.append(start_frame)
            end_frames.append(end_frame)
            durations.append(end_frame - start_frame)
        elif not ignore_all_pauses:
            if i == 0 or i == last_index:
                # leading or trailing silence
                phones.append("sil")
            else:
                # short pause between words
                phones.append("sp")

    n_samples = utt_end_time * sampling_rate
    n_frames = n_samples / hop_length
    # fix occasional length mismatches at the end of utterances when
    # duration in samples is an integer multiple of hop_length.
    # Guarded so that a tier containing only silences does not raise IndexError.
    if durations and n_frames.is_integer():
        durations[-1] += 1
    return phones, durations, start_frames, end_frames, utt_start_time, utt_end_time

def save_to_disk(tensor, word, utt_id, count, output_directory):
    """Write one word-level tensor to ``<output_directory>/<word>/``.

    The sub-folder is created if needed. The file name encodes the word, the
    utterance id, the occurrence count and the size of the tensor's first
    dimension: ``<word>__<utt_id>__occ<count>__seqlen<tensor.size(0)>.pt``.
    """
    word_dir = os.path.join(output_directory, word)
    os.makedirs(word_dir, exist_ok=True)
    file_name = f'{word}__{utt_id}__occ{count}__seqlen{tensor.size(0)}.pt'
    torch.save(tensor, os.path.join(word_dir, file_name))
    
def allowed_word(word):
    """Return True unless the word has at most one character or is exactly '--'."""
    return len(word) > 1 and word != '--'

In [12]:
# Command-line interface. parse_args() reads sys.argv, which an earlier cell
# overrides when this is run as a notebook.
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--type', type=str, default='hubert',
                    help='type of input speech reps that we are using, i.e. hubert wav2vec2 etc.')
parser.add_argument('--padding_idx_offset', type=int, default=0,
                    help='add 1 to token id of discrete reps in order to allow for padding_idx==0')
parser.add_argument('--utt_id_list', type=str, required=False, default="",
                    help='path to text file that contains list of utterance ids that we extract from')
parser.add_argument('-s', '--input_directory', type=str, required=True,
                    help='path to single non-nested folder containing speech representations (.pt files) or txt file (hubert)')
parser.add_argument('-a', '--alignments', type=str, required=True,
                    help='path to single non-nested folder containing MFA alignments (.TextGrid files)')
parser.add_argument('-o', '--output_directory', type=str, required=True,
                    help='where to write word-level data')
args = parser.parse_args()

# Whether loaded mels must be transposed to reach shape (T, D) is inferred from
# the input path: "speechbrain" paths are used as-is, "ljspeech" paths are transposed.
if "speechbrain" in args.input_directory:
    args.transpose_mel = False
elif "ljspeech" in args.input_directory:
    args.transpose_mel = True
elif args.type == "mel":
    # Previously transpose_mel was left undefined here, which only surfaced
    # later as an AttributeError while loading the first mel.
    raise ValueError(
        f"cannot infer whether mels need transposing from input path {args.input_directory!r} "
        "(expected it to contain 'speechbrain' or 'ljspeech')"
    )
else:
    # not used for non-mel inputs; defined so the attribute always exists
    args.transpose_mel = False

# load speech reps

In [16]:
# Load the utterance-level speech reps selected by args.type and build the
# sorted list of utterance ids to process.
cuda = torch.cuda.is_available()
map_location = 'cuda' if cuda else 'cpu'

if args.type == "hubert":
    # input is a text file with one "<utt_id>|<space separated unit ids>" entry per line
    with open(args.input_directory, 'r') as f:
        lines = [line.rstrip('\n') for line in f if line.strip()]  # ignore blank lines
    num_of_utts = len(lines)
    utt_id2speechreps = {}
    for line in lines:
        fields = line.split('|')
        utt_id2speechreps[fields[0]] = fields[1]
    utt_ids = sorted(utt_id2speechreps.keys()) # ensure we always process utts in same alphabetical order
elif args.type == "wav2vec2":
    # splitext keeps utterance ids that themselves contain a '.' intact
    utt_ids = sorted(os.path.splitext(file)[0] for file in os.listdir(args.input_directory))
    num_of_utts = len(utt_ids)
elif args.type == "mel":
    if args.utt_id_list:
        # we specified a subset of utt ids (blank lines are ignored)
        with open(args.utt_id_list, 'r') as f:
            utt_ids = [line.strip() for line in f.read().splitlines() if line.strip()]
    else:
        # all .pt files in directory (only '<utt_id>.pt' files can be loaded below)
        utt_ids = sorted(os.path.splitext(file)[0]
                         for file in os.listdir(args.input_directory)
                         if file.endswith('.pt'))
    num_of_utts = len(utt_ids)
    utt_id2speechreps = {}
    print(f"loading mels from disk in dir {args.input_directory} for {len(utt_ids)} utts")
    for utt_id in tqdm(utt_ids):
        # load mel data
        # NOTE(review): torch.load unpickles the file; only use it on trusted data.
        mel_path = os.path.join(args.input_directory, f'{utt_id}.pt')
        mel = torch.load(mel_path, map_location=map_location)
        if args.transpose_mel:
            mel = mel.transpose(0,1)

        # mels should be of shape (T, D) now
        utt_id2speechreps[utt_id] = mel
else:
    raise ValueError(f"invalid input type {args.type}")

# sanity check - assert that each utt has a corresponding alignment,
# and report which ones are missing instead of failing without a message
alignment_files = set(os.listdir(args.alignments))
missing_alignments = [utt_id for utt_id in utt_ids if f"{utt_id}.TextGrid" not in alignment_files]
assert not missing_alignments, (
    f"{len(missing_alignments)} utts have no .TextGrid in {args.alignments}, "
    f"e.g. {missing_alignments[:10]}"
)

loading mels from disk in dir /home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats for 6551 utts


  4%|▍         | 269/6551 [00:30<11:46,  8.89it/s]


KeyboardInterrupt: 

# perform splitting of mel specs using MFA alignments and save to disk

In [None]:
# Track the longest word (in frames) seen over the whole corpus; reported after the loop.
longest_word = ''
longest_word_utt_id = ''
longest_word_num_frames = 0

# split each speech reps file using the word-level alignments
print("split speech reps using word alignments")
for utt_id in tqdm(utt_ids):
    # load speech reps
    if args.type == "hubert":
        # discrete units are stored as whitespace-separated integers
        # NOTE padding_idx_offset (e.g. 1) is added to each index so that 0 is available as a padding_idx
        reps = [int(s) + args.padding_idx_offset for s in utt_id2speechreps[utt_id].split()]
        reps = torch.tensor(reps)
        reps.requires_grad = False

        # check dimensions
        if reps.dim() != 1:
            raise ValueError("speech representations have an incorrect number of dimensions")
    elif args.type == "mel":
        reps = utt_id2speechreps[utt_id]
        reps.requires_grad = False

        # check dimensions: expect (T, 80)
        if not (reps.dim() == 2 and reps.size(1) == 80):
            raise ValueError("speech representations have an incorrect number of dimensions")
    else:
        raise ValueError(f"invalid input type {args.type}")

    tg_path = os.path.join(args.alignments, f"{utt_id}.TextGrid")
    tg = tgt.io.read_textgrid(tg_path, include_empty_intervals=True)
    words, all_durs, start_frames, end_frames, utt_start, utt_end = parse_textgrid(tg.get_tier_by_name('words'), SAMPLING_RATE, HOP_LENGTH)

    # counts occurrences of each (normalised) word within this utterance
    word_occ_in_utt_counter = Counter()
    # assert reps.size(0) == sum(all_durs), f"{reps.size(0)=} != {sum(all_durs)=}" # verify that MFA frame durations match up with the extracted mels
    for word, dur, start_frame, end_frame in zip(words, all_durs, start_frames, end_frames):
        if not allowed_word(word):
            continue

        skip_word = False
        normalise_non_ascii = False
        for c in word:
            if c not in string.ascii_lowercase:
                s = f'WARNING: char {c} in word {word}'
                if SKIP_NON_ASCII or word in WORDS_TO_SKIP:
                    s += '. skipping!...'
                    skip_word = True
                else:
                    normalise_non_ascii = True

                print(s)

        if skip_word:
            continue

        if normalise_non_ascii: # normalise word
            prenorm_word = word
            # remove trailing '-'
            word = word.rstrip('-')
            # convert diacritics to ascii
            word = unidecode.unidecode(word)
            print(f"\tnormalised '{prenorm_word}' to '{word}'")
            # normalisation can shrink the word (e.g. '---' -> ''); an empty or
            # one-character name must not be used as an output folder
            if not allowed_word(word):
                print(f"\tskipping '{prenorm_word}': normalised form '{word}' is not an allowed word")
                continue

        # check if word is the longest word we have seen so far
        word_dur = end_frame - start_frame
        if word_dur > longest_word_num_frames:
            longest_word_num_frames = word_dur
            longest_word = word
            longest_word_utt_id = utt_id

        # extract the word's frames from the reps prepared above (for hubert this is
        # the offset tensor, not the raw string held in utt_id2speechreps).
        # clone() gives the slice its own storage; saving a view would serialise
        # the storage of the whole utterance into every word file.
        wordaligned_mel = reps[start_frame:end_frame].clone()

        # save extracted reps to disk
        word_occ_in_utt_counter[word] += 1
        extracted_timesteps = wordaligned_mel.size(0)
        assert dur == extracted_timesteps == word_dur, f"{dur=}, {extracted_timesteps=}, {word_dur=}"
        save_to_disk(wordaligned_mel, word, utt_id, word_occ_in_utt_counter[word], args.output_directory)

print("wordtype with longest num of timesteps is", longest_word, "from", longest_word_utt_id, "with len", longest_word_num_frames)
print("you can set transformer max_source_positions to this")

split speech reps using word alignments


 29%|████████████████████████████████████████▌                                                                                                     | 1870/6551 [06:35<07:15, 10.76it/s]

	normalised 'grâce' to 'grace'


 33%|██████████████████████████████████████████████▎                                                                                               | 2134/6551 [07:35<10:10,  7.23it/s]

	normalised 'habitué' to 'habitue'


 44%|██████████████████████████████████████████████████████████████▋                                                                               | 2894/6551 [10:01<08:14,  7.40it/s]

	normalised 'dêtre' to 'detre'


 44%|██████████████████████████████████████████████████████████████▉                                                                               | 2905/6551 [10:03<07:25,  8.19it/s]

	normalised 'müllers' to 'mullers'


 44%|███████████████████████████████████████████████████████████████▏                                                                              | 2913/6551 [10:04<07:07,  8.51it/s]

	normalised 'müller' to 'muller'
	normalised 'müller' to 'muller'


 45%|███████████████████████████████████████████████████████████████▎                                                                              | 2919/6551 [10:04<07:22,  8.22it/s]

	normalised 'müller' to 'muller'


 45%|███████████████████████████████████████████████████████████████▍                                                                              | 2925/6551 [10:05<07:18,  8.26it/s]

	normalised 'müller' to 'muller'


 49%|█████████████████████████████████████████████████████████████████████▎                                                                        | 3198/6551 [10:50<09:42,  5.75it/s]

	normalised 'müller' to 'muller'
	normalised 'müller' to 'muller'


 49%|█████████████████████████████████████████████████████████████████████▎                                                                        | 3200/6551 [10:50<13:00,  4.29it/s]

	normalised 'müller' to 'muller'


 52%|██████████████████████████████████████████████████████████████████████████▏                                                                   | 3421/6551 [11:37<11:40,  4.47it/s]

	normalised 'célèbre' to 'celebre'


 57%|█████████████████████████████████████████████████████████████████████████████████▎                                                            | 3752/6551 [12:33<04:30, 10.34it/s]

	normalised 'forward-' to 'forward'


 58%|██████████████████████████████████████████████████████████████████████████████████▌                                                           | 3811/6551 [12:44<08:05,  5.64it/s]

	normalised 'self-' to 'self'


 58%|██████████████████████████████████████████████████████████████████████████████████▊                                                           | 3823/6551 [12:46<05:41,  7.99it/s]

	normalised 'vice-' to 'vice'


 73%|███████████████████████████████████████████████████████████████████████████████████████████████████████▏                                      | 4762/6551 [15:30<03:03,  9.76it/s]

	normalised 'full-' to 'full'


 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████                                 | 5034/6551 [16:20<03:46,  6.69it/s]

	normalised 'gray-' to 'gray'


 91%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊             | 5942/6551 [18:53<00:48, 12.62it/s]



100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6551/6551 [20:11<00:00,  5.41it/s]

wordtype with longest num of timesteps is anesthesiologists from LJ031-0023 with len 152
you can set transformer max_source_positions to this





# Create train, dev, and test data splits for training the respeller

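We hold out word types from training for the dev and test splits: dev and test words are sampled from OOV singletons (word types with frequency 1 in the OOV list), and every remaining word type goes to the train split.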

## Random

In [None]:
import random
import json

# Seed Python's `random` module, which sample_and_remove() uses to draw the dev and test words.
random.seed(1337)

# Split fractions. dev_ratio and test_ratio size the dev and test splits below;
# train_ratio is not referenced in the visible cells (train is whatever remains).
train_ratio, dev_ratio, test_ratio = [0.9, 0.05, 0.05]

In [None]:
# Load the OOV word types (words that are not seen in tts training).
# The JSON file is used as a mapping from word type to frequency.
oov_wordlist_path = '/home/s1785140/data/ljspeech_fastpitch/oov_list.json'
with open(oov_wordlist_path, 'r') as f:
    oovs_and_freqs = json.loads(f.read())

# iterating a dict yields its keys, i.e. the word types
all_wordtypes = set(oovs_and_freqs)
print(f'original before cleaning/sampling {len(all_wordtypes)=}')

original before cleaning/sampling len(all_wordtypes)=8343


In [None]:
# clean/remove words that do not have speech reps
# Each sub-folder of the output directory is named after a word that had reps saved by save_to_disk().
words_with_aligned_mels = set(os.listdir(args.output_directory))
words_no_mels = all_wordtypes - words_with_aligned_mels
# '=' belongs inside the braces so the output reads "len(words_no_mels)=34" rather than "34="
print(f'{len(words_no_mels)=}')

34=


In [None]:
# Print an explanation, then show the excluded words as the cell's output (bare last expression).
print("list of words to be excluded from respeller training as they do not have mels (likely due to how normalisation is different between mfa and our own data processing):")
words_no_mels

list of words to be excluded from respeller training as they do not have mels (likely due to how normalisation is different between mfa and our own data processing):


{'aaa',
 'cos',
 'dc',
 'eg',
 'eightthirty',
 'elevenfifty',
 'eleventhirty',
 'fivefifty',
 'fourfifty',
 'fourforty',
 'fourthirty',
 'iq',
 'k',
 'lj',
 'lld',
 'mps',
 'ninethirty',
 'onefifteen',
 'onefifty',
 'oneforty',
 'ps',
 'sevenfifteen',
 'seventhirty',
 'sixthirty',
 'tenforty',
 'tenthirty',
 'threetwenty',
 'tv',
 'twelvefifteen',
 'twofifteen',
 'twoforty',
 'twothirty',
 'u',
 'uss'}

In [None]:
# remove these problematic words from respeller training dev test
# pop(..., None) instead of del keeps this cell re-runnable: on a second run the
# words are already gone and del would raise KeyError.
for w in words_no_mels:
    oovs_and_freqs.pop(w, None)

all_wordtypes = set(oovs_and_freqs.keys())
print(f'original after cleaning {len(all_wordtypes)=}')

# number of word types to hold out for dev and test (rounded down)
dev_N = int(dev_ratio * len(all_wordtypes))
test_N = int(test_ratio * len(all_wordtypes))

original after cleaning len(all_wordtypes)=8309


In [None]:
def sample_and_remove(s: set, N: int):
    """Sample N items from set ``s`` without replacement and remove them from ``s``.

    Args:
        s: set to sample from; it is modified in place.
        N: number of items to draw.

    Returns:
        A new set containing the sampled items.

    Raises:
        ValueError: if N is larger than ``len(s)`` or negative (from ``random.sample``).
    """
    # random.sample() requires a sequence (sets raise TypeError on Python >= 3.11).
    # Sorting also makes the draw reproducible for a given seed, because set
    # iteration order of strings can change between interpreter runs.
    try:
        population = sorted(s)
    except TypeError:
        # items are not mutually orderable: fall back to the set's own order
        population = list(s)
    sampled = random.sample(population, N)
    s.difference_update(sampled)
    return set(sampled)

In [None]:
#get dev and test splits
# Dev/test words are drawn only from OOV singletons, i.e. word types whose value in oovs_and_freqs is 1.
oov_singletons = set(wordtype for wordtype, freq in oovs_and_freqs.items() if freq == 1)
assert len(oov_singletons) > dev_N + test_N, "not enough OOV singletons to create dev and test sets" 
print(f'before sampling dev and test {len(oov_singletons)=}')

# sample_and_remove() removes the sampled words from oov_singletons, so test is
# drawn from what is left after dev. The order of the two calls determines the random draws.
dev = sample_and_remove(oov_singletons, dev_N)
print(f'after sampling dev {len(oov_singletons)=}, {len(dev)=}')

test = sample_and_remove(oov_singletons, test_N)
print(f'after sampling test {len(oov_singletons)=}, {len(test)=}')

before sampling dev and test len(oov_singletons)=5440
after sampling dev len(oov_singletons)=5025, len(dev)=415
after sampling test len(oov_singletons)=4610, len(test)=415


In [None]:
# Peek at ten dev words; dev is a set, so which ten are shown is not a fixed order.
list(dev)[:10]

['sumpter',
 'rumor',
 'esther',
 'depressed',
 'violins',
 'apprise',
 'summarize',
 'adelphi',
 'sighing',
 'entreating']

In [None]:
# Peek at ten test words; test is a set, so which ten are shown is not a fixed order.
list(test)[:10]

['yell',
 'invaded',
 'oblivion',
 'punches',
 'divide',
 'permits',
 'facilitating',
 'resurrection',
 'cashier',
 'delicacy']

In [None]:
#get train split
# Train is every remaining word type once dev and test are taken out.
print(f'before removing dev and test wordtypes {len(all_wordtypes)=}')
# difference_update() ignores words that are already absent, so re-running this
# cell no longer raises KeyError the way per-item remove() did.
all_wordtypes.difference_update(dev | test)
print(f'after removing dev and test wordtypes {len(all_wordtypes)=}')

# copy, so later changes to all_wordtypes do not affect train
train = set(all_wordtypes)

before removing dev and test wordtypes len(all_wordtypes)=8309
after removing dev and test wordtypes len(all_wordtypes)=7479


In [None]:
# sanity checks: the three splits must be pairwise disjoint
assert dev.isdisjoint(test)
assert train.isdisjoint(dev)
assert train.isdisjoint(test)
print("Good! No overlapping words between train, dev, and test!!!")

Good! No overlapping words between train, dev, and test!!!


In [None]:
# write to disk
def save_wordlist(path, words):
    """Write ``words`` to ``path`` as a sorted JSON list, indented by 4 spaces."""
    with open(path, 'w') as out_file:
        out_file.write(json.dumps(sorted(words), indent=4))
        
# Output locations for the three word lists.
# NOTE(review): absolute, machine-specific paths; existing files are overwritten.
train_path = '/home/s1785140/data/ljspeech_fastpitch/respeller_train_words.json'
dev_path = '/home/s1785140/data/ljspeech_fastpitch/respeller_dev_words.json'
test_path = '/home/s1785140/data/ljspeech_fastpitch/respeller_test_words.json'

# each split is written as a sorted JSON list
save_wordlist(train_path, train)
save_wordlist(dev_path, dev)
save_wordlist(test_path, test)

## G2P selection