In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict, Counter
from dataclasses import replace
import itertools
import json
import logging
from pathlib import Path
import pickle

import datasets
from fastdist import fastdist
import matplotlib.pyplot as plt
from matplotlib import transforms
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist, pdist, squareform
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import torch
from tqdm.auto import tqdm, trange

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# --- Experiment configuration -------------------------------------------------
# Identifiers that are interpolated into the artifact paths below.
base_model = "w2v2_8"
model_class = "rnn_32-hinge-mAP4"
model_name = "word_broad"
train_dataset = "librispeech-train-clean-100"

# Artifact locations (relative paths, resolved against the working directory).
model_dir = f"outputs/models/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames"
# Directory that receives the CSV and PNG outputs. Plain literal: there is
# nothing to interpolate, so no f-string prefix is needed.
output_dir = "."
dataset_path = f"outputs/preprocessed_data/{train_dataset}"
equivalence_path = f"outputs/equivalence_datasets/{train_dataset}/{base_model}/{model_name}_10frames/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{train_dataset}/{base_model}/{train_dataset}.h5"
state_space_specs_path = f"outputs/state_space_specs/{train_dataset}/{base_model}/state_space_specs.pkl"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames/{train_dataset}.npy"

# NOTE(review): `seed` is not referenced by any of the cells visible in this
# notebook section -- confirm whether it should be passed to the subsampling.
seed = 1234

# Value passed to `StateSpaceAnalysisSpec.subsample_instances`.
max_instances_per_word = 100

# Distance metric name forwarded to `scipy.spatial.distance.cdist`.
metric = "cosine"

In [None]:
# Model embedding array, read from the .npy file at `embeddings_path`.
with open(embeddings_path, "rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)

# The specs file is deserialized with torch.load, which is pickle-based: only
# point `state_space_specs_path` at trusted, locally generated artifacts.
# NOTE(review): recent PyTorch releases default to `weights_only=True`, which
# may refuse to unpickle project classes -- confirm against the pinned version.
with open(state_space_specs_path, "rb") as specs_file:
    state_space_spec: StateSpaceAnalysisSpec = torch.load(specs_file)["word"]

# Stop here if the "word" spec and the embeddings do not belong together.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Replace the spec with a subsampled version; presumably this keeps at most
# `max_instances_per_word` instances per label -- confirm in
# `StateSpaceAnalysisSpec.subsample_instances`.
# The name is rebound, so re-running this cell subsamples the already
# subsampled spec rather than the one loaded from disk.
# NOTE(review): the configured `seed` is not passed here; if the subsampling is
# random, results may differ between runs -- verify.
state_space_spec = state_space_spec.subsample_instances(max_instances_per_word)

In [None]:
# Build the state trajectory from the embeddings and the spec, using NaN as
# the pad value.
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)
# Aggregate with the ("mean_within_cut", "phoneme") spec; presumably this
# averages the frames inside each phoneme cut -- confirm in
# src.analysis.state_space.
trajectory = aggregate_state_trajectory(trajectory, state_space_spec, ("mean_within_cut", "phoneme"), keepdims=True)
# `traj_flat` holds the flattened rows; each entry of `traj_flat_src` is used
# downstream as a (label_idx, instance_idx, frame_idx) key for the same row.
traj_flat, traj_flat_src = flatten_trajectory(trajectory)

In [None]:
# Tab-separated word-frequency table, indexed by its first column. Its
# `Lg10WF` column is later merged onto the cuts by word label.
# NOTE(review): with read_csv's default NA handling, index entries spelled like
# "null", "nan" or "NA" are parsed as missing values -- check whether any word
# labels are affected.
word_freq_df = pd.read_csv("data/SUBTLEXus74286wordstextversion.txt", sep="\t", index_col=0)

In [None]:
# Position of every label within the spec's label list.
label_to_idx = {label: idx for idx, label in enumerate(state_space_spec.labels)}

# Phoneme-level cuts, re-keyed by (label_idx, instance_idx, frame_idx), where
# frame_idx counts the phoneme cuts within one (label, instance_idx) pair.
cuts_df = (
    state_space_spec.cuts
    .xs("phoneme", level="level")
    .drop(columns=["onset_frame_idx", "offset_frame_idx"])
    .assign(
        label_idx=lambda df: df.index.get_level_values("label").map(label_to_idx),
        frame_idx=lambda df: df.groupby(["label", "instance_idx"]).cumcount(),
    )
    .reset_index()
    .set_index(["label_idx", "instance_idx", "frame_idx"])
    .sort_index()
)

# Map each (label_idx, instance_idx, frame_idx) source key to its row number in
# the flattened trajectory, then attach that row number to the cuts.
traj_flat_idxs = pd.Series({tuple(src_key): row for row, src_key in enumerate(traj_flat_src)})
traj_flat_idxs.index = traj_flat_idxs.index.set_names(["label_idx", "instance_idx", "frame_idx"])
# NOTE(review): this is an inner join (the merge default), so cuts without a
# matching trajectory row are dropped silently instead of showing up as NaN.
cuts_df = cuts_df.merge(traj_flat_idxs.rename("traj_flat_idx"), left_index=True, right_index=True)

cuts_df

In [None]:
# Prepare metadata for other groupers
# A word instance is identified by (label_idx, instance_idx); its length is the
# number of phoneme rows it owns.
word_instance_keys = ["label_idx", "instance_idx"]
# `transform("size")` yields one value per row, already aligned with the
# three-level index of `cuts_df`. Assigning the two-level result of `.size()`
# instead depends on pandas aligning MultiIndexes with different level counts.
cuts_df["word_length"] = cuts_df.groupby(word_instance_keys)["traj_flat_idx"].transform("size")
# merge in frequency data; dropping a pre-existing Lg10WF column first keeps
# the cell re-runnable (otherwise a second run yields Lg10WF_x / Lg10WF_y).
cuts_df = pd.merge(
    cuts_df.drop(columns=["Lg10WF"], errors="ignore"),
    word_freq_df[["Lg10WF"]],
    left_on="label", right_index=True, how="left",
)
# Zero-based phoneme position divided by word length, i.e. values in [0, 1).
cuts_df["word_relative_position"] = cuts_df.groupby(word_instance_keys).cumcount() / cuts_df.word_length

# Equal-width bins over the observed relative positions.
relative_position_bins = 5
cuts_df["word_relative_position_bin"] = pd.cut(cuts_df.word_relative_position, bins=relative_position_bins, labels=False)

In [None]:
# Quantile-bin word frequency and word length. Both columns are heavily tied
# (one value per word, repeated on every phoneme row; word length is a small
# integer), so neighbouring quantile edges can coincide. `duplicates="drop"`
# merges such edges instead of raising "Bin edges must be unique"; when all
# edges are distinct the result is unchanged. Fewer than the requested number
# of bins may therefore be produced.
frequency_bins = 5
cuts_df["word_frequency_bin"] = pd.qcut(cuts_df.Lg10WF, q=frequency_bins, labels=False, duplicates="drop")

word_length_bins = 5
cuts_df["word_length_bin"] = pd.qcut(cuts_df.word_length, q=word_length_bins, labels=False, duplicates="drop")

In [None]:
# Every cut row must carry a trajectory row index.
# NOTE(review): traj_flat_idx is attached with an inner merge, which drops
# unmatched rows rather than leaving NaN, so this check cannot detect cuts
# that had no matching trajectory row.
assert cuts_df["traj_flat_idx"].notna().all()

In [None]:
# Sorted inventory of the distinct `description` values (the phoneme labels)
# and a lookup from each one to its position in that inventory.
all_phonemes = sorted(cuts_df["description"].unique())
phoneme2idx = dict(zip(all_phonemes, range(len(all_phonemes))))

In [None]:
# Grouping schemes for the RSA evaluation: scheme name -> list of `cuts_df`
# column / index-level names to group by.
groupers = dict(
    position=["frame_idx"],
    position_within_length=["word_length_bin", "frame_idx"],
    position_within_frequency=["word_frequency_bin", "frame_idx"],
    relative_position=["word_relative_position_bin"],
)

In [None]:
def evaluate_rsa(grouper_name, grouper, cuts_df, traj_flat):
    """
    Compute, save and plot mean pairwise phoneme distances within each group.

    For every group of ``cuts_df.groupby(grouper)`` the rows are split by
    phoneme (the ``description`` column). For each unordered phoneme pair
    present in the group, all pairwise distances between the corresponding
    rows of ``traj_flat`` are computed with ``cdist`` and averaged.

    Reads the notebook-level names ``all_phonemes``, ``phoneme2idx``,
    ``metric`` and ``output_dir``.

    Args:
        grouper_name: Name of the grouping scheme; used in output file names
            and plot titles.
        grouper: List of ``cuts_df`` column / index-level names to group by.
        cuts_df: Frame with a ``description`` column and a ``traj_flat_idx``
            column holding row indices into ``traj_flat``.
        traj_flat: Array whose first axis is indexed by ``traj_flat_idx``.

    Returns:
        None. Writes ``rsa_results-<grouper_name>.csv`` and one
        ``rsa-<grouper_name>-<group>.png`` heatmap per group to ``output_dir``.
    """
    # Lower triangle including the diagonal, generated directly (i outer,
    # j inner) rather than filtering the full n x n product.
    n_phonemes = len(all_phonemes)
    phoneme_pairs = [(i, j) for i in range(n_phonemes) for j in range(i + 1)]

    rsa_results = {}
    for group, rows in tqdm(cuts_df.groupby(grouper), leave=False):
        # Rows of traj_flat belonging to each phoneme in this group. groupby
        # never yields empty groups, so every entry has at least one row.
        idxs_by_phoneme = {
            phoneme_idx: rows_i.traj_flat_idx.to_numpy()
            for phoneme_idx, rows_i in rows.groupby(rows.description.map(phoneme2idx))
        }

        rsa_i = {}
        for i, j in tqdm(phoneme_pairs, leave=False):
            if i not in idxs_by_phoneme or j not in idxs_by_phoneme:
                continue
            # NOTE(review): for i == j the mean includes each row's zero
            # distance to itself -- confirm this is intended.
            distances = cdist(traj_flat[idxs_by_phoneme[i]], traj_flat[idxs_by_phoneme[j]], metric=metric)
            rsa_i[all_phonemes[i], all_phonemes[j]] = distances.mean()

        rsa_results[group] = rsa_i

    # save results csv: index = grouper levels + phoneme1, columns = phoneme2
    rsa_results = pd.concat(
        {group: pd.Series(rsa_results_i) for group, rsa_results_i in rsa_results.items()},
        names=grouper + ["phoneme1", "phoneme2"]).unstack().sort_index()
    rsa_results.to_csv(f"{output_dir}/rsa_results-{grouper_name}.csv")

    # plot, sharing one colour scale across all groups
    vmax = rsa_results.max().max()
    vmin = rsa_results.min().min()

    for group, rsa_results_i in tqdm(rsa_results.groupby(grouper), desc="Plotting", leave=False):
        fig, ax = plt.subplots(figsize=(10, 10))
        sns.heatmap(rsa_results_i.droplevel(grouper),
                    ax=ax, square=True, vmin=vmin, vmax=vmax)
        ax.set_title(f"RSA distances in {grouper_name} at group {group}")
        ax.set_xlabel("Phoneme 2")
        ax.set_ylabel("Phoneme 1")

        # Grouping by a one-element list yields a scalar key on pandas < 2.0
        # and a 1-tuple on pandas >= 2.0; normalise so the join below works
        # for both instead of raising TypeError on a non-iterable scalar.
        group_key = group if isinstance(group, tuple) else (group,)
        group_id = "_".join(str(g) for g in group_key)
        fig.savefig(f"{output_dir}/rsa-{grouper_name}-{group_id}.png")
        # Close explicitly: many figures are created in this loop.
        plt.close(fig)

In [None]:
# Run the RSA evaluation once per grouping scheme; each call writes its own
# CSV and heatmap images.
for scheme_name, scheme_columns in tqdm(groupers.items(), unit="grouper"):
    evaluate_rsa(scheme_name, scheme_columns, cuts_df, traj_flat)