In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
from copy import copy
import itertools
import logging
import re

import datasets
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import soundfile as sf
import transformers

from src.utils import syllabifier

In [5]:
L = logging.getLogger(__name__)

In [6]:
sns.set_theme(style="whitegrid", font_scale=2)

In [7]:
split = "train-clean-100"
data_dir = f"data/librispeech/{split}"
alignment_dir = "data/librispeech_alignments"
out_path = "."

In [8]:
datasets.disable_caching()

In [10]:
dataset = datasets.load_dataset(
    "src/datasets/huggingface_librispeech.py", data_dir=data_dir,
    alignment_dir=alignment_dir)[split.replace("-", ".")]

Generating train.clean.100 split:   0%|          | 0/28539 [00:00<?, ? examples/s]

In [11]:
dev_dataset = dataset

In [12]:
tokenizer = transformers.Wav2Vec2Tokenizer.from_pretrained("charsiu/tokenizer_en_cmu")
feature_extractor = transformers.Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
processor = transformers.Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'Wav2Vec2CTCTokenizer'. 
The class this function is called from is 'Wav2Vec2Tokenizer'.


In [13]:
def add_phonemic_detail(item):
    starts = copy(item["phonetic_detail"]["start"])
    stops = copy(item["phonetic_detail"]["stop"])
    utterances = copy(item["phonetic_detail"]["utterance"])

    # remove stress annotations
    utterances = [re.sub(r"\d", "", u) for u in utterances]

    item["phonemic_detail"] = {
        "start": starts,
        "stop": stops,
        "utterance": utterances
    }

    return item

In [14]:
def group_phonetic_detail(item, idx, drop_phones=None, key="phonetic_detail"):
    """
    Group phonetic_detail entries according to the containing word.
    """
    phonetic_detail = item[key]
    word_detail = item["word_detail"]

    # Assure that each phone gets mapped to exactly one word. We'll arbitrarily map to the
    # first word that contains the phone; this seems to most frequently match TIMIT annotation standards
    phone_mask = np.zeros(len(phonetic_detail["start"]), dtype=bool)
    # Note that we also assign phonemes which span words to the leftmost word, consistent
    # with this strategy

    word_phonetic_detail = []
    for start, stop, word in zip(word_detail["start"], word_detail["stop"], word_detail["utterance"]):
        word_phonetic_detail.append([])
        for j, (phon_start, phon_stop, phon) in enumerate(zip(phonetic_detail["start"], phonetic_detail["stop"], phonetic_detail["utterance"])):
            if phone_mask[j]:
                continue
            elif drop_phones is not None and phon in drop_phones:
                phone_mask[j] = True
                continue
            
            # if the phoneme has start in this word, assign it to this word
            if phon_start >= start and phon_start < stop:
                phone_mask[j] = True
                word_phonetic_detail[-1].append({"phone": phon, "start": phon_start, "stop": phon_stop})

        if len(word_phonetic_detail[-1]) == 0:
            if word == "":
                # expected for these empty-word cases in librispeech annotations
                continue
            preceding_word_phones = " ".join(phone["phone"] for phone in word_phonetic_detail[-2]) if len(word_phonetic_detail) > 1 else ""
            L.warning(f"No phones found for word {word} in item {idx} ({item['text']}) (preceding word: {preceding_word_phones})")

    for unused_phone in np.flatnonzero(~phone_mask):
        preceding_phones = " ".join(phonetic_detail["utterance"][max(0, unused_phone - 3):unused_phone])
        following_phones = " ".join(phonetic_detail["utterance"][unused_phone + 1:min(len(phonetic_detail["utterance"]), unused_phone + 4)])
        unused_phone_str = phonetic_detail["utterance"][unused_phone]
        L.warning(f"Unused phone {unused_phone_str} in item {idx} ({item['text']}) (preceding: {preceding_phones}, following: {following_phones})")

    # from pprint import pprint
    # pprint(list(zip(word_detail["start"], word_detail["stop"], word_detail["utterance"])))
    # pprint(list(zip(phonetic_detail["start"], phonetic_detail["stop"], phonetic_detail["utterance"])))
    # pprint(word_phonetic_detail)

    item[f"word_{key}"] = word_phonetic_detail
    return item

In [15]:
def add_syllabic_detail(item):
    word_syllables = []

    # syllabifier doesn't use stress information so we can just use
    # phonemic detail here
    for word in item["word_phonemic_detail"]:
        phones = [ph["phone"] for ph in word if ph["phone"] not in ["[SIL]", ""]]
        if len(phones) > 0:
            syllables = syllabifier.syllabify(syllabifier.English, phones)

            assert phones == list(itertools.chain.from_iterable(
                [tuple(onset) + tuple(nucleus) + tuple(coda) for stress, onset, nucleus, coda in syllables]))
            # print(syllables)
            # word["syllables"] = syllables

            phoneme_idx, syllable_idx = 0, 0
            syllable_dicts = []
            for stress, onset, nucleus, coda in syllables:
                syllable_phones = tuple(onset + nucleus + coda)
                syllable_dict = {
                    "phones": syllable_phones,
                    "idx": syllable_idx,
                    "phoneme_start_idx": phoneme_idx,
                    "phoneme_end_idx": phoneme_idx + len(syllable_phones), # exclusive
                    "stress": stress,

                    "start": word[phoneme_idx]["start"],
                    "stop": word[phoneme_idx + len(syllable_phones) - 1]["stop"],
                }

                # Add cross-reference data in word_phonemic_detail
                for j, ph in enumerate(syllable_phones):
                    word[phoneme_idx + j]["syllable_idx"] = syllable_idx
                    word[phoneme_idx + j]["idx_in_syllable"] = j
                    word[phoneme_idx + j]["syllable_phones"] = tuple(syllable_phones)
                    word[phoneme_idx + j]["stress"] = stress
                    word[phoneme_idx + j]["syllable_start"] = syllable_dict["start"]
                    word[phoneme_idx + j]["syllable_stop"] = syllable_dict["stop"]

                syllable_dicts.append(syllable_dict)
                phoneme_idx += len(syllable_phones)
                syllable_idx += 1
        else:
            syllable_dicts = []

        word_syllables.append(syllable_dicts)
    
    item["word_syllable_detail"] = word_syllables
    return item

In [16]:
def check_item(item, idx, drop_phones=None):
    try:
        grouped_phonemic_detail = item["word_phonemic_detail"]
        grouped_syllable_detail = item["word_syllable_detail"]
        assert len(grouped_phonemic_detail) == len(item["word_detail"]["utterance"])
        assert len(grouped_syllable_detail) == len(item["word_detail"]["utterance"])

        all_phonemes = [phon["phone"] for word in grouped_phonemic_detail for phon in word]
        all_phonemes_syll = [phone for word in item["word_syllable_detail"] for syllable in word for phone in syllable["phones"]]
        assert len(all_phonemes) == len(all_phonemes_syll)
        assert all_phonemes == all_phonemes_syll, "phonemic detail does not match phonemes within syllable detail"

        # NB we do expect a mismatch here since some phonemes in the flat representation
        # won't appear in the word grouped representation, if they are outside the span of a word
        # all_phonemes_flat = [ph for ph in item["phonemic_detail"]["utterance"] if ph not in (drop_phones or [])]
        # assert all_phonemes == all_phonemes_flat, \
        #     f"grouped phonemic detail does not match non-grouped phonemic detail in item {idx}:" \
        #     f"\n{item['text']}\n{all_phonemes}\n{all_phonemes_flat}"
    except Exception as e:
        L.error(f"Error in item {idx} ({item['text']})")
        raise e

In [17]:
def prepare_audio(batch):
    audio = batch["audio"]
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    return batch

In [19]:
def add_idx(item, idx):
    item["idx"] = idx
    item["split"] = split

    return item

In [20]:
drop_phones = ["sil", "sp", "spn", ""]

dev_dataset = dev_dataset.map(add_phonemic_detail)
dev_dataset = dev_dataset.map(group_phonetic_detail, with_indices=True,
                              fn_kwargs=dict(drop_phones=drop_phones))
dev_dataset = dev_dataset.map(group_phonetic_detail, with_indices=True,
                              fn_kwargs=dict(key="phonemic_detail", drop_phones=drop_phones))

dev_dataset = dev_dataset.map(add_syllabic_detail)

dev_dataset.map(check_item, with_indices=True)

dev_dataset = dev_dataset.map(prepare_audio)
dev_dataset = dev_dataset.map(add_idx, with_indices=True)

Map:   0%|          | 0/28538 [00:00<?, ? examples/s]

Map:   0%|          | 0/28538 [00:00<?, ? examples/s]

No phones found for word <unk> in item 4 (A ONE METER LUMPFISH BLACKISH ON TOP) (preceding word: M IY1 T ER0)
No phones found for word <unk> in item 6 (THEN THE CABLE WAS RESUBMERGED BUT A FEW DAYS LATER IT SNAPPED AGAIN AND COULDN'T BE RECOVERED FROM THE OCEAN DEPTHS) (preceding word: W AH0 Z)
No phones found for word <unk> in item 8 (SIGNED BETWEEN PRUSSIA AND AUSTRIA AFTER THE BATTLE OF SADOVA THROUGH THE MISTS ON THE TWENTY SEVENTH IT SIGHTED THE PORT OF HEART'S CONTENT) (preceding word: AH1 V)
No phones found for word <unk> in item 12 (MEASURING TWO DECIMETERS GRENADIERS WITH LONG TAILS AND GLEAMING WITH A SILVERY GLOW SPEEDY FISH VENTURING FAR FROM THEIR HIGH ARCTIC SEAS OUR NETS ALSO HAULED IN A BOLD) (preceding word: T UW1)
No phones found for word <unk> in item 20 (IT CARRIED SEVENTY FOUR CANNONS AND WAS LAUNCHED IN SEVENTEEN SIXTY TWO ON AUGUST THIRTEENTH SEVENTEEN SEVENTY EIGHT COMMANDED BY LA POYPE VERTRIEUX) (preceding word: L AA1)
No phones found for word <unk> in item 20

Map:   0%|          | 0/28538 [00:00<?, ? examples/s]

No phones found for word <unk> in item 4 (A ONE METER LUMPFISH BLACKISH ON TOP) (preceding word: M IY T ER)
No phones found for word <unk> in item 6 (THEN THE CABLE WAS RESUBMERGED BUT A FEW DAYS LATER IT SNAPPED AGAIN AND COULDN'T BE RECOVERED FROM THE OCEAN DEPTHS) (preceding word: W AH Z)
No phones found for word <unk> in item 8 (SIGNED BETWEEN PRUSSIA AND AUSTRIA AFTER THE BATTLE OF SADOVA THROUGH THE MISTS ON THE TWENTY SEVENTH IT SIGHTED THE PORT OF HEART'S CONTENT) (preceding word: AH V)
No phones found for word <unk> in item 12 (MEASURING TWO DECIMETERS GRENADIERS WITH LONG TAILS AND GLEAMING WITH A SILVERY GLOW SPEEDY FISH VENTURING FAR FROM THEIR HIGH ARCTIC SEAS OUR NETS ALSO HAULED IN A BOLD) (preceding word: T UW)
No phones found for word <unk> in item 20 (IT CARRIED SEVENTY FOUR CANNONS AND WAS LAUNCHED IN SEVENTEEN SIXTY TWO ON AUGUST THIRTEENTH SEVENTEEN SEVENTY EIGHT COMMANDED BY LA POYPE VERTRIEUX) (preceding word: L AA)
No phones found for word <unk> in item 20 (IT C

Map:   0%|          | 0/28538 [00:00<?, ? examples/s]

Map:   0%|          | 0/28538 [00:00<?, ? examples/s]

Map:   0%|          | 0/28538 [00:00<?, ? examples/s]

In [None]:
def plot_item(item_idx, ax, plot_units="phoneme", viz_rate=1000):
    item = dev_dataset[item_idx]

    times = np.linspace(0, len(item["input_values"]) / 16000, int(len(item["input_values"]) / 16000 * viz_rate))
    # normalize to [-1, 1]
    values = np.array(item["input_values"])
    values = (values - values.min()) / (values.max() - values.min()) * 2 - 1
    # resample to viz frame rate
    values = np.interp(times, np.arange(len(values)) / 16000, values)
    ax.plot(times, values, alpha=0.2)

    # plot word and phoneme boundaries
    for i, word in enumerate(item["word_phonemic_detail"]):
        if not word:
            continue
        word_str = item["word_detail"]["utterance"][i]

        word_start, word_stop = word[0]["start"] / 16000, word[-1]["stop"] / 16000
        ax.axvline(word_start, color="black", linestyle="--")
        ax.text(word_start, 0.8, word_str, rotation=90, verticalalignment="bottom", alpha=0.7)

        if plot_units == "phoneme":
            for j, phoneme in enumerate(word):
                phoneme_str = phoneme["phone"]
                phoneme_start, phoneme_stop = phoneme["start"] / 16000, phoneme["stop"] / 16000

                if j > 0:
                    color = "black" if phoneme["idx_in_syllable"] == 0 else "gray"
                    ax.axvline(phoneme_start, color=color, linestyle=":", alpha=0.5)
                ax.text(phoneme_start + 0.01, -6, phoneme_str, rotation=90, verticalalignment="bottom",
                        fontdict={"size": 15})
        elif plot_units == "syllable":
            for j, syllable in enumerate(item["word_syllable_detail"][i]):
                syllable_str = " ".join(syllable["phones"])
                syllable_start, syllable_stop = syllable["start"] / 16000, syllable["stop"] / 16000

                if j > 0:
                    ax.axvline(syllable_start, color="black", linestyle=":", alpha=0.5)
                ax.text(syllable_start + 0.01, -6, syllable_str, rotation=90, verticalalignment="bottom",
                        fontdict={"size": 15})
        else:
            raise ValueError(f"Unknown plot_units: {plot_units}")

    # align at origin
    ax.set_ylim((-8, 8))

    ax.set_title(f"{item['speaker_id']}_{item['id']}: {item['text']}")
    ax.set_yticks([])
    ax.grid(False)
    ax.axis("off")

In [None]:
f, axs = plt.subplots(2, 1, figsize=(25, 2 * 8))
idx = np.random.choice(len(dev_dataset))
print(idx)
plot_item(idx, axs[0], plot_units="phoneme")
plot_item(idx, axs[1], plot_units="syllable")

## Check word-level correspondence with CMUdict

In [None]:
from collections import defaultdict, Counter
from tempfile import NamedTemporaryFile
from urllib.request import urlretrieve
import re
from pprint import pprint

# Download and parse cmudict
cmudict_entries = defaultdict(list)
with NamedTemporaryFile() as f:
    urlretrieve("https://github.com/cmusphinx/cmudict/raw/master/cmudict.dict", f.name)

    with open(f.name, "r") as f:
        for line in f:
            # remove comments
            line = re.sub(r'(\s)*#.*', '', line)

            fields = line.strip().split(" ")
            word = fields[0]

            # remove word idx number, indicating secondary pronunciation
            word = re.sub(r"\(\d\)$", "", word)

            phones = tuple(fields[1:])
            # remove stress markers
            phones = tuple(re.sub(r"\d", "", p) for p in phones)

            cmudict_entries[word].append(phones)


In [None]:
# Track attested pronunciations of each word in TIMIT
corpus_cmudict_mapping = defaultdict(Counter)
def process_item(item):
    for word, word_phonemes in zip(item["word_detail"]["utterance"], item["word_phonemic_detail"]):
        corpus_cmudict_mapping[word.lower()][tuple(p["phone"] for p in word_phonemes)] += 1
dev_dataset.map(process_item)

In [None]:
# How many words have multiple pronunciations?
multiple_pronunciations = {k: v for k, v in corpus_cmudict_mapping.items() if len(v) > 1}
print(f"{len(multiple_pronunciations)} words ({len(multiple_pronunciations) / len(corpus_cmudict_mapping) * 100}%) have multiple pronunciations")

In [None]:
# How many words have CMUDICT pronunciations?
has_cmudict = {k: v for k, v in corpus_cmudict_mapping.items() if k in cmudict_entries}
print(f"{len(has_cmudict)} words ({len(has_cmudict) / len(corpus_cmudict_mapping) * 100}%) have CMUDICT pronunciations")

In [None]:
# For how many words does the majority pronunciation align with the CMUDICT pronunciation?
majority_aligned = {k: v for k, v in corpus_cmudict_mapping.items()
                    if len(cmudict_entries[k]) > 0 and v.most_common(1)[0][0] == cmudict_entries[k][0]}
majority_misaligned = {k: v for k, v in corpus_cmudict_mapping.items()
                       if len(cmudict_entries[k]) > 0 and v.most_common(1)[0][0] != cmudict_entries[k][0]}
print(f"{len(majority_aligned)} words ({len(majority_aligned) / len(corpus_cmudict_mapping) * 100}%) have majority-aligned CMUDICT pronunciations")

In [None]:
# For misaligned majorities, compare with CMUDICT
for word, counts in majority_misaligned.items():
    print(f"{word}: {' '.join(counts.most_common(1)[0][0])} (LibriSpeech) vs {' '.join(cmudict_entries[word][0])} (CMUDICT)")

## Syllable analysis

In [None]:
all_syllable_counts = Counter()
word_syllable_counts = defaultdict(Counter)

def process_item(item):
    for i, (word, syllables) in enumerate(zip(item["word_detail"]["utterance"], item["word_syllable_detail"])):
        syll_string = tuple(tuple(syllable["phones"]) for syllable in syllables)
        word_syllable_counts[word.lower()][syll_string] += 1
        for syllable in syll_string:
            all_syllable_counts[syllable] += 1
dev_dataset.map(process_item)

In [None]:
all_syllable_counts.most_common(20)

In [None]:
cmudict_vowels = {"AA", "AE", "AH", "AO", "AW", "AY", "EH", "ER", "EY", "IH", "IY", "OW", "OY", "UH", "UW"}

print("Syllabic consonant frequencies:")
syllabic_frequencies = Counter({k: v for k, v in all_syllable_counts.items() if len(k) == 1 and k[0] not in cmudict_vowels})
pprint(syllabic_frequencies)

print("Proportion of total syllable tokens: ", sum(syllabic_frequencies.values()) / sum(all_syllable_counts.values()) * 100, "%")

In [None]:
multiple_syllabification_words = Counter({k: v for k, v in word_syllable_counts.items() if len(v) > 1})
print(f"{len(multiple_syllabification_words)} words ({len(multiple_syllabification_words) / len(word_syllable_counts) * 100}%) have multiple syllabifications")

In [None]:
# Log top token frequency syllables
sorted(multiple_syllabification_words.items(), key=lambda x: sum(x[1].values()), reverse=True)[:10]

## Save to disk

In [None]:
dev_dataset.save_to_disk(out_path)