In [None]:
import os
# Expose no CUDA devices to this process, so GPU-aware libraries imported in later
# cells see no GPU. This has to run before those libraries initialise CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
import sys
# Add the parent directory to the module search path.
# NOTE(review): presumably this is what makes the `src.*` imports in the next cell
# resolve when the notebook is run from a subdirectory — confirm.
sys.path.append("../")

In [None]:

from collections import defaultdict
from dataclasses import replace
import itertools
from pathlib import Path
import pickle

import datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, cdist
import seaborn as sns
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

from src.analysis.state_space import prepare_state_trajectory, StateSpaceAnalysisSpec
from src.datasets.speech_equivalence import SpeechEquivalenceDataset
from src.models import get_best_checkpoint
from src.models.integrator import ContrastiveEmbeddingModel, compute_embeddings, load_or_compute_embeddings

from sklearn.linear_model import LinearRegression, RidgeCV
from sklearn.model_selection import KFold, cross_val_score

In [None]:
# --- Notebook configuration -------------------------------------------------
# Directory holding the trained model; used to pick the best checkpoint and passed
# to the embedding loader.
model_dir = "outputs/models/w2v2_32/phoneme"
# Alternative dataset / checkpoint paths (currently disabled):
# equiv_dataset_path = "../data/timit_equiv_phoneme_within_word_prefix_1.pkl"
# model_checkpoint = "../out/ce_model_phoneme_6_32/checkpoint-800/"
# Pickled SpeechEquivalenceDataset to analyse.
equiv_dataset_path = "data/timit_equiv_phoneme_6_1.pkl"

# Pickled StateSpaceAnalysisSpec describing the trajectories to extract.
state_space_spec_path = "out/state_space_specs/all_words.pkl"
# Directory that receives the CSV written at the end of the notebook.
output_dir = "."

# NOTE(review): `metric` is not referenced by any of the visible cells — confirm whether it is still needed.
metric = "cosine"

In [None]:
# Restore the contrastive embedding model from the checkpoint that
# `get_best_checkpoint` selects under `model_dir`.
model = ContrastiveEmbeddingModel.from_pretrained(get_best_checkpoint(model_dir))
# NOTE(review): presumably switches the model to inference mode (torch-style `.eval()`) — confirm.
model.eval()

In [None]:
# Load the pickled equivalence dataset.
# SECURITY: unpickling executes arbitrary code — only point this at files you trust.
with open(equiv_dataset_path, "rb") as f:
    equiv_dataset: SpeechEquivalenceDataset = pickle.load(f)

In [None]:
# Load the pickled state-space analysis spec.
# SECURITY: unpickling executes arbitrary code — only point this at files you trust.
with open(state_space_spec_path, "rb") as f:
    state_space_spec: StateSpaceAnalysisSpec = pickle.load(f)
# Fail fast if the spec does not match the dataset loaded above.
assert state_space_spec.is_compatible_with(equiv_dataset)

In [None]:
# Obtain the model's embeddings for the equivalence dataset.
# NOTE(review): going by its name, the helper may reuse previously computed embeddings
# keyed on `model_dir` / `equiv_dataset_path` — verify before relying on fresh results.
model_representations = load_or_compute_embeddings(model, equiv_dataset, model_dir, equiv_dataset_path)

In [None]:
# Build the state trajectories described by the spec, using NaN as the pad value.
# The length computation in the next cell detects padding via `np.isnan`, so the
# pad value must stay NaN.
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)

In [None]:
# Number of valid (non-padded) frames for every instance of every trajectory array.
# The first NaN along the frame axis marks where padding starts. `argmax` on a row
# with no NaN at all returns 0, so instances that fill every frame are handled
# explicitly; otherwise they would be given length 0 and silently excluded from
# every downstream analysis.
# NOTE(review): assumes each array is indexed (instance, frame, feature) with trailing
# NaN padding, as the `[:, :, 0]` indexing suggests — confirm.
lengths = [
    np.where(is_pad.any(axis=1), is_pad.argmax(axis=1), is_pad.shape[1])
    for is_pad in (np.isnan(traj_i[:, :, 0]) for traj_i in trajectory)
]

In [None]:
def evaluate_temporal_generalization(trajectory, lengths, train_frame, test_frame,
                                     min_samples=100, n_splits=5, random_state=0):
    """
    Score how well the representation at `train_frame` linearly predicts the
    representation at `test_frame`.

    Args:
        trajectory: iterable of arrays, each indexed as `traj_i[instance, frame]`.
        lengths: iterable parallel to `trajectory`; each element holds the number of
            valid frames of every instance in the matching array.
        train_frame: frame index supplying the regression inputs.
        test_frame: frame index supplying the regression targets.
        min_samples: minimum number of usable instances; below this no model is fit.
        n_splits: number of cross-validation folds.
        random_state: seed for the shuffled K-fold split, so that repeated runs give
            the same scores. Pass `None` for an unseeded shuffle.

    Returns:
        `np.nan` if fewer than `min_samples` instances are long enough to contain
        both frames; otherwise the array of per-fold R^2 scores returned by
        `cross_val_score` for a `RidgeCV` regression.
    """
    # An instance is usable only if it is long enough to contain both frames.
    last_frame = max(train_frame, test_frame)

    X_parts, Y_parts = [], []
    for traj_i, lengths_i in zip(trajectory, lengths):
        keep = np.flatnonzero(np.asarray(lengths_i) > last_frame)
        if keep.size == 0:
            # Nothing usable here; also avoids indexing a frame this array may not have.
            continue
        X_parts.append(traj_i[keep, train_frame])
        Y_parts.append(traj_i[keep, test_frame])

    n_samples = sum(len(part) for part in X_parts)
    if n_samples < min_samples:
        return np.nan

    X = np.concatenate(X_parts)
    Y = np.concatenate(Y_parts)

    # Cross-validated ridge regression from the train-frame to the test-frame state.
    cv = KFold(n_splits, shuffle=True, random_state=random_state)
    return cross_val_score(RidgeCV(), X, Y, cv=cv, scoring="r2")

In [None]:
# Score every (train_frame, test_frame) pair. Rows index the train frame, columns the
# test frame; cells stay NaN when the evaluation returns NaN for that pair.
n_frames = trajectory[0].shape[1]
temporal_generalization_scores = np.full((n_frames, n_frames), np.nan)
frame_pairs = list(itertools.product(range(n_frames), repeat=2))
for train_frame, test_frame in tqdm(frame_pairs):
    scores = evaluate_temporal_generalization(trajectory, lengths, train_frame, test_frame)
    # Collapse the per-fold scores (or the NaN sentinel) into a single number.
    temporal_generalization_scores[train_frame, test_frame] = np.mean(scores)

In [None]:
# Wrap the score matrix in a labelled frame: rows are train frames, columns are test frames.
frame_ids = range(trajectory[0].shape[1])
temporal_generalization_df = pd.DataFrame(
    temporal_generalization_scores,
    index=pd.Index(frame_ids, name="train_frame"),
    columns=pd.Index(frame_ids, name="test_frame"),
)
temporal_generalization_df.head()

In [None]:
# Persist the full train-frame x test-frame score matrix as CSV inside `output_dir`.
temporal_generalization_df.to_csv(Path(output_dir) / "temporal_generalization.csv")

In [None]:
# Restrict the plot to the first 30 train/test frames, then flip the row (train-frame)
# order so that frame 0 sits at the bottom and the diagonal runs from bottom-left to
# top-right.
plot_df = temporal_generalization_df.iloc[:30, :30].iloc[::-1]
ax = sns.heatmap(plot_df, cmap="RdBu_r", center=0, xticklabels=10, yticklabels=10)
ax.set_xlabel("Test frame")
ax.set_ylabel("Train frame")