In [None]:
# Enable IPython's autoreload extension in mode 2 (reload all modules before
# executing each cell), so edits to the imported `src` package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Set before `torch` is imported in the next cell. An empty
# CUDA_VISIBLE_DEVICES exposes no GPU to CUDA, so this notebook runs on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
import itertools
from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.model_selection import KFold
import torch
from tqdm.auto import tqdm, trange

from src.analysis import coherence
from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset
from src.utils import ndarray_to_long_dataframe

In [None]:
# Notebook parameters. Only `output_dir`, `state_space_specs_path` and
# `embeddings_path` are read by the cells visible in this notebook; the other
# paths are kept as declared.
# NOTE(review): this looks like a parameters cell meant to be overridden by a
# runner (e.g. papermill) -- confirm before removing the unused entries.
model_dir = "outputs/models/librispeech-train-clean-100/w2v2_8/rnn_32-aniso3/word_broad_10frames"
# Directory that receives the CSV / pickle / PNG outputs written below.
output_dir = "."
dataset_path = "outputs/preprocessed_data/librispeech-train-clean-100"
equivalence_path = "outputs/equivalence_datasets/librispeech-train-clean-100/w2v2_8/word_broad_10frames/equivalence.pkl"
hidden_states_path = "outputs/hidden_states/librispeech-train-clean-100/w2v2_8/librispeech-train-clean-100.h5"
state_space_specs_path = "outputs/state_space_specs/librispeech-train-clean-100/w2v2_8/state_space_specs.pkl"
embeddings_path = "outputs/model_embeddings/librispeech-train-clean-100/w2v2_8/rnn_32-aniso3/word_broad_10frames/librispeech-train-clean-100.npy"

# Aggregation methods to evaluate, keyed by the name used in output file names.
# Each value is passed as `agg_method` to `evaluate_cca`, which forwards it to
# `aggregate_state_trajectory`; `None` skips aggregation and uses the raw
# trajectory.
agg_methods = {
    "mean_within_phoneme": ("mean_within_cut", "phoneme"),
    "mean_within_syllable": ("mean_within_cut", "syllable"),
    "mean": "mean",
    "last_frame": "last_frame",
    "max": "max",
    "none": None,
}

# Keep just the K most frequent words
# (also read as a global by `evaluate_cca` for the width of the one-hot targets)
k = 500

# Keep at most N instances of each word
n = 500

In [None]:
# Model embeddings, stored as a single .npy array.
with open(embeddings_path, "rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)

# The spec file holds several analysis specs keyed by name; only the
# word-level one is used here.
# NOTE(review): torch.load unpickles arbitrary objects -- only point
# `state_space_specs_path` at files produced by this project.
with open(state_space_specs_path, "rb") as specs_file:
    loaded_specs = torch.load(specs_file)
state_space_spec: StateSpaceAnalysisSpec = loaded_specs["word"]

# Fail early if the spec does not match the loaded embeddings.
assert state_space_spec.is_compatible_with(model_representations)

## PWCCA definition

The implementation below is adapted from https://github.com/1Konny/Projection_Weighted_CCA/blob/master/cca.py, which in turn builds on Google's SVCCA code.

In [None]:
# copy&paste from: https://github.com/google/svcca/blob/master/cca_core.py
def positivedef_matrix_sqrt(array):
    """Matrix square root of a positive definite matrix via eigendecomposition.

    Args:
        array: A numpy 2d array holding a symmetric (or, if complex valued,
            hermitian) positive definite matrix.

    Returns:
        sqrtarray: A matrix S such that S @ S equals `array` up to round-off.
    """
    eigvals, eigvecs = np.linalg.eigh(array)
    # array = V diag(w) V^H, hence sqrt(array) = V diag(sqrt(w)) V^H.
    scaled_basis = np.diag(np.sqrt(eigvals)) @ eigvecs.conj().T
    return eigvecs @ scaled_basis


# copy&paste from: https://github.com/google/svcca/blob/master/cca_core.py
def remove_small(sigma_xx, sigma_xy, sigma_yx, sigma_yy, threshold=1e-6):
    """Drop the directions of X and Y whose variance is negligibly small.

    A direction is kept when the absolute value of its diagonal entry in the
    corresponding variance matrix is at least `threshold`.

    Args:
        sigma_xx: 2d numpy array, variance matrix for x
        sigma_xy: 2d numpy array, cross-variance matrix for x, y
        sigma_yx: 2d numpy array, cross-variance matrix for y, x
            ((conjugate) transpose of sigma_xy)
        sigma_yy: 2d numpy array, variance matrix for y
        threshold: cutoff below which a direction is thrown away

    Returns:
        A 6-tuple
        (sigma_xx_crop, sigma_xy_crop, sigma_yx_crop, sigma_yy_crop, x_idxs, y_idxs)
        where the first four entries are the inputs restricted to the kept
        directions, and x_idxs / y_idxs are boolean masks over the original
        x / y directions that are True for the directions that were KEPT.
    """

    def _keep_mask(sigma):
        # True wherever the variance on the diagonal reaches the threshold.
        return np.abs(np.diagonal(sigma)) >= threshold

    def _crop(sigma, row_mask, col_mask):
        # Restrict rows first, then columns.
        return sigma[row_mask][:, col_mask]

    x_idxs = _keep_mask(sigma_xx)
    y_idxs = _keep_mask(sigma_yy)

    return (
        _crop(sigma_xx, x_idxs, x_idxs),
        _crop(sigma_xy, x_idxs, y_idxs),
        _crop(sigma_yx, y_idxs, x_idxs),
        _crop(sigma_yy, y_idxs, y_idxs),
        x_idxs,
        y_idxs,
    )


def gs_orthonormalize(array):
    """Orthonormalize the columns of `array` using a QR decomposition.

    `np.linalg.qr` yields at most min(rows, cols) orthonormal columns. When
    `array` has more columns than rows, the missing columns are filled with
    zeros so that the result always has the same shape as the input.
    """
    basis, _ = np.linalg.qr(array)
    num_rows, num_found = basis.shape
    num_missing = array.shape[1] - num_found
    if num_missing > 0:
        padding = np.zeros(shape=(num_rows, num_missing))
        basis = np.concatenate([basis, padding], 1)
    return basis


# modified from and built on the codes in https://github.com/google/svcca/blob/master/cca_core.py
def solve_cca(x, y):
    """Calculate CCA correlations, position vectors, images, and mean / projection-weighted similarity.

    The terms 'correlation', 'position vector' and 'image' are detailed in [1].
    [1] A Tutorial on Canonical Correlation Methods, Uurtio et al

    Args:
        x: A representation of shape (num_neurons_x, num_datapoints)
        y: A representation of shape (num_neurons_y, num_datapoints)

    Returns:
        A dict with the following entries (one `_x` and one `_y` variant each):
            cca_corr_*: canonical correlation coefficients, clipped to [0, 1]
            cca_pos_*: canonical position vectors, one per row, expressed over
                the rows of the input that survived `remove_small`
            cca_image_*: canonical images, one per row (position vectors
                applied to the retained rows of the input)
            ewcca_sim_*: mean of the first min(numx, numy) correlations
            pwcca_sim_*: projection-weighted mean of those correlations
        plus `x_idxs` / `y_idxs`, boolean masks over the rows of the original
        x / y marking the rows that were retained by `remove_small`.

    Raises:
        AssertionError: if the inputs are not 2D, disagree on the number of
            datapoints, or have more neurons than datapoints.
        NotImplementedError: if every row of x or of y has negligible variance.
    """
    assert x.ndim == y.ndim == 2, 'both x and y should be 2D array, [num_neurons, num_datapoints]'
    assert x.shape[1] == y.shape[1], 'the number of datapoints between x and y do not match'
    assert x.shape[0] <= x.shape[1], 'num_datapoints should be greater than or equal to num_neurons. please check x.shape'
    assert y.shape[0] <= y.shape[1], 'num_datapoints should be greater than or equal to num_neurons. please check y.shape'
    epsilon = 1e-6

    numx = x.shape[0]
    numy = y.shape[0]

    # Joint covariance of the stacked [x; y] rows, split into its four blocks.
    sigma = np.cov(x, y)
    sigmaxx = sigma[:numx, :numx]
    sigmaxy = sigma[:numx, numx:]
    sigmayx = sigma[numx:, :numx]
    sigmayy = sigma[numx:, numx:]

    # normalize covariance matrices for stability
    xmax = np.max(np.abs(sigmaxx))
    ymax = np.max(np.abs(sigmayy))
    sigmaxx /= xmax
    sigmayy /= ymax
    sigmaxy /= np.sqrt(xmax*ymax)
    sigmayx /= np.sqrt(ymax*xmax)

    # remove negligibly small covariances
    sigmaxx, sigmaxy, sigmayx, sigmayy, x_idxs, y_idxs = remove_small(sigmaxx, sigmaxy, sigmayx, sigmayy)
    x = x[x_idxs]
    y = y[y_idxs]

    numx = sigmaxx.shape[0]
    numy = sigmayy.shape[0]
    if numx == 0 or numy == 0:
        raise NotImplementedError(
            'CCA is undefined here: every row of x or y has negligible variance '
            f'(retained {numx} rows of x and {numy} rows of y)')

    # Regularize before inverting.
    sigmaxx += epsilon*np.eye(numx)
    sigmayy += epsilon*np.eye(numy)
    inv_sigmaxx = np.linalg.pinv(sigmaxx)
    inv_sigmayy = np.linalg.pinv(sigmayy)
    invsqrt_sigmaxx = positivedef_matrix_sqrt(inv_sigmaxx)
    invsqrt_sigmayy = positivedef_matrix_sqrt(inv_sigmayy)

    # Whitened cross-covariance products; their singular values are the
    # squared canonical correlations.
    arrx = invsqrt_sigmaxx.dot(sigmaxy).dot(inv_sigmayy.dot(sigmayx.dot(invsqrt_sigmaxx)))
    arry = invsqrt_sigmayy.dot(sigmayx).dot(inv_sigmaxx.dot(sigmaxy.dot(invsqrt_sigmayy)))
    arrx += epsilon*np.eye(arrx.shape[0])
    arry += epsilon*np.eye(arry.shape[0])

    ux, sx, vhx = np.linalg.svd(arrx)
    uy, sy, vhy = np.linalg.svd(arry)

    cca_corr_x = np.sqrt(np.abs(sx)) # each value represents k-th order canonical correlation coefficient of x
    cca_corr_x = np.where(cca_corr_x>1, 1, cca_corr_x)
    cca_corr_x = np.where(cca_corr_x<epsilon, 0, cca_corr_x)

    cca_corr_y = np.sqrt(np.abs(sy))
    cca_corr_y = np.where(cca_corr_y>1, 1, cca_corr_y)
    cca_corr_y = np.where(cca_corr_y<epsilon, 0, cca_corr_y)

    cca_pos_x = vhx.dot(invsqrt_sigmaxx) # each row represents k-th order canonical correlation position vector of x
    cca_pos_y = vhy.dot(invsqrt_sigmayy)

    cca_image_x = cca_pos_x.dot(x) # each row represents k-th order canonical correlation image of x
    cca_image_y = cca_pos_y.dot(y)

    min_numxy = min(numx, numy)
    truncated_corr_x = cca_corr_x[:min_numxy]
    truncated_corr_y = cca_corr_y[:min_numxy]
    equally_weighted_cca_sim_x = truncated_corr_x.mean()
    equally_weighted_cca_sim_y = truncated_corr_y.mean()

    truncated_cca_image_x = cca_image_x[:min_numxy]
    truncated_cca_image_y = cca_image_y[:min_numxy]

    # Orthonormalize the images as vectors over datapoints. Each image is a
    # ROW here, while gs_orthonormalize works on COLUMNS, so transpose in and
    # out. Orthonormalizing the untransposed (k, num_datapoints) matrix would
    # yield a k x k basis zero-padded to the full width, and the projection
    # below would then only look at the first k datapoints.
    orthonorm_cca_image_x = gs_orthonormalize(truncated_cca_image_x.T).T
    orthonorm_cca_image_y = gs_orthonormalize(truncated_cca_image_y.T).T

    # Weight of each canonical direction: total absolute projection of the
    # neurons' activation vectors onto that direction, normalized to sum to 1.
    projection_weights_x = np.abs(orthonorm_cca_image_x.dot(x.T)).sum(1)
    projection_weights_x /= projection_weights_x.sum()
    projection_weighted_cca_sim_x = (projection_weights_x*truncated_corr_x).sum() # dist = 1 - sim

    projection_weights_y = np.abs(orthonorm_cca_image_y.dot(y.T)).sum(1)
    projection_weights_y /= projection_weights_y.sum()
    projection_weighted_cca_sim_y = (projection_weights_y*truncated_corr_y).sum() # dist = 1 - sim

    output_dicts = {}
    output_dicts['cca_corr_x'] = cca_corr_x
    output_dicts['cca_corr_y'] = cca_corr_y
    output_dicts['cca_pos_x'] = cca_pos_x
    output_dicts['cca_pos_y'] = cca_pos_y
    output_dicts['cca_image_x'] = cca_image_x
    output_dicts['cca_image_y'] = cca_image_y
    output_dicts['ewcca_sim_x'] = equally_weighted_cca_sim_x
    output_dicts['ewcca_sim_y'] = equally_weighted_cca_sim_y
    output_dicts['pwcca_sim_x'] = projection_weighted_cca_sim_x
    output_dicts['pwcca_sim_y'] = projection_weighted_cca_sim_y
    # Masks of the input rows that the position vectors refer to.
    output_dicts['x_idxs'] = x_idxs
    output_dicts['y_idxs'] = y_idxs

    return output_dicts

In [None]:
# keep the K most frequent words
# (`k` comes from the parameters cell; the full `state_space_spec` is left untouched)
state_space_spec_small = state_space_spec.keep_top_k(k)

In [None]:
# keep at most N instances per word
# NOTE(review): `random=True` presumably draws the instances at random and no
# seed is set in this notebook, so the selection may differ between runs -- confirm.
# This cell overwrites its own input, so re-running it subsamples the already
# subsampled spec.
state_space_spec_small = state_space_spec_small.subsample_instances(n, random=True)

In [None]:
# Build the state trajectory for the reduced spec from the model embeddings.
# NOTE(review): `pad=np.nan` presumably fills positions past the end of shorter
# instances with NaN -- confirm against `prepare_state_trajectory`.
trajectory = prepare_state_trajectory(model_representations, state_space_spec_small, pad=np.nan)

In [None]:
def evaluate_cca(trajectory, state_space_spec, agg_method, cv=5, random_state=None):
    """
    Evaluate CCA alignment between model representations and one-hot word embeddings.

    For every frame index, CCA is fit on the training portion of each
    cross-validation fold; the test portion is not scored.

    Args:
        trajectory: state trajectory as produced by `prepare_state_trajectory`.
        state_space_spec: spec the trajectory was built from; its `labels`
            determine the width of the one-hot targets.
        agg_method: aggregation passed to `aggregate_state_trajectory`, or
            None to use the trajectory as is.
        cv: number of folds (a shuffled KFold is created), or a ready-made
            splitter object.
        random_state: seed for the KFold shuffle when `cv` is an int; ignored
            when a splitter object is passed. None keeps the unseeded behavior.

    Returns:
        flat_traj: z-scored flattened trajectory, one row per sample.
        flat_traj_src: per-sample source indices as returned by
            `flatten_trajectory`; column 0 is used as the label index and
            column 2 as the frame index.
        cca_scores_df: long dataframe with columns frame_idx, fold_idx,
            measure ("pw_x", "pw_y", "ew_x", "ew_y") and value. Frames without
            enough samples keep NaN values.
        cca_images: dict keyed by (frame_idx, fold_idx) holding the projection
            of ALL samples in `flat_traj` onto that fit's x position vectors.
    """
    if agg_method is not None:
        trajectory_agg = aggregate_state_trajectory(trajectory, state_space_spec, agg_method, keepdims=True)
    else:
        trajectory_agg = trajectory
    flat_traj, flat_traj_src = flatten_trajectory(trajectory_agg)

    # Z-score
    flat_traj = (flat_traj - flat_traj.mean(0)) / flat_traj.std(0)
    num_samples, num_features = flat_traj.shape

    # Target values: one-hot label indicators. The width comes from the spec
    # that was passed in rather than from a notebook-level global.
    num_labels = len(state_space_spec.labels)
    Y = np.zeros((num_samples, num_labels), dtype=int)
    Y[np.arange(num_samples), flat_traj_src[:, 0]] = 1

    cv = KFold(cv, shuffle=True, random_state=random_state) if isinstance(cv, int) else cv
    num_folds = cv.get_n_splits()
    # NB here "frame" depends on the aggregation method; this may correspond to a model frame,
    # phoneme, syllable, etc.
    max_num_frames = flat_traj_src[:, 2].max() + 1

    # solve_cca requires at least as many datapoints as rows in x AND in y.
    min_train_samples = max(num_features, num_labels)

    # store the images of all instances in the aligned space
    # keys are (frame_idx, fold_idx)
    cca_images = {}
    cca_scores = np.full((max_num_frames, num_folds, 4), np.nan)
    for frame_idx in trange(max_num_frames, desc="Estimating CCA", unit="frame"):
        sample_idxs = np.where(flat_traj_src[:, 2] == frame_idx)[0]
        if len(sample_idxs) / num_folds < num_features:
            # Not enough samples
            continue

        splits = list(cv.split(sample_idxs))
        if not splits or min(len(train_idxs) for train_idxs, _ in splits) < min_train_samples:
            # Some training fold is too small for solve_cca; skip the whole
            # frame so that every stored frame has results for every fold.
            continue

        for fold_idx, (train_idxs, _) in enumerate(splits):
            train_sample_idxs = sample_idxs[train_idxs]
            x, y = flat_traj[train_sample_idxs].T, Y[train_sample_idxs].T
            cca = solve_cca(x, y)
            cca_scores[frame_idx, fold_idx, 0] = cca["pwcca_sim_x"]
            cca_scores[frame_idx, fold_idx, 1] = cca["pwcca_sim_y"]
            cca_scores[frame_idx, fold_idx, 2] = cca["ewcca_sim_x"]
            cca_scores[frame_idx, fold_idx, 3] = cca["ewcca_sim_y"]

            # NOTE(review): this assumes solve_cca kept every feature dimension;
            # if it dropped low-variance rows of x the shapes will not match.
            cca_images[frame_idx, fold_idx] = cca["cca_pos_x"] @ flat_traj.T

    cca_scores_df = ndarray_to_long_dataframe(cca_scores, ["frame_idx", "fold_idx", "measure"]).reset_index()
    cca_scores_df["measure"] = cca_scores_df["measure"].map({0: "pw_x", 1: "pw_y", 2: "ew_x", 3: "ew_y"})

    return flat_traj, flat_traj_src, cca_scores_df, cca_images

In [None]:
# For every aggregation method: estimate CCA alignment, save the scores and
# images, plot the scores, and plot a 2D PCA of the image space for a handful
# of frames spread across the trajectory.
for name, agg_spec in tqdm(agg_methods.items(), unit="method"):
    flat_traj, flat_traj_src, cca_scores_df, cca_images = evaluate_cca(trajectory, state_space_spec_small, agg_spec, cv=5)
    cca_scores_df.to_csv(Path(output_dir) / f"cca_scores-{name}.csv", index=False)
    with open(Path(output_dir) / f"cca_images-{name}.pkl", "wb") as images_file:
        pickle.dump(cca_images, images_file)

    valid_scores_df = cca_scores_df.dropna()
    if valid_scores_df.empty:
        # No frame had enough samples for this aggregation; nothing to plot.
        continue

    max_num_frames = int(valid_scores_df["frame_idx"].max()) + 1
    min_value = min(0.5, cca_scores_df["value"].min())
    max_value = cca_scores_df["value"].max()

    scores_fig, ax = plt.subplots(figsize=(12, 6))
    if max_num_frames > 1:
        sns.lineplot(data=cca_scores_df, x="frame_idx", y="value", hue="measure", ax=ax)
        ax.set_title(f"CCA alignment scores (aggregation: {name})")
        ax.set_xlabel("Frame index")
    else:
        sns.barplot(data=cca_scores_df, x="measure", y="value", ax=ax)
        ax.set_title(f"CCA alignment scores ({name})")
    ax.set_ylim((min_value, max_value))
    scores_fig.savefig(Path(output_dir) / f"cca_scores-{name}.png")
    # Close right away; otherwise open figures pile up across methods.
    plt.close(scores_fig)

    # plot PCA of resulting image space for a spectrum of frames
    num_plots = 5
    # pick a random fold; randint's upper bound is exclusive, hence the +1 so
    # that the last fold can be chosen too
    fold_idx = np.random.randint(int(valid_scores_df["fold_idx"].max()) + 1)
    # pick random words to sample
    plot_sample_idxs = np.random.choice(len(flat_traj), min(100, len(flat_traj)), replace=False)
    frame_points = np.unique(np.linspace(0, max_num_frames - 1, num_plots, dtype=int))

    for frame_idx in frame_points:
        if (frame_idx, fold_idx) not in cca_images:
            # This frame was skipped during estimation (not enough samples).
            continue
        cca_image_i = cca_images[frame_idx, fold_idx]
        pca = PCA(2).fit(cca_image_i.T)

        plot_points = pca.transform(cca_image_i[:, plot_sample_idxs].T)
        plot_label_idxs = flat_traj_src[plot_sample_idxs, 0]

        pca_fig, ax = plt.subplots(figsize=(12, 12))
        ax.scatter(*plot_points.T)
        ax.set_title(f"PCA of CCA image space (aggregation: {name}, frame {frame_idx})")
        for i, label_idx in enumerate(plot_label_idxs):
            # Label indices refer to the reduced spec the trajectory was built
            # from, so the names must be looked up there as well.
            ax.text(*plot_points[i], state_space_spec_small.labels[label_idx], fontsize=8)

        pca_fig.savefig(Path(output_dir) / f"pca_image-{name}-frame{frame_idx}.png")
        plt.close(pca_fig)

plt.close("all")