In [None]:
# Reload imported project modules (e.g. src.analysis.state_space) before each cell runs,
# so edits to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelBinarizer
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory

In [None]:
dataset = "librispeech-train-clean-100"
# Key selecting which state space to analyze from the loaded state-space-spec mapping.
state_space_name = "word"

# Alternative model configuration, kept for reference:
# base_model = "w2v2_6"
# model_class = "rnn_8-weightdecay0.01"
# model_name = "biphone_recon"

base_model = "w2v2_8"
model_class = "rnn_32-hinge-mAP4"
model_name = "word_broad"

# Input/output locations. In the cells below only `output_dir`, `state_space_specs_path`
# and `embeddings_path` are read.
# NOTE(review): `model_dir`, `dataset_path`, `equivalence_path` and `hidden_states_path`
# are not referenced in this notebook — presumably kept for parity with related notebooks.
model_dir = f"outputs/models/{dataset}/{base_model}/{model_class}/{model_name}_10frames"
output_dir = f"outputs/notebooks/{dataset}/{base_model}/{model_class}/{model_name}_10frames/state_space"
dataset_path = f"outputs/preprocessed_data/{dataset}"
equivalence_path = f"outputs/equivalence_datasets/{dataset}/{base_model}/{model_name}_10frames/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{dataset}/{base_model}/hidden_states.h5"
state_space_specs_path = f"outputs/state_space_specs/{dataset}/{base_model}/state_space_specs.pkl"
embeddings_path = f"outputs/model_embeddings/{dataset}/{base_model}/{model_class}/{model_name}_10frames/{dataset}.npy"

# NOTE(review): `metric` is not referenced by the cells below — confirm whether it is still needed.
metric = "cosine"

# name -> (agg_spec, length_grouping_level)
# `agg_spec` is passed as the aggregation method to `aggregate_state_trajectory`; only the
# "mean_within_phoneme" entry is used below.
# `length_grouping_level` names a `level` in the state space spec cuts and is the unit in
# which word length is counted (e.g. syllable-by-syllable comparison only makes sense
# between words matched in syllable count).
# NOTE(review): the original comment described this grouping for CCA estimation; this
# notebook computes classification mAP and does not read the second element — likely
# carried over from a CCA notebook.
agg_methods = {
    "mean_within_phoneme": (("mean_within_cut", "phoneme"), "phoneme"),
    # "mean_within_syllable": (("mean_within_cut", "syllable"), "syllable"),
    # "mean": ("mean", "phoneme"),
    # "last_frame": ("last_frame", "phoneme"),
    # "max": ("max", "phoneme"),
    # "none": (None, "phoneme"),
}

# Keep only the K most frequent words (passed to `state_space_spec.keep_top_k`).
freq_top_k = 1000

# Keep at most this many instances per label (passed to `state_space_spec.subsample_instances`).
max_instances_per_label = 50

In [None]:
# Load the model embeddings and the state-space spec selected by `state_space_name`,
# then check that the two are compatible before any analysis.
with open(embeddings_path, "rb") as f:
    model_representations: np.ndarray = np.load(f)
with open(state_space_specs_path, "rb") as f:
    # NOTE(review): torch.load unpickles arbitrary Python objects — only load trusted files.
    # Recent PyTorch releases default to `weights_only=True`, which may refuse to load a
    # pickled StateSpaceAnalysisSpec; confirm against the installed torch version.
    state_space_spec: StateSpaceAnalysisSpec = torch.load(f)[state_space_name]
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Restrict the spec to the `freq_top_k` most frequent labels (frequency criterion is
# defined by `keep_top_k` — TODO confirm it counts instances).
# NOTE(review): this rebinds `state_space_spec`, so re-running the cell re-applies the
# filter to the already-filtered spec rather than to the loaded one.
state_space_spec = state_space_spec.keep_top_k(freq_top_k)

In [None]:
# Cap the number of instances kept per label at `max_instances_per_label`.
# NOTE(review): this rebinds `state_space_spec`, so re-running the cell subsamples an
# already-subsampled spec. If `subsample_instances` is random, confirm how it is seeded.
state_space_spec = state_space_spec.subsample_instances(max_instances_per_label)

In [None]:
# Arrange the model representations into per-label trajectories following the spec.
# `pad=np.nan` is forwarded to `prepare_state_trajectory` (presumably the fill value for
# ragged instances — TODO confirm).
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)

In [None]:
# Only the aggregation spec (first element of the `agg_methods` entry) is needed here.
agg_method = agg_methods["mean_within_phoneme"][0]
agg_traj = aggregate_state_trajectory(
    trajectory,
    state_space_spec,
    agg_method,
    keepdims=True,
)

In [None]:
flat_traj, flat_traj_src = flatten_trajectory(agg_traj)
# Column 2 of `flat_traj_src` is used below as a 0-based frame index, so the number of
# distinct frame positions is one more than its maximum.
max_num_frames = 1 + flat_traj_src[:, 2].max()

In [None]:
def compute_mean_average_precision(
    embeddings: np.ndarray,
    classes: np.ndarray,
    n_splits: int = 3,
    random_state: int = 0,
) -> np.ndarray:
    """
    Estimate how well `classes` can be decoded from `embeddings`, as cross-validated mAP.

    Rows of `embeddings` are L2-normalized. A logistic-regression classifier is then trained
    and evaluated with shuffled stratified K-fold cross-validation. Each fold is scored as
    follows:
      * multi-class: the macro-averaged (one-vs-rest) average precision over the classes
        that occur in that fold's test split;
      * two classes: the average precision of the second class, matching the single-column
        encoding `LabelBinarizer` uses for binary labels.

    Args:
        embeddings: array of shape (n_samples, n_features).
        classes: array of shape (n_samples,), one label per row of `embeddings`.
        n_splits: number of cross-validation folds.
        random_state: seed for the fold shuffling, so results are reproducible.

    Returns:
        Array of shape (n_splits,) with one score per fold. An entry is NaN when that fold
        cannot be scored. This happens when:
          * there are fewer than two distinct classes;
          * no class has at least `n_splits` members (StratifiedKFold cannot split);
          * the training split holds a single class;
          * in the binary case, the test split has no positive example.
    """
    embeddings = np.asarray(embeddings)
    classes = np.asarray(classes)
    scores = np.full(n_splits, np.nan)

    _, class_counts = np.unique(classes, return_counts=True)
    # StratifiedKFold raises when every class has fewer than `n_splits` members, and a
    # classifier needs two classes; report "no estimate" instead of aborting the caller's loop.
    if len(class_counts) < 2 or class_counts.max() < n_splits:
        return scores

    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Avoid 0/0 -> NaN for all-zero rows; such rows stay all-zero after normalization.
    X = embeddings / np.where(norms == 0, 1.0, norms)

    binarizer = LabelBinarizer().fit(classes)
    all_classes = binarizer.classes_
    Y = binarizer.transform(classes)
    is_binary = len(all_classes) == 2

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    for fold, (train_idx, test_idx) in enumerate(skf.split(X, classes)):
        train_classes = classes[train_idx]
        if np.unique(train_classes).size < 2:
            continue  # LogisticRegression cannot be fit on a single class
        clf = LogisticRegression(max_iter=1000)
        clf.fit(X[train_idx], train_classes)

        # Rare classes can be absent from the training split, in which case predict_proba
        # has fewer columns than Y. Scatter the columns into the full label space so that
        # column j always refers to all_classes[j]; unseen classes get probability 0.
        proba = clf.predict_proba(X[test_idx])
        Y_score = np.zeros((len(test_idx), len(all_classes)))
        Y_score[:, np.searchsorted(all_classes, clf.classes_)] = proba
        Y_true = Y[test_idx]

        if is_binary:
            # LabelBinarizer encodes the binary case as one column for all_classes[1].
            y_true = Y_true[:, 0]
            if y_true.any():
                scores[fold] = average_precision_score(y_true, Y_score[:, 1])
            continue

        # A class with no positive in this test split has undefined AP. Leave it out of
        # the macro average rather than letting it count as 0.
        present = Y_true.any(axis=0)
        scores[fold] = average_precision_score(
            Y_true[:, present], Y_score[:, present], average="macro")
    return scores

In [None]:
# One array of per-fold mAP scores for each frame position of the aggregated trajectory.
map_scores = []
for num_frames in range(max_num_frames):
    # `num_frames` is the 0-based frame index; report it 1-based so every log line agrees
    # with the `num_frames` axis of `map_scores_df` built below.
    frame_count = num_frames + 1
    print(f"Computing mAP for {frame_count} frames")
    traj_idxs = (flat_traj_src[:, 2] == num_frames).nonzero()[0]
    X = flat_traj[traj_idxs]
    # Column 0 of `flat_traj_src` is used as the class label.
    Y = flat_traj_src[traj_idxs, 0]

    print(X.shape, Y.shape)
    map_scores_i = compute_mean_average_precision(X, Y)
    # nanmean: folds that could not be scored (too few examples per class) are skipped.
    print(f"Mean average precision at {frame_count} frames: {np.nanmean(map_scores_i)}")
    map_scores.append(map_scores_i)

In [None]:
# Long-format table: one row per (num_frames, fold) with that fold's mAP. num_frames is 1-based.
map_scores_df = (
    pd.DataFrame(map_scores, index=pd.RangeIndex(1, len(map_scores) + 1, name="num_frames"))
    .reset_index()
    .melt(id_vars=["num_frames"], var_name="fold", value_name="mAP")
)
# Nothing else creates `output_dir`; without this, to_csv fails on a fresh checkout.
Path(output_dir).mkdir(parents=True, exist_ok=True)
map_scores_df.to_csv(Path(output_dir) / "map_scores.csv", index=False)
map_scores_df

In [None]:
fig, ax = plt.subplots()
sns.lineplot(data=map_scores_df, x="num_frames", y="mAP", ax=ax)
ax.set_ylim((0.0, 1.0))
# Approximate chance level: an uninformative scorer has per-class AP near the class
# prevalence, and prevalences averaged over all labels give 1 / n_labels.
# NOTE(review): subsets at larger num_frames contain fewer labels, so the true chance
# level there is higher than this line.
ax.axhline(1 / len(state_space_spec.labels), color="gray", linestyle="--", label="chance (1 / #labels)")
ax.set(
    xlabel="num_frames (1-based aggregated frame position)",
    ylabel="mAP",
    title=f"{state_space_name} decoding mAP: {base_model}/{model_class}/{model_name}",
)
ax.legend();