In [16]:
import textgrid
import torch
import matplotlib.pyplot as plt
import math
def normalize_continious_phonemes(src):
    """Insert explicit '<SIL>' entries so consecutive tokens share boundaries.

    ``src`` is a sequence of ``(token, start, end)`` items (only indices 1 and
    2 are read here). Whenever an item's start differs from the previous
    item's end (or from 0 for the first item), a ``('<SIL>', prev_end, start)``
    tuple is inserted before it. The original items are appended unchanged
    (same objects, not copies).

    The comparison is exact ``!=``, so tiny float mismatches also produce a
    (near-zero) silence entry. Overlapping items (start < previous end) yield
    a reversed ``<SIL>`` span.
    """
    normalized = []
    cursor = 0  # end time of the previously emitted item; timeline starts at 0
    for item in src:
        begin = item[1]
        if begin != cursor:
            # Fill the hole between the previous end and this item's start.
            normalized.append(('<SIL>', cursor, begin))
        normalized.append(item)
        cursor = item[2]
    return normalized

def quantisize_phoneme_positions(src, phoneme_duration):
    """Convert ``(token, start_sec, end_sec)`` items to integer frame indices.

    Each boundary is divided by ``phoneme_duration`` (seconds per frame) and
    converted with ``int()``, i.e. truncated toward zero. Returns a new list of
    ``(token, start_frame, end_frame)`` tuples.
    """
    def to_frame(seconds):
        return int(seconds / phoneme_duration)

    return [(item[0], to_frame(item[1]), to_frame(item[2])) for item in src]

def continious_phonemes_to_discreete(raw_phonemes, phoneme_duration):
    """Turn timed ``(token, start_sec, end_sec)`` items into ``(token, n_frames)``.

    Steps:
      1. ``normalize_continious_phonemes`` inserts '<SIL>' entries wherever an
         item does not start where the previous one ended (or at 0).
      2. ``quantisize_phoneme_positions`` maps both boundaries to frame
         indices using ``phoneme_duration`` seconds per frame.
      3. Each span becomes its frame count ``end_frame - start_frame``.
         Counts may be 0 (or negative for overlapping input); callers filter.
    """
    contiguous = normalize_continious_phonemes(raw_phonemes)
    framed = quantisize_phoneme_positions(contiguous, phoneme_duration)
    return [(token, stop - begin) for token, begin, stop in framed]

def extract_textgrid_alignments(tg):
    """Collect ``(mark, minTime, maxTime)`` triples from tier ``tg[1]``.

    Intervals whose mark is the empty string are skipped; the gaps they leave
    are not filled here. The mark 'spn' is renamed to '<UNK>'. Other marks are
    kept verbatim.
    NOTE(review): tier index 1 is presumably the phone tier of the aligner's
    output — confirm against the TextGrid layout.
    """
    tier = tg[1]
    return [
        ('<UNK>' if interval.mark == 'spn' else interval.mark,
         interval.minTime,
         interval.maxTime)
        for interval in tier
        if interval.mark != ''  # empty marks are unlabeled gaps
    ]

def prepare_textgrid_alignments(tg, total_duration, phoneme_duration):
    """Build a ``(token, n_frames)`` sequence covering exactly ``total_duration`` frames.

    Pipeline:
      1. Extract labeled intervals from ``tg`` (see extract_textgrid_alignments).
      2. Convert them to per-token frame counts using ``phoneme_duration``
         seconds per frame (see continious_phonemes_to_discreete).
      3. Drop spans that quantized to zero (or fewer) frames.
      4. Pad the tail with '<SIL>' so the counts sum to ``total_duration``.
      5. Reserve the first frame for '<BEGIN>' and the last for '<END>',
         shrinking (or replacing, if it is one frame long) the first/last span.

    Args:
        tg: TextGrid-like object; ``tg[1]`` is the tier that gets read.
        total_duration: target number of frames.
        phoneme_duration: seconds per frame.

    Returns:
        List of ``(token, n_frames)`` tuples whose frame counts sum to
        ``total_duration``, starting with ``('<BEGIN>', 1)`` and ending with
        ``('<END>', 1)``.

    Raises:
        AssertionError: if the alignments are longer than ``total_duration``,
            or ``total_duration < 2`` (no room for both '<BEGIN>' and '<END>').
            These are plain ``assert`` statements and are skipped under ``-O``.
    """

    # Extract alignments
    x = extract_textgrid_alignments(tg)

    # Convert to discrete frame counts
    x = continious_phonemes_to_discreete(x, phoneme_duration)

    # Filter spans that quantized to zero (or negative) frames
    x = [i for i in x if i[1] > 0]

    total_length = sum(i[1] for i in x)
    # We don't have alignments longer than the audio in our datasets.
    assert total_length <= total_duration, (
        f"alignments cover {total_length} frames, audio only {total_duration}"
    )
    # '<BEGIN>' and '<END>' need one frame each. Check the padded length
    # (== total_duration), not the raw alignment length: an utterance with
    # 0-1 aligned frames is still valid once padded with silence.
    assert total_duration >= 2, (
        f"need at least 2 frames for <BEGIN>/<END>, got {total_duration}"
    )

    # Pad with silence up to the audio length
    if total_length < total_duration:
        x.append(('<SIL>', total_duration - total_length))

    # Patch first token: it either becomes '<BEGIN>' or donates one frame to it.
    head_token, head_len = x[0]
    if head_len == 1:
        x[0] = ('<BEGIN>', 1)
    else:
        x[0:1] = [('<BEGIN>', 1), (head_token, head_len - 1)]

    # Patch last token: it either becomes '<END>' or donates one frame to it.
    tail_token, tail_len = x[-1]
    if tail_len == 1:
        x[-1] = ('<END>', 1)
    else:
        x[-1:] = [(tail_token, tail_len - 1), ('<END>', 1)]

    return x


# Sanity check on one utterance: expand the (token, n_frames) alignment into a
# per-frame token list and verify it has exactly audio.shape[1] entries.
tg = textgrid.TextGrid.fromFile("datasets/vctk-aligned/00000068/00055895.TextGrid")
audio = torch.load("datasets/vctk-prepared/00000068/00055895.pt")
target_duration = audio.shape[1]

res = prepare_textgrid_alignments(tg, target_duration, 256 / 24000)
print(res)

# One entry per frame: each token repeated n_frames times.
phonemes = [token for token, n_frames in res for _ in range(n_frames)]
if len(phonemes) != target_duration:
    raise Exception(f"Phonemes and audio length mismatch: {len(phonemes)} != {target_duration}")

# Scratch example (kept for manual experiments):
# p_d = 256 / 24000
# source = [['a', 0.2, 0.3], ['b', 0.34, 0.46]]
# source_norm = normalize_continious_phonemes(source)
# target = continious_phonemes_to_discreete(source, p_d)
# print(source)
# print(source_norm)
# print(target)
# print(list((i[0], i[2] - i[1]) for i in source_norm))
# print(list((i[0], i[1], i[1] * p_d) for i in target))

[('<BEGIN>', 1), ('<SIL>', 2), ('ð', 84), ('ə', 17), ('ç', 9), ('iː', 10), ('b', 10), ('ɹ', 2), ('ʉː', 18), ('z', 7), ('ɪ', 8), ('w', 8), ('ə', 4), ('z', 6), ('ə', 6), ('tʰ', 12), ('ow', 6), ('k', 14), ('ə', 6), ('n', 5), ('d̪', 3), ('æ', 5), ('ʔ', 4), ('d̪', 5), ('ɛ', 3), ('ɹ', 7), ('w', 5), ('ʊ', 5), ('d', 2), ('bʲ', 7), ('iː', 5), ('n', 9), ('ow', 10), ('m', 9), ('ɒː', 3), ('ɹ', 6), ('j', 11), ('ʉ', 1), ('ɲ', 4), ('ɪ', 5), ('v', 4), ('ɝ', 9), ('s', 9), ('ə', 4), ('ɫ', 4), ('f', 9), ('l', 5), ('ɐ', 10), ('d', 5), ('z', 24), ('<SIL>', 249), ('<END>', 1)]


In [None]:
def plot_alignments(spec, alignments, tg, phoneme_duration=256 / 24000):
    """Overlay discretized alignments and raw TextGrid intervals on a spectrogram.

    The spectrogram is drawn with ``imshow`` using x in frames
    ``[0, spec.shape[1]]`` and y in bins ``[0, spec.shape[0]]``, with the color
    range clamped to [-10, 0].
    Yellow boxes (labels at y=90) are the consecutive ``(token, n_frames)``
    spans of ``alignments``, laid end to end from frame 0. Red boxes (labels
    at y=70) are the intervals of ``tg[1]``, converted from seconds to frames
    with ``floor(time / phoneme_duration)``.

    Args:
        spec: 2-D array-like with ``spec.shape == (bins, frames)``.
            NOTE(review): the [-10, 0] color range suggests log-scaled
            magnitudes, and the label heights (90/70) assume more than 90
            bins (labels outside the axes are clipped) — confirm.
        alignments: iterable of ``(token, n_frames)`` pairs, e.g. the output
            of prepare_textgrid_alignments.
        tg: TextGrid-like object; ``tg[1]`` intervals expose ``minTime``,
            ``maxTime`` and ``mark``.
        phoneme_duration: seconds per frame. Defaults to ``256 / 24000``, the
            value previously hard-coded here (presumably hop 256 at 24 kHz).
    """

    # Draw spec
    _, ax = plt.subplots(1, 1, figsize=(20, 10))
    ax.imshow(spec, vmin=-10, vmax=0, origin="lower", aspect="auto",
              extent=(0, spec.shape[1], 0, spec.shape[0]))

    # Discretized alignments: spans are consecutive, so track a running frame cursor.
    cursor = 0
    for span in alignments:
        token, n_frames = span[0], span[1]
        ax.axvspan(cursor, cursor + n_frames, facecolor="None", edgecolor="yellow")
        ax.annotate(token, (cursor, 90), annotation_clip=True, color="white")
        cursor += n_frames

    # Raw TextGrid intervals (seconds -> frames) for comparison.
    for interval in tg[1]:
        start = math.floor(interval.minTime / phoneme_duration)
        end = math.floor(interval.maxTime / phoneme_duration)
        ax.axvspan(start, end, facecolor="None", edgecolor="red")
        ax.annotate(interval.mark, (start, 70), annotation_clip=True, color="white")


plot_alignments(audio, res, tg)