In [None]:
# Development convenience: enable IPython's autoreload extension in mode 2 so
# edits to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Set CUDA_VISIBLE_DEVICES to the empty string for this process. This cell
# runs ahead of the project imports in the cells below.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

In [None]:
import sys

# Put the parent directory on the import path so the `src` package imported
# below can be resolved. Guarded so that re-running this cell does not keep
# appending duplicate entries to sys.path.
if "../" not in sys.path:
    sys.path.append("../")

In [None]:

from dataclasses import replace
from pathlib import Path
import pickle

import datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, cdist
import seaborn as sns
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

from src.analysis import coherence
from src.analysis.state_space import prepare_state_trajectory, StateSpaceAnalysisSpec
from src.datasets.speech_equivalence import SpeechEquivalenceDataset
from src.models import get_best_checkpoint
from src.models.integrator import ContrastiveEmbeddingModel, compute_embeddings, load_or_compute_embeddings

In [None]:
# --- Inputs -----------------------------------------------------------------
# Directory passed to `get_best_checkpoint` to locate the model to analyse.
model_dir = "outputs/models/w2v2_6_8/phoneme_within_word_suffix"
# Pickled SpeechEquivalenceDataset.
equiv_dataset_path = "data/timit_equivalence_facebook-wav2vec2-base_6-phoneme-1.pkl"
# Pickled StateSpaceAnalysisSpec.
state_space_spec_path = "out/state_space_specs/all_words.pkl"

# --- Outputs ----------------------------------------------------------------
# Directory that receives the CSV files written by this notebook.
output_dir = "."

# --- Analysis options -------------------------------------------------------
# Distance metric name handed to the `coherence` estimators and used in the
# plot axis labels.
metric = "cosine"

In [None]:
# Resolve the checkpoint selected by `get_best_checkpoint` for `model_dir`,
# load the contrastive embedding model from it, and call `eval()` on it.
best_checkpoint = get_best_checkpoint(model_dir)
model = ContrastiveEmbeddingModel.from_pretrained(best_checkpoint)
model.eval()

In [None]:
# Load the pickled equivalence dataset.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open(equiv_dataset_path, mode="rb") as dataset_file:
    equiv_dataset: SpeechEquivalenceDataset = pickle.load(dataset_file)

In [None]:
# Load the pickled state-space spec.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open(state_space_spec_path, mode="rb") as spec_file:
    state_space_spec: StateSpaceAnalysisSpec = pickle.load(spec_file)

# Fail early if the spec does not match the dataset loaded above.
assert state_space_spec.is_compatible_with(equiv_dataset)

In [None]:
# Model embeddings for the equivalence dataset. Presumably the helper reuses a
# cached result when one exists and computes it otherwise; confirm in
# `src.models.integrator`.
model_representations = load_or_compute_embeddings(
    model,
    equiv_dataset,
    model_dir,
    equiv_dataset_path,
)

In [None]:
# Keep only labels whose entry in `target_frame_spans` has at least
# `retain_n` elements; every other label index is dropped from the spec.
retain_n = 10

drop_idxs = []
for label_idx, label_spans in enumerate(state_space_spec.target_frame_spans):
    if len(label_spans) < retain_n:
        drop_idxs.append(label_idx)

state_space_spec = state_space_spec.drop_labels(drop_idxs)

In [None]:
# Build the per-label trajectories from the embeddings, using NaN as the
# padding value.
trajectory = prepare_state_trajectory(
    model_representations,
    state_space_spec,
    pad=np.nan,
)

In [None]:
def _instance_lengths(traj_i):
    """Return the number of non-padding frames for each row of `traj_i`.

    `traj_i` is indexed as `traj_i[:, :, 0]`, i.e. it has three axes; the
    second axis is treated as the frame axis and NaN in feature 0 marks
    padding. The length of a row is the index of its first NaN frame.

    A row with no NaN at all fills the whole frame axis. For such a row
    `argmax` over an all-False mask returns 0, which would wrongly report a
    length of 0, so the full frame count is reported instead.
    """
    is_pad = np.isnan(traj_i[:, :, 0])
    first_pad = is_pad.argmax(axis=1)
    return np.where(is_pad.any(axis=1), first_pad, is_pad.shape[1])


# One integer array per entry of `trajectory`.
lengths = [_instance_lengths(traj_i) for traj_i in trajectory]

In [None]:
# Number of entries in `trajectory` (presumably one per retained label; confirm).
len(trajectory)

## Estimate within-word distance

In [None]:
# Within-word distance estimates. The helper returns two arrays; the second
# is the one later written out as "aligned offset".
(within_distance, within_distance_offset) = coherence.estimate_within_distance(
    trajectory,
    lengths,
    state_space_spec,
    metric=metric,
)

In [None]:
# Heatmap of within-word distance. Rows follow `state_space_spec.labels`
# (words) and columns are frames, matching how `within_distance` is turned
# into a long-format table below. The diverging colormap is centred on 1.
ax = sns.heatmap(
    within_distance,
    center=1,
    cmap="RdBu",
    cbar_kws={"label": f"{metric.capitalize()} distance"},
)
ax.set_title("Within-word representational distance")
ax.set_xlabel("Frames since word onset")
ax.set_ylabel("Word")

In [None]:
# Long-format table of within-word distance: one row per (word, frame) pair.
word_index = pd.Index(state_space_spec.labels, name="word")
within_distance_df = (
    pd.DataFrame(within_distance, index=word_index)
    .reset_index()
    .melt(id_vars=["word"], var_name="frame", value_name="distance")
)

In [None]:
# Same long-format reshaping, applied to the second array returned by the
# within-distance estimator.
word_index = pd.Index(state_space_spec.labels, name="word")
within_distance_offset_df = (
    pd.DataFrame(within_distance_offset, index=word_index)
    .reset_index()
    .melt(id_vars=["word"], var_name="frame", value_name="distance")
)

## Estimate between-word distance

In [None]:
# Between-word distance estimates, returned as a pair of arrays like the
# within-word estimator.
(between_distance, between_distance_offset) = coherence.estimate_between_distance(
    trajectory,
    lengths,
    state_space_spec,
    metric=metric,
)

In [None]:
# Average over the last axis of `between_distance`, ignoring NaNs, then
# reshape to one row per (word, frame) pair.
between_mean = np.nanmean(between_distance, axis=-1)
word_index = pd.Index(state_space_spec.labels, name="word")
between_distances_df = (
    pd.DataFrame(between_mean, index=word_index)
    .reset_index()
    .melt(id_vars=["word"], var_name="frame", value_name="distance")
)

In [None]:
# Average over the last axis of `between_distance_offset`, ignoring NaNs,
# then reshape to one row per (word, frame) pair.
between_offset_mean = np.nanmean(between_distance_offset, axis=-1)
word_index = pd.Index(state_space_spec.labels, name="word")
between_distances_offset_df = (
    pd.DataFrame(between_offset_mean, index=word_index)
    .reset_index()
    .melt(id_vars=["word"], var_name="frame", value_name="distance")
)

## Within- and between-word distance together

In [None]:
# Stack the within- and between-word tables, tagging each row with its
# source in a `type` column, and save the result next to the notebook outputs.
within_tagged = within_distance_df.assign(type="within")
between_tagged = between_distances_df.assign(type="between")
merged_df = pd.concat([within_tagged, between_tagged])

distances_csv = Path(output_dir) / "distances.csv"
merged_df.to_csv(distances_csv, index=False)
merged_df

In [None]:
# Distance by frame, one line per `type` (within / between). Rows containing
# NaN are removed before plotting.
merged_plot_data = merged_df.dropna()
ax = sns.lineplot(x="frame", y="distance", hue="type", data=merged_plot_data)
ax.set_title("Representational distance within- and between-word")
ax.set_xlabel("Frames since word onset")
ax.set_ylabel(f"{metric.capitalize()} distance")

In [None]:
# Stack the offset-aligned within- and between-word tables, tagged by `type`,
# and save them to their own CSV.
within_offset_tagged = within_distance_offset_df.assign(type="within")
between_offset_tagged = between_distances_offset_df.assign(type="between")
merged_offset_df = pd.concat([within_offset_tagged, between_offset_tagged])

offset_csv = Path(output_dir) / "distances_aligned_offset.csv"
merged_offset_df.to_csv(offset_csv, index=False)
merged_offset_df

In [None]:
# Upper x-limit: the 95th percentile of all instance lengths pooled together.
length_cutoff = np.percentile(np.concatenate(lengths), 95)

# Offset-aligned distance by frame, one line per `type`; NaN rows removed.
merged_offset_plot_data = merged_offset_df.dropna()
ax = sns.lineplot(x="frame", y="distance", hue="type", data=merged_offset_plot_data)
ax.set_title("Representational distance within- and between-word")
ax.set_xlabel("Frames before word offset")
ax.set_ylabel(f"{metric.capitalize()} distance")
ax.set_xlim((0, length_cutoff))

## Estimate distance by grouping features

### Onset

In [None]:
# Grouping category for each label: its first element.
onsets = [label[0] for label in state_space_spec.labels]

# Distances grouped by whether the onset category matches. The helper returns
# two tables, the second being the one saved as "aligned offset".
(onset_distance_df, onset_distance_offset_df) = (
    coherence.estimate_category_within_between_distance(
        trajectory,
        lengths,
        onsets,
        metric=metric,
        labels=state_space_spec.labels,
    )
)

In [None]:
# Save the onset-grouped distances.
# NOTE(review): the other three grouped tables get a `time` column before
# being written; this one does not. Confirm that is intended.
onset_csv = Path(output_dir) / "distances-grouped_onset.csv"
onset_distance_df.to_csv(onset_csv, index=False)

In [None]:
# Onset-grouped distance by frame, one line per `type`; NaN rows removed.
onset_plot_data = onset_distance_df.dropna()
ax = sns.lineplot(x="frame", y="distance", hue="type", data=onset_plot_data)
ax.set_title("Representational distance by onset match/mismatch")
ax.set_xlabel("Frames since word onset")
ax.set_ylabel(f"{metric.capitalize()} distance")

In [None]:
# Add a `time` column derived from the frame index as (frame - 1) / 20, then
# save. Presumably this converts frames to a time unit; confirm the divisor.
onset_distance_offset_df["time"] = onset_distance_offset_df["frame"].sub(1).div(20)

onset_offset_csv = Path(output_dir) / "distances-grouped_onset_aligned_offset.csv"
onset_distance_offset_df.to_csv(onset_offset_csv, index=False)

### Offset

In [None]:
# Grouping category for each label: its last element.
offsets = [label[-1] for label in state_space_spec.labels]

# Distances grouped by whether the offset category matches. The helper returns
# two tables, the second being the one saved as "aligned offset".
(offset_distance_df, offset_distance_offset_df) = (
    coherence.estimate_category_within_between_distance(
        trajectory,
        lengths,
        offsets,
        metric=metric,
        labels=state_space_spec.labels,
    )
)

In [None]:
# Add a `time` column derived from the frame index as (frame - 1) / 20, then
# save. Presumably this converts frames to a time unit; confirm the divisor.
offset_distance_df["time"] = offset_distance_df["frame"].sub(1).div(20)

offset_csv = Path(output_dir) / "distances-grouped_offset.csv"
offset_distance_df.to_csv(offset_csv, index=False)

In [None]:
# Add a `time` column derived from the frame index as (frame - 1) / 20, then
# save. Presumably this converts frames to a time unit; confirm the divisor.
offset_distance_offset_df["time"] = offset_distance_offset_df["frame"].sub(1).div(20)

offset_aligned_csv = Path(output_dir) / "distances-grouped_offset_aligned_offset.csv"
offset_distance_offset_df.to_csv(offset_aligned_csv, index=False)

In [None]:
# Offset-grouped, offset-aligned distance by frame, one line per `type`;
# NaN rows removed.
offset_plot_data = offset_distance_offset_df.dropna()
ax = sns.lineplot(x="frame", y="distance", hue="type", data=offset_plot_data)
ax.set_title("Representational distance by offset match/mismatch")
ax.set_xlabel("Frames before word offset")
ax.set_ylabel(f"{metric.capitalize()} distance")