In [None]:
import os
# Hide every CUDA device so torch runs on the CPU; this must happen before CUDA is initialised.
os.environ.update(CUDA_VISIBLE_DEVICES="")

In [None]:
# Automatically re-import edited modules (e.g. the local `src` package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict, Counter
from dataclasses import replace
import json
import logging
from pathlib import Path
import pickle

import datasets
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# Module-level logger used for warnings emitted by the cells below.
L = logging.getLogger(name=__name__)

In [None]:
# --- Model / data selection ---------------------------------------------------------------
base_model = "w2v2_8"
model_class = "rnn_32-hinge-mAP4"
model_name = "word_broad"
train_dataset = "librispeech-train-clean-100"
model_dir = f"outputs/models/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames"
# Directory the state-space GIFs are written to.
output_dir = "."
dataset_path = f"outputs/preprocessed_data/{train_dataset}"
equivalence_path = f"outputs/equivalence_datasets/{train_dataset}/{base_model}/{model_name}_10frames/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{train_dataset}/{base_model}/{train_dataset}.h5"
state_space_specs_path = f"outputs/state_space_specs/{train_dataset}/{base_model}/state_space_specs.pkl"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames/{train_dataset}.npy"

seed = 1234

# Add 4 frames prior to onset to each trajectory
expand_frame_window = (4, 0)

# Only use plot words for PCA or use whole vocabulary?
pca_plot_words_only = False
# Use words with this many or more instances to estimate embedding PCA space
pca_freq_min = 15
# Ignore words with this many or more instances when estimating embedding PCA space
pca_freq_max = 10000

# Use at most this many samples of each word in computing PCA (for computational efficiency)
pca_max_samples_per_word = 100

metric = "cosine"

# Fail fast on configurations that would leave the PCA vocabulary empty.
assert 0 < pca_freq_min < pca_freq_max, "PCA frequency band [pca_freq_min, pca_freq_max) is empty"
assert pca_max_samples_per_word > 0, "pca_max_samples_per_word must be positive"

In [None]:
# Seed numpy's legacy global RNG; the trajectory subsampling (np.random.choice) draws from it.
np.random.seed(seed=seed)

In [None]:
# Load the pre-computed model embeddings and the word-level state-space spec.
with open(embeddings_path, "rb") as f:
    model_representations: np.ndarray = np.load(f)
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load spec files you trust.
# The file stores a StateSpaceAnalysisSpec object rather than plain tensors, which the
# weights-only loader (torch's default since 2.6) refuses to deserialize, so opt out explicitly.
with open(state_space_specs_path, "rb") as f:
    state_space_spec: StateSpaceAnalysisSpec = torch.load(f, weights_only=False)["word"]
assert state_space_spec.is_compatible_with(model_representations), \
    "state space spec is not compatible with the loaded model embeddings"

In [None]:
# Number of instances (frame spans) of each word type in the state space.
word_freqs = {label: len(trajs) for trajs, label in
              zip(state_space_spec.target_frame_spans, state_space_spec.labels)}

min_suffix_type_count = 10
min_prefix_type_count = 10
plot_suffixes = 3
plot_prefixes = 3


def _find_best_affixes(extract_affix, min_type_count, top_k, max_length=8):
    """Find the longest affixes shared by at least `min_type_count` word types.

    Only words whose frequency lies in [pca_freq_min, pca_freq_max) are counted.

    Args:
        extract_affix: callable ``(word, length) -> affix`` (e.g. prefix or suffix slice).
        min_type_count: minimum number of distinct word types that must share an affix.
        top_k: number of affixes to return.
        max_length: longest affix length (in characters) to consider.

    Returns:
        Up to `top_k` tuples ``(affix_length, type_count, affix)``, longest first, then most
        frequent (ties broken by reverse affix order).
    """
    candidates = []
    # Start at length 1: a length-0 prefix is the empty string (shared by every word), and
    # `word[-0:]` is the *whole word*, not an empty suffix.
    for affix_length in range(1, max_length + 1):
        affix_counts = Counter(
            extract_affix(word, affix_length)
            for word in state_space_spec.labels
            if len(word) >= affix_length and pca_freq_min <= word_freqs[word] < pca_freq_max
        )
        # most_common() is sorted by descending count, so we can stop at the first miss.
        for affix, count in affix_counts.most_common():
            if count < min_type_count:
                break
            candidates.append((affix_length, count, affix))
    return sorted(candidates, reverse=True)[:top_k]


best_suffixes = _find_best_affixes(lambda word, length: word[-length:],
                                   min_suffix_type_count, plot_suffixes)
best_prefixes = _find_best_affixes(lambda word, length: word[:length],
                                   min_prefix_type_count, plot_prefixes)

print("Best prefixes", best_prefixes)
print("Best suffixes", best_suffixes)

In [None]:
# Word groups to animate: a hand-picked exploratory set plus one set per discovered affix.
plot_sets = {
    "exploratory": ["allow", "about", "around",
                    "before", "black", "barely",
                    "small", "said", "such",
                    "please", "people", "problem"],
}
plot_sets.update({f"prefix_{prefix}": [word for word in state_space_spec.labels
                                       if word.startswith(prefix)
                                       and pca_freq_min <= word_freqs[word] < pca_freq_max]
                  for _, _, prefix in best_prefixes})
plot_sets.update({f"suffix_{suffix}": [word for word in state_space_spec.labels
                                       if word.endswith(suffix)
                                       and pca_freq_min <= word_freqs[word] < pca_freq_max]
                  for _, _, suffix in best_suffixes})

plot_words = set(word for words in plot_sets.values() for word in words)
# Report only the words that are actually missing, not every plot word.
missing_plot_words = sorted(plot_words - set(state_space_spec.labels))
if missing_plot_words:
    raise ValueError(f"Plot words not found in state space: {missing_plot_words}")

if pca_plot_words_only:
    pca_words = plot_words
else:
    # use all words with frequency in [pca_freq_min, pca_freq_max) to compute PCA,
    # most frequent first
    pca_words = sorted([(freq, label) for label, freq in word_freqs.items()
                        if pca_freq_min <= freq < pca_freq_max], reverse=True)
    pca_words = [label for _, label in pca_words]

    # Plot words outside the frequency band are still added so they can be projected.
    # Iterate in sorted order so warnings and list order are stable across kernel restarts.
    for plot_word in sorted(plot_words - set(pca_words)):
        L.warning(f"Plot word {plot_word} not found in PCA words")
        pca_words.append(plot_word)

# Restrict the state space to the PCA vocabulary (set lookup keeps this linear).
pca_word_set = set(pca_words)
drop_idxs = [idx for idx, word in enumerate(state_space_spec.labels)
             if word not in pca_word_set]
state_space_spec = state_space_spec.drop_labels(drop_idxs)

## Prepare PCA on full set

In [None]:
# Build (instance, frame, dim) trajectories for every PCA word, NaN-padded.
trajectory = prepare_state_trajectory(
    model_representations,
    state_space_spec,
    expand_window=expand_frame_window,
    pad=np.nan
)
# Subsample trajectories to reduce computation time
for i in range(len(trajectory)):
    if len(trajectory[i]) > pca_max_samples_per_word:
        subsample_idxs = np.random.choice(len(trajectory[i]), pca_max_samples_per_word, replace=False)
        trajectory[i] = trajectory[i][subsample_idxs]

all_trajectories_full = np.concatenate(trajectory)
# (instance, frame) coordinates of each row of the flattened array below
all_trajectories_src = np.array(list(np.ndindex(all_trajectories_full.shape[:2])))

# flatten & drop frames containing the NaN pad value
all_trajectories = all_trajectories_full.reshape(-1, all_trajectories_full.shape[-1])
retain_idxs = ~np.isnan(all_trajectories).any(axis=1)
all_trajectories = all_trajectories[retain_idxs]
all_trajectories_src = all_trajectories_src[retain_idxs]

# Standardize first, then fit the PCA on the *standardized* data: the projection is always
# applied as pca.transform(scaler.transform(...)) (here and in plot_state_space), so it must be
# fit in that same space rather than on the raw representations.
scaler = StandardScaler().fit(all_trajectories)
all_trajectories_scaled = scaler.transform(all_trajectories)
pca = PCA(n_components=2, random_state=seed)
pca.fit(all_trajectories_scaled)

all_trajectories_pca = pca.transform(all_trajectories_scaled)

## State space analysis over plot sets

In [None]:
def plot_state_space(plot_key):
    """Animate one plot set's trajectories in the shared PCA space and save them as a GIF.

    For every word in ``plot_sets[plot_key]``, up to ``pca_max_samples_per_word`` instances are
    drawn as one scatter series; animation frame ``t`` shows each instance's position at model
    frame ``t`` of its (window-expanded) trajectory. The module-level ``scaler`` and ``pca``
    fitted on the full PCA vocabulary are reused, so all plot sets share one coordinate system.

    Args:
        plot_key: key into the module-level ``plot_sets`` dict; also used in the output filename.

    Returns:
        Path of the written GIF, ``Path(output_dir) / f"state_space-{plot_key}.gif"``.
    """
    plot_words = plot_sets[plot_key]
    plot_word_set = set(plot_words)
    state_space_spec_sub = state_space_spec.drop_labels(
        [idx for idx, word in enumerate(state_space_spec.labels) if word not in plot_word_set])

    trajectory = prepare_state_trajectory(
        model_representations,
        state_space_spec_sub,
        expand_window=expand_frame_window,
        pad=np.nan
    )

    # Subsample trajectories to reduce computation time
    for i in range(len(trajectory)):
        if len(trajectory[i]) > pca_max_samples_per_word:
            subsample_idxs = np.random.choice(len(trajectory[i]), pca_max_samples_per_word, replace=False)
            trajectory[i] = trajectory[i][subsample_idxs]

    all_trajectories_full = np.concatenate(trajectory)
    # (instance, frame) coordinates of each row of the flattened array below
    all_trajectories_src = np.array(list(np.ndindex(all_trajectories_full.shape[:2])))

    # flatten & drop frames containing the NaN pad value
    all_trajectories = all_trajectories_full.reshape(-1, all_trajectories_full.shape[-1])
    retain_idxs = ~np.isnan(all_trajectories).any(axis=1)
    all_trajectories = all_trajectories[retain_idxs]
    all_trajectories_src = all_trajectories_src[retain_idxs]

    # use previously fit scaler+PCA to transform these representations
    all_trajectories_pca = pca.transform(scaler.transform(all_trajectories))

    # Scatter the projected points back into an (instance, frame, 2) array; padding stays NaN.
    all_trajectories_pca_padded = np.full(all_trajectories_full.shape[:2] + (2,), np.nan)
    all_trajectories_pca_padded[all_trajectories_src[:, 0], all_trajectories_src[:, 1]] = all_trajectories_pca

    # Hold each instance at its last valid position after its trajectory ends (instead of NaN).
    # Locating the *last* valid frame, rather than the first NaN, also works when NaN padding
    # precedes the onset and never overwrites valid frames that follow an interior NaN.
    valid_frames = ~np.isnan(all_trajectories_pca_padded[:, :, 0])
    num_frames = valid_frames.shape[1]
    for idx, row_valid in enumerate(valid_frames):
        if not row_valid.any():
            continue
        last_valid = num_frames - 1 - int(np.argmax(row_valid[::-1]))
        all_trajectories_pca_padded[idx, last_valid + 1:] = all_trajectories_pca_padded[idx, last_valid]

    # Row ranges [start, end) of each plot word's instances in all_trajectories_pca_padded.
    trajectory_dividers = np.concatenate([[0], np.cumsum([traj.shape[0] for traj in trajectory])])
    plot_word_dividers = []
    for word in plot_words:
        class_idx = state_space_spec_sub.labels.index(word)
        left_edge = trajectory_dividers[class_idx]
        right_edge = (trajectory_dividers[class_idx + 1] if class_idx + 1 < len(trajectory_dividers)
                      else len(all_trajectories_pca_padded))
        plot_word_dividers.append((left_edge, right_edge))

    # Same limits on both axes, covering every projected point.
    lo = np.floor(all_trajectories_pca.min())
    hi = np.ceil(all_trajectories_pca.max())

    fig, ax = plt.subplots()
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_title(plot_key)
    # Frame counter in axes coordinates so it stays in view whatever the data range is.
    annot_frame = ax.text(0.02, 0.95, "-1", transform=ax.transAxes)

    color_classes = sorted(set(plot_words))
    # "Set1" has only 9 distinct colours (seaborn cycles it beyond that); switch to "husl"
    # so larger plot sets remain distinguishable.
    cmap = sns.color_palette("Set1" if len(color_classes) <= 9 else "husl", len(color_classes))
    color_values = {class_: cmap[i] for i, class_ in enumerate(color_classes)}
    marker_values = {class_: "o" if i % 2 == 0 else "x" for i, class_ in enumerate(color_classes)}

    # One series per word with one placeholder point per instance; `update` sets real offsets.
    scats = [ax.scatter(np.zeros(end - start), np.zeros(end - start),
                        alpha=0.5,
                        marker=marker_values[word],
                        color=color_values[word])
             for word, (start, end) in zip(plot_words, plot_word_dividers)]
    ax.legend(scats, plot_words, loc="upper right")

    def init():
        for scat in scats:
            scat.set_offsets(np.zeros((0, 2)))
        return tuple(scats)

    def update(frame):
        for scat, (idx_start, idx_end) in zip(scats, plot_word_dividers):
            scat.set_offsets(all_trajectories_pca_padded[idx_start:idx_end, frame])
        annot_frame.set_text(str(frame))
        return tuple(scats) + (annot_frame,)

    # Animate by model frame
    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=500,
                                  init_func=init)
    output_path = Path(output_dir) / f"state_space-{plot_key}.gif"
    ani.save(output_path, writer="ffmpeg")
    # Close the figure so repeated calls (one per plot set) don't accumulate open figures.
    plt.close(fig)
    return output_path

In [None]:
# Render one state-space GIF per plot set.
for plot_set_key in tqdm(list(plot_sets.keys())):
    plot_state_space(plot_set_key)

## Analysis on trajectory aggregates

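The brain encoding linking analysis assumes that the mean over a word's representational trajectory is a meaningful summary of that trajectory. Is it? Below we fit a separate 2-D PCA to each trajectory aggregate (mean, max, last frame, and the mean of the last k frames) and compare how the plot words are arranged under each one.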

In [None]:
# Aggregations passed to aggregate_state_trajectory; tuple entries carry an extra
# parameter (presumably k for "mean_last_k").
agg_fns = [
    "mean",
    "max",
    "last_frame",
    ("mean_last_k", 2),
    ("mean_last_k", 5),
]

In [None]:
# Fit an independent 2-D PCA to each aggregation of the (subsampled) trajectories and keep
# both the fitted model and the projected points, keyed by aggregation function.
pca_aggs, pca_agg_results = {}, {}
for agg_fn in tqdm(agg_fns):
    agg_pca = pca_aggs[agg_fn] = PCA(n_components=2)
    agg_data = np.concatenate(aggregate_state_trajectory(trajectory, state_space_spec, agg_fn))
    pca_agg_results[agg_fn] = agg_pca.fit_transform(agg_data)

In [None]:
# Re-embed the flat PCA coordinates into an (instance, frame, 2) grid; padded frames stay NaN.
all_trajectories_pca_padded = np.full((*all_trajectories_full.shape[:2], 2), np.nan)
all_trajectories_pca_padded[tuple(all_trajectories_src.T)] = all_trajectories_pca

# Rows [dividers[k], dividers[k + 1]) hold the instances of label k.
instance_counts = [traj.shape[0] for traj in trajectory]
trajectory_dividers = np.concatenate([[0], np.cumsum(instance_counts)])

# (start, end) row range for each plot word.
plot_word_dividers = []
for word in plot_words:
    class_idx = state_space_spec.labels.index(word)
    left_edge = trajectory_dividers[class_idx]
    if class_idx + 1 < len(trajectory_dividers):
        right_edge = trajectory_dividers[class_idx + 1]
    else:
        right_edge = len(all_trajectories_pca_padded)
    plot_word_dividers.append((left_edge, right_edge))

In [None]:
# One colour/marker per plot word, shared across all aggregation panels. Iterating the sorted
# list (plot_words is a set) keeps the legend order stable across kernel restarts.
color_classes = sorted(plot_words)
# "Set1" has only 9 distinct colours and seaborn cycles it beyond that; fall back to "husl"
# so larger plot sets stay distinguishable.
cmap = sns.color_palette("Set1" if len(color_classes) <= 9 else "husl", len(color_classes))
color_values = {class_: cmap[i] for i, class_ in enumerate(color_classes)}
marker_values = {class_: "o" if i % 2 == 0 else "x" for i, class_ in enumerate(color_classes)}

# squeeze=False keeps `axs` a 2-D array even when there is a single aggregation function.
f, axs = plt.subplots(len(agg_fns), 1, figsize=(7, 5 * len(agg_fns)), squeeze=False)
for ax, (agg_fn, trajectory_agg_pca) in zip(axs.ravel(), pca_agg_results.items()):
    for word in color_classes:
        class_idx = state_space_spec.labels.index(word)
        traj_i = trajectory_agg_pca[trajectory_dividers[class_idx]:trajectory_dividers[class_idx + 1]]
        ax.scatter(*traj_i.T, color=color_values[word], marker=marker_values[word], alpha=0.5,
                   label=word)
    # Parameterised aggregations such as ("mean_last_k", 2) are titled "mean_last_k(2)".
    if isinstance(agg_fn, str):
        title = agg_fn
    else:
        title = f"{agg_fn[0]}({', '.join(map(str, agg_fn[1:]))})"
    ax.set_title(title)
    ax.legend(loc="upper right", bbox_to_anchor=(1.2, 1))