In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules (e.g. the
# project's `src` package) are reloaded before each cell runs and edits to them
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Must run before `torch` is imported in the next cell. The empty value is intended
# to hide every CUDA device so that this analysis runs on CPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
import itertools
from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.model_selection import KFold
import torch
from tqdm.auto import tqdm, trange

from src.analysis import coherence
from src.analysis.pwcca import solve_cca
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset
from src.utils import ndarray_to_long_dataframe

In [None]:
# --- Paths ---------------------------------------------------------------------------
# Only `output_dir`, `state_space_specs_path` and `embeddings_path` are read by the cells
# visible in this notebook; the remaining paths are not referenced here.
# NOTE(review): presumably they are kept for parameterized runs -- confirm before removing.
model_dir = "outputs/models/librispeech-train-clean-100/w2v2_8/rnn_32-aniso3/word_broad_10frames"
output_dir = "."
dataset_path = "outputs/preprocessed_data/librispeech-train-clean-100"
equivalence_path = "outputs/equivalence_datasets/librispeech-train-clean-100/w2v2_8/word_broad_10frames/equivalence.pkl"
hidden_states_path = "outputs/hidden_states/librispeech-train-clean-100/w2v2_8/librispeech-train-clean-100.h5"
state_space_specs_path = "outputs/state_space_specs/librispeech-train-clean-100/w2v2_8/state_space_specs.pkl"
embeddings_path = "outputs/model_embeddings/librispeech-train-clean-100/w2v2_8/rnn_32-aniso3/word_broad_10frames/librispeech-train-clean-100.npy"

# name -> (agg_spec, length_grouping_level)
# `agg_spec` is handed to `evaluate_cca` as its aggregation method; `None` means the
# trajectory is used without aggregation.
# CCA will be estimated and evaluated on words within length groups; the unit of this length count
# is determined by `length_grouping_level`. This is because it makes more sense to talk about syllable-by-syllable
# representation within words matched in syllable count.
# The `length_grouping_level` should correspond to a `level` in the state space spec cuts.
agg_methods = {
    "mean_within_phoneme": (("mean_within_cut", "phoneme"), "phoneme"),
    "mean_within_syllable": (("mean_within_cut", "syllable"), "syllable"),
    "mean": ("mean", "phoneme"),
    "last_frame": ("last_frame", "phoneme"),
    "max": ("max", "phoneme"),
    "none": (None, "phoneme"),
}

# Keep just the K most frequent words.
# `k` is also read as a global inside `evaluate_cca`, where it sets the width of the
# one-hot word targets -- keep the two uses in sync.
k = 500

# Keep at most N instances of each word
n = 500

In [None]:
# Read the precomputed model embeddings from disk.
with open(embeddings_path, "rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)

# Read the state space specs and keep only the word-level one.
# NOTE(review): `torch.load` unpickles the file, so only point this at trusted artifacts.
with open(state_space_specs_path, "rb") as specs_file:
    state_space_spec: StateSpaceAnalysisSpec = torch.load(specs_file)["word"]

# Fail fast if the spec and the embeddings do not belong together.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Restrict the analysis to the `k` most frequent words (`k` is set in the config cell).
# The unfiltered `state_space_spec` is left untouched; the reduced spec gets its own name.
state_space_spec_small = state_space_spec.keep_top_k(k)

In [None]:
# Keep at most `n` instances per word, drawn at random (`random=True`).
# NOTE(review): no seed is set in the cells visible here, so the subsample may differ
# between runs -- confirm whether `subsample_instances` seeds its own RNG.
state_space_spec_small = state_space_spec_small.subsample_instances(n, random=True)

In [None]:
# Build the state trajectory for the reduced spec, padding with NaN.
# NOTE(review): the analysis loop further down prepares its own trajectory for every
# length group, so nothing visible in this notebook reads this full-spec trajectory.
trajectory = prepare_state_trajectory(model_representations, state_space_spec_small, pad=np.nan)

In [None]:
def evaluate_cca(trajectory, state_space_spec, agg_method, cv=4, random_state=None):
    """
    Evaluate CCA alignment between model representations and one-hot word embeddings.

    The trajectory is optionally aggregated, flattened into one row per sample, and
    z-scored per feature over all samples. Then, for every frame index and every
    cross-validation fold, CCA is solved between the training-split representations at
    that frame and the one-hot word identities of those samples.

    NOTE(review): only the training indices of each fold are used; the held-out split is
    never scored, so the reported similarities are in-sample for each fold.

    Args:
        trajectory: state trajectory, e.g. the output of `prepare_state_trajectory`.
        state_space_spec: spec the trajectory was prepared from; it is only forwarded to
            `aggregate_state_trajectory`.
        agg_method: aggregation spec forwarded to `aggregate_state_trajectory`, or `None`
            to use the trajectory without aggregation.
        cv: number of folds (an int builds a shuffled `KFold`) or a ready-made splitter
            exposing `split` and `get_n_splits`.
        random_state: seed for the shuffled `KFold` built when `cv` is an int; ignored
            when a splitter object is passed. The default `None` keeps the previous
            unseeded behavior.

    Returns:
        Tuple `(flat_traj, flat_traj_src, cca_scores_df, cca_images)`:
        - `flat_traj`: z-scored flattened representations, one row per sample.
        - `flat_traj_src`: per-row source indices as returned by `flatten_trajectory`;
          column 0 is used here as the word label index and column 2 as the frame index.
        - `cca_scores_df`: long-form table of scores over `frame_idx`, `fold_idx` and
          `measure` (`pw_x`, `pw_y`, `ew_x`, `ew_y`); entries stay NaN for frames or folds
          that were skipped.
        - `cca_images`: dict keyed by `(frame_idx, fold_idx)` holding the projection of
          ALL rows of `flat_traj` (not just that frame's rows) through the fitted
          `cca_pos_x` map.
    """
    if agg_method is not None:
        trajectory_agg = aggregate_state_trajectory(trajectory, state_space_spec, agg_method, keepdims=True)
    else:
        trajectory_agg = trajectory
    flat_traj, flat_traj_src = flatten_trajectory(trajectory_agg)

    # Z-score each feature over all samples.
    # NOTE(review): a feature with zero variance divides by zero here.
    flat_traj = (flat_traj - flat_traj.mean(0)) / flat_traj.std(0)

    # One-hot word identity targets.
    # NOTE: the width comes from the notebook-level `k` (number of retained words), so the
    # label indices in `flat_traj_src[:, 0]` must all be smaller than `k`.
    Y = np.zeros((len(flat_traj), k), dtype=int)
    Y[np.arange(len(flat_traj)), flat_traj_src[:, 0]] = 1

    splitter = KFold(cv, shuffle=True, random_state=random_state) if isinstance(cv, int) else cv
    n_splits = splitter.get_n_splits()
    # NB here "frame" depends on the aggregation method; this may correspond to a model frame,
    # phoneme, syllable, etc.
    max_num_frames = flat_traj_src[:, 2].max() + 1

    # store the images of all instances in the aligned space
    # keys are (frame_idx, fold_idx)
    cca_images = {}
    # NaN marks (frame, fold, measure) entries that were never estimated.
    cca_scores = np.full((max_num_frames, n_splits, 4), np.nan)
    for frame_idx in trange(max_num_frames, desc="Estimating CCA", unit="frame", leave=False):
        sample_idxs = np.where(flat_traj_src[:, 2] == frame_idx)[0]
        if len(sample_idxs) / n_splits < flat_traj.shape[1]:
            # Not enough samples: fewer samples per fold than feature dimensions.
            continue

        for fold_idx, (train_idxs, _test_idxs) in enumerate(splitter.split(sample_idxs)):
            train_rows = sample_idxs[train_idxs]
            # `solve_cca` is called with features on the first axis and samples on the second.
            x, y = flat_traj[train_rows].T, Y[train_rows].T
            try:
                cca = solve_cca(x, y)
            except AssertionError:
                # Deliberate best-effort: this (frame, fold) keeps its NaN scores and
                # gets no entry in `cca_images`.
                continue

            cca_scores[frame_idx, fold_idx, 0] = cca["pwcca_sim_x"]
            cca_scores[frame_idx, fold_idx, 1] = cca["pwcca_sim_y"]
            cca_scores[frame_idx, fold_idx, 2] = cca["ewcca_sim_x"]
            cca_scores[frame_idx, fold_idx, 3] = cca["ewcca_sim_y"]

            cca_images[frame_idx, fold_idx] = cca["cca_pos_x"] @ flat_traj.T

    cca_scores_df = ndarray_to_long_dataframe(cca_scores, ["frame_idx", "fold_idx", "measure"]).reset_index()
    cca_scores_df["measure"] = cca_scores_df["measure"].map({0: "pw_x", 1: "pw_y", 2: "ew_x", 3: "ew_y"})

    return flat_traj, flat_traj_src, cca_scores_df, cca_images

In [None]:
# For every aggregation method: group word instances by their length (counted in units of
# that method's `grouping_level`), run the CCA evaluation within each length group, and
# write the score table plus diagnostic figures to `output_dir`.
for name, (agg_spec, grouping_level) in tqdm(agg_methods.items(), unit="method"):
    # Number of `grouping_level` units (e.g. phonemes or syllables) per word instance.
    state_space_spec_small.cuts["grouping_value"] = state_space_spec_small.cuts \
        .xs(grouping_level, level="level") \
        .groupby(["label", "instance_idx"]).size()
    total_num_groups = state_space_spec_small.cuts["grouping_value"].nunique()

    for length, group in tqdm(state_space_spec_small.groupby("grouping_value"), total=total_num_groups, unit="length group", leave=False):
        # Use a loop-local name so the notebook-level `trajectory` is not overwritten.
        group_trajectory = prepare_state_trajectory(model_representations, group, pad=np.nan)
        flat_traj, flat_traj_src, cca_scores_df, cca_images = evaluate_cca(group_trajectory, group, agg_spec, cv=5)
        cca_scores_df.to_csv(f"{output_dir}/cca_scores-{name}-len{length}.csv", index=False)
        # with open(f"{output_dir}/cca_images-{name}-len{length}.pkl", "wb") as f:
        #     pickle.dump(cca_images, f)

        # Rows are NaN for (frame, fold) pairs that `evaluate_cca` skipped.
        cca_scores_df = cca_scores_df.dropna()
        if cca_scores_df.empty:
            # there is no hope
            continue
        max_num_frames = cca_scores_df["frame_idx"].max() + 1
        min_value = min(0.5, cca_scores_df["value"].min())
        max_value = cca_scores_df["value"].max()

        f, ax = plt.subplots(figsize=(12, 6))
        if max_num_frames > 1:
            sns.lineplot(data=cca_scores_df, x="frame_idx", y="value", hue="measure", ax=ax)
            ax.set_title(f"CCA alignment scores (aggregation: {name}; max {grouping_level} length: {length})")
            ax.set_xlabel("Frame index")
            ax.set_ylim((min_value, max_value))
        else:
            sns.barplot(data=cca_scores_df, x="measure", y="value", ax=ax)
            ax.set_title(f"CCA alignment scores ({name})")
            ax.set_ylim((min_value, max_value))
        f.savefig(Path(output_dir) / f"cca_scores-{name}-len{length}.png")

        # plot PCA of resulting image space for a spectrum of frames
        num_plots = 5
        # Pick a random fold among those that actually produced scores. The previous
        # `np.random.randint(fold_idx.max())` could never select the last fold and raised
        # ValueError when the only scored fold was fold 0.
        fold_idx = int(np.random.choice(cca_scores_df["fold_idx"].unique()))
        # pick random words to sample
        plot_sample_idxs = np.random.choice(len(flat_traj), min(100, len(flat_traj)), replace=False)
        frame_points = np.unique(np.linspace(0, max_num_frames - 1, num_plots, dtype=int))

        for frame_idx in frame_points:
            image_key = (int(frame_idx), fold_idx)
            if image_key not in cca_images:
                # No CCA solution for this (frame, fold): the frame was skipped for lack
                # of samples or the solve failed. Nothing to plot.
                continue
            cca_image_i = cca_images[image_key]
            pca = PCA(2).fit(cca_image_i.T)

            plot_points = pca.transform(cca_image_i[:, plot_sample_idxs].T)
            plot_label_idxs = flat_traj_src[plot_sample_idxs, 0]

            f, ax = plt.subplots(figsize=(12, 12))
            ax.scatter(*plot_points.T)
            ax.set_title(f"PCA of CCA image space (aggregation: {name},\nmax {grouping_level} length: {length}; frame {frame_idx})")
            for i, label_idx in enumerate(plot_label_idxs):
                # Label indices come from the trajectory built on `group`, so resolve them
                # against `group.labels` rather than the unfiltered `state_space_spec`.
                ax.text(*plot_points[i], group.labels[label_idx], fontsize=8)

            f.savefig(Path(output_dir) / f"pca_image-{name}-len{length}-frame{frame_idx}.png")

        # Release this length group's figures before moving on.
        plt.close("all")

plt.close("all")