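Study how well lexical coherence relations in model representations (e.g. within- vs. between-word distances) are preserved by different functions that aggregate those representations over time.
This matters because the brain encoding pipeline aggregates these representations over time.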

In [None]:
# IPython autoreload: mode 2 reloads all modules before executing code, so edits
# to project modules under `src/` are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Hide every GPU from CUDA so the notebook runs on CPU only. This must happen
# before torch is imported (the import cell comes later).
os.environ.update(CUDA_VISIBLE_DEVICES="")

In [None]:
import sys

# Make the parent directory importable so `src.*` modules resolve
# (presumably the repository root, one level above this notebook).
sys.path.extend(["../"])

In [None]:
import itertools
from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis import coherence
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# --- Inputs / outputs -------------------------------------------------------
# NOTE(review): `model_dir` names model "rnn_8-aniso2" while `embeddings_path`
# points at "rnn_8-weightdecay0.01"; only `embeddings_path` is read below.
# Confirm these are meant to refer to different models.
model_dir = "outputs/models/timit/w2v2_6/rnn_8-aniso2/word_broad_10frames"
output_dir = "."
dataset_path = "outputs/preprocessed_data/timit"
equivalence_path = "outputs/equivalence_datasets/timit/w2v2_6/word_broad_10frames/equivalence.pkl"
hidden_states_path = "outputs/hidden_states/timit/w2v2_6/hidden_states.h5"
state_space_specs_path = "outputs/state_space_specs/timit/w2v2_6/state_space_specs.pkl"
embeddings_path = "outputs/model_embeddings/timit/w2v2_6/rnn_8-weightdecay0.01/word_broad_10frames/embeddings.npy"

# --- Analysis settings --------------------------------------------------------
# Distance metric passed to scipy / coherence helpers.
metric = "cosine"

# Minimum number of instances a word needs in order to be kept.
retain_n = 10

In [None]:
# Load the precomputed model embeddings and the word-level state-space spec.
with open(embeddings_path, "rb") as f:
    model_representations: np.ndarray = np.load(f)
with open(state_space_specs_path, "rb") as f:
    # The specs file contains pickled project objects (StateSpaceAnalysisSpec
    # instances). `torch.load` refuses to unpickle those under
    # `weights_only=True`, which is the default since PyTorch 2.6, so opt out
    # explicitly. This is acceptable only because the file is locally generated
    # and trusted. (`weights_only` requires PyTorch >= 1.13.)
    state_space_spec: StateSpaceAnalysisSpec = torch.load(f, weights_only=False)["word"]
# Sanity-check that the spec and the embeddings fit together before analysis.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Drop every label (word) that has fewer than `retain_n` instances.
drop_idxs = []
for label_idx, instance_spans in enumerate(state_space_spec.target_frame_spans):
    if len(instance_spans) < retain_n:
        drop_idxs.append(label_idx)
state_space_spec = state_space_spec.drop_labels(drop_idxs)

In [None]:
# Temporal aggregation functions to compare. Plain strings name a function;
# ("mean_last_k", k) tuples carry a parameter k (presumably the number of
# trailing frames averaged -- confirm in src.analysis.state_space).
agg_fns = ["mean", "max", "last_frame"] + [("mean_last_k", k) for k in (2, 5)]

In [None]:
# Per-label trajectories; instances shorter than the padded length are filled
# with NaN frames.
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)
# Number of real (non-NaN) frames per instance, taken from feature 0. We count
# rather than use `isnan(...).argmax(axis=1)`, because argmax of an all-False row
# is 0: an instance with no padding at all would get length 0.
# Assumes padding is trailing and real frames contain no NaN -- TODO confirm.
lengths = [(~np.isnan(traj_i[:, :, 0])).sum(axis=1) for traj_i in trajectory]

In [None]:
# Aggregate each trajectory under every aggregation function. keepdims=True
# presumably keeps a singleton frame axis (later cells `.squeeze(1)` it away).
trajectory_aggs = {}
for agg_fn in agg_fns:
    trajectory_aggs[agg_fn] = aggregate_state_trajectory(
        trajectory, state_space_spec, agg_fn, keepdims=True)

# After aggregation each instance is a single "frame", so every length is 1.
dummy_lengths = [np.ones(len(traj_i), dtype=int) for traj_i in trajectory]

In [None]:
# Display: number of label trajectories and mean instance length (in frames).
(len(trajectory), np.concatenate(lengths).mean())

## Estimate within-word distance

In [None]:
# Within-word distance estimates, one DataFrame (indexed by word) per
# aggregation function.
within_distance_dfs = {}

for agg_fn, traj_agg in tqdm(trajectory_aggs.items(), unit="aggfn"):
    within_distance, within_distance_offset = coherence.estimate_within_distance(
        traj_agg, dummy_lengths, state_space_spec, metric=metric)

    word_index = pd.Index(state_space_spec.labels, name="word")
    within_distance_dfs[agg_fn] = pd.DataFrame(
        within_distance,
        columns=["distance"],
        index=word_index,
    )

## Estimate between-word distance

In [None]:
# Between-word distance estimates, one DataFrame (indexed by word) per
# aggregation function.
between_distance_dfs = {}

for agg_fn, traj_agg in tqdm(trajectory_aggs.items(), unit="aggfn"):
    between_distance, between_distance_offset = coherence.estimate_between_distance(
        traj_agg, dummy_lengths, state_space_spec, metric=metric)

    # Assumed shape (n_words, 1, n_samples) -- TODO confirm: drop the singleton
    # aggregated-frame axis, then average the last axis to one value per word.
    per_word_between = between_distance.squeeze(1).mean(axis=-1)
    between_distance_dfs[agg_fn] = pd.DataFrame(
        per_word_between,
        columns=["distance"],
        index=pd.Index(state_space_spec.labels, name="word"),
    )

## Within- and between-word distances together

In [None]:
# Stack within- and between-word distances into a single frame indexed by
# (type, agg_fn, word).
distance_dfs_by_type = {
    "within": within_distance_dfs,
    "between": between_distance_dfs,
}
merged_df = pd.concat(
    {dist_type: pd.concat(dfs, names=["agg_fn"])
     for dist_type, dfs in distance_dfs_by_type.items()},
    names=["type"],
)
merged_df

In [None]:
# Bar plot of within- vs. between-word distance per aggregation function.
# Bar height is seaborn's default estimator, the mean over words.
g = sns.catplot(data=merged_df.reset_index(),
                x="agg_fn", y="distance", hue="type", kind="bar")
# `catplot` returns a FacetGrid, not an Axes, so `ax.set_title` etc. would fail;
# label the plot through the grid-level API instead.
g.set(title="Representational distance within- and between-word")
g.set_axis_labels("Aggregation function", f"{metric.capitalize()} distance")

## Estimate distance by grouping features

### Onset

In [None]:
# Category of each word = its first element (onset). Estimate within-/between-
# category distances for that grouping under every aggregation function.
onsets = [label[0] for label in state_space_spec.labels]

onset_distance_dfs = {}
for agg_fn, traj_agg in tqdm(trajectory_aggs.items(), unit="aggfn"):
    onset_df, _ = coherence.estimate_category_within_between_distance(
        traj_agg,
        dummy_lengths,
        onsets,
        metric=metric,
        labels=state_space_spec.labels,
    )
    onset_distance_dfs[agg_fn] = onset_df

In [None]:
# Combine per-aggregation results, keyed by an outer `agg_fn` index level.
onset_distance_df = pd.concat(list(onset_distance_dfs.values()),
                              keys=list(onset_distance_dfs.keys()),
                              names=["agg_fn"])

In [None]:
# Persist the onset-grouped distances next to the other outputs.
onset_csv_path = Path(output_dir) / "distances-grouped_onset.csv"
onset_distance_df.to_csv(onset_csv_path)

In [None]:
# Bar plot of onset-grouped distances per aggregation function.
# reset_index() turns the `agg_fn` index level into a column, matching the
# other plotting cells; some seaborn versions' `catplot` cannot resolve
# variables that exist only as index levels.
g = sns.catplot(data=onset_distance_df.reset_index(),
                x="agg_fn", y="distance", hue="type", kind="bar")
# `catplot` returns a FacetGrid, not an Axes: use the grid-level setters.
g.set(title="Representational distance by onset match/mismatch")
g.set_axis_labels("Aggregation function", f"{metric.capitalize()} distance")

### Offset

In [None]:
# Category of each word = its last element (offset). Estimate within-/between-
# category distances for that grouping under every aggregation function.
offsets = [label[-1] for label in state_space_spec.labels]

offset_distance_dfs = {}
for agg_fn, traj_agg in tqdm(trajectory_aggs.items(), unit="aggfn"):
    offset_df, _ = coherence.estimate_category_within_between_distance(
        traj_agg,
        dummy_lengths,
        offsets,
        metric=metric,
        labels=state_space_spec.labels,
    )
    offset_distance_dfs[agg_fn] = offset_df

In [None]:
# Combine per-aggregation results, keyed by an outer `agg_fn` index level.
offset_distance_df = pd.concat(list(offset_distance_dfs.values()),
                               keys=list(offset_distance_dfs.keys()),
                               names=["agg_fn"])

In [None]:
# Persist the offset-grouped distances next to the other outputs.
offset_csv_path = Path(output_dir) / "distances-grouped_offset.csv"
offset_distance_df.to_csv(offset_csv_path)

In [None]:
# Bar plot of offset-grouped distances per aggregation function; the index is
# flattened so `agg_fn` is available as a column.
offset_plot_data = offset_distance_df.reset_index()
sns.catplot(data=offset_plot_data, x="agg_fn", y="distance", hue="type", kind="bar")

In [None]:
# Display the offset-grouped distance table.
offset_distance_df

## Model-free exploration

In [None]:
mean_aggs = trajectory_aggs["mean"]

# All mean-aggregated instances stacked row-wise (presumably
# (n_instances_total, dim) once the singleton frame axis is squeezed).
knn_references = np.concatenate(mean_aggs).squeeze(1)

# Row r = (label index, instance index within that label) for knn_references[r].
label_column = np.concatenate(
    [np.full(len(traj), label_idx) for label_idx, traj in enumerate(mean_aggs)])
instance_column = np.concatenate([np.arange(len(traj)) for traj in mean_aggs])
knn_reference_ids = np.stack([label_column, instance_column], axis=1).astype(int)

assert len(knn_references) == len(knn_reference_ids)

In [None]:
# Draw distinct query rows from the kNN reference set (unseeded: varies per run).
num_knn_queries = 10
knn_instances = np.random.choice(len(knn_references), size=num_knn_queries, replace=False)

In [None]:
from scipy.spatial.distance import cdist, pdist, squareform

In [None]:
# For each sampled query instance, print its nearest and furthest neighbors
# (under `metric`) among all mean-aggregated instances.
n_neighbors = 9
for knn_instance in knn_instances:
    ref_embedding = knn_references[knn_instance]
    knn_instance_results = cdist(knn_references, ref_embedding[None, :], metric=metric).ravel()

    # Rank all references by ascending distance and drop the query itself
    # explicitly, rather than assuming it sorts first (not true under ties).
    neighbor_order = knn_instance_results.argsort()
    neighbor_order = neighbor_order[neighbor_order != knn_instance]

    print(state_space_spec.labels[knn_reference_ids[knn_instance][0]], knn_reference_ids[knn_instance][1])
    print("Nearest neighbors:")
    for class_idx, instance_idx in knn_reference_ids[neighbor_order[:n_neighbors]]:
        print("\t", state_space_spec.labels[class_idx], instance_idx)
    print("Furthest neighbors:")
    # Reverse the ascending order to get the largest distances first. (Negating
    # argsort output instead would yield negative positions, i.e. rows counted
    # from the end of the array, not the furthest neighbors.)
    for class_idx, instance_idx in knn_reference_ids[neighbor_order[::-1][:n_neighbors]]:
        print("\t", state_space_spec.labels[class_idx], instance_idx)
    print()

### RSA, collapsed over instances

In [None]:
# Word-level RSA: collapse each word's instances to one centroid, then build the
# full word-by-word distance matrix for every aggregation function.
rsa_distances = {}

for agg_fn, traj_agg in tqdm(trajectory_aggs.items(), unit="aggfn"):
    word_centroids = np.stack([
        np.mean(word_instances.squeeze(1), axis=0) for word_instances in traj_agg
    ])
    word_distance_matrix = squareform(pdist(word_centroids, metric=metric))
    rsa_distances[agg_fn] = pd.DataFrame(
        word_distance_matrix,
        index=state_space_spec.labels,
        columns=state_space_spec.labels,
    )

In [None]:
# Show the word-level RSA matrices on a random subset of words (at most 20, so
# the heatmaps stay legible), one subplot per aggregation function.
n_rsa_viz = min(20, len(state_space_spec.labels))
rsa_viz_sample = np.random.choice(state_space_spec.labels, size=n_rsa_viz, replace=False)

# squeeze=False keeps `axs` a 2-D array even for a single aggregation function;
# otherwise plt.subplots returns a bare Axes, which has no `.ravel()`.
f, axs = plt.subplots(len(agg_fns), 1, figsize=(10, 10 * len(agg_fns)), squeeze=False)
for ax, (agg_fn, rsa_distances_i) in zip(axs.ravel(), rsa_distances.items()):
    rsa_viz = rsa_distances_i.loc[rsa_viz_sample, rsa_viz_sample]
    sns.heatmap(rsa_viz, ax=ax)
    ax.set_title(str(agg_fn))

In [None]:
# Vectorize each RSA matrix's strict upper triangle once (the matrices are
# symmetric with a zero diagonal).
rsa_triu = {
    agg_fn: dist_df.values[np.triu_indices(len(dist_df), k=1)]
    for agg_fn, dist_df in rsa_distances.items()
}

# Spearman rank correlation between every ordered pair of aggregation functions.
rsa_sims = {}
for agg1, agg2 in itertools.product(agg_fns, repeat=2):
    rsa_sims[agg1, agg2] = scipy.stats.spearmanr(rsa_triu[agg1], rsa_triu[agg2])[0]

In [None]:
# Turn the {(agg1, agg2): rho} dict into a Series with a named 2-level index.
rsa_sims = pd.Series(rsa_sims)
rsa_sims.index.names = ["agg1", "agg2"]
rsa_sims

In [None]:
# agg1 x agg2 grid of Spearman correlations between word-level RSA matrices.
rsa_sims_matrix = rsa_sims.unstack()
ax = sns.heatmap(rsa_sims_matrix)
ax.set_title("Similarity in word-level RSA")

In [None]:
# Hierarchically cluster the sampled words' RSA distances for one aggregation
# function. The function is selected explicitly instead of reusing `rsa_viz`,
# a loop variable leaked from the heatmap cell that silently depends on that
# loop's last iteration (and on that cell having run first).
clustermap_agg_fn = agg_fns[-1]
sns.clustermap(rsa_distances[clustermap_agg_fn].loc[rsa_viz_sample, rsa_viz_sample])