In [None]:
# Re-import modified modules automatically before each cell executes, so edits
# to project code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import itertools
from pathlib import Path
import pickle
from typing import Any

import datasets
import matplotlib.pyplot as plt
from mne.decoding import ReceptiveField
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import train_test_split, cross_val_score
import torch
import transformers
from tqdm.auto import tqdm, trange

from src.analysis.trf import estimate_trf_cv
from src.models import get_best_checkpoint
from src.models.integrator import ContrastiveEmbeddingModel, load_or_compute_embeddings

In [None]:
# Global seaborn theme for every figure created in this notebook.
sns.set_theme(context='talk', style='whitegrid')

In [None]:
# Directory passed to get_best_checkpoint() and load_or_compute_embeddings() below.
model_dir = "outputs/models/w2v2_6_8/phoneme"
# model_checkpoint = "out/ce_model_phoneme_6_8/checkpoint-800"
# use a word-level equivalence dataset regardless of model, so that we can look up cohort facts
# NOTE(review): the file name below says "phoneme", not "word" — confirm this is the intended dataset.
equiv_dataset_path = "data/timit_equivalence_facebook-wav2vec2-base_6-phoneme-1.pkl"
# On-disk dataset read with datasets.load_from_disk() below.
timit_corpus_path = "data/timit_syllables"

# NOTE(review): not referenced by any cell visible in this part of the notebook.
output_dir = "."

In [None]:
# Load the checkpoint selected by get_best_checkpoint() for `model_dir`, then
# call .eval() (presumably standard torch inference mode — confirm in src.models).
model = ContrastiveEmbeddingModel.from_pretrained(get_best_checkpoint(model_dir))
model.eval()

In [None]:
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point `equiv_dataset_path` at trusted, locally produced datasets.
with open(equiv_dataset_path, "rb") as f:
    equiv_dataset = pickle.load(f)

In [None]:
# Indexed by corpus item index below; each entry unpacks to (start_frame, end_frame),
# which is used to slice rows of `model_representations`.
frames_by_item = equiv_dataset.hidden_state_dataset.frames_by_item

In [None]:
# Model embeddings; indexed below as [frame, dimension].
# Presumably cached on disk by load_or_compute_embeddings — confirm in src.models.integrator.
model_representations = load_or_compute_embeddings(model, equiv_dataset, model_dir, equiv_dataset_path)

In [None]:
# "train" split of the dataset stored at `timit_corpus_path`; the plotting helpers
# below read audio ("input_values") and word/phoneme/syllable annotations from it.
timit_corpus = datasets.load_from_disk(timit_corpus_path)["train"]

In [None]:
def plot_item(item_idx, ax, plot_dims=None, sample_rate=16000):
    """Plot one utterance's waveform, its word/phoneme boundaries and model activations.

    The waveform and boundary markers are drawn on ``ax``; the selected columns
    of ``model_representations`` are drawn on a twin y-axis sharing the same
    time axis (in seconds).

    Args:
        item_idx: Index into ``timit_corpus`` and ``frames_by_item``.
        ax: Matplotlib axes receiving the waveform and boundary markers.
        plot_dims: Iterable of column indices of ``model_representations`` to
            draw. ``None`` (default) draws every column.
        sample_rate: Sampling rate in Hz used to convert the sample indices
            stored in the annotations to seconds. Defaults to 16000, the value
            that was previously hard-coded.
    """
    item = timit_corpus[item_idx]
    start_frame, end_frame = frames_by_item[item_idx]

    waveform = item["input_values"]
    num_samples = len(waveform)
    num_frames = end_frame - start_frame
    # Model frames per audio sample; used to place frames on the audio time axis.
    compression_ratio = num_frames / num_samples

    # Shared plotting grid: one point per millisecond of audio.
    duration = num_samples / sample_rate
    times = np.linspace(0, duration, int(duration * 1000))
    ax.plot(times, np.interp(times, np.arange(num_samples) / sample_rate, waveform),
            alpha=0.2)

    # plot word and phoneme boundaries
    utterance = item["word_detail"]["utterance"]
    for word_idx, word in enumerate(item["word_phonemic_detail"]):
        word_start = word[0]["start"] / sample_rate
        ax.axvline(word_start, color="black", linestyle="--")
        ax.text(word_start, -5, utterance[word_idx], rotation=90, verticalalignment="top", alpha=0.7)

        # The first phoneme coincides with the word boundary drawn above, so skip it.
        for phoneme in word[1:]:
            # Syllable-initial phonemes get a darker marker than syllable-internal ones.
            color = "black" if phoneme["idx_in_syllable"] == 0 else "gray"
            ax.axvline(phoneme["start"] / sample_rate, color=color, linestyle=":", alpha=0.5)

    model_ax = ax.twinx()
    if plot_dims is None:
        plot_dims = range(model_representations.shape[1])

    # Both are identical for every dimension, so compute them once.
    frame_times = np.arange(num_frames) / compression_ratio / sample_rate
    item_representations = model_representations[start_frame:end_frame]
    for dim in plot_dims:
        model_ax.plot(times, np.interp(times, frame_times, item_representations[:, dim]),
                      label=f"Model dimension {dim + 1}")

    # Symmetric limits on both axes so their zero lines coincide.
    ax.set_ylim((-8, 8))
    model_ax.set_ylim((-2, 2))
    model_ax.legend()

    ax.set_title(f"{item['speaker_id']}_{item['id']}: {item['text']}")
    ax.set_yticks([])
    model_ax.set_yticks([])
    ax.grid(False)
    model_ax.grid(False)
    ax.axis("off")

In [None]:
# Corpus item indices to draw, one full-width panel per item.
plot_items = [18]
f, axs = plt.subplots(len(plot_items), 1, figsize=(18, 8 * len(plot_items)))
# plt.subplots returns a bare Axes for a single panel and an ndarray (never a
# list) for several, so normalise both cases to a 1-D sequence of Axes.
for item_idx, ax in zip(plot_items, np.atleast_1d(axs)):
    plot_item(item_idx, ax, plot_dims=[0])

## Plot single word

In [None]:
word_lookup = "act"
matches = []

def find_word(item, idx):
    """Record (item index, word position) when the utterance contains `word_lookup`.

    Only the first occurrence within an utterance is recorded. Results are
    appended to the module-level `matches` list; nothing is returned.
    """
    utterance = item["word_detail"]["utterance"]
    if word_lookup not in utterance:
        return
    matches.append((idx, utterance.index(word_lookup)))

# `map` is used only for its side effect of filling `matches`.
timit_corpus.map(find_word, with_indices=True)

In [None]:
# Number of utterances found to contain `word_lookup`.
len(matches)

In [None]:
# (item index, word index within the utterance) pairs collected by find_word.
matches

In [None]:
def plot_word_in_item(item_idx, word_idx, ax, annot=True, text=True, sample_rate=16000):
    """Plot the waveform of one word, rescaled to [-1, 1], with optional phoneme marks.

    Args:
        item_idx: Index of the utterance in ``timit_corpus``.
        word_idx: Position of the word within the utterance.
        ax: Matplotlib axes to draw on; its axis decorations are switched off.
        annot: When true, draw a vertical line at the start of every phoneme of
            the word (black for syllable-initial phonemes, gray otherwise).
        text: When true (and ``annot`` is true), also label each phoneme.
        sample_rate: Sampling rate in Hz used to convert sample indices to
            seconds. Defaults to 16000, the value that was previously hard-coded.
    """
    item = timit_corpus[item_idx]

    word_detail = item["word_detail"]
    word_start_sample = word_detail["start"][word_idx]
    word_end_sample = word_detail["stop"][word_idx]
    word_start, word_end = word_start_sample / sample_rate, word_end_sample / sample_rate

    audio_samples = np.arange(word_start_sample, word_end_sample)
    # Take the point count from the integer sample indices: converting to
    # seconds and back can lose one sample to floating-point truncation.
    times = np.linspace(word_start, word_end, word_end_sample - word_start_sample)

    # Normalize audio samples to [-1, 1]
    values = np.array(item["input_values"][word_start_sample:word_end_sample])
    lowest, highest = values.min(), values.max()
    if highest > lowest:
        values = (values - lowest) / (highest - lowest) * 2 - 1
    else:
        # A constant segment has no range to rescale; draw a flat line at zero
        # instead of dividing by zero and producing NaNs.
        values = np.zeros(values.shape)

    ax.plot(times, np.interp(times, audio_samples / sample_rate, values),
            alpha=0.3)

    ax.set_xlim((word_start, word_end))
    ax.axis("off")

    if annot:
        for phoneme in item["word_phonemic_detail"][word_idx]:
            phoneme_start = phoneme["start"] / sample_rate

            color = "black" if phoneme["idx_in_syllable"] == 0 else "gray"
            ax.axvline(phoneme_start, color=color, linestyle=":")

            if text:
                ax.text(phoneme_start + 0.01, 0, phoneme["phone"], verticalalignment="bottom", fontdict={"size": 15, "weight": "bold"})

In [None]:
# Preview the first match with phoneme boundaries and labels.
# NOTE: raises IndexError when `matches` is empty (i.e. `word_lookup` was not found).
f, ax = plt.subplots(figsize=(18, 4))
plot_word_in_item(*matches[0], ax, annot=True)

In [None]:
# Hand-picked (item index, word index) pairs; the trailing comment names the
# word each pair is expected to select.
# NOTE(review): the indices depend on the ordering of the dataset at
# `timit_corpus_path` — re-verify them if that dataset is regenerated.
manual_plots = [
    (143, 0), # positive
    (206, 5), # popularity
    (253, 9), # impossible
    (4442, 8), # employee
]

In [None]:
# One annotated, labelled panel per hand-picked word, stacked vertically.
f, axs = plt.subplots(nrows=len(manual_plots), ncols=1, figsize=(18, 4 * len(manual_plots)))
# With a single row plt.subplots hands back a bare Axes; wrap it so the loop below works.
axs = [axs] if len(manual_plots) == 1 else axs
for ax, (item_idx, word_idx) in zip(axs, manual_plots):
    plot_word_in_item(item_idx, word_idx, ax, annot=True, text=True)
    # Waveforms are normalised to [-1, 1]; leave a small margin around them.
    ax.set_ylim((-1.1, 1.1))

In [None]:
# Cap the figure at a fixed number of panels. The subset is drawn with a
# seeded generator so that Restart & Run All reproduces the same figure.
max_word_panels = 18
if len(matches) > max_word_panels:
    word_sample_rng = np.random.default_rng(0)
    keep = word_sample_rng.choice(len(matches), max_word_panels, replace=False)
    matches = [matches[idx] for idx in keep]

f, axs = plt.subplots(len(matches), 1, figsize=(18, 4 * len(matches)))
# With a single row plt.subplots returns a bare Axes rather than an array.
if len(matches) == 1:
    axs = [axs]
for (item_idx, word_idx), ax in zip(matches, axs):
    plot_word_in_item(item_idx, word_idx, ax, annot=False)

# Waveforms are normalised to [-1, 1]; leave a small margin around them.
for ax in axs:
    ax.set_ylim((-1.1, 1.1))

## Plot single syllable

In [None]:
syllable_lookup = ("IH", "M")
matched_syllables = []
matched_words = []

def find_syllable(item, idx):
    """Record syllables whose phone sequence equals `syllable_lookup`.

    At most one syllable is kept per distinct word string: once a word has
    been recorded in `matched_words`, later matches in that word are skipped.
    Hits are appended to `matched_syllables` as (item index, word index,
    syllable index); nothing is returned.
    """
    utterance = item["word_detail"]["utterance"]
    for word_idx, sylls in enumerate(item["word_syllable_detail"]):
        for syll_idx, syll in enumerate(sylls):
            if tuple(syll["phones"]) != syllable_lookup:
                continue
            word = utterance[word_idx]
            if word not in matched_words:
                matched_syllables.append((idx, word_idx, syll_idx))
                matched_words.append(word)

# `map` is used only for its side effect of filling the two lists above.
timit_corpus.map(find_syllable, with_indices=True)

In [None]:
# (item index, word index, syllable index) triples collected by find_syllable.
matched_syllables

In [None]:
# Distinct words containing the looked-up syllable, in the same order as `matched_syllables`.
matched_words

In [None]:
def plot_syllable_in_item(item_idx, word_idx, syll_idx, ax, annot=True, text=True, sample_rate=16000):
    """Plot the waveform of one syllable, rescaled to [-1, 1], with optional phoneme marks.

    Args:
        item_idx: Index of the utterance in ``timit_corpus``.
        word_idx: Position of the word within the utterance.
        syll_idx: Position of the syllable within the word.
        ax: Matplotlib axes to draw on; its axis decorations are switched off.
        annot: When true, draw a vertical line at the start of every phoneme
            belonging to the syllable (black for the syllable-initial phoneme,
            gray otherwise).
        text: When true (and ``annot`` is true), also label each phoneme.
            Defaults to true, matching the previous behaviour.
        sample_rate: Sampling rate in Hz used to convert sample indices to
            seconds. Defaults to 16000, the value that was previously hard-coded.
    """
    item = timit_corpus[item_idx]
    syllable = item["word_syllable_detail"][word_idx][syll_idx]

    syll_start_sample, syll_end_sample = syllable["start"], syllable["stop"]
    syll_start, syll_end = syll_start_sample / sample_rate, syll_end_sample / sample_rate

    audio_samples = np.arange(syll_start_sample, syll_end_sample)
    # Take the point count from the integer sample indices: converting to
    # seconds and back can lose one sample to floating-point truncation.
    times = np.linspace(syll_start, syll_end, syll_end_sample - syll_start_sample)

    # Normalize audio samples to [-1, 1]
    values = np.array(item["input_values"][syll_start_sample:syll_end_sample])
    lowest, highest = values.min(), values.max()
    if highest > lowest:
        values = (values - lowest) / (highest - lowest) * 2 - 1
    else:
        # A constant segment has no range to rescale; draw a flat line at zero
        # instead of dividing by zero and producing NaNs.
        values = np.zeros(values.shape)

    ax.plot(times, np.interp(times, audio_samples / sample_rate, values),
            alpha=0.3)

    ax.set_xlim((syll_start, syll_end))
    ax.axis("off")

    if annot:
        for j, phoneme in enumerate(item["word_phonemic_detail"][word_idx]):
            # Only phonemes inside this syllable's inclusive index range are marked.
            if not syllable["phoneme_start_idx"] <= j <= syllable["phoneme_stop_idx"]:
                continue
            phoneme_start = phoneme["start"] / sample_rate

            color = "black" if phoneme["idx_in_syllable"] == 0 else "gray"
            ax.axvline(phoneme_start, color=color, linestyle=":")
            if text:
                ax.text(phoneme_start + 0.01, 0, phoneme["phone"], verticalalignment="bottom", fontdict={"size": 15, "weight": "bold"})

In [None]:
# Cap the figure at a fixed number of panels. The subset is drawn with a
# seeded generator so that Restart & Run All reproduces the same figure.
max_syllable_panels = 18
if len(matched_syllables) > max_syllable_panels:
    syllable_sample_rng = np.random.default_rng(0)
    keep = syllable_sample_rng.choice(len(matched_syllables), max_syllable_panels, replace=False)
    matched_syllables = [matched_syllables[idx] for idx in keep]

f, axs = plt.subplots(len(matched_syllables), 1, figsize=(8, 4 * len(matched_syllables)))
# With a single row plt.subplots returns a bare Axes rather than an array.
if len(matched_syllables) == 1:
    axs = [axs]
for (item_idx, word_idx, syll_idx), ax in zip(matched_syllables, axs):
    plot_syllable_in_item(item_idx, word_idx, syll_idx, ax, annot=False)

# Waveforms are normalised to [-1, 1]; leave a small margin around them.
for ax in axs:
    ax.set_ylim((-1.1, 1.1))