In [None]:
# Re-import edited project modules (e.g. the `src.*` helpers) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Hide every GPU from CUDA-aware libraries (torch is imported in the next cell),
# so all computation below runs on the CPU.
os.environ.update(CUDA_VISIBLE_DEVICES="")

In [None]:
from collections import defaultdict, Counter
from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import prepare_state_trajectory, StateSpaceAnalysisSpec, flatten_trajectory
from src.utils.timit import get_word_metadata

In [None]:
# ---- Artifact paths -------------------------------------------------------
# Only `state_space_specs_path` and `embeddings_path` are read in the cells shown;
# the remaining paths document which model / dataset the embeddings came from.
model_dir = (
    "outputs/models/librispeech-train-clean-100/w2v2_8/"
    "rnn_32-hinge-mAP4/word_broad_10frames_fixedlen25"
)
output_dir = "."
dataset_path = "outputs/preprocessed_data/librispeech-train-clean-100"
equivalence_path = (
    "outputs/equivalence_datasets/librispeech-train-clean-100/w2v2_8/"
    "word_broad_10frames_fixedlen25/equivalence.pkl"
)
hidden_states_path = (
    "outputs/hidden_states/librispeech-train-clean-100/w2v2_8/"
    "hidden_states.h5"
)
state_space_specs_path = (
    "outputs/state_space_specs/librispeech-train-clean-100/w2v2_8/"
    "state_space_specs.h5"
)
embeddings_path = (
    "outputs/model_embeddings/librispeech-train-clean-100/w2v2_8/"
    "rnn_32-hinge-mAP4/word_broad_10frames_fixedlen25/"
    "librispeech-train-clean-100.npy"
)

# ---- Analysis parameters ---------------------------------------------------
metric = "cosine"           # not referenced in the cells shown
retain_n = 10               # drop word labels with fewer than this many instances
subsample_instances = 50    # passed to state_space_spec.subsample_instances
model_sfreq = 50            # model frames per second; frame counts / model_sfreq = seconds
expand_window = (15, 0)     # expand_window[0] frames of context precede each word onset

In [None]:
# Word-level state-space spec plus the model embeddings it indexes into;
# np.load opens and closes the .npy file itself when given a path.
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "word")
model_representations: np.ndarray = np.load(embeddings_path)
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Drop labels with fewer than `retain_n` instances from the spec.
label_counts = state_space_spec.label_counts
drop_labels = label_counts[label_counts < retain_n].index
state_space_spec = state_space_spec.drop_labels(drop_names=drop_labels)

In [None]:
# Subsample instances per label (presumably capping each label at `subsample_instances`).
# NOTE(review): if this sampling is random it is not seeded here -- confirm reproducibility.
state_space_spec = state_space_spec.subsample_instances(subsample_instances)

In [None]:
# Per-instance word metadata; later cells read it by (label, instance_idx) and use columns
# such as stress_primary_initial, word_frequency_quantile and onset_phoneme.
metadata = get_word_metadata(state_space_spec)

In [None]:
# One NaN-padded trajectory array per label, with `expand_window` frames of context.
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan,
                                      expand_window=expand_window)


def _trajectory_lengths(traj_i):
    """Return the number of non-padded frames for each instance of one label.

    ``traj_i`` is indexed (instance, frame, feature) -- presumably; axis 1 is the frame
    axis. Padding is NaN. ``argmax`` over an all-False row returns 0, so an instance
    that needs no padding (it fills the whole array) would be reported as length 0.
    Those rows get the full frame count instead.
    NOTE(review): assumes NaN padding is trailing only -- confirm in prepare_state_trajectory.
    """
    is_pad = np.isnan(traj_i[:, :, 0])
    return np.where(is_pad.any(axis=1), is_pad.argmax(axis=1), traj_i.shape[1])


lengths = [_trajectory_lengths(traj_i) for traj_i in trajectory]

In [None]:
# Flatten to one row per frame: `traj_full` holds the features and `traj_full_flat_src`
# the (label_idx, instance_idx, frame_idx) of each row (presumably NaN padding is dropped).
traj_full, traj_full_flat_src = flatten_trajectory(trajectory)

In [None]:
# Number of labels, and mean trajectory length in frames over all instances.
(len(trajectory), np.mean(np.concatenate(lengths)))

In [None]:
# 4-component PCA fit on every frame of every (expanded) trajectory.
# random_state pins the randomized SVD solver that sklearn may pick for large inputs,
# so projections (and every plot below) are reproducible across runs.
pca_full = PCA(n_components=4, random_state=0)
traj_full_flat_pca = pca_full.fit_transform(traj_full)

#### Prepare truncated trajectory data

In [None]:
# Keep the pre-onset context plus the first 15 frames from word onset
# (frame expand_window[0] is the word onset).
traj_trunc = list(map(lambda traj_i: traj_i[:, :expand_window[0] + 15], trajectory))
traj_trunc_flat, traj_trunc_flat_src = flatten_trajectory(traj_trunc)

In [None]:
# Frame offsets relative to word onset for each column of the truncated window: -expand_window[0] .. 14.
trunc_times = np.arange(-expand_window[0], 15)

In [None]:
# 4-component PCA fit on the truncated trajectories only.
# random_state makes a randomized-solver fit reproducible (see full-trajectory PCA).
pca = PCA(n_components=4, random_state=0)
traj_trunc_flat_pca = pca.fit_transform(traj_trunc_flat)

## State space

In [None]:
# Reverse lookup: (label_idx, instance_idx, frame_idx) -> row of traj_full / traj_full_flat_pca.
traj_full_flat_src_dict = dict(zip(map(tuple, traj_full_flat_src), range(len(traj_full_flat_src))))

In [None]:
def plot_state_space_binned(n_bins, groupby=None, return_data=False):
    """Plot mean PCA state-space trajectories of word tokens, binned by trajectory length.

    Tokens are split into ``n_bins`` quantile bins of their trajectory length (global
    ``lengths``, in frames, including the ``expand_window[0]`` pre-onset frames) and,
    optionally, further split by the ``groupby`` column of the global ``metadata``.
    For every (bin[, group]) key, the projection onto the full-trajectory PCA space
    (global ``traj_full_flat_pca``) is averaged frame by frame and drawn as PC1-vs-PC2
    and PC3-vs-PC4 paths. An "o" marks the word-onset frame, and an "x" marks the frame
    halfway between onset and the last frame that has data.

    Parameters
    ----------
    n_bins : int
        Number of length quantile bins.
    groupby : str, optional
        ``metadata`` column to split tokens by. Tokens whose value is None/NaN are
        skipped. At most six distinct values are supported (one line style each).
    return_data : bool
        If True, also return the per-key frame data.

    Returns
    -------
    f
        The figure (returned alone when ``return_data`` is False).
    bin_frame_data : dict
        ``bin_frame_data[key][frame_idx]`` is the array of PCA rows averaged for that
        frame, or an empty list when no token contributes.
    bin_frame_src : dict
        The matching row indices into ``traj_full_flat_pca``.
    bin_time_edges : np.ndarray
        The ``n_bins + 1`` length-bin edges divided by ``model_sfreq``.
    """
    from matplotlib import patches as mpatches
    from matplotlib.lines import Line2D

    all_lengths = np.concatenate(lengths)
    max_traj_length = all_lengths.max()

    # Quantile bin edges over all token lengths (n_bins + 1 edges).
    length_bins = pd.qcut(all_lengths, q=n_bins, labels=np.arange(n_bins), retbins=True)[1]
    # NOTE(review): lengths include the expand_window[0] pre-onset frames, so these edges are
    # not pure word durations -- confirm this is the intended legend quantity.
    bin_time_edges = np.maximum(0, length_bins) / model_sfreq
    # Digitize against the interior edges only. Like qcut, this makes bins right-closed
    # with the minimum length falling in bin 0. Digitizing against all edges with
    # right=True maps the minimum to bin -1 and silently drops the shortest tokens.
    inner_edges = length_bins[1:-1]
    bin_assignments = [np.digitize(traj_lengths, inner_edges, right=True)
                       for traj_lengths in lengths]

    # key := (bin,) or (bin, group value) -> list of (label_idx, instance_idx)
    group_lookup = metadata[groupby].to_dict() if groupby is not None else None
    bin_assignments_rev = defaultdict(list)
    for label_idx, assignments_i in enumerate(bin_assignments):
        for instance_idx, bin_idx in enumerate(assignments_i):
            if groupby is None:
                key = (bin_idx,)
            else:
                group_value = group_lookup[state_space_spec.labels[label_idx], instance_idx]
                if group_value is None or (isinstance(group_value, float) and np.isnan(group_value)):
                    continue
                key = (bin_idx, group_value)
            bin_assignments_rev[key].append((label_idx, instance_idx))

    # Per-key, per-frame means; also keep the contributing rows and their indices.
    n_components = traj_full_flat_pca.shape[1]
    bin_frame_means = defaultdict(list)
    bin_frame_data = defaultdict(list)
    bin_frame_src = defaultdict(list)
    for key, traj_indices in tqdm(bin_assignments_rev.items(), desc="retrieving per-bin data"):
        for frame_idx in range(max_traj_length):
            flat_idxs = [traj_full_flat_src_dict[(label_idx, instance_idx, frame_idx)]
                         for label_idx, instance_idx in traj_indices
                         if (label_idx, instance_idx, frame_idx) in traj_full_flat_src_dict]
            if flat_idxs:
                bin_frame_means[key].append(traj_full_flat_pca[flat_idxs].mean(axis=0))
                bin_frame_data[key].append(traj_full_flat_pca[flat_idxs])
            else:
                bin_frame_means[key].append(np.full(n_components, np.nan))
                bin_frame_data[key].append([])
            bin_frame_src[key].append(flat_idxs)

    bin_frame_means = {key: np.array(frame_means_i) for key, frame_means_i in bin_frame_means.items()}

    ## Plot
    pcs = [[0, 1], [2, 3]]
    f, axs = plt.subplots(1, len(pcs), figsize=(10 * len(pcs), 10), squeeze=False)

    bin_colors = sns.color_palette("spring", n_bins)
    grouping_values = sorted(set(key[1:] for key in bin_frame_means))
    grouping_styles = ["-", "--", "-.", ":", (0, (3, 5, 1, 5)), (0, (3, 5, 1, 5, 1, 5))]
    assert len(grouping_values) <= len(grouping_styles)
    word_start_frame = expand_window[0]

    # `ax_idx` must not be shadowed by the PC indices, otherwise the legend is never drawn.
    for ax_idx, ((pc_x, pc_y), ax) in enumerate(zip(pcs, axs.flat)):
        ax.axvline(0, color="black", linestyle="--")
        ax.axhline(0, color="black", linestyle="--")
        ax.set_xlabel(f"PC {pc_x+1}")
        ax.set_ylabel(f"PC {pc_y+1}")

        for key in sorted(bin_frame_means):
            bin_idx = key[0]
            means = bin_frame_means[key]
            ax.plot(means[:, pc_x], means[:, pc_y], label=f"{key[1:]} {bin_time_edges[bin_idx]:.2f} s",
                    color=bin_colors[bin_idx],
                    linestyle=grouping_styles[grouping_values.index(key[1:])],
                    alpha=0.7)
            ax.quiver(means[:-1, pc_x], means[:-1, pc_y],
                      means[1:, pc_x] - means[:-1, pc_x], means[1:, pc_y] - means[:-1, pc_y],
                      angles='xy', scale_units='xy', scale=1, color="gray", alpha=0.5)

            # O at start of word
            ax.scatter(means[word_start_frame, pc_x], means[word_start_frame, pc_y], color="blue", marker="o")

            # X halfway between onset and the last frame with any data
            valid_frames = np.flatnonzero(~np.isnan(means).any(axis=1))
            if len(valid_frames) > 0:
                last_frame = valid_frames[-1]
                word_midpoint_frame = (last_frame - word_start_frame) // 2 + word_start_frame
                ax.scatter(means[word_midpoint_frame, pc_x], means[word_midpoint_frame, pc_y],
                           color="red", marker="x")

        if ax_idx == len(pcs) - 1:
            handles = [mpatches.Patch(color="white", label="")]
            labels = ["length bin"]
            for bin_idx in range(n_bins):
                handles.append(mpatches.Patch(color=bin_colors[bin_idx]))
                labels.append(f"{bin_time_edges[bin_idx]:.2f}-{bin_time_edges[bin_idx + 1]:.2f} s")
            if groupby is not None:
                handles.append(mpatches.Patch(color="white", label=""))
                labels.append(groupby)
                for group, style in zip(grouping_values, grouping_styles):
                    handles.append(Line2D([0], [0], color="black", linestyle=style))
                    labels.append(", ".join(map(str, group)))
            ax.legend(handles, labels, loc="upper right", bbox_to_anchor=(1.5, 1))

    if return_data:
        return f, bin_frame_data, bin_frame_src, bin_time_edges
    return f

In [None]:
# Length-binned state-space trajectories over all word tokens.
f, ss_data_all, ss_src_all, ss_edges_all = plot_state_space_binned(n_bins=5, return_data=True)

In [None]:
# Same, additionally split by the stress_primary_initial metadata column.
f, ss_data_stress, ss_src_stress, ss_edges_stress = plot_state_space_binned(
    n_bins=5, groupby="stress_primary_initial", return_data=True)

In [None]:
# Same, additionally split by word frequency quantile.
f, ss_data_freq, ss_src_freq, ss_edges_freq = plot_state_space_binned(
    n_bins=5, groupby="word_frequency_quantile", return_data=True)

## Basic plots

In [None]:
# Fraction of variance explained by each of the 4 full-trajectory PCs.
pca_full.explained_variance_ratio_

In [None]:
# Row indices of the flattened data grouped by frame index (column 2 of traj_full_flat_src).
traj_idxs_by_frame = [np.flatnonzero(traj_full_flat_src[:, 2] == frame)
                      for frame in range(traj_full_flat_src[:, 2].max() + 1)]

In [None]:
# Per-frame mean, standard deviation and row count of the PCA projections.
frame_means = np.array([traj_full_flat_pca[idxs].mean(axis=0) for idxs in traj_idxs_by_frame])
frame_sds = np.array([traj_full_flat_pca[idxs].std(axis=0) for idxs in traj_idxs_by_frame])
frame_counts = np.array([len(idxs) for idxs in traj_idxs_by_frame])

In [None]:
f, ax = plt.subplots()

# Frame expand_window[0] is the word onset, so shift frame indices to put onset at t = 0.
times = np.arange(-expand_window[0], frame_means.shape[0] - expand_window[0]) / model_sfreq
ax.axvline(0, color="gray", linestyle="--")
ax.axhline(0, color="gray", linestyle="--")

# Shade mean ± standard error of the mean for each component.
frame_sems = frame_sds / np.sqrt(frame_counts)[:, None]
for component in range(frame_means.shape[1]):
    ax.plot(times, frame_means[:, component], label=f"PC {component+1}")
    ax.fill_between(times,
                    frame_means[:, component] - frame_sems[:, component],
                    frame_means[:, component] + frame_sems[:, component],
                    alpha=0.3)

ax.legend()
ax.set_xlabel("Time relative to word onset (s)")
ax.set_ylabel("PC projection (mean ± SEM)");

### By stress

In [None]:
# Position of each row's label in state_space_spec.labels, i.e. the label_idx used in traj_full_flat_src.
metadata["label_idx"] = metadata.index.get_level_values("label").map({label: idx for idx, label in enumerate(state_space_spec.labels)})

In [None]:
# Rebuilds the same (label_idx, instance_idx, frame_idx) -> row lookup as in the State space section.
traj_full_flat_src_dict = {tuple(src): row for row, src in enumerate(traj_full_flat_src)}

In [None]:
def plot_boundary_grouped(grouping_variable, num_components=4):
    """Plot per-frame mean (± SEM) PCA projections around word onset, split by a metadata column.

    Word instances are grouped by ``metadata[grouping_variable]`` (rows with a missing
    value are dropped by ``groupby``). For each group and each frame, the rows of the
    global ``traj_full_flat_pca`` belonging to that group's instances are averaged. Each
    of the first ``num_components`` PCs is drawn in its own color, and each group in its
    own line style.

    Parameters
    ----------
    grouping_variable : str
        Column of ``metadata`` to group word instances by (at most six groups).
    num_components : int
        Maximum number of PCs to draw.

    Returns
    -------
    f, ax
        The matplotlib figure and axes.

    Raises
    ------
    ValueError
        If there are more groups than available line styles.
    """
    from matplotlib import patches as mpatches
    from matplotlib.lines import Line2D

    # Derive sizes from the flattened data itself, not from `frame_means` computed in another cell.
    n_frames = traj_full_flat_src[:, 2].max() + 1
    n_pcs = traj_full_flat_pca.shape[1]
    n_plot_components = min(num_components, n_pcs)

    grouped_frame_means, grouped_frame_sds, grouped_frame_counts = {}, {}, {}
    instances = metadata.reset_index().set_index(["label_idx", "instance_idx"])
    for group_values, rows in tqdm(instances.groupby(grouping_variable)):
        frame_means_i, frame_sds_i, frame_counts_i = [], [], []
        for frame_idx in range(n_frames):
            traj_idxs_i = [traj_full_flat_src_dict[(label_idx, instance_idx, frame_idx)]
                           for label_idx, instance_idx in rows.index
                           if (label_idx, instance_idx, frame_idx) in traj_full_flat_src_dict]
            if traj_idxs_i:
                frame_means_i.append(traj_full_flat_pca[traj_idxs_i].mean(axis=0))
                frame_sds_i.append(traj_full_flat_pca[traj_idxs_i].std(axis=0))
            else:
                frame_means_i.append(np.full(n_pcs, np.nan))
                frame_sds_i.append(np.full(n_pcs, np.nan))
            frame_counts_i.append(len(traj_idxs_i))

        grouped_frame_means[group_values] = np.array(frame_means_i)
        grouped_frame_sds[group_values] = np.array(frame_sds_i)
        grouped_frame_counts[group_values] = np.array(frame_counts_i)

    # The first three styles match the original; more are available for variables with >3 groups.
    group_styles = ["-", "--", ":", "-.", (0, (3, 5, 1, 5)), (0, (3, 5, 1, 5, 1, 5))]
    if len(grouped_frame_means) > len(group_styles):
        raise ValueError(f"{grouping_variable!r} has {len(grouped_frame_means)} groups; "
                         f"at most {len(group_styles)} can be distinguished")

    f, ax = plt.subplots()
    # Frame expand_window[0] is the word onset -> t = 0.
    times = (np.arange(n_frames) - expand_window[0]) / model_sfreq
    ax.axvline(0, color="gray", linestyle="--")
    ax.axhline(0, color="gray", linestyle="--")

    component_palette = sns.color_palette("tab10", n_pcs)
    for style, (group, means) in zip(group_styles, grouped_frame_means.items()):
        # Frames with no data have count 0 (and NaN sd); silence the resulting divide warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            sems = grouped_frame_sds[group] / np.sqrt(grouped_frame_counts[group])[:, None]
        for component in range(n_plot_components):
            color = component_palette[component]
            ax.plot(times, means[:, component], linestyle=style, color=color,
                    label=f"{group}-{component}")
            # Explicit color so each SEM band matches its line (fill_between otherwise cycles colors).
            ax.fill_between(times,
                            means[:, component] - sems[:, component],
                            means[:, component] + sems[:, component],
                            color=color, alpha=0.3)

    handles = [mpatches.Patch(color="white", label="")]
    labels = ["component"]
    for component in range(n_plot_components):
        handles.append(mpatches.Patch(color=component_palette[component], label=str(component + 1)))
        labels.append(str(component + 1))
    handles.append(mpatches.Patch(color="white", label=""))
    labels.append(grouping_variable)
    for style, group in zip(group_styles, grouped_frame_means):
        handles.append(Line2D([0], [0], color="black", linestyle=style))
        labels.append(str(group))
    ax.legend(handles, labels, loc="upper right", bbox_to_anchor=(1.5, 1))
    ax.set_xlabel("Time relative to word onset (s)")
    ax.set_ylabel("PC projection (mean ± SEM)")

    return f, ax

In [None]:
# Mean ± SEM PC trajectories around onset, split by stress_primary_initial.
# NOTE(review): the extra reference line sits 4 frames after word onset; its meaning is not documented here.
f, ax = plot_boundary_grouped(grouping_variable="stress_primary_initial")
ax.axvline(x=4 / model_sfreq)

### By word frequency

In [None]:
# Mean ± SEM PC trajectories around onset, split by word frequency quantile
# (reference line 4 frames after onset, as above).
f, ax = plot_boundary_grouped(grouping_variable="word_frequency_quantile")
ax.axvline(x=4 / model_sfreq)

### By onset phoneme category

In [None]:
# ARPAbet phonemes split into consonants and vowels, used to categorise each word's onset phoneme.
categorization = {
    "consonant": "B CH D DH F G HH JH K L M N NG P R S SH T TH V W Y Z ZH".split(),
    "vowel": "AA AE AH AO AW AY EH ER EY IH IY OW OY UH UW".split(),
}
category_lookup = {}
for category, phonemes in categorization.items():
    category_lookup.update(dict.fromkeys(phonemes, category))
# NOTE(review): onset phonemes missing from the lookup (e.g. stress-marked vowels like "AH0",
# if onset_phoneme carries stress digits) become NaN and drop out of groupby-based plots.
metadata["onset_phoneme_category"] = metadata.onset_phoneme.map(category_lookup)

In [None]:
# Mean ± SEM PC trajectories around onset, split by onset phoneme category
# (reference line 4 frames after onset, as above).
f, ax = plot_boundary_grouped(grouping_variable="onset_phoneme_category")
ax.axvline(x=4 / model_sfreq)

## Cluster trajectories

In [None]:
# Cluster on the truncated trajectories exactly as prepared above (aliases, no copies).
traj_cluster_trunc, traj_cluster_flat, traj_cluster_src = traj_trunc, traj_trunc_flat, traj_trunc_flat_src

In [None]:
# Reuse the truncated-trajectory PCA (and its projections) for clustering.
pca_cluster, traj_cluster_flat_pca = pca, traj_trunc_flat_pca

In [None]:
# A new word starts wherever the (label_idx, instance_idx) pair changes between consecutive rows;
# split the projected rows into one array per word there.
chunk_points = np.flatnonzero((np.diff(traj_cluster_src[:, :2], axis=0) != 0).any(axis=1)) + 1
traj_cluster_word_level = np.split(traj_cluster_flat_pca, chunk_points)

In [None]:
# (label_idx, instance_idx) of each word: the first source row of every chunk.
traj_cluster_word_level_src = traj_cluster_src[np.concatenate([[0], chunk_points]), :2]

In [None]:
# Stack the per-word trajectories into a (word, frame, component) matrix.
# NOTE(review): shorter words are right-padded with zeros; those zeros enter both the
# baseline below and the KMeans features -- confirm this is intended.
maxlen = max(map(len, traj_cluster_word_level))
traj_cluster_mat = np.zeros((len(traj_cluster_word_level), maxlen, traj_cluster_flat_pca.shape[1]))
for i, traj_i in enumerate(tqdm(traj_cluster_word_level)):
    traj_cluster_mat[i, :traj_i.shape[0], :] = traj_i

# Baseline each word on its mean over the 8 frames around the window centre.
# Building and baselining in one cell keeps this idempotent: re-running an in-place
# subtraction on already-baselined data would make `baseline_data` ~0, which the
# centroid plot adds back.
center_frame = traj_cluster_mat.shape[1] // 2
baseline_data = traj_cluster_mat[:, center_frame - 4: center_frame + 4, :].mean(axis=1, keepdims=True)
traj_cluster_mat -= baseline_data

In [None]:
from sklearn.cluster import KMeans

# Cluster words on their flattened (frame x component) baselined trajectories.
# random_state makes the initialisation -- and hence the cluster ids used below -- reproducible.
km = KMeans(n_clusters=3, n_init="auto", random_state=0)
km.fit(traj_cluster_mat.reshape(traj_cluster_mat.shape[0], -1))

In [None]:
# Reshape flat cluster centres back to (cluster, frame, component) and add back the mean
# baseline. `reshape` returns a view of km.cluster_centers_, so build a new array with `+`
# rather than `+=`; otherwise the fitted centres are mutated and a re-run adds the baseline twice.
centroids = (km.cluster_centers_.reshape((km.cluster_centers_.shape[0], *traj_cluster_mat.shape[1:]))
             + baseline_data.mean(axis=(0, 1))[None, None, :])

n_components = centroids.shape[2]
# squeeze=False keeps `axs` iterable even with a single component.
f, axs = plt.subplots(1, n_components, figsize=(6 * n_components, 4), squeeze=False)
for component, ax in enumerate(axs.flat):
    ax.axvline(0, color="black", linestyle="--")
    ax.axhline(0, color="black", linestyle="--")
    ax.set_title(f"PC {component + 1}")
    ax.set_xlabel("Frames relative to word onset")
    # NOTE(review): assumes centroids have len(trunc_times) frames, aligned to onset.
    for cluster_idx, centroid in enumerate(centroids[:, :, component]):
        ax.plot(trunc_times, centroid, label=f"cluster {cluster_idx}")
    ax.legend()

f.tight_layout()

In [None]:
# Number of words (rows of traj_cluster_mat) assigned to each cluster.
np.bincount(km.labels_)

In [None]:
# Per cluster, count the word instances of each label; then attach word-level metadata
# (the first metadata row per label). Each cluster_df row is one (cluster, label) pair.
# `Counter` comes from the imports cell; no need to re-import it here.
match_frames = {}
for km_label in range(len(km.cluster_centers_)):
    matches = np.where(km.labels_ == km_label)[0]
    matches_src = traj_cluster_word_level_src[matches]
    matches_labels = [state_space_spec.labels[idx] for idx in matches_src[:, 0]]
    match_frames[km_label] = pd.DataFrame(list(Counter(matches_labels).items()), columns=["label", "count"])
match_df = pd.concat(match_frames, names=["cluster"]).droplevel(-1)
cluster_df = pd.merge(match_df.reset_index(), metadata.groupby("label").first().reset_index(), on="label")
# NOTE(review): a word_frequency of 0 would give -inf here.
cluster_df["word_log_frequency"] = np.log10(cluster_df.word_frequency)

In [None]:
# Share of word types (cluster_df rows are (cluster, label) pairs) in each word-frequency
# quantile, per cluster.
sns.barplot(data=cluster_df.groupby("cluster")["word_frequency_quantile"].value_counts(normalize=True).reset_index(),
            hue="word_frequency_quantile", x="cluster", y="proportion")

In [None]:
# Mean log10 word frequency of the word types in each cluster (seaborn's default estimator).
sns.barplot(data=cluster_df, x="cluster", y="word_log_frequency")

In [None]:
# Share of word types with each syllable count, per cluster.
sns.barplot(data=cluster_df.groupby("cluster")["num_syllables"].value_counts(normalize=True).reset_index(),
            hue="num_syllables", x="cluster", y="proportion")

In [None]:
# Share of word types in each cluster with each onset phoneme: one bar group per phoneme,
# one bar per cluster. Unstacking the cluster level directly avoids "(proportion, X)" tick labels.
(cluster_df.groupby("cluster").onset_phoneme.value_counts(normalize=True)
    .unstack("cluster")
    .plot(kind="bar", ylabel="proportion of word types in cluster"))

In [None]:
# Share of word types in each cluster per stress_primary_initial value: one bar group per
# value, one bar per cluster.
(cluster_df.groupby("cluster").stress_primary_initial.value_counts(normalize=True)
    .unstack("cluster")
    .plot(kind="bar", ylabel="proportion of word types in cluster"))

In [None]:
# 10 example word types per cluster; seeded so the displayed sample is reproducible.
cluster_df.groupby("cluster").sample(10, random_state=0)

In [None]:
# cluster entropy within word type
word_cluster_entropy = cluster_df.groupby("label").cluster.value_counts(normalize=True).groupby("label").apply(lambda x: -np.sum(x * np.log2(x)))
word_cluster_entropy.sort_values().head(20)

In [None]:
# The 20 word types with the highest cluster entropy.
word_cluster_entropy.sort_values().tail(20)