In [None]:
# Automatically reload edited project modules (e.g. `src.*`) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Hide all CUDA devices; set before `torch` is imported below so it only sees the CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
from collections import defaultdict, Counter
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import prepare_state_trajectory, StateSpaceAnalysisSpec, flatten_trajectory
from src.utils.timit import get_word_metadata

In [None]:
# Input/output locations (relative paths).
# NOTE(review): model_dir, dataset_path, equivalence_path and hidden_states_path are not
# referenced anywhere in this notebook.
model_dir = "outputs/models/librispeech-train-clean-100/w2v2_8/discrim-rnn_32-mAP1/word_broad_10frames_fixedlen25"
output_dir = "."
dataset_path = "outputs/preprocessed_data/librispeech-train-clean-100"
equivalence_path = "outputs/equivalence_datasets/librispeech-train-clean-100/w2v2_8/word_broad_10frames_fixedlen25/equivalence.pkl"
hidden_states_path = "outputs/hidden_states/librispeech-train-clean-100/w2v2_8/hidden_states.h5"
state_space_specs_path = "outputs/state_space_specs/librispeech-train-clean-100/w2v2_8/state_space_specs.h5"
embeddings_path = "outputs/model_embeddings/librispeech-train-clean-100/w2v2_8/discrim-rnn_32-mAP1/word_broad_10frames_fixedlen25/librispeech-train-clean-100.npy"

# NOTE(review): `metric` is not referenced anywhere in this notebook.
metric = "cosine"

# Retain labels (syllables, in this analysis) with N or more instances
retain_n = 10

# Passed to `state_space_spec.subsample_instances`; presumably the per-label instance cap
subsample_instances = 50

# Frame rate of the model representations (frames per second); used to convert frames to seconds
model_sfreq = 50

# Passed to `prepare_state_trajectory`. Downstream cells treat frame index expand_window[0]
# as the span onset, i.e. expand_window[0] frames of context precede each span.
# NOTE(review): the second element presumably extends past the span end — TODO confirm.
expand_window = (15, 0)

In [None]:
# Load precomputed model embeddings and the syllable- and word-level state space specs.
# Only the syllable spec is checked for compatibility: the word spec is used just for its cuts.
model_representations: np.ndarray = np.load(embeddings_path)
state_space_spec_syll = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "syllable")
state_space_spec_word = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "word")
assert state_space_spec_syll.is_compatible_with(model_representations)

In [None]:
# Syllable-level cuts within each word token; used to flag word-initial syllables.
# `.copy()` so adding a column below does not write into a view of the spec's cuts.
syll_from_word = state_space_spec_word.cuts.xs("syllable", level="level").copy()
# A syllable is word-initial iff it is the first row of its (word label, instance_idx) group.
# Grouping on both levels avoids the old `instance_idx.diff() != 0` test, which misflags the
# first syllable of a word whose instance_idx equals that of the preceding row from another
# label. Single-syllable words come out word-initial, so they are excluded downstream as intended.
word_level_pos = [0, syll_from_word.index.names.index("instance_idx")]
syll_from_word["word_initial"] = syll_from_word.groupby(level=word_level_pos).cumcount() == 0
syll_from_word.loc["reaction"]

In [None]:
# Map each syllable onset frame to whether that syllable is word-initial.
# NOTE: `to_dict` keeps the last row if two syllables share an onset frame.
word_initial = syll_from_word.reset_index().set_index("onset_frame_idx").word_initial.to_dict()
# Summarize rather than dumping the full (potentially very large) mapping as cell output.
Counter(word_initial.values())

In [None]:
# Analyze syllable-level spans; the word-level spec is only used above to flag word-initial syllables.
state_space_spec = state_space_spec_syll

In [None]:
# Keep only labels with at least `retain_n` instances.
label_counts = state_space_spec.label_counts
drop_labels = label_counts.index[label_counts < retain_n]
state_space_spec = state_space_spec.drop_labels(drop_names=drop_labels)

In [None]:
# Subsample instances per label (see `subsample_instances` in the config cell).
# NOTE(review): if this subsamples randomly, no seed is set here — confirm reproducibility.
state_space_spec = state_space_spec.subsample_instances(subsample_instances)

In [None]:
# Post-hoc hack: keep only non-word-initial syllables, i.e. drop every span whose onset
# frame is flagged as word-initial in `word_initial`.
from dataclasses import replace

to_drop = {(label_idx, span_idx)
           for label_idx, label_spans in enumerate(state_space_spec.target_frame_spans)
           for span_idx, (start_frame, _end_frame) in enumerate(label_spans)
           if word_initial[start_frame]}
print(f"Dropping {len(to_drop)} word-initial syllables")
new_target_frame_spans = [
    [span for span_idx, span in enumerate(label_spans) if (label_idx, span_idx) not in to_drop]
    for label_idx, label_spans in enumerate(state_space_spec.target_frame_spans)
]
ss_new = replace(state_space_spec, cuts=None, target_frame_spans=new_target_frame_spans)

In [None]:
# Labels left with no spans after the word-initial filter are removed entirely.
drop_label_idxs = [label_idx
                   for label_idx, label_spans in enumerate(ss_new.target_frame_spans)
                   if len(label_spans) == 0]
ss_new = ss_new.drop_labels(drop_idxs=drop_label_idxs)

In [None]:
trajectory = prepare_state_trajectory(
    model_representations, ss_new, pad=np.nan, expand_window=expand_window)

# Per label, each instance's length in frames: index of the first NaN-padded frame, or the
# full padded length when the instance has no padding.
# NOTE(review): assumes padding only ever appears at the end of a trajectory — TODO confirm.
lengths = [np.where(pad_mask.any(axis=1), pad_mask.argmax(axis=1), pad_mask.shape[1])
           for pad_mask in (np.isnan(label_traj[:, :, 0]) for label_traj in trajectory)]

In [None]:
# Flatten per-label trajectories into one matrix plus a per-row source array whose columns are
# (label_idx, instance_idx, frame_idx). NOTE(review): presumably NaN-padded frames are dropped.
traj_full, traj_full_flat_src = flatten_trajectory(trajectory)

In [None]:
# Sanity check: number of labels and mean trajectory length (frames).
len(trajectory), np.concatenate(lengths).mean()

In [None]:
# Shape of the flattened full-trajectory matrix.
traj_full.shape

In [None]:
# 4-component PCA over all frames of the full trajectories.
# Fixed random_state: with svd_solver="auto" sklearn may select the randomized solver,
# which would otherwise make the components (and all downstream plots) vary between runs.
pca_full = PCA(n_components=4, random_state=0)
traj_full_flat_pca = pca_full.fit_transform(traj_full)

#### Prepare truncated trajectory data

In [None]:
# Truncate each label's trajectories to the pre-onset window plus the first 15 frames from
# onset, then flatten them the same way as the full trajectories.
traj_trunc = [label_traj[:, :expand_window[0] + 15] for label_traj in trajectory]
traj_trunc_flat, traj_trunc_flat_src = flatten_trajectory(traj_trunc)

In [None]:
# Frame offsets relative to onset for the truncated trajectories: -expand_window[0], ..., 14.
trunc_times = np.arange(-expand_window[0], 15)

In [None]:
# 4-component PCA over the truncated trajectories.
# Fixed random_state so a possibly-randomized SVD solver gives reproducible components.
pca = PCA(n_components=4, random_state=0)
traj_trunc_flat_pca = pca.fit_transform(traj_trunc_flat)

## State space

In [None]:
# Lookup from (label_idx, instance_idx, frame_idx) to row index in the flattened arrays.
traj_full_flat_src_dict = dict(zip(map(tuple, traj_full_flat_src), range(len(traj_full_flat_src))))

In [None]:
# Lookup from (label_idx, instance_idx) to all of that instance's row indices, in row order.
traj_full_flat_frame_dict = defaultdict(list)
for flat_idx, src_row in enumerate(traj_full_flat_src):
    label_idx, instance_idx = src_row[:2]
    traj_full_flat_frame_dict[label_idx, instance_idx].append(flat_idx)

In [None]:
# Spot check: number of flattened frames for label 6, instance 0.
# `.get` avoids inserting an empty entry into the defaultdict if the key is absent.
len(traj_full_flat_frame_dict.get((6, 0), []))

In [None]:
def plot_state_space_binned(n_bins, groupby=None, return_data=False, hide_largest_bin=True,
                            metadata: Optional[pd.DataFrame] = None, spec=None):
    """
    Plot mean state-space trajectories (full-trajectory PCA space) binned by trajectory length.

    Reads the notebook-level ``lengths``, ``traj_full_flat_pca``, ``traj_full_flat_src``,
    ``traj_full_flat_src_dict``, ``traj_full_flat_frame_dict``, ``model_sfreq`` and
    ``expand_window``.

    Args:
        n_bins: number of quantile bins over trajectory length (frames).
        groupby: optional ``metadata`` column. When given, trajectories are also split by its
            value, and each group is drawn with its own line style.
        return_data: if True, also return the per-key data, sources, means and bin edges.
        hide_largest_bin: if True, do not plot the longest-length bin.
        metadata: DataFrame whose ``groupby`` column is looked up by (label, instance_idx).
            Required when ``groupby`` is given; falls back to a notebook-level ``metadata``.
        spec: state space spec whose ``labels`` align with ``lengths``; defaults to ``ss_new``.

    Returns:
        The figure, or ``(figure, bin_frame_data, bin_frame_src, bin_frame_means,
        bin_time_edges)`` when ``return_data`` is True.

    Raises:
        ValueError: if ``groupby`` is given and no metadata is available.
    """
    # Only needed for the custom legend handles.
    from matplotlib import patches as mpatches
    from matplotlib.lines import Line2D

    if spec is None:
        # `lengths` was computed from `ss_new`, so its label indices refer to ss_new.labels.
        spec = ss_new
    if groupby is not None and metadata is None:
        metadata = globals().get("metadata")
        if metadata is None:
            raise ValueError(f"groupby={groupby!r} requires a `metadata` DataFrame")

    all_lengths = np.concatenate(lengths)
    max_traj_length = all_lengths.max()

    # Quantile bin edges over trajectory length in frames (only the edges are used).
    length_bins = pd.qcut(all_lengths, q=n_bins, labels=False, retbins=True)[1]
    bin_time_edges = np.maximum(0, length_bins) / model_sfreq
    # digitize(right=True) puts lengths equal to the lowest edge at index 0, i.e. bin -1 after
    # the shift. Clip them into bin 0 so the shortest trajectories are kept, as qcut does.
    bin_assignments = [np.clip(np.digitize(traj_lengths, length_bins, right=True) - 1, 0, n_bins - 1)
                       for traj_lengths in lengths]

    # Reverse map from key := (bin, *grouping value) to its (label_idx, instance_idx) pairs.
    bin_assignments_rev = defaultdict(list)
    group_lookup = metadata[groupby].to_dict() if groupby is not None else None
    for label_idx, assignments_i in enumerate(bin_assignments):
        for instance_idx, bin_idx in enumerate(assignments_i):
            if group_lookup is None:
                key = (bin_idx,)
            else:
                # NOTE(review): instance_idx indexes ss_new's spans, from which word-initial
                # syllables were removed; confirm `metadata` uses the same instance numbering.
                group_value = group_lookup[spec.labels[label_idx], instance_idx]
                if group_value is None or (isinstance(group_value, float) and np.isnan(group_value)):
                    continue
                key = (bin_idx, group_value)
            bin_assignments_rev[key].append((label_idx, instance_idx))

    # Per-key, per-instance PCA data and source rows (only built when requested).
    bin_frame_data, bin_frame_src = {}, {}
    if return_data:
        for key, traj_indices in bin_assignments_rev.items():
            # `.get` avoids inserting empty entries into the shared defaultdict.
            flat_idx_lists = [traj_full_flat_frame_dict.get((label_idx, instance_idx), [])
                              for label_idx, instance_idx in traj_indices]
            bin_frame_data[key] = [traj_full_flat_pca[idxs] for idxs in flat_idx_lists]
            bin_frame_src[key] = [traj_full_flat_src[idxs] for idxs in flat_idx_lists]

    # Per-key mean state at each frame; NaN where no trajectory in the key reaches that frame.
    n_components = traj_full_flat_pca.shape[1]
    bin_frame_means = {}
    for key, traj_indices in tqdm(bin_assignments_rev.items(), desc="retrieving per-bin data"):
        frame_means = np.full((max_traj_length, n_components), np.nan)
        for frame_idx in range(max_traj_length):
            flat_idxs = [traj_full_flat_src_dict[label_idx, instance_idx, frame_idx]
                         for label_idx, instance_idx in traj_indices
                         if (label_idx, instance_idx, frame_idx) in traj_full_flat_src_dict]
            if flat_idxs:
                frame_means[frame_idx] = traj_full_flat_pca[flat_idxs].mean(axis=0)
        bin_frame_means[key] = frame_means

    ## Plot
    pcs = [[0, 1], [2, 3]]
    f, axs = plt.subplots(1, len(pcs), figsize=(10 * len(pcs), 10), squeeze=False)

    # One colour per length bin, one line style per grouping value.
    bin_colors = sns.color_palette("spring", n_bins)
    grouping_values = sorted(set(key[1:] for key in bin_frame_means))
    grouping_styles = ["-", "--", "-.", ":", (0, (3, 5, 1, 5)), (0, (3, 5, 1, 5, 1, 5))]
    assert len(grouping_values) <= len(grouping_styles)

    # Trajectory frame at which the analyzed span starts (after the pre-onset window).
    onset_frame = expand_window[0]

    for (pc_x, pc_y), ax in zip(pcs, axs.flat):
        ax.axvline(0, color="black", linestyle="--")
        ax.axhline(0, color="black", linestyle="--")
        ax.set_xlabel(f"PC {pc_x + 1}")
        ax.set_ylabel(f"PC {pc_y + 1}")

        for key in sorted(bin_frame_means):
            bin_idx = key[0]
            if hide_largest_bin and bin_idx == n_bins - 1:
                continue

            means = bin_frame_means[key]
            ax.plot(means[:, pc_x], means[:, pc_y],
                    label=f"{key[1:]} {bin_time_edges[bin_idx]:.2f} s",
                    color=bin_colors[bin_idx],
                    linestyle=grouping_styles[grouping_values.index(key[1:])],
                    alpha=0.7)
            ax.quiver(means[:-1, pc_x], means[:-1, pc_y],
                      np.diff(means[:, pc_x]), np.diff(means[:, pc_y]),
                      angles="xy", scale_units="xy", scale=1, color="gray", alpha=0.5)

            # O at span onset
            ax.scatter(means[onset_frame, pc_x], means[onset_frame, pc_y], color="blue", marker="o")

            # X halfway between onset and the last frame that has data
            empty_frames = np.isnan(means).any(axis=1)
            last_frame = empty_frames.argmax() - 1 if empty_frames.any() else means.shape[0] - 1
            midpoint_frame = (last_frame - onset_frame) // 2 + onset_frame
            ax.scatter(means[midpoint_frame, pc_x], means[midpoint_frame, pc_y], color="red", marker="x")

    # Legend on the last axis only: bbox_to_anchor puts it right of the panel, so drawing it on
    # earlier axes would cover the neighbouring panel. Bin entries show each bin's lower edge.
    handles = [mpatches.Patch(color="white", label="")]
    legend_labels = ["length bin"]
    for bin_idx in range(n_bins):
        handles.append(mpatches.Patch(color=bin_colors[bin_idx], label=f"bin {bin_idx}"))
        legend_labels.append(f"{bin_time_edges[bin_idx]:.2f} s")
    if groupby is not None:
        handles.append(mpatches.Patch(color="white", label=""))
        legend_labels.append(groupby)
        for group, style in zip(grouping_values, grouping_styles):
            handles.append(Line2D([0], [0], color="black", linestyle=style))
            legend_labels.append(", ".join(map(str, group)))
    axs[0, -1].legend(handles, legend_labels, loc="upper right", bbox_to_anchor=(1.5, 1))

    if return_data:
        return f, bin_frame_data, bin_frame_src, bin_frame_means, bin_time_edges
    return f

In [None]:
# Five length bins; keep the per-bin data for the norm analysis below.
f, ss_data_all, ss_src_all, ss_means_all, ss_edges_all = plot_state_space_binned(
    n_bins=5, return_data=True)
f.savefig(Path(output_dir, "state_space.png"), bbox_inches="tight")

In [None]:
# Per-instance state norm and frame-to-frame step size (PCA space) for the shortest length bin.
norm_track_key = (0,)
norm_track_instances = ss_data_all[norm_track_key]
maxlen = max(data_i.shape[0] for data_i in norm_track_instances)
norm_track_data = np.full((len(norm_track_instances), maxlen), np.nan)
# Column t holds ||x[t+1] - x[t]||; the last valid frame of each instance has no successor.
diff_norm_track_data = np.full((len(norm_track_instances), maxlen), np.nan)
for i, data_i in enumerate(norm_track_instances):
    n_frames = data_i.shape[0]
    norm_track_data[i, :n_frames] = np.linalg.norm(data_i, axis=1)
    # np.diff rather than np.roll, so the last frame is not paired with the first one.
    diff_norm_track_data[i, :max(n_frames - 1, 0)] = np.linalg.norm(np.diff(data_i, axis=0), axis=1)

# Order instances by the frame at which their norm peaks. NaN padding is mapped to -inf so that
# argmax ignores it. Both arrays are reordered so their rows stay aligned.
peak_order = np.argsort(np.nan_to_num(norm_track_data, nan=-np.inf).argmax(axis=1), kind="stable")
norm_track_data = norm_track_data[peak_order]
diff_norm_track_data = diff_norm_track_data[peak_order]

In [None]:
# Mean state norm per frame across instances of the shortest length bin.
plt.plot(np.nanmean(norm_track_data, axis=0))
plt.xlabel(f"Frame within trajectory (span onset at frame {expand_window[0]})")
plt.ylabel("Mean state norm (PC space)");

In [None]:
# Mean frame-to-frame step size ||x[t+1] - x[t]|| across instances of the shortest length bin.
plt.plot(np.nanmean(diff_norm_track_data, axis=0))
plt.xlabel(f"Frame within trajectory (span onset at frame {expand_window[0]})")
plt.ylabel("Mean step size (PC space)");

## Basic plots

In [None]:
# Fraction of variance explained by each component of the truncated-trajectory PCA.
pca.explained_variance_ratio_

In [None]:
# Row indices of traj_trunc_flat(_pca), grouped by frame position (column 2 of the source array).
traj_idxs_by_frame = [np.flatnonzero(traj_trunc_flat_src[:, 2] == frame_pos)
                      for frame_pos in range(traj_trunc_flat_src[:, 2].max() + 1)]

In [None]:
# Per-frame mean, standard deviation and instance count in truncated-PCA space.
frame_means = np.array([traj_trunc_flat_pca[idxs].mean(axis=0) for idxs in traj_idxs_by_frame])
frame_sds = np.array([traj_trunc_flat_pca[idxs].std(axis=0) for idxs in traj_idxs_by_frame])
frame_counts = np.array([len(idxs) for idxs in traj_idxs_by_frame])

In [None]:
# Mean (± SEM) of each truncated-trajectory PC across instances, aligned to the syllable boundary.
f, ax = plt.subplots()

# Frame index expand_window[0] is the span onset, so shift frames to put it at t = 0.
times = np.arange(-expand_window[0], frame_means.shape[0] - expand_window[0]) / model_sfreq
frame_sems = frame_sds / np.sqrt(frame_counts)[:, None]
ax.axvline(0, color="gray", linestyle="--")
ax.axhline(0, color="gray", linestyle="--")

for component in range(frame_means.shape[1]):
    ax.plot(times, frame_means[:, component], label=f"PC {component+1}")
    ax.fill_between(times,
                    frame_means[:, component] - frame_sems[:, component],
                    frame_means[:, component] + frame_sems[:, component],
                    alpha=0.3)

ax.legend()
ax.set_xlabel("Time from syllable boundary (s)")
ax.set_ylabel("PC projection (mean ± SEM)")

f.savefig(Path(output_dir) / "syllable_boundary.png", bbox_inches="tight")