In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from copy import copy
import itertools
import logging
import re
from tempfile import TemporaryDirectory

import datasets
import matplotlib.pyplot as plt
import numpy as np
import transformers

from src.datasets.barakeet import BarakeetDataset
from src.utils import syllabifier

In [None]:
# Notebook-wide logger; the preprocessing helpers below report warnings and errors through it.
L = logging.getLogger("barakeet")

In [None]:
# Directory handed to the dataset loading script as `data_dir`.
data_dir = "data/barakeet"
# Destination of the final `save_to_disk` call.
out_path = "outputs/preprocessed_data/barakeet"
# Amount of silence prepended to every clip. It is multiplied by the sampling rate
# to obtain a sample count, so the unit is seconds.
left_pad = 0.5

In [None]:
# Turn off the `datasets` library cache for this session.
datasets.disable_caching()

In [None]:
# Build the dataset from the local loading script, using a throwaway cache directory
# that is deleted as soon as the `with` block exits.
# NOTE(review): `ds` keeps being used after `tempdir` is removed — confirm the loaded
# data does not depend on files inside that directory.
with TemporaryDirectory() as tempdir:
    ds = datasets.load_dataset("src/datasets/barakeet.py", data_dir=data_dir, cache_dir=tempdir)
# Only the "train" split is used from here on.
ds = ds["train"]

## Make audio statistically comparable to LibriSpeech

In [None]:
# Previously preprocessed LibriSpeech data, used as the reference for amplitude statistics.
librispeech_ds = datasets.load_from_disk("outputs/preprocessed_data/librispeech-train-clean-100")

In [None]:
from scipy.signal import welch
import pandas as pd
import seaborn as sns

def compute_item_stats(item):
    """
    Compute summary statistics of a single item's waveform.

    Args:
        item: mapping where ``item["audio"]["array"]`` is the waveform as a
            numpy array.

    Returns:
        dict with the keys ``mean``, ``median``, ``var``, ``min``, ``max`` and
        ``snr``, each a scalar computed over the whole waveform.
    """
    arr = item["audio"]["array"]

    # The variance is both a reported statistic and the denominator of the ratio below,
    # so compute it once.
    variance = np.var(arr)
    signal_power = np.mean(arr ** 2)
    # NOTE(review): this treats the total variance as "noise", i.e. the value is
    # 10*log10(mean(x^2) / var(x)), which is ~0 dB for any zero-mean signal. Kept as-is;
    # confirm this is the intended definition before relying on it.
    snr = 10 * np.log10(signal_power / (variance + 1e-10))  # Avoid div by zero

    return {
        "mean": np.mean(arr),
        "median": np.median(arr),
        "var": variance,
        "min": np.min(arr),
        "max": np.max(arr),
        "snr": snr,
    }

def compute_dataset_stats(ds, n_samples=300, rng=None):
    """
    Compute per-item audio statistics over a random subset of a dataset.

    Args:
        ds: dataset exposing ``__len__``, ``select``, ``map`` and ``features``.
        n_samples: maximum number of items to sample (without replacement). The whole
            dataset is used when it has fewer items than this.
        rng: optional object with a numpy-style ``choice`` method (e.g. the result of
            ``np.random.default_rng(seed)``) for reproducible sampling. Defaults to
            numpy's global random state.

    Returns:
        pandas DataFrame with one row per sampled item and one column per statistic
        added by ``compute_item_stats``.
    """
    sampler = np.random if rng is None else rng

    # draw a random subset of at most n_samples items
    sample_size = min(n_samples, len(ds))
    subset = ds.select(sampler.choice(len(ds), size=sample_size, replace=False))

    # compute stats for each item
    stats = subset.map(compute_item_stats, keep_in_memory=True)

    # keep only the columns added by the map, i.e. drop the dataset's own features
    original_columns = set(subset.features.keys())
    records = [{k: v for k, v in row.items() if k not in original_columns} for row in stats]
    return pd.DataFrame(records)

In [None]:
# Per-item statistics over a random sample of the LibriSpeech reference data.
librispeech_stats = compute_dataset_stats(librispeech_ds)

In [None]:
# Average every statistic over the sampled items; the result is displayed as the cell output.
librispeech_mean_stats = librispeech_stats.mean()
librispeech_mean_stats

In [None]:
# Peak amplitude the audio is rescaled to: the larger magnitude of LibriSpeech's
# average per-item minimum and average per-item maximum.
rescale_target = max(abs(librispeech_mean_stats["min"]), abs(librispeech_mean_stats["max"]))

In [None]:
# Visual sanity check: apply the rescaling steps to one randomly chosen clip and overlay it
# on the LibriSpeech clip stored at the same index.
# The index is drawn so that it is valid for both datasets.
i = np.random.choice(min(len(ds), len(librispeech_ds)))
fixed = ds[i]["audio"]["array"].copy()
# remove DC distortion
fixed -= fixed.mean()
# match value of silence in librispeech
# fixed += librispeech_ds[0]["audio"]["array"][:10].mean()
# pad
# NOTE(review): the pad length assumes a 16 kHz sampling rate.
fixed = np.pad(fixed, (int(16000 * left_pad), 0), mode="constant", constant_values=0)
# shift by the average LibriSpeech mean, then peak-normalize and scale to the target peak
fixed += librispeech_mean_stats["mean"]
fixed /= np.max(np.abs(fixed))
fixed *= rescale_target
plt.plot(librispeech_ds[i]["audio"]["array"], alpha=0.5)
plt.plot(fixed, alpha=0.5)

In [None]:
from IPython.display import Audio

In [None]:
# Listen to the rescaled example clip built in the previous cell.
Audio(fixed, rate=ds[i]["audio"]["sampling_rate"])

In [None]:
def rescale_item(item):
    """
    Left-pad an item's audio with silence and rescale it to the LibriSpeech amplitude range.

    Steps: remove the DC offset, prepend ``left_pad`` seconds of zeros, add the average
    LibriSpeech sample mean, then peak-normalize and scale so the peak magnitude equals
    ``rescale_target``.

    Reads the notebook globals ``left_pad``, ``librispeech_mean_stats`` and
    ``rescale_target``.

    Args:
        item: mapping where ``item["audio"]["array"]`` is a numpy float waveform and
            ``item["audio"]["sampling_rate"]`` its sampling rate.

    Returns:
        The same ``item`` with ``item["audio"]["array"]`` replaced by the processed copy.
    """
    arr = item["audio"]["array"].copy()
    # remove DC distortion
    arr -= arr.mean()

    # pad; the sample count is derived from the item's own sampling rate so that it is
    # exactly the amount by which fix_offsets shifts the annotations
    pad_samples = int(left_pad * item["audio"]["sampling_rate"])
    arr = np.pad(arr, (pad_samples, 0), mode="constant", constant_values=0)

    arr += librispeech_mean_stats["mean"]

    # peak-normalize, then scale to the target peak; an all-zero clip has no peak to
    # normalize by and is left untouched rather than turned into NaNs
    peak = np.max(np.abs(arr))
    if peak > 0:
        arr /= peak
        arr *= rescale_target
    item["audio"]["array"] = arr
    return item
# Apply the rescaling to every item. NOTE: not idempotent — `ds` is overwritten, so
# re-running this cell pads and rescales the already-processed audio again.
ds = ds.map(rescale_item)

In [None]:
# shift every annotation boundary by the padding prepended to the audio
def fix_offsets(item):
    """
    Move all annotation boundaries later by the left padding.

    The shift is ``left_pad`` (notebook global) times the item's sampling rate,
    truncated to an integer, and is added to every ``start`` and ``stop`` entry of
    ``word_detail``, ``word_raw_detail`` and ``phonetic_detail``.

    Returns the same ``item`` with the shifted lists written back.
    """
    shift = int(left_pad * item["audio"]["sampling_rate"])

    for annot in ("word_detail", "word_raw_detail", "phonetic_detail"):
        detail = item[annot]
        for boundary in ("start", "stop"):
            detail[boundary] = [position + shift for position in detail[boundary]]

    return item
# Shift the annotations of every item. NOTE: not idempotent — re-running this cell
# shifts the already-shifted boundaries again.
ds = ds.map(fix_offsets)

In [None]:
# Tokenizer loaded from the pretrained "charsiu/tokenizer_en_cmu" files.
tokenizer = transformers.Wav2Vec2Tokenizer.from_pretrained("charsiu/tokenizer_en_cmu")
# Feature extractor configured for 16 kHz input with normalization and no attention mask.
feature_extractor = transformers.Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
# Bundle of the two; `prepare_audio` calls it to turn waveforms into model inputs.
processor = transformers.Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
def add_phonemic_detail(item):
    """
    Add a ``phonemic_detail`` entry derived from ``phonetic_detail``.

    The new entry has the same ``start`` and ``stop`` values (as shallow copies) and the
    same labels with every digit removed, which strips the stress annotations.

    Returns the same ``item`` with the new key set.
    """
    phonetic = item["phonetic_detail"]

    item["phonemic_detail"] = {
        "start": copy(phonetic["start"]),
        "stop": copy(phonetic["stop"]),
        # remove stress annotations
        "utterance": [re.sub(r"\d", "", label) for label in phonetic["utterance"]],
    }

    return item

In [None]:
def group_phonetic_detail(item, idx, drop_phones=None, key="phonetic_detail"):
    """
    Group the entries of ``item[key]`` by the word that contains them.

    Each phone is assigned to at most one word: the first word (in annotation order)
    whose ``[start, stop)`` interval contains the phone's start. A phone that spans a
    word boundary therefore goes to the leftmost word. Phones whose label is in
    ``drop_phones`` are discarded. Words that end up without phones (other than
    empty-string words) and phones that belong to no word are reported as warnings.

    Args:
        item: mapping with ``word_detail``, ``item[key]`` and ``text``.
        idx: index of the item, used only in log messages.
        drop_phones: optional collection of phone labels to discard.
        key: name of the flat phone annotation to group.

    Returns:
        The same ``item`` with ``item["word_" + key]`` set to a list (one entry per
        word) of lists of ``{"phone", "start", "stop"}`` dicts.
    """
    phones = item[key]
    words = item["word_detail"]
    phone_records = list(zip(phones["start"], phones["stop"], phones["utterance"]))

    # taken[j] becomes True once phone j has been assigned to a word or dropped,
    # which guarantees that no phone is mapped to more than one word
    taken = np.zeros(len(phones["start"]), dtype=bool)

    grouped = []
    for word_start, word_stop, word in zip(words["start"], words["stop"], words["utterance"]):
        members = []
        grouped.append(members)

        for j, (phone_start, phone_stop, label) in enumerate(phone_records):
            if taken[j]:
                continue
            if drop_phones is not None and label in drop_phones:
                taken[j] = True
                continue

            # a phone belongs to this word when it starts inside the word
            if word_start <= phone_start < word_stop:
                taken[j] = True
                members.append({"phone": label, "start": phone_start, "stop": phone_stop})

        # empty-string words are expected to have no phones; anything else is suspicious
        if members or word == "":
            continue
        previous = " ".join(entry["phone"] for entry in grouped[-2]) if len(grouped) > 1 else ""
        L.warning(f"No phones found for word {word} in item {idx} ({item['text']}) (preceding word: {previous})")

    # report every phone that was neither assigned nor dropped, with up to three
    # neighbours on each side for context
    labels = phones["utterance"]
    for j in np.flatnonzero(~taken):
        before = " ".join(labels[max(0, j - 3):j])
        after = " ".join(labels[j + 1:min(len(labels), j + 4)])
        orphan = labels[j]
        L.warning(f"Unused phone {orphan} in item {idx} ({item['text']}) (preceding: {before}, following: {after})")

    item[f"word_{key}"] = grouped
    return item

In [None]:
def add_syllabic_detail(item):
    """
    Syllabify every word and attach syllable information to the item.

    For each word in ``item["word_phonemic_detail"]`` the phones other than ``"[SIL]"``
    and ``""`` are passed to the syllabifier. Every resulting syllable is described by a
    dict with:

    - ``phones``: tuple of the syllable's phone labels
    - ``idx``: index of the syllable within the word
    - ``phoneme_start_idx`` / ``phoneme_end_idx``: positions of the syllable's first
      phone and one past its last phone within the word's phone list (end exclusive)
    - ``stress``: stress value returned by the syllabifier
    - ``start`` / ``stop``: start of the first phone and stop of the last phone

    The phone dicts that take part in a syllable are additionally annotated in place with
    ``syllable_idx``, ``idx_in_syllable``, ``syllable_phones``, ``stress``,
    ``syllable_start`` and ``syllable_stop``. Filtered-out phones are left unannotated.

    Returns:
        The same ``item`` with ``item["word_syllable_detail"]`` set to a list (one entry
        per word) of lists of syllable dicts; words without usable phones get ``[]``.

    Raises:
        AssertionError: if the syllabifier output does not reproduce the input phones.
    """
    word_syllables = []

    # syllabifier doesn't use stress information so we can just use
    # phonemic detail here
    for word in item["word_phonemic_detail"]:
        # Remember where each kept phone sits in the word. The syllabifier only sees the
        # kept phones, so syllable positions must be mapped back through this list
        # instead of being used directly as indices into `word`.
        kept = [(pos, ph) for pos, ph in enumerate(word) if ph["phone"] not in ["[SIL]", ""]]
        phones = [ph["phone"] for _, ph in kept]

        syllable_dicts = []
        if len(phones) > 0:
            syllables = syllabifier.syllabify(syllabifier.EnglishIPA, phones)

            # the syllables must be a partition of the input phones, in order
            assert phones == list(itertools.chain.from_iterable(
                [tuple(onset) + tuple(nucleus) + tuple(coda) for stress, onset, nucleus, coda in syllables]))

            consumed = 0  # number of kept phones already assigned to earlier syllables
            for syllable_idx, (stress, onset, nucleus, coda) in enumerate(syllables):
                syllable_phones = tuple(onset) + tuple(nucleus) + tuple(coda)
                members = kept[consumed:consumed + len(syllable_phones)]
                first_pos, first_phone = members[0]
                last_pos, last_phone = members[-1]

                syllable_dict = {
                    "phones": syllable_phones,
                    "idx": syllable_idx,
                    "phoneme_start_idx": first_pos,
                    "phoneme_end_idx": last_pos + 1, # exclusive
                    "stress": stress,

                    "start": first_phone["start"],
                    "stop": last_phone["stop"],
                }

                # Add cross-reference data in word_phonemic_detail
                for j, (_, ph) in enumerate(members):
                    ph["syllable_idx"] = syllable_idx
                    ph["idx_in_syllable"] = j
                    ph["syllable_phones"] = tuple(syllable_phones)
                    ph["stress"] = stress
                    ph["syllable_start"] = syllable_dict["start"]
                    ph["syllable_stop"] = syllable_dict["stop"]

                syllable_dicts.append(syllable_dict)
                consumed += len(syllable_phones)

        word_syllables.append(syllable_dicts)

    item["word_syllable_detail"] = word_syllables
    return item

In [None]:
def check_item(item, idx, drop_phones=None):
    """
    Verify that the word-grouped phoneme and syllable annotations of an item agree.

    Checks that both groupings have one entry per word and that the phones listed in
    the syllables are exactly the phones listed in the phonemic grouping, in order.
    On any failure an error naming the item is logged and the exception is re-raised.

    ``drop_phones`` is accepted but not used. The flat ``phonemic_detail`` is
    deliberately not compared against the grouped one: phones outside every word's span
    appear only in the flat representation, so a mismatch there is expected.
    """
    try:
        phonemes_by_word = item["word_phonemic_detail"]
        syllables_by_word = item["word_syllable_detail"]
        word_count = len(item["word_detail"]["utterance"])
        assert len(phonemes_by_word) == word_count
        assert len(syllables_by_word) == word_count

        from_phonemes = []
        for word in phonemes_by_word:
            from_phonemes.extend(entry["phone"] for entry in word)

        from_syllables = []
        for word in item["word_syllable_detail"]:
            for syllable in word:
                from_syllables.extend(syllable["phones"])

        assert len(from_phonemes) == len(from_syllables)
        assert from_phonemes == from_syllables, "phonemic detail does not match phonemes within syllable detail"
    except Exception:
        L.error(f"Error in item {idx} ({item['text']})")
        raise

In [None]:
def prepare_audio(batch):
    """
    Run the item's waveform through the notebook-level ``processor`` and store the
    first entry of the resulting ``input_values`` under ``batch["input_values"]``.

    Returns the same ``batch`` with the new key set.
    """
    clip = batch["audio"]
    waveform, rate = clip["array"], clip["sampling_rate"]
    processed = processor(waveform, sampling_rate=rate)
    batch["input_values"] = processed.input_values[0]
    return batch

In [None]:
def add_idx(item, idx):
    """Store the index passed by the caller under the ``idx`` key and return the item."""
    item["idx"] = idx
    return item

In [None]:
# Phone labels that are discarded when phones are grouped into words.
drop_phones = ["sil", "sp", "spn", ""]

# 1. derive the stress-free phonemic annotation from the phonetic one
dev_dataset = ds.map(add_phonemic_detail)
# 2. group phones by containing word, once for the phonetic annotation (default key) ...
dev_dataset = dev_dataset.map(group_phonetic_detail, with_indices=True,
                              fn_kwargs=dict(drop_phones=drop_phones))
# ... and once for the phonemic annotation
dev_dataset = dev_dataset.map(group_phonetic_detail, with_indices=True,
                              fn_kwargs=dict(key="phonemic_detail", drop_phones=drop_phones))

# 3. syllabify each word based on the grouped phonemic annotation
dev_dataset = dev_dataset.map(add_syllabic_detail)

# 4. consistency check; run only for its assertions, the mapped result is discarded
dev_dataset.map(check_item, with_indices=True)

# 5. compute model inputs and record each item's index
dev_dataset = dev_dataset.map(prepare_audio)
dev_dataset = dev_dataset.map(add_idx, with_indices=True)

In [None]:
def plot_item(item_idx, ax, plot_units="phoneme", viz_rate=1000, sampling_rate=16000):
    """
    Draw one item's waveform with word and sub-word boundaries on a matplotlib axis.

    The item is read from the notebook-level ``dev_dataset``. Its ``input_values`` are
    min-max normalized to [-1, 1], resampled to ``viz_rate`` points per second and
    plotted. A dashed line and a label mark the start of every non-empty word; dotted
    lines and labels mark the phonemes or syllables inside each word.

    Args:
        item_idx: index of the item in ``dev_dataset``.
        ax: axis-like object to draw on.
        plot_units: ``"phoneme"`` or ``"syllable"``.
        viz_rate: number of plotted points per second of audio.
        sampling_rate: sampling rate used to convert sample offsets to seconds.

    Raises:
        ValueError: if ``plot_units`` is not a supported value; raised before anything
            is drawn.
    """
    if plot_units not in ("phoneme", "syllable"):
        raise ValueError(f"Unknown plot_units: {plot_units}")

    item = dev_dataset[item_idx]

    values = np.array(item["input_values"])
    duration = len(values) / sampling_rate
    times = np.linspace(0, duration, int(duration * viz_rate))

    # normalize to [-1, 1]; a constant signal has no range to normalize by and is
    # drawn as a flat line at 0 instead of NaNs
    value_range = values.max() - values.min()
    if value_range > 0:
        values = (values - values.min()) / value_range * 2 - 1
    else:
        values = np.zeros(len(values))
    # resample to viz frame rate
    values = np.interp(times, np.arange(len(values)) / sampling_rate, values)
    ax.plot(times, values, alpha=0.2)

    # plot word and phoneme boundaries
    for i, word in enumerate(item["word_phonemic_detail"]):
        if not word:
            continue
        word_str = item["word_detail"]["utterance"][i]

        word_start = word[0]["start"] / sampling_rate
        ax.axvline(word_start, color="black", linestyle="--")
        ax.text(word_start, 0.8, word_str, rotation=90, verticalalignment="bottom", alpha=0.7)

        if plot_units == "phoneme":
            for j, phoneme in enumerate(word):
                phoneme_start = phoneme["start"] / sampling_rate

                # the word line already marks the first phoneme; later phonemes get a
                # darker line when they open a new syllable
                if j > 0:
                    color = "black" if phoneme["idx_in_syllable"] == 0 else "gray"
                    ax.axvline(phoneme_start, color=color, linestyle=":", alpha=0.5)
                ax.text(phoneme_start + 0.01, -6, phoneme["phone"], rotation=90, verticalalignment="bottom",
                        fontdict={"size": 15})
        else:
            for j, syllable in enumerate(item["word_syllable_detail"][i]):
                syllable_start = syllable["start"] / sampling_rate

                if j > 0:
                    ax.axvline(syllable_start, color="black", linestyle=":", alpha=0.5)
                ax.text(syllable_start + 0.01, -6, " ".join(syllable["phones"]), rotation=90,
                        verticalalignment="bottom", fontdict={"size": 15})

    # align at origin
    ax.set_ylim((-8, 8))

    ax.set_title(f"{item['id']}: {item['text']}")
    ax.set_yticks([])
    ax.grid(False)
    ax.axis("off")

In [None]:
# Spot check: draw one randomly chosen item twice, with phoneme boundaries on the top
# panel and syllable boundaries on the bottom panel. The index is printed so the item
# can be looked up again.
f, axs = plt.subplots(2, 1, figsize=(25, 2 * 8))
idx = np.random.choice(len(dev_dataset))
print(idx)
plot_item(idx, axs[0], plot_units="phoneme")
plot_item(idx, axs[1], plot_units="syllable")

In [None]:
# Persist the fully preprocessed dataset to `out_path`.
dev_dataset.save_to_disk(out_path)