In [None]:
# Reload edited modules (e.g. `src.analysis.state_space`) automatically before each cell runs,
# so changes to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelBinarizer
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory

In [None]:
# Module-level logger; `compute_mean_average_precision` reports classifier fitting errors through it.
L = logging.getLogger(name=__name__)

In [None]:
dataset = "librispeech-train-clean-100"
state_space_name = "word"

# Alternative model configuration:
# base_model = "w2v2_6"
# model_class = "rnn_8-weightdecay0.01"
# model_name = "biphone_recon"

base_model = "w2v2_8"
model_class = "rnn_32-hinge-mAP4"
model_name = "word_broad"

model_dir = f"outputs/models/{dataset}/{base_model}/{model_class}/{model_name}_10frames"
output_dir = f"outputs/notebooks/{dataset}/{base_model}/{model_class}/{model_name}_10frames/state_space"
dataset_path = f"outputs/preprocessed_data/{dataset}"
equivalence_path = f"outputs/equivalence_datasets/{dataset}/{base_model}/{model_name}_10frames/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{dataset}/{base_model}/hidden_states.h5"
state_space_specs_path = f"outputs/state_space_specs/{dataset}/{base_model}/state_space_specs.pkl"
embeddings_path = f"outputs/model_embeddings/{dataset}/{base_model}/{model_class}/{model_name}_10frames/{dataset}.npy"

# NOTE(review): not referenced in the cells below; the classifier L2-normalizes embeddings itself.
metric = "cosine"

# Trajectory aggregation methods: name -> (agg_spec, length_grouping_level).
# `agg_spec` is passed to `aggregate_state_trajectory`. `length_grouping_level` should correspond to
# a `level` in the state space spec cuts (e.g. group words by phoneme or syllable count); it is not
# used by the per-frame classification in this notebook, which only reads `agg_spec`.
agg_methods = {
    "mean_within_phoneme": (("mean_within_cut", "phoneme"), "phoneme"),
    # "mean_within_syllable": (("mean_within_cut", "syllable"), "syllable"),
    # "mean": ("mean", "phoneme"),
    # "last_frame": ("last_frame", "phoneme"),
    # "max": ("max", "phoneme"),
    # "none": (None, "phoneme"),
}

# keep only the K most frequent words (labels)
freq_top_k = 2000

# keep at most this many instances per label
max_instances_per_label = 50

In [None]:
# TODO get a measure of random chance
# TODO record prediction performance on *all* words

In [None]:
with open(embeddings_path, "rb") as f:
    model_representations: np.ndarray = np.load(f)
with open(state_space_specs_path, "rb") as f:
    # The specs file holds pickled project objects (`StateSpaceAnalysisSpec`), which torch>=2.6
    # refuses to unpickle under its default `weights_only=True`. Unpickling can execute arbitrary
    # code, so only load spec files produced by this project.
    state_space_spec: StateSpaceAnalysisSpec = torch.load(f, weights_only=False)[state_space_name]
assert state_space_spec.is_compatible_with(model_representations), \
    f"state space spec {state_space_name!r} is incompatible with representations of shape {model_representations.shape}"

In [None]:
# Restrict the spec to the `freq_top_k` most frequent labels (per the config comment).
# NOTE: this rebinds `state_space_spec`, so re-running the cell applies the filter again to the
# already-filtered spec; re-run the loading cell first for a clean result.
state_space_spec = state_space_spec.keep_top_k(freq_top_k)

In [None]:
# Cap each label at `max_instances_per_label` instances (per the config comment).
# NOTE: this rebinds `state_space_spec`, so re-running the cell subsamples the already-subsampled
# spec; whether that is a no-op depends on `subsample_instances` (TODO confirm it is deterministic).
state_space_spec = state_space_spec.subsample_instances(max_instances_per_label)

In [None]:
# Build the state trajectory from the model representations according to the spec.
# `pad=np.nan` is presumably the fill value for padded positions (TODO confirm in src.analysis).
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)

In [None]:
# Only the aggregation spec is needed here; the length-grouping level is not used below.
agg_method, _length_grouping_level = agg_methods["mean_within_phoneme"]
agg_traj = aggregate_state_trajectory(
    trajectory,
    state_space_spec,
    agg_method,
    keepdims=True,
)

In [None]:
flat_traj, flat_traj_src = flatten_trajectory(agg_traj)
# Column 2 of `flat_traj_src` is the 0-based frame position that the mAP loop below filters on;
# the number of positions to evaluate is one past the largest index.
max_num_frames = np.max(flat_traj_src[:, 2]) + 1

In [None]:
def _fold_average_precision(
    y_true: np.ndarray,
    proba: np.ndarray,
    fitted_classes: np.ndarray,
    all_classes: np.ndarray,
) -> float:
    """
    Macro-averaged average precision on one held-out fold.

    `y_true` is the `LabelBinarizer` output for the fold: one column per class in `all_classes`,
    or a single column for the positive class `all_classes[1]` in the binary case. `proba` is
    `predict_proba` output whose columns follow `fitted_classes`, the classes seen in the
    training fold. Classes with no positive example in the fold are left out of the average,
    because their AP is undefined and sklearn would otherwise assign them a degenerate score.
    Returns NaN if no class has a positive example in the fold.
    """
    # The training fold may lack some classes. Place each probability column under its class
    # (sklearn derives both class arrays with np.unique, so they are sorted); classes the
    # classifier never saw keep a score of 0.
    y_score = np.zeros((proba.shape[0], len(all_classes)))
    y_score[:, np.searchsorted(all_classes, fitted_classes)] = proba
    if y_true.shape[1] == 1 and len(all_classes) == 2:
        # binary case: LabelBinarizer keeps only the column for the positive class all_classes[1]
        y_score = y_score[:, 1:]

    present = y_true.any(axis=0)
    if not present.any():
        return np.nan
    y_true, y_score = y_true[:, present], y_score[:, present]
    if y_true.shape[1] == 1:
        # a single scored column is a plain binary problem
        y_true, y_score = y_true[:, 0], y_score[:, 0]
    return float(average_precision_score(y_true, y_score, average="macro"))


def compute_mean_average_precision(
    embeddings: np.ndarray,
    classes: np.ndarray,
    n_splits: int = 3,
    random_state: int = 0,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Estimate classification performance by cross-validating a logistic-regression classifier
    and computing macro-averaged average precision (mAP) on each held-out fold.

    Embeddings are L2-normalized row-wise before fitting (all-zero rows are left as zeros).

    Args:
    - embeddings: n_examples * n_features array
    - classes: n_examples array of integer class labels
    - n_splits: number of stratified cross-validation folds
    - random_state: seed used to shuffle examples before splitting into folds

    Returns:
    - n_splits array: mAP score for each fold. NaN for folds whose classifier could not be fit,
      and NaN for every fold when the data cannot be split into `n_splits` stratified folds
      (fewer than two classes, or no class with at least `n_splits` examples).
    - n_examples array: predicted class for each example (integer values of `classes`, not
      indices); -1 for examples whose fold could not be evaluated.
    """
    classes = np.asarray(classes)
    scores = np.full(n_splits, np.nan)
    Y_pred = np.full(len(classes), -1, dtype=int)

    # StratifiedKFold raises if every class has fewer than `n_splits` members, and a classifier
    # needs at least two classes; report NaN instead of crashing the caller's loop.
    unique_classes, class_counts = np.unique(classes, return_counts=True)
    if len(unique_classes) < 2 or class_counts.max() < n_splits:
        L.warning(
            f"Cannot run {n_splits}-fold stratified CV on {len(classes)} examples of "
            f"{len(unique_classes)} classes; returning NaN scores"
        )
        return scores, Y_pred

    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    X = embeddings / np.where(norms > 0, norms, 1.0)

    binarizer = LabelBinarizer().fit(classes)
    Y = binarizer.transform(classes)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    for fold, (train_idx, test_idx) in enumerate(skf.split(X, classes)):
        clf = LogisticRegression(max_iter=1000)
        try:
            clf.fit(X[train_idx], classes[train_idx])
        except ValueError as e:
            # fitting error -- this may be due to e.g. a degenerate training set
            L.error(f"Error fitting classifier: {e}")
            continue

        X_test = X[test_idx]
        scores[fold] = _fold_average_precision(
            Y[test_idx], clf.predict_proba(X_test), clf.classes_, binarizer.classes_)
        Y_pred[test_idx] = clf.predict(X_test)

    return scores, Y_pred

In [None]:
# map_scores: for each frame position, the mAP score of each cross-validation fold
# traj_srcs: for each frame position, the rows of `flat_traj_src` whose examples were used to estimate mAP
# Y_pred: for each frame position, the predicted label (a value of `flat_traj_src[:, 0]`) for each example
map_scores, traj_srcs, Y_pred = [], [], []
for frame_idx in range(max_num_frames):
    # positions are 0-based in `flat_traj_src`; messages (like `map_scores_df`) report them 1-based
    num_frames = frame_idx + 1
    rows_i = np.flatnonzero(flat_traj_src[:, 2] == frame_idx)
    traj_srcs.append(flat_traj_src[rows_i])
    X_i = flat_traj[rows_i]
    Y_i = flat_traj_src[rows_i, 0]
    print(f"Computing mAP for {num_frames} frames: "
          f"{X_i.shape[0]} examples, {len(np.unique(Y_i))} classes")

    map_scores_i, Y_pred_i = compute_mean_average_precision(X_i, Y_i)
    # avoid nanmean's "mean of empty slice" warning when every fold failed
    mean_map_i = np.nan if np.all(np.isnan(map_scores_i)) else np.nanmean(map_scores_i)
    print(f"Mean average precision at {num_frames} frames: {mean_map_i}")
    map_scores.append(map_scores_i)
    Y_pred.append(Y_pred_i)

In [None]:
# Long format: one row per (num_frames, fold), where num_frames is the 1-based frame position.
map_scores_df = (
    pd.DataFrame(map_scores, index=pd.RangeIndex(1, len(map_scores) + 1, name="num_frames"))
    .reset_index()
    .melt(id_vars=["num_frames"], var_name="fold", value_name="mAP")
)

# `output_dir` is not guaranteed to exist; create it so the write below cannot fail on a fresh run.
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
map_scores_df.to_csv(output_path / "map_scores.csv", index=False)
map_scores_df

In [None]:
# Chance level for macro-averaged AP: a random ranking has expected AP of roughly the class
# prevalence, and the prevalences of the classes present at a frame position sum to 1, so their
# mean is ~1 / (number of classes present). The set of classes evaluated can differ between frame
# positions, so chance is computed per position rather than from all labels in the spec.
chance_df = pd.DataFrame({
    "num_frames": np.arange(1, len(traj_srcs) + 1),
    "chance_mAP": [1 / len(np.unique(src[:, 0])) if len(src) > 0 else np.nan for src in traj_srcs],
})

fig, ax = plt.subplots()
sns.lineplot(data=map_scores_df, x="num_frames", y="mAP", ax=ax, label="logistic regression")
ax.plot(chance_df["num_frames"], chance_df["chance_mAP"],
        color="gray", linestyle="--", label="chance (1 / #classes)")
ax.set_ylim((0.0, 1.0))
ax.set(xlabel="frame position (num_frames)", ylabel="mAP",
       title=f"{state_space_name} classification mAP by frame position")
ax.legend();

In [None]:
# One row per (frame position, example): true and predicted labels, plus whether they match.
label_lookup = dict(enumerate(state_space_spec.labels))
per_frame_predictions = {
    frame_idx: pd.DataFrame({
        "label_idx": src[:, 0],
        "instance_idx": src[:, 1],
        "predicted_label_idx": preds,
    })
    for frame_idx, (src, preds) in enumerate(zip(traj_srcs, Y_pred))
}
predictions_df = pd.concat(per_frame_predictions, names=["frame_idx"]).droplevel(-1)
predictions_df["label"] = predictions_df["label_idx"].map(label_lookup)
predictions_df["predicted_label"] = predictions_df["predicted_label_idx"].map(label_lookup)
predictions_df["correct"] = predictions_df["label"] == predictions_df["predicted_label"]
predictions_df

In [None]:
# Fraction of correct predictions per (label, frame position), lowest first.
predictions_df.groupby(["label", "frame_idx"])["correct"].mean().sort_values()

In [None]:
# Create `output_dir` if needed so this cell does not depend on an earlier cell having created it.
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# the index (frame_idx) is written as the first column
predictions_df.to_csv(output_path / "predictions.csv")