In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Hide every CUDA device from this process. This is set before `torch` is
# imported in the next cell so the setting is in place when it loads.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis import coherence
from src.analysis.state_space import prepare_state_trajectory, StateSpaceAnalysisSpec
from src.datasets.speech_equivalence import SpeechEquivalenceDataset
from src.utils.timit import get_word_metadata

In [None]:
# Notebook parameters: input artifact locations and analysis options.
# NOTE(review): `model_dir`, `dataset_path`, `equivalence_path` and
# `hidden_states_path` are not referenced by any cell visible in this notebook;
# presumably they are kept so the parameter set matches sibling notebooks — confirm.
model_dir = "outputs/models/librispeech-train-clean-100/w2v2_8/rnn_32-hinge-mAP4/word_broad_10frames_fixedlen25"
# Directory that receives the CSV exports written by the cells below.
output_dir = "."
dataset_path = "outputs/preprocessed_data/librispeech-train-clean-100"
equivalence_path = "outputs/equivalence_datasets/librispeech-train-clean-100/w2v2_8/word_broad_10frames_fixedlen25/equivalence.pkl"
hidden_states_path = "outputs/hidden_states/librispeech-train-clean-100/w2v2_8/hidden_states.h5"
# HDF5 file from which the "word" state space spec is loaded.
state_space_specs_path = "outputs/state_space_specs/librispeech-train-clean-100/w2v2_8/state_space_specs.h5"
# .npy file holding the model representations analysed in this notebook.
embeddings_path = "outputs/model_embeddings/librispeech-train-clean-100/w2v2_8/rnn_32-hinge-mAP4/word_broad_10frames_fixedlen25/librispeech-train-clean-100.npy"

# Distance metric name, passed to the `coherence` estimators and to scipy's
# `cdist` / `pdist` further down.
metric = "cosine"

# Retain words with N or more instances
retain_n = 10

In [None]:
# Model representations, stored on disk as a single NumPy array.
model_representations: np.ndarray = np.load(embeddings_path)

# Word-level state space spec; it must line up with the loaded array.
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "word")
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Remove every label that has fewer than `retain_n` instances from the spec.
label_counts = state_space_spec.label_counts
drop_labels = label_counts.loc[label_counts < retain_n].index
state_space_spec = state_space_spec.drop_labels(drop_names=drop_labels)

In [None]:
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)


def _instance_lengths(traj_i: np.ndarray) -> np.ndarray:
    """Return, for each instance in a NaN-padded trajectory array, the index of its first padded position.

    `traj_i` is indexed as `[instance, position, feature]`; padding is detected
    on feature 0 only. `argmax` over an all-False row returns 0, so an instance
    with no padded position at all (it fills the whole axis) is mapped to the
    full axis length explicitly instead of being reported as length 0.
    """
    is_pad = np.isnan(traj_i[:, :, 0])
    return np.where(is_pad.any(axis=1), is_pad.argmax(axis=1), is_pad.shape[1])


lengths = [_instance_lengths(traj_i) for traj_i in trajectory]

In [None]:
# Number of trajectory arrays, and the mean of all per-instance lengths.
(len(trajectory), np.mean(np.concatenate(lengths)))

## Estimate within-word distance

In [None]:
# Within-word distances; the estimator returns a pair of results, the second
# of which is used below for the offset-aligned analysis.
within_distance, within_distance_offset = coherence.estimate_within_distance(
    trajectory,
    lengths,
    state_space_spec,
    metric=metric,
)

In [None]:
# Wide (one row per word, one column per frame) -> long (word, frame, distance).
within_distance_df = (
    pd.DataFrame(within_distance, index=pd.Index(state_space_spec.labels, name="word"))
    .reset_index()
    .melt(id_vars=["word"], var_name="frame", value_name="distance")
)

In [None]:
# Same wide -> long reshape for the offset-aligned within-word distances.
within_distance_offset_df = (
    pd.DataFrame(within_distance_offset, index=pd.Index(state_space_spec.labels, name="word"))
    .reset_index()
    .melt(id_vars=["word"], var_name="frame", value_name="distance")
)

## Estimate between-word distance

In [None]:
# Between-word distances; same call shape as the within-word estimator.
between_distance, between_distance_offset = coherence.estimate_between_distance(
    trajectory,
    lengths,
    state_space_spec,
    metric=metric,
)

In [None]:
# Average out the trailing axis (ignoring NaNs), then reshape wide -> long
# into (word, frame, distance) rows.
between_distances_df = (
    pd.DataFrame(
        np.nanmean(between_distance, axis=-1),
        index=pd.Index(state_space_spec.labels, name="word"),
    )
    .reset_index()
    .melt(id_vars=["word"], var_name="frame", value_name="distance")
)

In [None]:
# Offset-aligned counterpart: NaN-ignoring mean over the trailing axis,
# followed by the same wide -> long reshape.
between_distances_offset_df = (
    pd.DataFrame(
        np.nanmean(between_distance_offset, axis=-1),
        index=pd.Index(state_space_spec.labels, name="word"),
    )
    .reset_index()
    .melt(id_vars=["word"], var_name="frame", value_name="distance")
)

## Together

In [None]:
# Stack the within- and between-word frames, tagging each row with its contrast
# type, then export and display the result.
merged_df = pd.concat(
    [
        within_distance_df.assign(type="within"),
        between_distances_df.assign(type="between"),
    ]
)
merged_df.to_csv(Path(output_dir) / "distances.csv", index=False)
merged_df

In [None]:
# Within vs. between distance per frame; rows with missing values are excluded.
ax = sns.lineplot(
    data=merged_df.dropna(),
    x="frame",
    y="distance",
    hue="type",
)
ax.set(
    title="Representational distance within- and between-word",
    xlabel="Frames since word onset",
)
ax.set_ylabel(f"{metric.capitalize()} distance")

In [None]:
# Offset-aligned counterpart of `merged_df`: stack, tag by contrast type,
# export and display.
merged_offset_df = pd.concat(
    [
        within_distance_offset_df.assign(type="within"),
        between_distances_offset_df.assign(type="between"),
    ]
)
merged_offset_df.to_csv(Path(output_dir) / "distances_aligned_offset.csv", index=False)
merged_offset_df

In [None]:
# Offset-aligned within vs. between distance; rows with missing values are excluded.
ax = sns.lineplot(
    data=merged_offset_df.dropna(),
    x="frame",
    y="distance",
    hue="type",
)
ax.set(
    title="Representational distance within- and between-word",
    xlabel="Frames before word offset",
    ylabel=f"{metric.capitalize()} distance",
)
# Clip the x-axis at the 95th percentile of the instance lengths.
ax.set_xlim((0, np.percentile(np.concatenate(lengths), 95)))

## Estimate distance by grouping features

### Onset

In [None]:
# For each word label, pick the most frequent `description` among the first
# phoneme-level cut rows of its instances.
onsets = (
    state_space_spec.cuts.xs("phoneme", level="level")
    # first row per (label, instance) pair
    .groupby(["label", "instance_idx"]).first()
    # tally the descriptions within each label
    .groupby("label").description.value_counts()
    # idxmax yields index tuples; element 1 is the description
    .groupby("label").idxmax().str[1]
)
# Reorder into a plain list aligned with `state_space_spec.labels`.
onsets = [onsets.loc[label] for label in state_space_spec.labels]

In [None]:
# Distances with words grouped by their onset category.
onset_distance_df, onset_distance_offset_df = coherence.estimate_category_within_between_distance(
    trajectory,
    lengths,
    onsets,
    metric=metric,
    labels=state_space_spec.labels,
)

In [None]:
# Export the onset-grouped distances.
onset_distance_df.to_csv(
    Path(output_dir) / "distances-grouped_onset.csv",
    index=False,
)

In [None]:
# Onset match vs. mismatch distance per frame; rows with missing values are excluded.
ax = sns.lineplot(
    data=onset_distance_df.dropna(),
    x="frame",
    y="distance",
    hue="type",
)
ax.set(
    title="Representational distance by onset match/mismatch",
    xlabel="Frames since word onset",
)
ax.set_ylabel(f"{metric.capitalize()} distance")

In [None]:
# Add a `time` column derived from the frame index, then export.
# NOTE(review): the divisor 20 is presumably a frames-per-unit-time constant — confirm.
onset_distance_offset_df["time"] = onset_distance_offset_df["frame"] / 20
onset_distance_offset_df.to_csv(
    Path(output_dir) / "distances-grouped_onset_aligned_offset.csv",
    index=False,
)

### Onset as C/V

In [None]:
# Phoneme symbols split into two coarse classes.
categorization = {
    "consonant": "B CH D DH F G HH JH K L M N NG P R S SH T TH V W Y Z ZH".split(" "),
    "vowel": "AA AE AH AO AW AY EH ER EY IH IY OW OY UH UW".split(" "),
}
# Invert to a phoneme -> class lookup.
categorization_lookup = {
    phoneme: category
    for category, phonemes in categorization.items()
    for phoneme in phonemes
}

# Class of each word's onset; an onset missing from the lookup raises KeyError.
onset_categories = [categorization_lookup[onset] for onset in onsets]

onset_category_distance_df, onset_category_distance_offset_df = coherence.estimate_category_within_between_distance(
    trajectory,
    lengths,
    onset_categories,
    metric=metric,
    labels=state_space_spec.labels,
)

In [None]:
# Export the distances grouped by onset consonant/vowel class.
onset_category_distance_df.to_csv(
    Path(output_dir) / "distances-grouped_onset_category.csv",
    index=False,
)

### Offset

In [None]:
# For each word label, pick the most frequent `description` among the last
# phoneme-level cut rows of its instances.
offsets = (
    state_space_spec.cuts.xs("phoneme", level="level")
    # last row per (label, instance) pair
    .groupby(["label", "instance_idx"]).last()
    # tally the descriptions within each label
    .groupby("label").description.value_counts()
    # idxmax yields index tuples; element 1 is the description
    .groupby("label").idxmax().str[1]
)
# Reorder into a plain list aligned with `state_space_spec.labels`.
offsets = [offsets.loc[label] for label in state_space_spec.labels]

In [None]:
# Distances with words grouped by their offset category.
offset_distance_df, offset_distance_offset_df = coherence.estimate_category_within_between_distance(
    trajectory,
    lengths,
    offsets,
    metric=metric,
    labels=state_space_spec.labels,
)

In [None]:
# Add a `time` column derived from the frame index shifted by one, then export.
# NOTE(review): the `- 1` shift is applied only to this frame; the divisor 20 is
# presumably a frames-per-unit-time constant — confirm both.
offset_distance_df["time"] = (offset_distance_df["frame"] - 1) / 20
offset_distance_df.to_csv(
    Path(output_dir) / "distances-grouped_offset.csv",
    index=False,
)

In [None]:
# Add a `time` column derived from the frame index, then export.
# NOTE(review): the divisor 20 is presumably a frames-per-unit-time constant — confirm.
offset_distance_offset_df["time"] = offset_distance_offset_df["frame"] / 20
offset_distance_offset_df.to_csv(
    Path(output_dir) / "distances-grouped_offset_aligned_offset.csv",
    index=False,
)

In [None]:
# Offset match vs. mismatch distance, aligned to word offset; rows with missing
# values are excluded. The title names the offset grouping, since this cell
# plots the frame computed from `offsets`.
ax = sns.lineplot(data=offset_distance_offset_df.dropna(), x="frame", y="distance", hue="type")
ax.set_title("Representational distance by offset match/mismatch")
ax.set_xlabel("Frames before word offset")
ax.set_ylabel(f"{metric.capitalize()} distance")

## Add word metadata and explore

In [None]:
# Word metadata derived from the state space spec. `plot_coherence_panel` below
# reads this global and merges its first row per "label" into the distance frames.
word_metadata = get_word_metadata(state_space_spec)

In [None]:
def truncate_contrasts(distance_df, extra_grouping_variables=None, min_contrast_instances=30) -> pd.DataFrame:
    """Keep only the frames that precede the first under-populated frame.

    We want to make comparisons between contrasts only when we have enough data
    between individual units in the contrast to have a meaningful mean.
    e.g. if a "within" contrast at frame 5 only has 2 instances, it's not so
    informative to compare this mean to the "between" contrast at frame 5.

    Args:
        distance_df: Long-format frame with at least `type`, `frame` and
            `distance` columns. Rows containing any missing value are ignored
            when counting instances.
        extra_grouping_variables: Additional columns that define a contrast
            cell. `type` and `frame` are always used and are ignored if listed.
        min_contrast_instances: Minimum number of rows a (type, frame, ...)
            cell needs for its frame to count as sufficiently populated.

    Returns:
        A new frame with the rows of `distance_df` whose `frame` is strictly
        below the first frame at which any contrast cell has fewer than
        `min_contrast_instances` rows. If every frame is sufficiently
        populated, all rows are returned. If there are no countable rows at
        all, an empty frame is returned.
    """
    if extra_grouping_variables is None:
        extra_grouping_variables = []
    # De-duplicate while keeping the caller's order, so the grouping is deterministic.
    grouping = ["type", "frame"] + [
        variable for variable in dict.fromkeys(extra_grouping_variables)
        if variable not in ("type", "frame")
    ]

    cell_counts = distance_df.dropna().groupby(grouping).distance.count()
    if cell_counts.empty:
        # Nothing to count: no frame can meet the minimum.
        return distance_df.iloc[:0]

    # Per frame: is any contrast cell under-populated?
    frame_is_sparse = (cell_counts < min_contrast_instances).groupby("frame").max()
    if not frame_is_sparse.any():
        # `idxmax` on an all-False series would return the first frame and the
        # filter below would then discard every row, so return everything here.
        return distance_df.copy()

    # First frame at which some contrast cell falls below the minimum.
    first_sparse_frame = frame_is_sparse.idxmax()
    return distance_df[distance_df.frame < first_sparse_frame]

In [None]:
def plot_coherence_panel(distance_df, distance_offset_df, hue=None, style=None, dropna_columns=None):
    """Draw a two-panel distance plot: onset-aligned (left) and offset-aligned (right).

    Both frames are left-merged on "label" with the first metadata row per label
    from the notebook-level `word_metadata`. The onset-aligned frame is then
    truncated with `truncate_contrasts`, and the offset-aligned frame is limited
    to frames no greater than the largest remaining onset-aligned frame.

    Args:
        distance_df: Onset-aligned distances with `label`, `frame`, `distance`
            and `type` columns.
        distance_offset_df: Offset-aligned distances with the same columns.
        hue: Column used for line colour; falls back to "type" when None.
        style: Column used for line style, or None.
        dropna_columns: If given, rows missing any of these columns are dropped
            from both frames before plotting.

    Returns:
        The `(figure, axes)` pair created by `plt.subplots`.
    """
    # One metadata row per label, attached to both distance frames.
    label_metadata = word_metadata.groupby("label").first().reset_index()
    distance_df = distance_df.merge(label_metadata, on="label", how="left")
    distance_offset_df = distance_offset_df.merge(label_metadata, on="label", how="left")

    # Contrast cells are additionally split by whichever of hue/style is set.
    grouping = [variable for variable in (hue, style) if variable is not None]
    distance_df = truncate_contrasts(distance_df, extra_grouping_variables=grouping)
    last_kept_frame = distance_df.frame.max()
    distance_offset_df = distance_offset_df[distance_offset_df.frame <= last_kept_frame]

    if dropna_columns is not None:
        distance_df, distance_offset_df = [
            panel_df.dropna(subset=dropna_columns)
            for panel_df in (distance_df, distance_offset_df)
        ]

    f, ax = plt.subplots(1, 2, figsize=(2 * 8, 6), sharey=True)

    hue_column = "type" if hue is None else hue
    panels = (
        (ax[0], distance_df, "Frames since word onset"),
        (ax[1], distance_offset_df, "Frames before word offset"),
    )
    for panel_ax, panel_df, xlabel in panels:
        sns.lineplot(data=panel_df, x="frame", y="distance",
                     hue=hue_column, style=style,
                     errorbar="se", ax=panel_ax)
        panel_ax.set_xlabel(xlabel)
    # The right panel counts frames back from the offset, so flip its x-axis.
    ax[1].invert_xaxis()

    # Dashed reference line at distance 1 on both panels.
    for panel_ax in (ax[0], ax[1]):
        panel_ax.axhline(1, color="gray", linestyle="--")
    ax[0].set_ylim((0, 1.5))

    return f, ax

### Identity-matched

In [None]:
# Within/between-word panel, split by primary initial stress. The distance
# frames key words as "word", while the panel helper merges on "label".
f, axs = plot_coherence_panel(
    merged_df.rename(columns={"word": "label"}),
    merged_offset_df.rename(columns={"word": "label"}),
    hue="type",
    style="stress_primary_initial",
    dropna_columns=["stress_primary_initial"],
)
f.suptitle("Representational distance within- and between-word, by stress")

In [None]:
# Within/between-word panel, coloured by word frequency quantile.
f, axs = plot_coherence_panel(
    merged_df.rename(columns={"word": "label"}),
    merged_offset_df.rename(columns={"word": "label"}),
    hue="word_frequency_quantile",
    style="type",
)
f.suptitle("Representational distance within- and between-word, by word frequency quantile")

### Onset-matched

In [None]:
# Onset match/mismatch panel, styled by primary initial stress.
f, axs = plot_coherence_panel(
    onset_distance_df,
    onset_distance_offset_df,
    style="stress_primary_initial",
    dropna_columns=["stress_primary_initial"],
)
f.suptitle("Representational distance by onset match/mismatch and primary initial stress")

In [None]:
# Onset match/mismatch panel, styled by primary final stress.
f, axs = plot_coherence_panel(
    onset_distance_df,
    onset_distance_offset_df,
    style="stress_primary_final",
    dropna_columns=["stress_primary_final"],
)
f.suptitle("Representational distance by onset match/mismatch and primary final stress")

In [None]:
# Onset match/mismatch panel, coloured by word frequency quantile.
f, axs = plot_coherence_panel(
    onset_distance_df,
    onset_distance_offset_df,
    hue="word_frequency_quantile",
    style="type",
)
f.suptitle("Representational distance by onset match/mismatch and word frequency quantile")

### Offset-matched

In [None]:
# Offset match/mismatch panel, styled by primary initial stress.
f, axs = plot_coherence_panel(
    offset_distance_df,
    offset_distance_offset_df,
    style="stress_primary_initial",
    dropna_columns=["stress_primary_initial"],
)
f.suptitle("Representational distance by offset match/mismatch and primary initial stress")

In [None]:
# Offset match/mismatch panel, coloured by word frequency quantile.
f, axs = plot_coherence_panel(
    offset_distance_df,
    offset_distance_offset_df,
    hue="word_frequency_quantile",
    style="type",
)
f.suptitle("Representational distance by offset match/mismatch and word frequency quantile")

## Model-free exploration

In [None]:
# Number of leading positions along each trajectory's second axis that are
# averaged into a single reference vector per instance in the next cell.
knn_window_size = 10

In [None]:
# One reference vector per instance: NaN-ignoring mean over its first
# `knn_window_size` positions, concatenated across all labels.
knn_references = np.concatenate([
    np.nanmean(traj_i[:, :knn_window_size, :], axis=1)
    for traj_i in trajectory
])
# Matching (label, instance index) identifiers, in the same order as the rows above.
knn_reference_ids = np.stack([
    (state_space_spec.labels[label_idx], instance_idx)
    for label_idx, traj_i in enumerate(trajectory)
    for instance_idx in range(traj_i.shape[0])
])

In [None]:
# Draw the query instances from a seeded generator so the neighbour listings
# below are reproducible across "Restart & Run All". The sample size is capped
# at the number of available references so a small dataset does not raise.
knn_rng = np.random.default_rng(0)
knn_instances = knn_rng.choice(
    len(knn_references),
    size=min(10, len(knn_references)),
    replace=False,
)

In [None]:
# Every reference vector must have exactly one identifier row.
assert len(knn_references) == len(knn_reference_ids)

In [None]:
from scipy.spatial.distance import cdist, pdist, squareform

In [None]:
# Number of neighbours listed in each direction.
n_neighbors = 10

for knn_instance in knn_instances:
    ref_embedding = knn_references[knn_instance]
    knn_instance_results = cdist(knn_references, ref_embedding[None, :], metric=metric).ravel()

    # Rank in both directions, then drop the query by index rather than by
    # assuming it sorts first, so both lists have the same length.
    nearest = knn_instance_results.argsort()
    nearest = nearest[nearest != knn_instance]
    # Negating before sorting keeps NaN distances at the end of the furthest list.
    furthest = (-knn_instance_results).argsort()
    furthest = furthest[furthest != knn_instance]

    print(knn_reference_ids[knn_instance])
    print("Nearest neighbors:")
    print(knn_reference_ids[nearest[:n_neighbors]])
    print("Furthest neighbors:")
    print(knn_reference_ids[furthest[:n_neighbors]])
    print()

### RSA, collapsed over instances

In [None]:
# Collapse instances to one vector per label (NaN-ignoring mean of that label's
# reference vectors), then compute the pairwise distance matrix between labels.
rsa_ids = np.unique(knn_reference_ids[:, 0])
rsa_references = np.stack([
    np.nanmean(knn_references[knn_reference_ids[:, 0] == rsa_id], axis=0)
    for rsa_id in rsa_ids
])
rsa_distances = pd.DataFrame(
    squareform(pdist(rsa_references, metric=metric)),
    index=rsa_ids,
    columns=rsa_ids,
)

In [None]:
# Label-by-label distance matrix as a heatmap.
sns.heatmap(data=rsa_distances)

In [None]:
# Same matrix with rows and columns reordered by hierarchical clustering.
sns.clustermap(data=rsa_distances)

In [None]:
# Print closest and furthest label pairs.
# Each unordered pair is ranked once, using the strict upper triangle of the
# distance matrix. `rsa_distances` is only read here: writing into its
# `.values` would alter the frame already plotted above, would change the
# result on every re-run, and fails when pandas returns a read-only array.
print_n = 50

pair_rows, pair_cols = np.triu_indices(len(rsa_distances), k=1)
pair_distances = rsa_distances.to_numpy()[pair_rows, pair_cols]
pair_order = pair_distances.argsort()

# (row, column) positions of all pairs, sorted from closest to furthest.
closest_pair_idxs = np.stack([pair_rows[pair_order], pair_cols[pair_order]], axis=1)

for x, y in closest_pair_idxs[:print_n]:
    print(rsa_distances.index[x], rsa_distances.columns[y], rsa_distances.values[x, y])
print("---")
for x, y in closest_pair_idxs[-print_n:]:
    print(rsa_distances.index[x], rsa_distances.columns[y], rsa_distances.values[x, y])