In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict, Counter
from dataclasses import replace
import itertools
import json
import logging
from pathlib import Path
import pickle

import datasets
from fastdist import fastdist
import matplotlib.pyplot as plt
from matplotlib import transforms
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist, pdist, squareform
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import torch
from tqdm.auto import tqdm, trange

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# ---- Analysis parameters ----
seed = 1234
max_instances_per_word = 100
metric = "cosine"

# ---- Which model / dataset to analyse ----
base_model = "w2v2_8"
model_class = "rnn_32-hinge-mAP4"
model_name = "word_broad"
train_dataset = "librispeech-train-clean-100"

# ---- Paths derived from the selection above (all relative to the working directory) ----
output_dir = "."
model_dir = "/".join(
    ("outputs/models", train_dataset, base_model, model_class, model_name + "_10frames")
)
dataset_path = "/".join(("outputs/preprocessed_data", train_dataset))
equivalence_path = "/".join(
    ("outputs/equivalence_datasets", train_dataset, base_model,
     model_name + "_10frames", "equivalence.pkl")
)
hidden_states_path = "/".join(
    ("outputs/hidden_states", train_dataset, base_model, train_dataset + ".h5")
)
state_space_specs_path = "/".join(
    ("outputs/state_space_specs", train_dataset, base_model, "state_space_specs.pkl")
)
embeddings_path = "/".join(
    ("outputs/model_embeddings", train_dataset, base_model, model_class,
     model_name + "_10frames", train_dataset + ".npy")
)

In [None]:
# Load the "word" entry of the saved state-space specs.
# NOTE(review): torch.load unpickles arbitrary objects -- only point this at trusted files.
with open(state_space_specs_path, "rb") as specs_file:
    state_space_spec: StateSpaceAnalysisSpec = torch.load(specs_file)["word"]

# Load the model embeddings array saved at embeddings_path.
with open(embeddings_path, "rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)

# Fail fast if the spec and the embeddings do not belong together.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Subsample the spec using max_instances_per_word (presumably an upper bound on the
# number of instances kept per word -- confirm against subsample_instances).
# NOTE(review): `seed` from the config cell is not passed here; if the subsampling is
# random, confirm how it is seeded so results are reproducible.
# NOTE: this rebinds state_space_spec, so re-running the cell subsamples the
# already-subsampled spec rather than the one loaded from disk.
state_space_spec = state_space_spec.subsample_instances(max_instances_per_word)

In [None]:
# Build the state trajectory from the model representations; positions without data
# are padded with NaN.
trajectory = prepare_state_trajectory(
    model_representations,
    state_space_spec,
    pad=np.nan,
)

# Aggregate with the ("mean_within_cut", "phoneme") method
# (presumably one mean vector per phoneme cut -- confirm in src.analysis.state_space).
trajectory = aggregate_state_trajectory(
    trajectory,
    state_space_spec,
    ("mean_within_cut", "phoneme"),
    keepdims=True,
)

# Flatten into an array of states plus a parallel array describing each row's source.
traj_flat, traj_flat_src = flatten_trajectory(trajectory)

In [None]:
# Position of each label within state_space_spec.labels.
label2idx = {label: idx for idx, label in enumerate(state_space_spec.labels)}

# Phoneme-level cuts, re-indexed by (label_idx, instance_idx, frame_idx), where
# frame_idx is the running position of the cut within its (label, instance_idx) group.
cuts_df = (
    state_space_spec.cuts
    .xs("phoneme", level="level")
    .drop(columns=["onset_frame_idx", "offset_frame_idx"])
    .assign(
        label_idx=lambda df: df.index.get_level_values("label").map(label2idx),
        frame_idx=lambda df: df.groupby(["label", "instance_idx"]).cumcount(),
    )
    .reset_index()
    .set_index(["label_idx", "instance_idx", "frame_idx"])
    .sort_index()
)
cuts_df

In [None]:
# Column 2 of traj_flat_src holds the frame index (the RSA loop filters on it);
# the largest index plus one is the number of frame positions to iterate over.
max_num_frames = traj_flat_src[:, 2].max() + 1

# Sorted phoneme inventory and its phoneme -> integer index mapping.
all_phonemes = sorted(cuts_df["description"].unique())
phoneme2idx = dict(zip(all_phonemes, range(len(all_phonemes))))

In [None]:
# Lookup from a cut's (label_idx, instance_idx, frame_idx) index tuple to the integer
# index of its phoneme (the `description` column mapped through phoneme2idx).
traj_src_phoneme_lookup = cuts_df.description.map(phoneme2idx).to_dict()

In [None]:
# Display the number of frame positions the RSA loop will iterate over.
max_num_frames

In [116]:
# For every frame position, compute the mean pairwise distance (under `metric`)
# between the states of each pair of phonemes.
# Result: {frame_idx: DataFrame of shape (n_phonemes, n_phonemes)}.
#
# NOTE(review): diagonal entries average cdist of a phoneme's states against
# themselves, which includes each state paired with itself (distance 0).
rsa_results = {}
num_phonemes = len(all_phonemes)
num_pairs = num_phonemes * (num_phonemes + 1) // 2
for frame_idx in trange(max_num_frames):
    # Rows of traj_flat whose source frame index (column 2) is this frame.
    traj_idxs_i = (traj_flat_src[:, 2] == frame_idx).nonzero()[0]
    traj_srcs_i = traj_flat_src[traj_idxs_i]
    traj_phoneme_idxs = np.array(
        [traj_src_phoneme_lookup[tuple(src)] for src in traj_srcs_i], dtype=int)

    # Group the row indices by phoneme once per frame instead of re-masking
    # the same arrays for every phoneme pair.
    idxs_by_phoneme = [traj_idxs_i[traj_phoneme_idxs == p] for p in range(num_phonemes)]

    # NaN (rather than 0.0) marks pairs with no instances at this frame, so they
    # cannot be mistaken for a genuine zero distance or drag the colour scale to 0.
    rsa_i = np.full((num_phonemes, num_phonemes), np.nan)
    phoneme_pairs = itertools.combinations_with_replacement(range(num_phonemes), 2)
    for lo, hi in tqdm(phoneme_pairs, total=num_pairs, leave=False):
        traj_idxs_p1 = idxs_by_phoneme[hi]
        traj_idxs_p2 = idxs_by_phoneme[lo]
        if len(traj_idxs_p1) == 0 or len(traj_idxs_p2) == 0:
            continue
        mean_dist = cdist(traj_flat[traj_idxs_p1], traj_flat[traj_idxs_p2], metric=metric).mean()
        # The mean pairwise distance does not depend on argument order for a symmetric
        # metric, so each unordered pair is computed once and mirrored.
        rsa_i[hi, lo] = rsa_i[lo, hi] = mean_dist

    rsa_results[frame_idx] = pd.DataFrame(rsa_i, index=all_phonemes, columns=all_phonemes)

In [None]:
# Stack the per-frame matrices into one DataFrame whose outer index level is
# `frame_idx` and whose inner level is the row phoneme.
# The isinstance guard keeps this cell idempotent: after the first run rsa_results
# is already a DataFrame, and passing a DataFrame to pd.concat raises a TypeError.
if isinstance(rsa_results, dict):
    rsa_results = pd.concat(rsa_results, names=["frame_idx"])
rsa_results

In [None]:
# Use one colour range for all frames so the heatmaps are directly comparable.
vmin, vmax = rsa_results.min().min(), rsa_results.max().max()

# One heatmap per frame position, written to output_dir as rsa_<frame_idx>.png.
for frame_idx, rsa_results_i in rsa_results.groupby(level="frame_idx"):
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(
        rsa_results_i.droplevel("frame_idx"),
        vmin=vmin,
        vmax=vmax,
        square=True,
        ax=ax,
    )
    ax.set(
        title=f"RSA distances at frame {frame_idx}",
        xlabel="Phoneme 2",
        ylabel="Phoneme 1",
    )

    fig.savefig(f"{output_dir}/rsa_{frame_idx}.png")
    # Close the figure so figures do not accumulate in memory across iterations.
    plt.close(fig)