In [None]:
import os
# Blank out CUDA_VISIBLE_DEVICES before torch is imported further down
# (presumably to keep this notebook on CPU -- confirm intent).
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

In [None]:
import sys
# Put the parent directory on the import path so the `src.*` imports below
# resolve. Guarded so that re-running this cell does not keep appending
# duplicate entries to sys.path.
if "../" not in sys.path:
    sys.path.append("../")

In [None]:
import itertools
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import prepare_state_trajectory, StateSpaceAnalysisSpec
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

from sklearn.linear_model import LinearRegression, RidgeCV
from sklearn.model_selection import KFold, cross_val_score

In [None]:
# --- Input artifacts -------------------------------------------------------
dataset_path = "outputs/preprocessed_data/timit"
hidden_states_path = "outputs/hidden_states/timit/w2v2_6/hidden_states.h5"
state_space_specs_path = "outputs/state_space_specs/timit/w2v2_6/state_space_specs.h5"
equivalence_path = "outputs/equivalence_datasets/timit/w2v2_6/phoneme/equivalence.pkl"
model_dir = "outputs/models/timit/w2v2_6/rnn_8/phoneme"
embeddings_path = "outputs/model_embeddings/timit/w2v2_6/rnn_8/phoneme/embeddings.npy"

# --- Output location -------------------------------------------------------
output_dir = "outputs/notebooks/timit/w2v2_6/rnn_8/phoneme/plot"

# --- Analysis options ------------------------------------------------------
# Add 4 frames prior to phoneme onset to each trajectory; the second entry
# must stay 0 (the plotting cell asserts this).
expand_frame_window = (4, 0)

metric = "cosine"

In [None]:
# Load the model embeddings saved as a .npy array.
with open(embeddings_path, "rb") as f:
    model_representations: np.ndarray = np.load(f)
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only point equivalence_path at files you trust. Recent torch
# releases may also need an explicit `weights_only=False` to load a
# non-tensor object like this one -- TODO confirm against the installed version.
with open(equivalence_path, "rb") as f:
    equiv_dataset: SpeechEquivalenceDataset = torch.load(f)
# Read the "phoneme" state-space spec and fail fast if it does not match the
# loaded embeddings.
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "phoneme")
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Build the per-spec state trajectories, widened by `expand_frame_window` and
# using NaN as the pad value (the length computation in the next cell relies
# on NaN marking padded frames).
trajectory = prepare_state_trajectory(
    model_representations,
    state_space_spec,
    expand_window=expand_frame_window,
    pad=np.nan,
)

In [None]:
def _valid_frame_counts(traj):
    """
    Return, for each row of `traj`, the index of its first NaN-padded frame.

    `traj` is indexed as a 3-D array; padding is detected on feature 0 along
    axis 1 (presumably the frame axis -- confirm against
    `prepare_state_trajectory`).

    `argmax` on a boolean row returns 0 when the row contains no True at all,
    so a row with no padding (one that fills the whole frame axis) would be
    reported as length 0 and silently excluded downstream. Such rows are
    given the full frame-axis size instead.
    """
    is_pad = np.isnan(traj[:, :, 0])
    first_pad = is_pad.argmax(axis=1)
    return np.where(is_pad.any(axis=1), first_pad, traj.shape[1])


lengths = [_valid_frame_counts(traj_i) for traj_i in trajectory]

In [None]:
def evaluate_temporal_generalization(trajectory, lengths, train_frame, test_frame,
                                     min_samples=100, n_splits=3, random_state=None):
    """
    Score how well states at `train_frame` linearly predict states at `test_frame`.

    Every trajectory instance long enough to contain both frames contributes
    one (X, Y) pair: its state at `train_frame` and its state at `test_frame`.
    A `RidgeCV` model (regularisation picked by an inner shuffled K-fold) is
    then scored with an outer shuffled K-fold using R^2.

    Parameters
    ----------
    trajectory : sequence of arrays, each indexed as [instance, frame, feature]
    lengths : sequence of 1-D integer arrays, parallel to `trajectory`, giving
        the number of valid frames of each instance
    train_frame, test_frame : int
        Frame indices used as regressor / target.
    min_samples : int, default 100
        Minimum number of usable instances; below this nothing is fit.
    n_splits : int, default 3
        Number of folds for both the inner and the outer K-fold.
    random_state : int, RandomState or None, default None
        Passed to both shuffled K-folds. Leave as None for the original
        unseeded behaviour; pass an int for reproducible scores.

    Returns
    -------
    `np.nan` when fewer than `min_samples` instances qualify, otherwise the
    array of per-fold R^2 scores returned by `cross_val_score`.
    """
    # An instance is usable only if it is longer than the later of the two frames.
    last_frame = max(train_frame, test_frame)

    X, Y = [], []
    for traj_i, lengths_i in zip(trajectory, lengths):
        usable = np.asarray(lengths_i) > last_frame
        if not usable.any():
            continue
        # One vectorised gather per array instead of one append per instance.
        X.append(traj_i[usable, train_frame])
        Y.append(traj_i[usable, test_frame])

    n_samples = sum(len(block) for block in X)
    if n_samples < min_samples:
        return np.nan

    X = np.concatenate(X)
    Y = np.concatenate(Y)

    # Fit linear model
    model = RidgeCV(cv=KFold(n_splits, shuffle=True, random_state=random_state))
    outer_cv = KFold(n_splits, shuffle=True, random_state=random_state)
    return cross_val_score(model, X, Y, cv=outer_cv, scoring="r2")

In [None]:
# Square (train_frame x test_frame) grid of scores, initialised to NaN so any
# cell that is never filled stays visibly missing.
temporal_generalization_scores = np.full(
    (trajectory[0].shape[1], trajectory[0].shape[1]), np.nan
)
# Visit every ordered (train_frame, test_frame) pair; materialised as a list
# so tqdm knows the total.
for train_frame, test_frame in tqdm(
    list(itertools.product(range(trajectory[0].shape[1]), range(trajectory[0].shape[1])))
):
    scores = evaluate_temporal_generalization(trajectory, lengths, train_frame, test_frame)
    # Average the per-fold scores; a NaN return propagates as NaN.
    temporal_generalization_scores[train_frame, test_frame] = np.mean(scores)

In [None]:
# Wrap the score grid in a DataFrame whose rows are named "train_frame" and
# whose columns are named "test_frame", both indexed by frame number.
temporal_generalization_df = pd.DataFrame(
    temporal_generalization_scores,
    index=pd.Index(range(trajectory[0].shape[1]), name="train_frame"),
    columns=pd.Index(range(trajectory[0].shape[1]), name="test_frame"),
)
temporal_generalization_df.head()

In [None]:
# Create the output directory first: to_csv does not create missing parent
# directories and would fail on a fresh checkout.
Path(output_dir).mkdir(parents=True, exist_ok=True)
temporal_generalization_df.to_csv(Path(output_dir) / "temporal_generalization.csv")

In [None]:
# Only the first 30 train/test frames are plotted.
plot_df = temporal_generalization_df.iloc[:30, :30]

# Diverging colormap centred on 0, with a tick label every 10 frames.
ax = sns.heatmap(
    plot_df,
    cmap="RdBu_r",
    center=0,
    xticklabels=10,
    yticklabels=10,
)

# The onset marker below only accounts for frames added before onset.
assert expand_frame_window[1] == 0
# Draw phoneme onset on both axes when frames were prepended.
if expand_frame_window[0]:
    ax.axvline(expand_frame_window[0], color="gray", linestyle="--")
    ax.axhline(expand_frame_window[0], color="gray", linestyle="--")

ax.set_xlabel("Test frame")
ax.set_ylabel("Train frame")