In [None]:
import os

# Hide every GPU so torch (imported in a later cell) runs on CPU only.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

In [None]:
# Automatically re-import edited project modules (e.g. `src.*`) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
from dataclasses import replace
import json
from pathlib import Path
import pickle

import datasets
from IPython.display import HTML
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import torch

from src.analysis.state_space import StateSpaceAnalysisSpec, prepare_state_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
base_model = "w2v2_6"
model_class = "rnn_8-aniso1"
model_name = "word_broad"
# Run identifier shared by the model, notebook-output, equivalence and embedding paths.
run_name = f"{model_name}_10frames"

model_dir = f"outputs/models/timit/{base_model}/{model_class}/{run_name}"
output_dir = f"outputs/notebooks/timit/{base_model}/{model_class}/{run_name}/state_space"
dataset_path = "outputs/preprocessed_data/timit"
equivalence_path = f"outputs/equivalence_datasets/timit/{base_model}/{run_name}/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/timit/{base_model}/hidden_states.pkl"
state_space_specs_path = f"outputs/state_space_specs/timit/{base_model}/state_space_specs.pkl"
embeddings_path = f"outputs/model_embeddings/timit/{base_model}/{model_class}/{run_name}/embeddings.npy"

# Add 4 frames prior to onset to each trajectory
expand_frame_window = (4, 0)

# NOTE(review): `metric` is not used anywhere in this notebook; kept for parity with sibling notebooks?
metric = "cosine"

In [None]:
model_representations: np.ndarray = np.load(embeddings_path)

# The specs file is torch-serialized and holds project objects (StateSpaceAnalysisSpec),
# so full unpickling is required: torch>=2.6 defaults to weights_only=True, which would
# reject these classes. Only load trusted, locally produced files this way.
with open(state_space_specs_path, "rb") as f:
    state_space_spec: StateSpaceAnalysisSpec = torch.load(f, weights_only=False)["word"]
assert state_space_spec.is_compatible_with(model_representations), \
    "word state-space spec is not compatible with the loaded model embeddings"

In [None]:
# Disabled exploration helper: for words longer than 3 characters starting with a given
# letter ("p" below), list (count, word) pairs sorted most-frequent first. Useful when
# choosing `plot_words` in the next cell.
# vocab = sorted([(len(trajs), label) for trajs, label in zip(state_space_spec.target_frame_spans, state_space_spec.labels)], reverse=True)
# import itertools
# {k: sorted([(count, word) for count, word in vs if len(word) > 3], reverse=True)
#  for k, vs in itertools.groupby(sorted(vocab, key=lambda x: x[1]), key=lambda x: x[1][0])}["p"]

In [None]:
# use top N words to compute broader PCA
# Use the `pca_num_words` labels with the most frame spans to fit a broader PCA...
pca_num_words = 200
word_counts = sorted(
    ((len(trajs), label)
     for trajs, label in zip(state_space_spec.target_frame_spans, state_space_spec.labels)),
    reverse=True,
)
pca_words = [label for _, label in word_counts[:pca_num_words]]
# ...but just plot these words in the resulting space
plot_words = ["allow", "about", "around",
              "before", "black", "barely",
              "small", "said", "such",
              "please", "people", "problem"]
missing_words = [word for word in plot_words if word not in state_space_spec.labels]
if missing_words:
    raise ValueError(f"Plot words not found in state space: {missing_words}")

# Drop every label that is neither a PCA word nor a plot word (set lookup keeps this linear).
keep_words = set(pca_words) | set(plot_words)
drop_idxs = [idx for idx, word in enumerate(state_space_spec.labels)
             if word not in keep_words]
state_space_spec = state_space_spec.drop_labels(drop_idxs)

In [None]:
# Extract per-label trajectories from the embeddings, widening each span by
# `expand_frame_window` and padding with NaN (NaN rows are filtered out below).
trajectory = prepare_state_trajectory(
    model_representations, state_space_spec,
    pad=np.nan,
    expand_window=expand_frame_window,
)

In [None]:
# Stack all labels' trajectories into one (trajectory, frame, feature) array.
all_trajectories_full = np.concatenate(trajectory)
num_trajs, num_frames = all_trajectories_full.shape[:2]
# (trajectory, frame) source index for every flattened row, in the same C order as the
# reshape below; vectorized instead of materializing one Python tuple per frame.
all_trajectories_src = np.indices((num_trajs, num_frames)).reshape(2, -1).T

# flatten & retain non-padding (retain_idxs is a boolean mask over flattened rows)
all_trajectories = all_trajectories_full.reshape(-1, all_trajectories_full.shape[-1])
retain_idxs = ~np.isnan(all_trajectories).any(axis=1)
all_trajectories = all_trajectories[retain_idxs]
all_trajectories_src = all_trajectories_src[retain_idxs]

In [None]:
# svd_solver="auto" may choose the randomized solver for large inputs; fixing
# random_state makes the projection (and the saved animation) reproducible.
pca = PCA(n_components=2, random_state=0)
pca.fit(all_trajectories)

all_trajectories_pca = pca.transform(all_trajectories)

In [None]:
# Scatter the retained PCA points back into a (trajectory, frame, 2) grid; padding stays NaN.
padded_shape = all_trajectories_full.shape[:2] + (2,)
all_trajectories_pca_padded = np.full(padded_shape, np.nan)
src_rows = all_trajectories_src[:, 0]
src_frames = all_trajectories_src[:, 1]
all_trajectories_pca_padded[src_rows, src_frames] = all_trajectories_pca

# Hold each trajectory at its last position once it ends: copy the frame before the
# first NaN forward. Rows with no NaN, or whose first frame is NaN, give argmax 0 and
# are left untouched.
# NOTE(review): a row with leading NaN padding is never back-filled — confirm this cannot occur.
first_nan_frame = np.isnan(all_trajectories_pca_padded[:, :, 0]).argmax(axis=1)
for row, onset in enumerate(first_nan_frame):
    if onset > 0:
        all_trajectories_pca_padded[row, onset:] = all_trajectories_pca_padded[row, onset - 1]

In [None]:
# Row offsets of each label's block in the concatenated arrays: label i occupies
# rows trajectory_dividers[i]:trajectory_dividers[i + 1].
trajectory_dividers = np.concatenate([[0], np.cumsum([traj.shape[0] for traj in trajectory])])
# Offsets are looked up by label position, so trajectories must align 1:1 with labels.
assert len(trajectory) == len(state_space_spec.labels), \
    "expected one trajectory array per state-space label"

# get just the (start, end) row range for each plot word
plot_word_dividers = []
for word in plot_words:
    class_idx = state_space_spec.labels.index(word)
    plot_word_dividers.append((trajectory_dividers[class_idx], trajectory_dividers[class_idx + 1]))
plot_word_dividers

In [None]:
min, max = all_trajectories_pca.min(), all_trajectories_pca.max()

In [None]:
# Animate
import matplotlib.pyplot as plt
import matplotlib.animation as animation

fig, ax = plt.subplots()
ax.set_xlim(np.floor(min), np.ceil(max))
ax.set_ylim(np.floor(min), np.ceil(max))
annot_frame = ax.text(-0.75, 0.75, "-1")

color_classes = sorted(set(word[0] for word in plot_words))
color_values = {class_: i for i, class_ in enumerate(color_classes)}

cmap = plt.get_cmap("Set1")
scats = [ax.scatter(np.zeros(end - start + 1), np.zeros(end - start + 1),
                    alpha=0.5,
                    marker="o",
                    color=cmap(color_values[word[0]]),
                   ) for i, (word, (start, end)) in enumerate(zip(plot_words, plot_word_dividers))]
ax.legend(scats, plot_words, loc=1)

def init():
    for scat in scats:
        scat.set_offsets(np.zeros((0, 2)))
    return tuple(scats)

def update(frame):
    for scat, (idx_start, idx_end) in zip(scats, plot_word_dividers):
        traj_i = all_trajectories_pca_padded[idx_start:idx_end, frame]
        scat.set_offsets(traj_i)
        # scat.set_array(np.arange(traj_i.shape[0]))
    annot_frame.set_text(str(frame))
    return tuple(scats) + (annot_frame,)

# Animate by model frame
num_frames = all_trajectories_pca_padded.shape[1]
ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=500,
                              init_func=init)
ani.save(Path(output_dir) / "state_space.gif", writer="ffmpeg")

In [None]:
# Interactive in-notebook playback of the animation (JavaScript player).
HTML(ani.to_jshtml())