In [None]:
import sys
sys.path.append("../")

In [None]:
import datasets
from collections import Counter, defaultdict

from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
from sklearn.model_selection import cross_val_score, StratifiedKFold
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec

In [None]:
# Pickled state-space specs (read with torch.load further down).
state_space_specs_path = (
    "outputs/state_space_specs/"
    "librispeech-train-clean-100/w2v2_8/"
    "state_space_specs.pkl"
)
# Model embeddings stored as a NumPy .npy array.
embeddings_path = (
    "outputs/model_embeddings/"
    "librispeech-train-clean-100/w2v2_8/"
    "rnn_8-weightdecay0.01/phoneme_10frames/"
    "librispeech-train-clean-100.npy"
)

In [None]:
# np.load accepts a path directly, so no explicit file handle is needed here.
model_representations: np.ndarray = np.load(embeddings_path)

# NOTE: torch.load unpickles arbitrary Python objects -- only point this at trusted files.
with open(state_space_specs_path, "rb") as spec_file:
    all_state_space_specs = torch.load(spec_file)
# The loaded object is indexed by key; only its "word" entry is used in this notebook.
state_space_spec: StateSpaceAnalysisSpec = all_state_space_specs["word"]

# Fail early if the spec and the embedding array do not belong together.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Load the preprocessed dataset and drop the "audio" and "file" columns.
ds = (
    datasets.load_from_disk("outputs/preprocessed_data/librispeech-train-clean-100")
    .remove_columns(["audio", "file"])
)

In [None]:
# Maps word -> Counter of observed pronunciations (tuples of phone labels);
# missing words start out with an empty Counter.
observed_pronunciations = defaultdict(Counter)

In [None]:
def update(item):
    """
    Tally the pronunciation of every word in one dataset item.

    Pairs each entry of ``item["word_detail"]["utterance"]`` with the entry at
    the same position in ``item["word_phonemic_detail"]``, turns the latter
    into a tuple of its ``"phone"`` values, and increments that tuple's count
    in the module-level ``observed_pronunciations``.

    Used purely for its side effect; nothing is returned.
    """
    for word, word_phones in zip(item["word_detail"]["utterance"], item["word_phonemic_detail"]):
        pronunciation = tuple(phone["phone"] for phone in word_phones)
        observed_pronunciations[word][pronunciation] += 1


# Start from an empty tally so that re-running this cell does not double-count.
observed_pronunciations.clear()
ds.map(update)

In [None]:
# Words ordered from most to fewest distinct observed pronunciations.
sorted(
    observed_pronunciations,
    key=lambda word: len(observed_pronunciations[word]),
    reverse=True,
)

In [None]:
stat_rows = []
for word, counts in observed_pronunciations.items():
    # Entries keyed by an empty word are left out of the table.
    if word:
        n_observations = sum(counts.values())
        probs = np.array([n / n_observations for n in counts.values()])
        # Shannon entropy (natural log) of this word's pronunciation distribution.
        stat_rows.append((word, n_observations, -np.sum(probs * np.log(probs))))

pronunciation_stats = pd.DataFrame(stat_rows, columns=["word", "total", "entropy"])
pronunciation_stats

In [None]:
def plot_pronunciations_pca(study_word, ax=None):
    """
    Scatter the instances of ``study_word`` in a 2-D PCA projection, with one
    scatter series per distinct phoneme sequence.

    Each instance is represented by the mean of its rows in
    ``model_representations``, taken over the ``[frame_start:frame_end]`` span
    listed for it in ``state_space_spec.target_frame_spans``. The instance
    vectors are z-scored per feature before PCA is fit on them.

    Args:
        study_word: Word label; must be present in ``state_space_spec.labels``
            (``list.index`` raises ``ValueError`` otherwise).
        ax: Matplotlib axes to draw on. A new figure is created when omitted.
    """
    study_label_idx = state_space_spec.labels.index(study_word)
    study_classes = {}  # phoneme tuple -> integer class id, in order of first appearance
    study_X = []
    study_Y = []

    phoneme_cuts = state_space_spec.cuts.loc[study_word].xs("phoneme", level="level")
    for instance_idx, rows in phoneme_cuts.groupby("instance_idx"):
        phons = tuple(rows.description)
        # The default is evaluated before insertion, so a new sequence gets the next free id.
        cls = study_classes.setdefault(phons, len(study_classes))

        frame_start, frame_end = state_space_spec.target_frame_spans[study_label_idx][instance_idx]
        study_X.append(np.mean(model_representations[frame_start:frame_end], axis=0))
        study_Y.append(cls)

    study_X = np.array(study_X)
    study_Y = np.array(study_Y)

    # Z-score each feature. A feature that is constant across instances has
    # zero std; dividing by it would inject NaN/inf and make PCA raise, so such
    # features are only centered (they end up as all zeros).
    feature_std = study_X.std(axis=0)
    feature_std[feature_std == 0] = 1.0
    study_X = (study_X - study_X.mean(axis=0)) / feature_std

    study_X_pca = PCA(2).fit_transform(study_X)

    if ax is None:
        _, ax = plt.subplots()
    ax.set_title(study_word)
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    for phons, idx in study_classes.items():
        in_class = study_Y == idx
        ax.scatter(study_X_pca[in_class, 0], study_X_pca[in_class, 1], label=" ".join(phons), alpha=0.3)
    ax.legend()

In [None]:
plot_n = 20
# Lay the panels out on a near-square grid with room for plot_n words.
n_cols = int(np.floor(np.sqrt(plot_n)))
n_rows = int(np.ceil(plot_n / n_cols))
f, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5))

# The plot_n words whose pronunciation distribution has the highest entropy.
highest_entropy_words = (
    pronunciation_stats
    .sort_values("entropy", ascending=False)
    .head(plot_n)["word"]
)
for ax, word in zip(tqdm(axs.flat), highest_entropy_words):
    plot_pronunciations_pca(word, ax=ax)

f.suptitle("Words with maximal entropy over pronunciations")

In [None]:
plot_n = 20
# Lay the panels out on a near-square grid with room for plot_n words.
n_cols = int(np.floor(np.sqrt(plot_n)))
n_rows = int(np.ceil(plot_n / n_cols))
f, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5))

# Among words whose pronunciation entropy is non-zero, take the plot_n with the largest totals.
varying_words = pronunciation_stats.loc[pronunciation_stats["entropy"] != 0]
most_frequent_words = varying_words.sort_values("total", ascending=False).head(plot_n)["word"]
for ax, word in zip(tqdm(axs.flat), most_frequent_words):
    plot_pronunciations_pca(word, ax=ax)

## Number of syllables

In [None]:
# Number of "syllable"-level rows for every (label, instance_idx) pair.
syllable_cuts = state_space_spec.cuts.xs("syllable", level="level")
instance_syllable_counts = syllable_cuts.groupby(["label", "instance_idx"]).size()

# Labels whose instances show more than one distinct syllable count.
n_distinct_counts = instance_syllable_counts.groupby("label").nunique()
has_syllable_variation = n_distinct_counts[n_distinct_counts > 1].index
has_syllable_variation

In [None]:
def _syllable_count_entropy(counts):
    """
    Shannon entropy (natural log) of the empirical distribution over the
    values in `counts`.

    `counts` holds one syllable count per word instance. The distribution is
    over the distinct count values (e.g. 80% two-syllable, 20% three-syllable),
    so the counts are tallied with `value_counts` first; normalising the raw
    per-instance values directly would instead weight instances by their
    syllable count and grow with the number of instances.
    """
    probs = counts.value_counts(normalize=True)
    return -np.sum(probs * np.log(probs))


syllable_count_entropy = (
    instance_syllable_counts.loc[has_syllable_variation]
    .groupby("label")
    .apply(_syllable_count_entropy)
)
syllable_count_entropy.sort_values(ascending=False)

In [None]:
def plot_syllables_pca(study_word, ax=None):
    """
    Scatter the instances of ``study_word`` in a 2-D PCA projection, with one
    scatter series per distinct syllabification.

    An instance's syllabification label is built from its "syllable"-level
    rows: each row's description is joined with spaces, and the rows are then
    concatenated with "-". Each instance is represented by the mean of its
    rows in ``model_representations``, taken over the
    ``[frame_start:frame_end]`` span listed for it in
    ``state_space_spec.target_frame_spans``. The instance vectors are z-scored
    per feature before PCA is fit on them.

    Args:
        study_word: Word label; must be present in ``state_space_spec.labels``
            (``list.index`` raises ``ValueError`` otherwise).
        ax: Matplotlib axes to draw on. A new figure is created when omitted.
    """
    study_label_idx = state_space_spec.labels.index(study_word)
    study_classes = {}  # syllabification string -> integer class id, in order of first appearance
    study_X = []
    study_Y = []

    syllable_cuts = state_space_spec.cuts.loc[study_word].xs("syllable", level="level")
    for instance_idx, rows in syllable_cuts.groupby("instance_idx"):
        phons = rows.description.str.join(" ").str.cat(sep="-")
        # The default is evaluated before insertion, so a new label gets the next free id.
        cls = study_classes.setdefault(phons, len(study_classes))

        frame_start, frame_end = state_space_spec.target_frame_spans[study_label_idx][instance_idx]
        study_X.append(np.mean(model_representations[frame_start:frame_end], axis=0))
        study_Y.append(cls)

    study_X = np.array(study_X)
    study_Y = np.array(study_Y)

    # Z-score each feature. A feature that is constant across instances has
    # zero std; dividing by it would inject NaN/inf and make PCA raise, so such
    # features are only centered (they end up as all zeros).
    feature_std = study_X.std(axis=0)
    feature_std[feature_std == 0] = 1.0
    study_X = (study_X - study_X.mean(axis=0)) / feature_std

    study_X_pca = PCA(2).fit_transform(study_X)

    if ax is None:
        _, ax = plt.subplots()
    ax.set_title(study_word)
    ax.set_xlabel("PC 1")
    ax.set_ylabel("PC 2")
    for phons, idx in study_classes.items():
        in_class = study_Y == idx
        ax.scatter(study_X_pca[in_class, 0], study_X_pca[in_class, 1], label=phons, alpha=0.3)
    ax.legend()

In [None]:
plot_n = 20
# Lay the panels out on a near-square grid with room for plot_n words.
n_cols = int(np.floor(np.sqrt(plot_n)))
n_rows = int(np.ceil(plot_n / n_cols))
f, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5))

# Plot the plot_n labels ranked highest by syllable_count_entropy.
top_syllable_words = syllable_count_entropy.sort_values(ascending=False).head(plot_n).index
for ax, word in zip(tqdm(axs.flat), top_syllable_words):
    plot_syllables_pca(word, ax=ax)

## Variation in onset

In [None]:
phoneme_cuts = state_space_spec.cuts.xs("phoneme", level="level")

# (label, instance_idx) pairs that have more than one "phoneme"-level row.
phonemes_per_instance = phoneme_cuts.groupby(["label", "instance_idx"]).size()
multi_phoneme_words = phonemes_per_instance[phonemes_per_instance > 1].index

# First phoneme row of each such instance, tallied per label.
first_phonemes = phoneme_cuts.loc[multi_phoneme_words].groupby(["label", "instance_idx"]).head(1)
onset_counts = first_phonemes.groupby("label").description.value_counts()

# Keep only labels observed with more than one distinct first phoneme.
multi_onset_counts = onset_counts.groupby("label").filter(lambda counts: len(counts) > 1)

In [None]:
# Show the (label, first phoneme) counts from largest to smallest.
multi_onset_counts.sort_values(ascending=False)

In [None]:
plot_n = 20
# Labels need more than this many onset observations in total to be considered.
min_onset_observations = 100

# Lay the panels out on a near-square grid with room for plot_n words.
n_cols = int(np.floor(np.sqrt(plot_n)))
n_rows = int(np.ceil(plot_n / n_cols))
f, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 5, n_rows * 5))


def _count_entropy(counts):
    """Shannon entropy (natural log) of the distribution obtained by normalising `counts`."""
    probs = counts / counts.sum()
    return -np.sum(probs * np.log(probs))


frequent_onset_counts = multi_onset_counts.groupby("label").filter(
    lambda counts: counts.sum() > min_onset_observations
)
onset_entropy = frequent_onset_counts.groupby("label").apply(_count_entropy)

# Take plot_n words (previously a hard-coded 20) so the selection follows the grid size.
top_onset_words = onset_entropy.sort_values(ascending=False).head(plot_n).index
for ax, word in zip(tqdm(axs.flat), top_onset_words):
    plot_pronunciations_pca(word, ax=ax)