In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict, Counter
from dataclasses import replace
import json
import logging
from pathlib import Path
import pickle

import datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import Normalizer
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# Logger keyed on this notebook's __name__; no handlers or levels are configured here.
L = logging.getLogger(__name__)

In [None]:
# ---- Model / dataset identifiers (used to build every path below) ----
base_model = "w2v2_8"
model_class = "rnn_32-hinge-mAP4"
model_name = "word_broad"
train_dataset = "librispeech-train-clean-100"

# ---- Input / output locations, relative to the working directory ----
model_dir = f"outputs/models/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames"
output_dir = "."
dataset_path = f"outputs/preprocessed_data/{train_dataset}"
equivalence_path = f"outputs/equivalence_datasets/{train_dataset}/{base_model}/{model_name}_10frames/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{train_dataset}/{base_model}/{train_dataset}.h5"
state_space_specs_path = f"outputs/state_space_specs/{train_dataset}/{base_model}/state_space_specs.pkl"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames/{train_dataset}.npy"

# ---- Reproducibility ----
seed = 1234

# ---- Trajectory options ----
# Add 4 frames prior to onset to each trajectory.
expand_frame_window = (4, 0)

# ---- PCA vocabulary selection ----
# Estimate the PCA space from the plot words only (True) or from the whole vocabulary (False).
pca_plot_words_only = False
# Lower bound (inclusive): a word needs at least this many instances to enter the PCA estimate.
pca_freq_min = 15
# Upper bound (exclusive): words with this many instances or more are left out of the PCA estimate.
pca_freq_max = 10000

# Cap on the number of instances per word fed to PCA (for computational efficiency).
pca_max_samples_per_word = 100

# ---- Aggregation / distance settings ----
agg_method = "mean"

metric = "cosine"

In [None]:
# Word-frequency table (tab-separated text file). The first column becomes the
# index, which is what word labels are matched against when "Lg10WF" is merged in.
word_freq_df = pd.read_csv(
    "data/SUBTLEXus74286wordstextversion.txt",
    sep="\t",
    index_col=0,
)

In [None]:
# Seed NumPy's global RNG so anything downstream that draws from np.random is repeatable.
np.random.seed(seed)

In [None]:
# Model embeddings, stored on disk as a .npy array.
with open(embeddings_path, "rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)

# The spec file holds a mapping of analysis specs; only the "word" entry is used here.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load spec files
# from a trusted source. Newer PyTorch releases default to weights_only=True, which may
# refuse to unpickle project classes -- confirm against the installed version.
with open(state_space_specs_path, "rb") as specs_file:
    state_space_spec: StateSpaceAnalysisSpec = torch.load(specs_file)["word"]

# Fail fast if the spec and the embedding array do not belong together.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Count how many instances (target frame spans) each word label has in the spec.
word_freqs = {
    label: len(spans)
    for label, spans in zip(state_space_spec.labels, state_space_spec.target_frame_spans)
}

In [None]:
# Keep the words whose instance count lies in [pca_freq_min, pca_freq_max).
# Sorting the (freq, label) pairs in descending order puts the most frequent
# words first; ties are broken by descending label.
pca_words = sorted([(freq, label) for label, freq in word_freqs.items()
                    if pca_freq_min <= freq < pca_freq_max], reverse=True)
pca_words = [label for _, label in pca_words]

# Test membership against a set: checking every label against the list form of
# `pca_words` is quadratic in vocabulary size. Labels are already used as dict
# keys in `word_freqs`, so they are hashable.
pca_word_set = set(pca_words)
drop_idxs = [idx for idx, word in enumerate(state_space_spec.labels)
             if word not in pca_word_set]

# Restrict the spec to the selected vocabulary.
state_space_spec = state_space_spec.drop_labels(drop_idxs)

In [None]:
# Subsample word instances, passing pca_max_samples_per_word as the cap
# (configured as "at most this many samples of each word").
# NOTE(review): presumably the draw uses NumPy's global RNG, which is seeded
# earlier in this notebook -- confirm in subsample_instances.
state_space_spec = state_space_spec.subsample_instances(pca_max_samples_per_word)

In [None]:
# Build state trajectories from the embeddings for the words kept in the spec.
# NaN is passed as the pad value (presumably to fill instances of unequal
# length -- confirm in prepare_state_trajectory).
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)
# Aggregate each trajectory with `agg_method`.
traj_agg = aggregate_state_trajectory(trajectory, state_space_spec, agg_method, keepdims=True)
# Flatten into a sample collection plus a parallel source index. Each entry of
# `agg_src` is later unpacked as (label_idx, instance_idx, _).
agg_flat, agg_src = flatten_trajectory(traj_agg)

In [None]:
# Scale every sample to unit norm (Normalizer's default setting), then project
# onto the first 16 principal components.
pipeline = make_pipeline(
    Normalizer(),
    PCA(n_components=16),
)
all_trajectories_pca = pipeline.fit_transform(agg_flat)

# make_pipeline names each step after its lowercased class name, so the fitted
# PCA estimator is available under "pca".
pca = pipeline.named_steps["pca"]

In [None]:
# Cumulative explained variance, with a leading 0 so the curve starts at zero components.
cumulative_variance = [0] + np.cumsum(pca.explained_variance_ratio_).tolist()

f, ax = plt.subplots(figsize=(12, 4))
ax.plot(cumulative_variance)
ax.set_title("PCA explained variance")
ax.set_xlabel("Number of components")
ax.set_ylim((0, 1))
ax.set_ylabel("Cumulative explained variance")

In [None]:
pc_columns = [f"pca_{i}" for i in range(pca.n_components_)]

# One row per projected sample: its label and instance indices, followed by
# its PCA coordinates.
pca_rows = [
    (label_idx, instance_idx, *coords)
    for (label_idx, instance_idx, _), coords in zip(agg_src, all_trajectories_pca)
]
pca_df = pd.DataFrame(pca_rows, columns=["label_idx", "instance_idx"] + pc_columns)

# Resolve label indices to word labels and index the frame by (label, instance_idx).
pca_df["label"] = [state_space_spec.labels[idx] for idx in pca_df["label_idx"]]
pca_df = pca_df.set_index(["label", "instance_idx"])

In [None]:
# Group the phoneme-level and syllable-level cuts by word instance once and
# reuse the groupings for every derived feature.
instance_keys = ["label", "instance_idx"]
phoneme_groups = state_space_spec.cuts.xs("phoneme", level="level").groupby(instance_keys)
syllable_groups = state_space_spec.cuts.xs("syllable", level="level").groupby(instance_keys)

# First / last phoneme description and segment counts for each instance.
first_phoneme = phoneme_groups.head(1).description.rename("first_phoneme")
last_phoneme = phoneme_groups.tail(1).description.rename("last_phoneme")
num_phonemes = phoneme_groups.size().rename("num_phonemes")
num_syllables = syllable_groups.size().rename("num_syllables")
meta = pd.concat([first_phoneme, last_phoneme, num_phonemes, num_syllables], axis=1)

# Index-on-index merge (inner join by default), then turn the index back into
# regular columns.
pca_df = pca_df.merge(meta, left_index=True, right_index=True)
pca_df = pca_df.reset_index()

In [None]:
# Attach the "Lg10WF" column of the frequency table to every row by matching
# the word label against the table's index. The merge is an inner join by
# default, so rows whose label has no entry in the table are dropped.
pca_df = pca_df.merge(
    word_freq_df["Lg10WF"],
    left_on="label",
    right_index=True,
)

In [None]:
# Bin rows into 10 quantile bins of "Lg10WF", stored as integer codes.
# The quantiles are computed over the rows of pca_df, not over unique words.
pca_df["log_frequency_bin"] = pd.qcut(pca_df["Lg10WF"], q=10, labels=False)

In [None]:
# Pairwise correlations between the PCA coordinates and the word-level covariates.
pc_columns = [f"pca_{i}" for i in range(pca.n_components_)]
covariate_columns = ["Lg10WF", "num_phonemes", "num_syllables"]
pca_df[pc_columns + covariate_columns].corr()

In [None]:
# Hypothesis 0: the word starts with one of these phonemes, or its spelling
# starts with one of these prefixes.
hyp_0_first_phonemes = ("CH", "SH", "K", "P", "T")
hyp_0_prefixes = ("exp", "exe", "exa", "exc", "enc", "ext", "aca")
pca_df["hyp_0a"] = pca_df["first_phoneme"].isin(hyp_0_first_phonemes)
pca_df["hyp_0b"] = pca_df["label"].str.startswith(hyp_0_prefixes)
pca_df["hyp_0"] = pca_df["hyp_0a"] | pca_df["hyp_0b"]

# Hypothesis 1: the word starts with, or ends in, one of these phonemes.
# The last-phoneme T/D variant below is disabled and is not part of hyp_1.
# pca_df["hyp_1a"] = pca_df.last_phoneme.isin(("T", "D"))
hyp_1_first_phonemes = ("M", "N", "OY", "Y", "Z")
hyp_1_last_phonemes = ("Z", "JH", "ER")
pca_df["hyp_1b"] = pca_df["first_phoneme"].isin(hyp_1_first_phonemes)
pca_df["hyp_1c"] = pca_df["last_phoneme"].isin(hyp_1_last_phonemes)
pca_df["hyp_1"] = pca_df["hyp_1b"] | pca_df["hyp_1c"]

# Hypothesis 2: syllable count (author's note: better correlation than
# log-frequency and number of phonemes).
pca_df["hyp_2"] = pca_df["num_syllables"]

In [None]:
pc_columns = [f"pca_{i}" for i in range(pca.n_components_)]

# Every column that is neither an identifier nor a PCA coordinate is treated
# as metadata.
meta_cols = [
    col for col in pca_df.columns
    if col not in ("label", "instance_idx") and not col.startswith("pca_")
]

# Mean and standard deviation of each PC per (label + metadata) combination.
# Because the metadata columns are part of the grouping key, a label whose
# instances differ in any metadata value yields more than one row.
pca_type_df = (
    pca_df.groupby(["label"] + meta_cols)[pc_columns]
    .agg(["mean", "std"])
    .reset_index()
)

# Flatten the two-level column index: ("pca_0", "mean") -> "pca_0_mean",
# ("label", "") -> "label".
pca_type_df.columns = ["_".join(parts).strip("_") for parts in pca_type_df.columns.values]

In [None]:
# Display the aggregated (per-label) table as the cell output.
pca_type_df

In [None]:
# Instance-level scatter of PC 0 against PC 1, coloured by hypothesis 0.
sns.scatterplot(
    data=pca_df,
    x="pca_0",
    y="pca_1",
    hue="hyp_0",
    s=5,
    alpha=0.5,
)

In [None]:
# Per-label means of PC 2 against PC 3, coloured by hypothesis 2 (syllable count).
sns.scatterplot(
    data=pca_type_df,
    x="pca_2_mean",
    y="pca_3_mean",
    hue="hyp_2",
    alpha=0.4,
)

In [None]:
# Relationship between the within-word spread along each PC and word frequency.
# For every frequency bin: take the standard deviation of each PC within each
# word, then average those standard deviations over the words in the bin.
# The PC columns are derived from the fitted PCA rather than a hard-coded count,
# so this cell stays correct if n_components changes.
pc_columns = [f"pca_{i}" for i in range(pca.n_components_)]
pc_spread_by_bin = (
    # Select the needed columns up front so the apply does not operate on the
    # grouping column itself.
    pca_df.groupby("log_frequency_bin")[["label"] + pc_columns]
    .apply(lambda bin_df: bin_df.groupby("label")[pc_columns].std().mean(axis=0))
    .reset_index()
    .melt(id_vars="log_frequency_bin")
)

ax = sns.lineplot(data=pc_spread_by_bin, x="log_frequency_bin", y="value", hue="variable")
# The plotted quantity is a standard deviation, so label it as such.
ax.set_ylabel("Mean within-word standard deviation")
ax.set_title("Within-word standard deviation along PCA components by log frequency bin")

In [None]:
# ECDF of the per-label PC 0 mean, one curve per first phoneme.
# The hue order is computed from the same per-label table that is plotted (as
# the PC 1 cells do), so the ordering reflects the displayed data and only
# contains phonemes that actually occur in it.
pc0_first_phoneme_order = (
    pca_type_df.groupby("first_phoneme").pca_0_mean.mean().sort_values().index
)
sns.displot(data=pca_type_df, x="pca_0_mean", hue="first_phoneme", kind="ecdf",
            hue_order=pc0_first_phoneme_order)

### study PC 1

In [None]:
# ECDF of the per-label PC 1 mean, one curve per first phoneme; curves are
# ordered by each phoneme's average PC 1 value.
pc1_first_phoneme_order = (
    pca_type_df.groupby("first_phoneme").pca_1_mean.mean().sort_values().index
)
sns.displot(
    data=pca_type_df,
    x="pca_1_mean",
    hue="first_phoneme",
    kind="ecdf",
    hue_order=pc1_first_phoneme_order,
)

In [None]:

# ECDF of the per-label PC 1 mean, one curve per last phoneme; curves are
# ordered by each phoneme's average PC 1 value.
pc1_last_phoneme_order = (
    pca_type_df.groupby("last_phoneme").pca_1_mean.mean().sort_values().index
)
sns.displot(
    data=pca_type_df,
    x="pca_1_mean",
    hue="last_phoneme",
    kind="ecdf",
    hue_order=pc1_last_phoneme_order,
)

In [None]:
# Per-label means of PC 0 against PC 1, coloured by hypothesis 1.
sns.scatterplot(
    data=pca_type_df,
    x="pca_0_mean",
    y="pca_1_mean",
    hue="hyp_1",
    alpha=0.5,
)

In [None]:
# Zoom in on the rows with a high mean on both PC 0 (> 1) and PC 1 (> 2).
corner_mask = (pca_type_df["pca_0_mean"] > 1) & (pca_type_df["pca_1_mean"] > 2)
corner_types = pca_type_df[corner_mask]

sns.scatterplot(data=corner_types, x="pca_0_mean", y="pca_1_mean", alpha=0.5)

# Write each point's word next to it.
for row_idx, row in corner_types.iterrows():
    plt.text(row["pca_0_mean"], row["pca_1_mean"], row["label"], fontsize=6)