In [75]:
docstring = """
(Use fastpitch conda env)

Helper script that takes a folder of speech reps (wav2vec2, mel-spec, etc.)
and aligns them at word-level using MFA alignments.

Speech reps corresponding to word tokens in the corpus are then saved individually to an output folder
with the following structure:
- output_directory
    - word1
        - word1__LJ010-0292__occ1.pt
        - word1__LJ010-0292__occ2.pt
        - ...
    - word2
        - word2__LJ001-0012__occ1.pt
        - word2__LJ002-0024__occ1.pt
        - ...
    - ...

- word1, word2, ... subfolders refer to a particular wordtype in the corpus.
- .pt files contain speech representations that map to a particular example of a wordtype.
  It is named as:
    <wordtype>__<utt id>__occ<numbered occurrence in the utterance, starting at 1>.pt

NOTE: the word-splitting step currently supports `-t hubert` and `-t mel` only;
`-t wav2vec2` is accepted when loading reps but rejected when splitting.

Example usage:
    #hubert w/ padding offset
    cd ~/fairseq
    python examples/lexicon_learner/wordalign_speechreps.py \
        -t hubert \
        --padding_idx_offset 1 \
        -s /home/s1785140/fairseq/examples/lexicon_learner/lj_speech_quantized.txt \
        -a /home/s1785140/data/ljspeech_MFA_alignments \
        -o /home/s1785140/data/ljspeech_hubert_reps/hubert-base/layer-6/word_level_with_padding_idx_offset

    #hubert w/o padding offset
    cd ~/fairseq
    python examples/lexicon_learner/wordalign_speechreps.py \
        -t hubert \
        --padding_idx_offset 0 \
        -s /home/s1785140/fairseq/examples/lexicon_learner/lj_speech_quantized.txt \
        -a /home/s1785140/data/ljspeech_MFA_alignments \
        -o /home/s1785140/data/ljspeech_hubert_reps/hubert-base/layer-6/word_level_without_padding_idx_offset

    #wav2vec2
    cd ~/fairseq
    python examples/lexicon_learner/wordalign_speechreps.py \
        -t wav2vec2 \
        -s /home/s1785140/data/ljspeech_wav2vec2_reps/wav2vec2-large-960h/layer-15/utt_level \
        -a /home/s1785140/data/ljspeech_MFA_alignments \
        -o /home/s1785140/data/ljspeech_wav2vec2_reps/wav2vec2-large-960h/layer-15/word_level
"""

Command-line arguments (simulated via `sys.argv` so that `argparse` works inside the notebook)

In [76]:
# Imitate command-line arguments so the argparse cell below can run inside the notebook.
#
# Other configurations used previously (swap them into the option list below to use them):
#   fastpitch features:
#     '--type', 'mel'
#     '--utt_id_list', '/home/s1785140/data/ljspeech_fastpitch/respeller_uttids.txt'
#     '--input_directory', '/home/s1785140/data/ljspeech_fastpitch/mels'
#     '--alignments', '/home/s1785140/data/ljspeech_fastpitch/aligns'
#     '--output_directory', '/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels'
#   testing (replaces the three directory options):
#     '--input_directory', '/home/s1785140/data/ljspeech_fastpitch/mels_test'
#     '--alignments', '/home/s1785140/data/ljspeech_fastpitch/aligns_test'
#     '--output_directory', '/home/s1785140/data/ljspeech_fastpitch/wordaligned_mels_test'
import sys

# active configuration: speechbrain features; each tuple is one (flag, value) pair
sys.argv = ['train.py'] + [
    token
    for option in (
        ('--type', 'mel'),
        ('--utt_id_list', '/home/s1785140/data/ljspeech_fastpitch/respeller_uttids.txt'),
        ('--input_directory', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats'),
        ('--alignments', '/home/s1785140/data/ljspeech_fastpitch/aligns'),
        ('--output_directory', '/home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned'),
    )
    for token in option
]

In [77]:
# install package from notebook
import sys
!{sys.executable} -m pip install nltk



# imports and globals

In [78]:
import os
import argparse
import torch
from tqdm import tqdm
from collections import Counter
import numpy as np
import tgt
import string
import librosa
import glob
import random
from IPython.display import Audio
import nltk
from nltk.corpus import stopwords
# quiet=True: don't print the download / "already up-to-date" log into the notebook output
nltk.download('stopwords', quiet=True)

# if True, skip every word containing a character outside a-z instead of normalising it to ASCII
SKIP_NON_ASCII = False
# words that are always skipped when they contain a character outside a-z
WORDS_TO_SKIP = ["wdsu-tv"]

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/s1785140/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


# Parser

In [79]:
parser = argparse.ArgumentParser()
parser.add_argument('-t', '--type', type=str, default='hubert',
                    help='type of input speech reps that we are using, i.e. hubert wav2vec2 etc.')
parser.add_argument('--padding_idx_offset', type=int, default=0,
                    help='add 1 to token id of discrete reps in order to allow for padding_idx==0')
parser.add_argument('--utt_id_list', type=str, required=False, default="",
                    help='path to text file that contains list of utterance ids that we extract from')
parser.add_argument('-s', '--input_directory', type=str, required=True,
                    help='path to single non-nested folder containing speech representations (.pt files) or txt file (hubert)')
parser.add_argument('-a', '--alignments', type=str, required=True,
                    help='path to single non-nested folder containing MFA alignments (.TextGrid files)')
parser.add_argument('-o', '--output_directory', type=str, required=True,
                    help='where to write word-level data')
args = parser.parse_args()

# Corpus-specific feature settings, inferred from the input path.
# "speechbrain" is checked first because the speechbrain feature directory name also contains "ljspeech".
if "speechbrain" in args.input_directory:
    args.corpus_name = "speechbrain"
    args.transpose_mel = False  # loaded features are used as (T, D) without transposing
    SAMPLING_RATE = 16000 # hz
    HOP_LENGTH_IN_MS = 10 # in ms
    WIN_LENGTH_IN_MS = 25 # in ms
    HOP_LENGTH = int(HOP_LENGTH_IN_MS * SAMPLING_RATE / 1000) # convert HOP_LENGTH to samples
    WIN_LENGTH = int(WIN_LENGTH_IN_MS * SAMPLING_RATE / 1000) # convert WIN_LENGTH to samples
elif "ljspeech" in args.input_directory:
    args.corpus_name = "ljspeech"
    args.transpose_mel = True  # loaded features are transposed to (T, D)
    SAMPLING_RATE = 22050 # hz
    HOP_LENGTH = 256 # in samples
    # WIN_LENGTH is needed by the Griffin-Lim sanity check.
    # NOTE(review): 1024 samples is assumed to match the FastPitch/LJSpeech STFT config — confirm.
    WIN_LENGTH = 1024 # in samples
elif args.type == "mel":
    # mel loading/checking needs corpus_name and transpose_mel, so fail early rather than later
    raise ValueError(f"cannot infer corpus (speechbrain or ljspeech) from input directory {args.input_directory}")

# number of utterances processed when debugging
args.max_utts_to_generate = 25

# Functions

In [80]:
def save_to_disk(tensor, word, utt_id, count, output_directory):
    """Save one word-aligned tensor to `<output_directory>/<word>/<word>__<utt_id>__occ<count>.pt`.

    Args:
        tensor: the word's speech reps (typically a slice of an utterance tensor).
        word: word type; also used as the sub-folder name.
        utt_id: id of the utterance the word came from.
        count: 1-based occurrence number of `word` within the utterance.
        output_directory: root folder; the word sub-folder is created if needed.
    """
    output_directory = os.path.join(output_directory, word)
    os.makedirs(output_directory, exist_ok=True)
    save_path = os.path.join(output_directory, f'{word}__{utt_id}__occ{count}.pt')
    # torch.save serialises a view's entire underlying storage (the whole utterance);
    # cloning makes the file contain only this word's frames.
    torch.save(tensor.detach().clone(), save_path)
    
def allowed_word(word):
    """Return True for words worth extracting: longer than one character and not the '--' token."""
    return len(word) > 1 and word != '--'

# load speech reps

In [81]:
# debug mode: cap the number of utterances at args.max_utts_to_generate (None = no cap)
debug = True
max_utts = args.max_utts_to_generate if debug else None

# load tensors onto the GPU when one is available, otherwise onto the CPU
cuda = torch.cuda.is_available()
map_location = 'cuda' if cuda else 'cpu'

if args.type == "hubert":
    # each non-blank line is "<utt_id>|<space-separated unit ids>"
    with open(args.input_directory, 'r') as f:
        lines = [l for l in f if l.strip()]  # a blank line would have no '|' field and crash the split
    num_of_utts = len(lines)
    utt_id2speechreps = {l.split('|')[0]: l.split('|')[1] for l in lines}
    utt_ids = sorted(utt_id2speechreps.keys()) # ensure we always process utts in same alphabetical order
elif args.type == "wav2vec2":
    # NOTE(review): this branch does not fill utt_id2speechreps; the splitting step only handles hubert and mel
    num_of_utts = len(os.listdir(args.input_directory))
    utt_ids = sorted(file.split('.')[0] for file in os.listdir(args.input_directory))
elif args.type == "mel":
    if args.utt_id_list:
        # we specified a subset of utt ids; blank lines are skipped so they don't become bogus ids
        with open(args.utt_id_list, 'r') as f:
            utt_ids = [line.strip() for line in f if line.strip()]
    else:
        # all files in directory
        utt_ids = list(sorted(file.split('.')[0] for file in os.listdir(args.input_directory)))
    num_of_utts = len(utt_ids)
    utt_ids = utt_ids[:max_utts]
    utt_id2speechreps = {}
    print(f"loading mels from disk in dir {args.input_directory} for {len(utt_ids)} utts")
    for utt_id in tqdm(utt_ids):
        # load mel data
        mel_path = os.path.join(args.input_directory, f'{utt_id}.pt')
        mel = torch.load(mel_path, map_location=map_location)
        if args.transpose_mel:
            mel = mel.transpose(0,1)

        # mels should be of shape (T, D) now
        utt_id2speechreps[utt_id] = mel
else:
    raise ValueError(f"invalid input type {args.type}")

# sanity check - assert that each utt has a corresponding alignment, reporting which ones are missing
alignment_files = set(os.listdir(args.alignments))
missing_alignments = [u for u in utt_ids if f"{u}.TextGrid" not in alignment_files]
assert not missing_alignments, (
    f"{len(missing_alignments)} utts have no MFA alignment in {args.alignments}, e.g. {missing_alignments[:5]}"
)

loading mels from disk in dir /home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats for 25 utts


100%|██████████| 25/25 [00:03<00:00,  8.18it/s]


# perform splitting of mel specs using MFA alignments and save to disk

In [82]:
def parse_textgrid(tier, sampling_rate, hop_length, ignore_all_pauses=True):
    """Convert an MFA tier into frame-level segments.

    Interval times are converted to frame indices with ceil(time * sampling_rate / hop_length).

    Args:
        tier: tgt tier (uses tier[0], tier[-1], len(tier) and tier._objects).
        sampling_rate: audio sampling rate in Hz.
        hop_length: hop between frames, in samples.
        ignore_all_pauses: if True, silence intervals are dropped entirely. If False they are
            kept as "sil" (first/last interval) or "sp" (any other interval). Their frames and
            durations are recorded too, so all returned lists stay parallel.

    Returns:
        (phones, durations, start_frames, end_frames, utt_start_time, utt_end_time).
        durations[i] == end_frames[i] - start_frames[i], except that the last duration is
        incremented by one when utt_end_time * sampling_rate / hop_length is a whole number.
        The lists are empty when every interval is a (dropped) silence.
    """
    # latest MFA replaces silence phones with "" in output TextGrids
    sil_phones = ["sil", "sp", "spn", ""]
    utt_start_time = tier[0].start_time
    utt_end_time = tier[-1].end_time

    def to_frame(t):
        return int(np.ceil(t * sampling_rate / hop_length))

    phones = []
    durations = [] # includes silences only when ignore_all_pauses is False
    start_frames = []
    end_frames = []
    for i, t in enumerate(tier._objects):
        s, e, p = t.start_time, t.end_time, t.text
        if p in sil_phones:
            if ignore_all_pauses:
                continue
            # leading or trailing silence -> "sil"; short pause between words -> "sp"
            p = "sil" if (i == 0) or (i == len(tier) - 1) else "sp"
        start_frame, end_frame = to_frame(s), to_frame(e)
        phones.append(p)
        start_frames.append(start_frame)
        end_frames.append(end_frame)
        durations.append(end_frame - start_frame)

    n_samples = utt_end_time * sampling_rate
    n_frames = n_samples / hop_length
    # fix occasional length mismatches at the end of utterances when
    # duration in samples is an integer multiple of hop_length
    if durations and n_frames.is_integer():
        durations[-1] += 1
    return phones, durations, start_frames, end_frames, utt_start_time, utt_end_time

In [83]:
import math
def extract_reprs_with_timestamps(total_num_frames, start_time, end_time, utt_duration):
    """Map a time interval onto frame indices of an utterance with `total_num_frames` frames.

    Each time is scaled proportionally (time / utt_duration * total_num_frames) and truncated
    with int(). No tensor is sliced here; the caller uses the indices.

    Returns:
        (start_idx, end_idx, num_frames) with num_frames == end_idx - start_idx.
    """
    def to_index(t):
        return int(t / utt_duration * total_num_frames)

    first = to_index(start_time)
    last = to_index(end_time)
    return first, last, last - first

def parse_textgrid2(tier, mel_spectrogram, ignore_all_pauses=True):
    """Split an MFA word tier into word and silence spans measured in frames of `mel_spectrogram`.

    Interval times are mapped to frame indices proportionally, using the end time of the
    last interval as the utterance duration (NOTE(review): assumes the tier starts at 0).

    Args:
        tier: word tier from a tgt TextGrid (uses tier[0], tier[-1] and tier._objects).
        mel_spectrogram: tensor whose first dimension is time; only .size(0) is read.
        ignore_all_pauses: accepted for signature compatibility but currently unused;
            silences are never returned as words.

    Returns:
        (words, word_durations, sil_durations, start_frames, end_frames, utt_start_time, utt_end_time)
    """
    n_frames_total = mel_spectrogram.size(0)
    # latest MFA replaces silence phones with "" in output TextGrids
    silence_labels = ("sil", "sp", "spn", "")
    first_start = tier[0].start_time
    last_end = tier[-1].end_time

    words, word_durations, sil_durations, start_frames, end_frames = [], [], [], [], []
    for interval in tier._objects:
        span_start, span_end, span_len = extract_reprs_with_timestamps(
            n_frames_total, interval.start_time, interval.end_time, last_end)
        if interval.text in silence_labels:
            sil_durations.append(span_len)
            continue
        words.append(interval.text)
        start_frames.append(span_start)
        end_frames.append(span_end)
        word_durations.append(span_len)

    return words, word_durations, sil_durations, start_frames, end_frames, first_start, last_end

In [84]:
longest_word = ''
longest_word_utt_id = ''
longest_word_num_frames = 0
buffer_frames = 0  # extra context frames kept on each side of a word (clipped to the utterance)

# ASCII transliteration: use `unidecode` when installed, otherwise a stdlib approximation.
try:
    from unidecode import unidecode as _to_ascii
except ImportError:
    import unicodedata

    def _to_ascii(text):
        """Stdlib fallback: strip diacritics via NFKD; chars without an ASCII decomposition are dropped."""
        return unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('ascii')


def _load_reps(utt_id):
    """Return the speech reps for `utt_id` as a tensor, validating their dimensions."""
    if args.type == "hubert":
        # NOTE add padding_idx_offset to each index so that 0 can be used as a padding_idx
        reps = torch.tensor([int(s) + args.padding_idx_offset for s in utt_id2speechreps[utt_id].split(' ')])
        reps.requires_grad = False
        if reps.dim() != 1:
            raise ValueError("speech representations have an incorrect number of dimensions")
    elif args.type == "mel":
        reps = utt_id2speechreps[utt_id]
        reps.requires_grad = False
        if reps.dim() != 2:
            raise ValueError(f"speech representations have an incorrect number of dimensions {reps.dim()=}")
        if not ((args.corpus_name == 'speechbrain' and reps.size(1) == 40)
                or (args.corpus_name == 'ljspeech' and reps.size(1) == 80)):
            raise ValueError(f"feat dimension is wrong size for corpus {reps.size(1)=}")
    else:
        raise ValueError(f"invalid input type {args.type}")
    return reps


def _normalise_word(word):
    """Return `word` ready to be used as a folder/file name, or None if it should be skipped.

    A warning is printed for every character outside a-z. Such words are skipped when
    SKIP_NON_ASCII is set or the word is in WORDS_TO_SKIP. Otherwise they are normalised
    (trailing '-' removed, then converted to ASCII). Words that normalise to '' are skipped.
    """
    skip_word = False
    normalise_non_ascii = False
    for c in word:
        if c not in string.ascii_lowercase:
            s = f'WARNING: char {c} in word {word}'
            if SKIP_NON_ASCII or word in WORDS_TO_SKIP:
                s += '. skipping!...'
                skip_word = True
            else:
                normalise_non_ascii = True
            print(s)

    if skip_word:
        return None
    if normalise_non_ascii:
        prenorm_word = word
        word = word.rstrip('-')  # remove trailing '-'
        word = _to_ascii(word)  # convert diacritics to ascii
        print(f"\tnormalised '{prenorm_word}' to '{word}'")
        if not word:
            # an empty name would save the file directly into the output root
            print(f"WARNING: '{prenorm_word}' normalised to an empty string. skipping!...")
            return None
    return word


# glob recursively number of torch tensor files in output directory
orig_num_files = len(glob.glob(os.path.join(args.output_directory, '**', '*.pt'), recursive=True))

# split each speech reps file using the word-level alignments
print("split speech reps using word alignments")
for utt_id in tqdm(utt_ids):
    reps = _load_reps(utt_id)

    tg_path = f"{args.alignments}/{utt_id}.TextGrid"
    tg = tgt.io.read_textgrid(tg_path, include_empty_intervals=True)
    words, word_durs, sil_durations, start_frames, end_frames, utt_start, utt_end = parse_textgrid2(
        tg.get_tier_by_name('words'), reps, ignore_all_pauses=True)

    word_occ_in_utt_counter = Counter()
    # slice the parsed tensor: for hubert, utt_id2speechreps holds the raw unit string, not a tensor
    mel = reps
    # verify that MFA frame durations match up with the extracted mels
    assert mel.size(0) == (sum(word_durs) + sum(sil_durations)), f"{mel.size(0)=} != {sum(word_durs)=} + {sum(sil_durations)=}"

    for word, dur, start_frame, end_frame in zip(words, word_durs, start_frames, end_frames):
        if not allowed_word(word):
            continue
        word = _normalise_word(word)
        if word is None:
            continue

        # check if word is the longest word we have seen so far
        word_dur = end_frame - start_frame
        if word_dur > longest_word_num_frames:
            longest_word_num_frames = word_dur
            longest_word = word
            longest_word_utt_id = utt_id

        # extract the word's frames (plus optional buffer context)
        a = max(0, start_frame - buffer_frames)
        b = min(mel.size(0), end_frame + buffer_frames)
        wordaligned_mel = mel[a:b]

        # save extracted mel to disk
        word_occ_in_utt_counter[word] += 1
        extracted_timesteps = wordaligned_mel.size(0)
        assert dur == extracted_timesteps - (start_frame - a) - (b - end_frame) == word_dur, f"{dur=}, {extracted_timesteps=}, {word_dur=}"
        save_to_disk(wordaligned_mel, word, utt_id, word_occ_in_utt_counter[word], args.output_directory)

print(f"wordtype with longest num of timesteps is '{longest_word}' from", longest_word_utt_id, "with len", longest_word_num_frames)
print("you can set transformer max_source_positions to this")

new_num_files = len(glob.glob(os.path.join(args.output_directory, '**', '*.pt'), recursive=True))

print(f"Added {new_num_files - orig_num_files} files to {args.output_directory}, now contains {new_num_files} files, used to contain {orig_num_files} files")

split speech reps using word alignments


100%|██████████| 25/25 [00:40<00:00,  1.63s/it]


wordtype with longest num of timesteps is 'surpassed' from LJ001-0008 with len 407
you can set transformer max_source_positions to this
Added 289 files to /home/s1785140/speechbrain/templates/speech_recognition_CharTokens_NoLM/data/ljspeech_dumped_feats_word_aligned, now contains 454 files, used to contain 165 files


## Sanity-check alignments by synthesising audio from the word-aligned mel spectrograms with Griffin-Lim

In [85]:
# glob pytorch tensors from nested folders in output directory (sorted so the order is reproducible)
mel_paths = sorted(glob.glob(f'{args.output_directory}/**/*.pt', recursive=True))

# load mels into list, together with the word each one belongs to
mels = []
words = []
for mel_path in tqdm(mel_paths):
    # map to CPU: tensors may have been saved from a CUDA device, and synthesis below runs on CPU
    mels.append(torch.load(mel_path, map_location='cpu'))
    # the word type is the name of the folder the tensor was saved in
    words.append(os.path.basename(os.path.dirname(mel_path)))

100%|██████████| 454/454 [00:09<00:00, 46.06it/s]


In [86]:
def griffin_lim_synthesise(mel, n_iter=100, n_fft=400):
    """Synthesise audio from a mel spectrogram using the Griffin-Lim algorithm.

    Uses the module-level SAMPLING_RATE, HOP_LENGTH and WIN_LENGTH settings.

    Args:
        mel (torch.Tensor): mel spectrogram (B, C, T).
        n_iter (int): number of Griffin-Lim iterations.
        n_fft (int): FFT size used to invert the mel filterbank. The default of 400 equals a
            25 ms window at 16 kHz; pass the value used at feature extraction for other setups.
            librosa requires WIN_LENGTH <= n_fft.
    Returns:
        torch.Tensor: float32 audio waveform (B, T_samples).
    """
    mel_np = mel.detach().cpu().numpy()
    audio = librosa.feature.inverse.mel_to_audio(
        mel_np, sr=SAMPLING_RATE, n_fft=n_fft, hop_length=HOP_LENGTH, win_length=WIN_LENGTH,
        window='hamming', center=True, pad_mode='constant', power=1.0, n_iter=n_iter,
        )
    return torch.from_numpy(audio).float()

def reshape_mel_for_librosa(mel):
    """Turn a (T, D) mel into the (1, D, T) batched layout expected by griffin_lim_synthesise."""
    batched = mel[None]  # add a leading batch dimension
    return batched.transpose(1, 2)

def synthesise_and_play_Audio(mel, n_iter=100):
    """Griffin-Lim a (T, D) mel and return an IPython Audio player at SAMPLING_RATE."""
    waveform = griffin_lim_synthesise(reshape_mel_for_librosa(mel), n_iter)
    return Audio(waveform, rate=SAMPLING_RATE)

## listen to generated audio for each word

In [87]:
# get list of wordnet nltk stopwords
stop_words = set(stopwords.words('english'))
print(f"{len(stop_words)=}")

# add to stop_words all letters in alphabet, since tokenisation might create these edge cases 
# for example, "wasn't" might get tokenized as "wasn" and "t"
stop_words.update(string.ascii_lowercase)
print(f"{len(stop_words)=}")
print(sorted(stop_words))

# words_to_synth = "stopwords"
words_to_synth = "functionwords"

tuples = list(zip(mels, words))
# filter out words that are not in the wordnet nltk stopwords list
if words_to_synth == "stopwords":
    tuples = [t for t in tuples if t[1] in stop_words]
# filter out words that are not in the function word list
elif words_to_synth == "functionwords":
    tuples = [t for t in tuples if t[1] not in stop_words]

len(stop_words)=179
len(stop_words)=197
['a', 'about', 'above', 'after', 'again', 'against', 'ain', 'all', 'am', 'an', 'and', 'any', 'are', 'aren', "aren't", 'as', 'at', 'b', 'be', 'because', 'been', 'before', 'being', 'below', 'between', 'both', 'but', 'by', 'c', 'can', 'couldn', "couldn't", 'd', 'did', 'didn', "didn't", 'do', 'does', 'doesn', "doesn't", 'doing', 'don', "don't", 'down', 'during', 'e', 'each', 'f', 'few', 'for', 'from', 'further', 'g', 'h', 'had', 'hadn', "hadn't", 'has', 'hasn', "hasn't", 'have', 'haven', "haven't", 'having', 'he', 'her', 'here', 'hers', 'herself', 'him', 'himself', 'his', 'how', 'i', 'if', 'in', 'into', 'is', 'isn', "isn't", 'it', "it's", 'its', 'itself', 'j', 'just', 'k', 'l', 'll', 'm', 'ma', 'me', 'mightn', "mightn't", 'more', 'most', 'mustn', "mustn't", 'my', 'myself', 'n', 'needn', "needn't", 'no', 'nor', 'not', 'now', 'o', 'of', 'off', 'on', 'once', 'only', 'or', 'other', 'our', 'ours', 'ourselves', 'out', 'over', 'own', 'p', 'q', 'r', 're', 's

In [91]:
# Griffin-Lim output may sound poor because:
#   1. the stft -> mel spec conversion is not lossless
#   2. (speechbrain) ASR features have fewer bins (40 rather than 80 mel spec bins)
# It is still good enough for sanity checking alignments.
NUM_TO_LISTEN = None  # None = all selected words; e.g. 20 to synthesise only a random subset

# Helpers: synthesise wavs from mels, concatenate them with silences in between, and play them
# in the notebook. Griffin-Lim is not lossless, but it is good enough for sanity checking alignments.
def generate_wav_from_mel(mel, n_iter=100):
    """Griffin-Lim a single (T, D) mel and return the waveform as a 1-D numpy array."""
    waveform = griffin_lim_synthesise(reshape_mel_for_librosa(mel), n_iter)
    return waveform.squeeze(0).numpy()

def generate_wav_from_mels_with_silences(mels, n_iter=100, silence_duration=0.5):
    """Synthesise each mel and concatenate the waveforms, with `silence_duration` seconds
    of silence appended after each one.

    Returns:
        1-D numpy array; empty when `mels` is empty (np.concatenate would raise on an empty list).
    """
    pieces = []
    for mel in mels:
        pieces.append(generate_wav_from_mel(mel, n_iter))
        pieces.append(np.zeros(int(silence_duration * SAMPLING_RATE)))
    if not pieces:
        return np.zeros(0)
    return np.concatenate(pieces)

def generate_wav_from_mels_with_silences_and_play(mels, n_iter=100, silence_duration=0.5):
    """Return an IPython Audio player for the silence-separated concatenation of `mels`."""
    return Audio(
        generate_wav_from_mels_with_silences(mels, n_iter, silence_duration),
        rate=SAMPLING_RATE,
    )

random.shuffle(tuples)  # gives a random subset when NUM_TO_LISTEN is set
selected = tuples[:NUM_TO_LISTEN]
assert selected, "no word-aligned mels selected for synthesis (check words_to_synth and the output directory)"

# sort mels and words by alphabetical order of words (stable sort: duplicates keep shuffled order)
mels_to_gen, words_to_gen = zip(*sorted(selected, key=lambda pair: pair[1]))

print("Generating audio for following words:", words_to_gen)
generate_wav_from_mels_with_silences_and_play(mels_to_gen, silence_duration=0.4)

Generating audio for following words: ('actually', 'aforesaid', 'ages', 'ages', 'also', 'although', 'always', 'arrangement', 'art', 'arts', 'basle', 'beautiful', 'beautiful', 'beautiful', 'beautiful', 'began', 'bible', 'bible', 'bible', 'birth', 'block', 'blocks', 'book', 'book', 'book', 'book', 'books', 'books', 'books', 'brought', 'called', 'calligraphy', 'care', 'case', 'casting', 'centuries', 'century', 'century', 'century', 'certainly', 'character', 'chinese', 'cities', 'come', 'comparatively', 'composed', 'concerned', 'considered', 'considered', 'consist', 'cost', 'course', 'crafts', 'craftsmen', 'dated', 'differs', 'earliest', 'earliest', 'easier', 'ecclesiastical', 'eleventh', 'engraved', 'especially', 'especially', 'etc', 'even', 'exact', 'example', 'exceedingly', 'exhibition', 'fact', 'far', 'fifteen', 'fifteenth', 'fifteenth', 'fifty', 'fine', 'first', 'five', 'five', 'form', 'form', 'formal', 'forms', 'forty', 'fourteen', 'fourteen', 'fourteen', 'france', 'freer', 'germany'

In [89]:
# Deliberate stop: halts "Run All" here so the cells below, which sample new train/dev/test
# word splits and overwrite the respeller_*_words.json files, only run when executed explicitly.
raise ValueError("stop before creating datasplits")

ValueError: stop before creating datasplits

# create train,dev,test datasplits for training respeller

We hold out word types (not just individual tokens) from training to form the dev and test splits, so dev/test words are never seen during training.

## Random

In [None]:
import random
import json

random.seed(1337)  # fixed seed so the sampled dev/test word types are reproducible

# Proportions of word types per split. Only dev_ratio and test_ratio are used to size the
# held-out sets; train receives every remaining word type.
train_ratio, dev_ratio, test_ratio = 0.9, 0.05, 0.05

In [None]:
# get oov wordtypes list (words that are not seen in tts training);
# presumably a JSON object mapping each OOV word type to its corpus frequency
oov_wordlist_path = '/home/s1785140/data/ljspeech_fastpitch/oov_list.json'
with open(oov_wordlist_path) as f:
    oovs_and_freqs = json.load(f)

all_wordtypes = set(oovs_and_freqs)
print(f'original before cleaning/sampling {len(all_wordtypes)=}')

In [None]:
# clean/remove words that do not have speech reps
words_with_aligned_mels = set(os.listdir(args.output_directory))
words_no_mels = all_wordtypes - words_with_aligned_mels
print(f'{len(words_no_mels)}=')

In [None]:
print("list of words to be excluded from respeller training as they do not have mels (likely due to how normalisation is different between mfa and our own data processing):")
# sorted: set display order of strings changes between runs (hash randomisation)
sorted(words_no_mels)

In [None]:
# sort

In [None]:
# remove these problematic words from respeller training dev test
for w in words_no_mels:
    del oovs_and_freqs[w]
    
all_wordtypes = set(oovs_and_freqs.keys())
print(f'original after cleaning {len(all_wordtypes)=}')

dev_N = int(dev_ratio * len(all_wordtypes))
test_N = int(test_ratio * len(all_wordtypes))

In [None]:
def sample_and_remove(s: set, N: int):
    """Sample N items from set `s` without replacement, then remove them from `s`.

    The set is sorted before sampling for two reasons. `random.sample` no longer accepts sets
    (deprecated in Python 3.9, TypeError since 3.11). Also, the iteration order of a set of
    strings changes between interpreter runs (hash randomisation), so sampling from it directly
    would not be reproducible even with a fixed `random.seed`.

    Args:
        s: set of (orderable) items; modified in place.
        N: number of items to sample.
    Returns:
        set: the sampled items, which are no longer in `s`.
    Raises:
        ValueError: if N is negative or larger than len(s) (raised by random.sample).
    """
    sampled = random.sample(sorted(s), N)
    s.difference_update(sampled)
    return set(sampled)

In [None]:
# get dev and test splits: both are drawn only from OOV word types that occur exactly once
oov_singletons = {wordtype for wordtype, freq in oovs_and_freqs.items() if freq == 1}
assert len(oov_singletons) > dev_N + test_N, "not enough OOV singletons to create dev and test sets"
print(f'before sampling dev and test {len(oov_singletons)=}')

dev = sample_and_remove(oov_singletons, dev_N)
print(f'after sampling dev {len(oov_singletons)=}, {len(dev)=}')

test = sample_and_remove(oov_singletons, test_N)
print(f'after sampling test {len(oov_singletons)=}, {len(test)=}')

In [None]:
# sorted: set iteration order of strings changes between runs (hash randomisation)
sorted(dev)[:10]

In [None]:
# sorted: set iteration order of strings changes between runs (hash randomisation)
sorted(test)[:10]

In [None]:
# get train split: every word type not held out for dev or test
print(f'before removing dev and test wordtypes {len(all_wordtypes)=}')
# set difference (rather than .remove per word) keeps this cell safe to re-run
all_wordtypes -= dev | test
print(f'after removing dev and test wordtypes {len(all_wordtypes)=}')

train = set(all_wordtypes)

In [None]:
# sanity checks: no word type may appear in more than one split
assert dev.isdisjoint(test)
assert train.isdisjoint(dev)
assert train.isdisjoint(test)
print("Good! No overlapping words between train, dev, and test!!!")

In [None]:
# write to disk
def save_wordlist(path, words):
    """Write `words` to `path` as an alphabetically sorted JSON list indented by 4 spaces."""
    ordered = sorted(words)
    with open(path, mode='w') as handle:
        json.dump(ordered, handle, indent=4)
        
# output word lists for each respeller split
train_path = '/home/s1785140/data/ljspeech_fastpitch/respeller_train_words.json'
dev_path = '/home/s1785140/data/ljspeech_fastpitch/respeller_dev_words.json'
test_path = '/home/s1785140/data/ljspeech_fastpitch/respeller_test_words.json'

for _split_path, _split_words in ((train_path, train), (dev_path, dev), (test_path, test)):
    save_wordlist(_split_path, _split_words)

## G2P selection