In [None]:
# Enable IPython's autoreload extension in mode 2, which reloads imported
# modules before executing each cell so edits to the project code under
# `src/` are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from dataclasses import replace
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from scipy.spatial.distance import pdist
from sklearn.metrics import average_precision_score
import torch
from tqdm.auto import tqdm

from src.datasets.speech_equivalence import SpeechHiddenStateDataset
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory
from src.analysis.trf import coefs_to_df
from src.datasets.speech_equivalence import SpeechEquivalenceDataset
from src.utils.timit import get_word_metadata

In [None]:
# --- Experiment selection -------------------------------------------------
dataset = "timit-no_repeats"
# Key used to pick one spec out of the object loaded from
# `state_space_specs_path`.
state_space_name = "word"

# Alternative configuration, kept for reference:
# base_model = "w2v2_6"
# model_class = "rnn_8-weightdecay0.01"
# model_name = "biphone_recon"

base_model = "w2v2_8"
model_class = "rnn_8-aniso2"
model_name = "word_broad"

# The "<model_name>_<N>frames" tag appears in several paths below; build it
# once so the frame count cannot drift between them.
num_frames = 10
model_tag = f"{model_name}_{num_frames}frames"

# --- Derived paths --------------------------------------------------------
model_dir = f"outputs/models/{dataset}/{base_model}/{model_class}/{model_tag}"
output_dir = f"outputs/notebooks/{dataset}/{base_model}/{model_class}/{model_tag}/state_space"
dataset_path = f"outputs/preprocessed_data/{dataset}"
equivalence_path = f"outputs/equivalence_datasets/{dataset}/{base_model}/{model_tag}/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{dataset}/{base_model}/hidden_states.h5"
state_space_specs_path = f"outputs/state_space_specs/{dataset}/{base_model}/state_space_specs.pkl"
embeddings_path = f"outputs/model_embeddings/{dataset}/{base_model}/{model_class}/{model_tag}/embeddings.npy"

# Distance metric name handed to scipy's `pdist` when comparing embeddings.
metric = "cosine"

In [None]:
# Read the model embedding array from `embeddings_path`.
model_representations: np.ndarray = np.load(embeddings_path)

# Read the saved state-space specs and keep the one named by `state_space_name`.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only point
# this at files from a trusted source.
with open(state_space_specs_path, "rb") as spec_file:
    state_space_spec: StateSpaceAnalysisSpec = torch.load(spec_file)[state_space_name]

# Stop here if the spec reports that it does not match the embeddings.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Build the state trajectories from the embeddings according to the spec;
# NaN is supplied as the pad value.
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)

In [None]:
# Aggregate the trajectories with the "mean" method, keeping the reduced
# dimension (keepdims=True).
# NOTE: this rebinds `trajectory` to its own aggregate, so the cell is not
# idempotent -- re-run the cell that calls prepare_state_trajectory before
# re-running this one.
trajectory = aggregate_state_trajectory(
    trajectory,
    "mean",
    keepdims=True,
)

In [None]:
# For every row contributed by trajectory entry `state_idx`, record that index
# (stored as float) so each row of `traj_mat` can be traced back to its entry.
traj_idxs = np.concatenate([
    state_idx * np.ones(len(state_traj))
    for state_idx, state_traj in enumerate(trajectory)
])

# Stack all entries along axis 0 and drop the singleton axis 1.
traj_mat = np.concatenate(trajectory, axis=0).squeeze(1)

# Sanity checks: a plain 2-D matrix with exactly one index per row.
assert traj_mat.ndim == 2
assert len(traj_idxs) == traj_mat.shape[0]

In [None]:
# Condensed vector of pairwise distances between all rows of `traj_mat`,
# computed under the configured `metric`.
dists = pdist(traj_mat, metric=metric)

In [None]:
# Binary label per row pair, in the same condensed order pdist uses for
# `dists`: 1 when both rows carry the same trajectory index (index distance
# is exactly 0), otherwise 0.
labels = (pdist(traj_idxs[:, None], metric="euclidean") == 0).astype(int)

In [None]:
# Average precision for retrieving same-index pairs (labels == 1).
# average_precision_score treats larger scores as stronger evidence for the
# positive class, whereas `dists` is a distance (positives should be *small*),
# so the distances are negated to turn them into similarity scores.
ap = average_precision_score(labels, -dists)
ap

In [None]:
from sklearn.metrics import PrecisionRecallDisplay

# Precision-recall curve for the same retrieval task. As with average
# precision, the scorer expects "higher = more likely positive", so the
# distances are negated before being passed in.
PrecisionRecallDisplay.from_predictions(labels, -dists)