In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from collections import defaultdict, Counter
from dataclasses import replace
import json
import logging
from pathlib import Path
import pickle

import dask
from dask.distributed import Client, LocalCluster
import datasets
import matplotlib.pyplot as plt
from matplotlib import transforms
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# Logger for this notebook, named after the current module (`__name__`).
L = logging.getLogger(__name__)

In [None]:
# --- Model / dataset selection -------------------------------------------
base_model = "w2v2_8"
model_class = "rnn_32-hinge-mAP4"
model_name = "word_broad"
train_dataset = "librispeech-train-clean-100"

# --- Input / output locations (all relative to the working directory) -----
model_dir = f"outputs/models/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames"
output_dir = "."
dataset_path = f"outputs/preprocessed_data/{train_dataset}"
equivalence_path = f"outputs/equivalence_datasets/{train_dataset}/{base_model}/{model_name}_10frames/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{train_dataset}/{base_model}/{train_dataset}.h5"
state_space_specs_path = f"outputs/state_space_specs/{train_dataset}/{base_model}/state_space_specs.h5"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames/{train_dataset}.npy"

# --- Reproducibility --------------------------------------------------------
seed = 1234
# The notebook draws random subsamples and random word instances through the
# global `np.random` API, so the seed has to be applied to NumPy's global RNG;
# otherwise it has no effect and results change on every run.
np.random.seed(seed)

# Upper bound on the number of instances kept per word when subsampling.
max_samples_per_word = 100

# Distance metric passed to `cdist` for nearest-neighbour lookups.
metric = "cosine"

# Aggregation specs: either a name, or a (name, argument) tuple.
agg_fns = [
    "mean", "max", "last_frame",
    ("mean_last_k", 2), ("mean_last_k", 5),
    ("mean_first_k", 10),
]

In [None]:
# Read the embedding matrix straight from its .npy path, then load the
# word-level state space spec and verify that the two agree before going on.
model_representations: np.ndarray = np.load(embeddings_path)
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "word")
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Build per-word trajectories; positions without data are filled with NaN
# (`pad=np.nan`).
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)

# Subsample trajectories to reduce computation time.
# NOTE: this draws from NumPy's global RNG and replaces entries of
# `trajectory` in place.
for i in range(len(trajectory)):
    if len(trajectory[i]) > max_samples_per_word:
        subsample_idxs = np.random.choice(len(trajectory[i]), max_samples_per_word, replace=False)
        trajectory[i] = trajectory[i][subsample_idxs]


def _unpadded_lengths(traj):
    """Return, per instance, the index of the first NaN frame (feature 0).

    `argmax` over an all-False row returns 0, which would report a length of 0
    for instances that have no NaN padding at all. Those instances are given
    the full frame-axis size instead.
    """
    is_pad = np.isnan(traj[:, :, 0])
    return np.where(is_pad.any(axis=1), is_pad.argmax(axis=1), is_pad.shape[1])


# Number of leading non-NaN frames for every instance of every word.
lengths = [_unpadded_lengths(traj_i) for traj_i in trajectory]

In [None]:
# One aggregated version of the trajectories per aggregation spec in `agg_fns`
# (keepdims=True is forwarded to `aggregate_state_trajectory`).
trajectory_aggs = {
    agg_spec: aggregate_state_trajectory(trajectory, state_space_spec, agg_spec, keepdims=True)
    for agg_spec in tqdm(agg_fns)
}
# Length-1 placeholders: one integer 1 per instance of each word.
dummy_lengths = [np.full(len(word_traj), 1, dtype=int) for word_traj in trajectory]

In [None]:
# Flatten every aggregated trajectory with `flatten_trajectory`. Later cells
# unpack each value as a `(references, references_src)` pair.
trajectory_aggs_flat = {k: flatten_trajectory(v) for k, v in trajectory_aggs.items()}

In [None]:
# Display the aggregation keys available to the analyses below.
trajectory_aggs_flat.keys()

## Prepare quantitative tools

In [None]:
def get_local_clustering_coefficients(xs, agg_method, k=5, metric="cosine", references=None):
    """
    Compute local clustering coefficients for each point in the collection `xs`.

    The local clustering coefficient measures the density of connections between the
    neighbors of a point. For every query point, its `k` nearest references are
    found; a pair of those neighbors counts as connected when the closer-ranked
    one appears among the `k` nearest references of the farther-ranked one. The
    coefficient is the number of connected pairs divided by the number of
    possible pairs, `k * (k - 1) / 2`, so it lies in [0, 1].

    Args:
    - xs: np.ndarray, shape (n_samples, n_features)
        Collection of points to evaluate.
    - agg_method: key into the notebook-level `trajectory_aggs_flat`
        Selects the reference embeddings. Ignored when `references` is given.
    - k: int
        Number of nearest neighbors to use at both levels.
    - metric: str
        Distance metric forwarded to `scipy.spatial.distance.cdist`.
    - references: optional np.ndarray, shape (n_references, n_features)
        Reference embeddings to search. Defaults to the first element of
        `trajectory_aggs_flat[agg_method]`.

    Returns:
    - np.ndarray, shape (n_samples,), one coefficient per row of `xs`.
      All zeros when `k < 2`, since no neighbor pair exists.

    Raises:
    - ValueError if there are fewer than `k + 1` references.
    """
    assert xs.ndim == 2
    if references is None:
        references, _ = trajectory_aggs_flat[agg_method]
    assert xs.shape[1] == references.shape[1]
    if references.shape[0] < k + 1:
        raise ValueError(
            f"need at least k + 1 = {k + 1} references, got {references.shape[0]}")

    # compute K nearest neighbors for each point in `xs`
    dists = cdist(xs, references, metric=metric)
    neighbors = np.argsort(dists, axis=1)[:, :k]

    # find neighbors of each of these neighbors
    neighbor_embeddings = references[neighbors]
    neighbor_dists = cdist(neighbor_embeddings.reshape(-1, neighbor_embeddings.shape[-1]),
                           references, metric=metric)
    # skip the first neighbor: each neighbor is itself a reference, so its
    # closest reference is itself at distance 0
    neighbor_neighbors = np.argsort(neighbor_dists, axis=1)[:, 1:k+1] \
        .reshape(neighbor_embeddings.shape[0], k, k)

    local_clustering_coeffs = np.zeros(xs.shape[0])

    # Each unordered neighbor pair is inspected exactly once below (the j-th
    # neighbor is only compared with neighbors ranked before it), so the
    # normaliser is the number of unordered pairs.
    n_neighbors = neighbors.shape[1]
    max_possible_triangles = n_neighbors * (n_neighbors - 1) // 2
    if max_possible_triangles == 0:
        # fewer than two neighbors: no pair can be connected
        return local_clustering_coeffs

    # compute local clustering coefficients
    for i, (neighbors_i, neighbor_neighbors_i) in enumerate(zip(neighbors, neighbor_neighbors)):
        n_triangles = 0
        for j, neighbor_neighbors_ij in enumerate(neighbor_neighbors_i):
            n_triangles += len(np.intersect1d(neighbor_neighbors_ij, neighbors_i[:j]))

        local_clustering_coeffs[i] = n_triangles / max_possible_triangles

    return local_clustering_coeffs

In [None]:
# Fit one KernelDensity model (bandwidth 0.1) per aggregation method, using the
# reference embeddings rescaled to unit L2 norm row by row.
kde_models = {
    agg_key: KernelDensity(bandwidth=0.1).fit(
        points / np.linalg.norm(points, axis=1, keepdims=True))
    for agg_key, (points, _) in tqdm(trajectory_aggs_flat.items())
}

## Line search for hand-picked pairs

In [None]:
# Hand-picked (start_word, end_word) pairs; each pair defines one line search.
word_pairs = [
    ("the", "pomegranate"),
    ("ice", "nice"),
    ("eyes", "yes"),
    ("supervision", "vision"),
    ("eyes", "eye"),
    ("boys", "boy"),
    ("girls", "girl"),
    ("say", "ace"),
    ("says", "say"),
    ("reign", "rain"),
]

# Aggregation used for the line searches below; must be one of the keys of
# `trajectory_aggs` / `trajectory_aggs_flat`.
agg_method = ('mean_first_k', 10)


In [None]:
def run_line_search(word_pair, agg_method, step_norm=0.2, k=5, verbose=False, ax=None):
    """
    Walk in a straight line between one instance of two words and probe the
    embedding space at every step.

    A random instance of each word is drawn (via NumPy's global RNG) from
    `trajectory_aggs[agg_method]`. Starting at the first instance, points are
    generated every `step_norm` (Euclidean) towards the second instance until
    the remaining distance is at most `step_norm`. For each point the function
    records its nearest reference embeddings (using the notebook-level
    `metric`), its local clustering coefficient and its KDE log density.

    Args:
    - word_pair: (start_word, end_word); both must be in `state_space_spec.labels`.
    - agg_method: key into `trajectory_aggs`, `trajectory_aggs_flat` and `kde_models`.
    - step_norm: float > 0, Euclidean length of each step.
    - k: int. Neighbors are collected in order of increasing distance, and
        collection stops once more than `k` distinct words have been seen
        (so the first neighbor of the (k+1)-th word is included).
    - verbose: if True, print per-step metrics and per-word neighbor summaries.
    - ax: optional matplotlib Axes; if given, the metrics and the nearest word
        at each step are drawn on it.

    Returns:
    - (neighbors_df, metrics_df): `neighbors_df` has one row per recorded
      neighbor with columns `step`, `word`, `dist`; `metrics_df` has one row
      per step with columns `step`, `log_density`, `local_clustering_coeff`.

    Raises:
    - AssertionError if either word is missing from `state_space_spec.labels`.
    - ValueError if `step_norm` is not positive (the walk would never end).
    """
    start_word, end_word = word_pair

    assert start_word in state_space_spec.labels
    assert end_word in state_space_spec.labels

    if step_norm <= 0:
        raise ValueError(f"step_norm must be positive, got {step_norm}")

    start_word_idx = state_space_spec.labels.index(start_word)
    end_word_idx = state_space_spec.labels.index(end_word)

    # Instance indices are capped at max_samples_per_word, the cap applied
    # when the trajectories were subsampled.
    start_instance = np.random.choice(min(max_samples_per_word, len(state_space_spec.target_frame_spans[start_word_idx])))
    end_instance = np.random.choice(min(max_samples_per_word, len(state_space_spec.target_frame_spans[end_word_idx])))

    start_traj = trajectory_aggs[agg_method][start_word_idx][start_instance].squeeze()
    end_traj = trajectory_aggs[agg_method][end_word_idx][end_instance].squeeze()

    # navigate from start_traj to end_traj by steps of norm step_norm
    displacement = end_traj - start_traj
    total_dist = np.linalg.norm(displacement)
    # identical endpoints: no direction to move in (and the loop below will not run)
    step_vector = displacement / total_dist if total_dist > 0 else np.zeros_like(displacement)
    xs = [start_traj]
    while np.linalg.norm(xs[-1] - end_traj) > step_norm:
        xs.append(xs[-1] + step_vector * step_norm)
    assert np.allclose(xs[0], start_traj)
    xs = np.array(xs)

    assert xs.ndim == 2
    references, references_src = trajectory_aggs_flat[agg_method]

    # compute nearest neighbors at each step
    dists = cdist(xs, references, metric=metric)
    ranks = dists.argsort(axis=1)

    # compute local clustering coefficients at each step
    local_clustering_coeffs = get_local_clustering_coefficients(
        xs, agg_method, k=10, metric="cosine")

    # compute density estimates at each step (the KDE models were fit on
    # unit-normalised embeddings, so normalise the queries the same way)
    log_densities = kde_models[agg_method].score_samples(xs / np.linalg.norm(xs, axis=1, keepdims=True))

    # prepare labeled results at each step
    word_results = {}
    # prepare long-format records as well
    neighbors_df = []
    metrics_df = []
    for i, ranks_i in enumerate(ranks):
        word_dist_results = defaultdict(list)
        for ref_idx in ranks_i:
            # first field of a reference's source record is its label index
            label_idx = references_src[ref_idx][0]
            # `ref_idx` is the reference's column in `dists`; indexing by the
            # rank position instead would return an unrelated distance
            word_dist_results[state_space_spec.labels[label_idx]].append(dists[i, ref_idx])

            if len(word_dist_results) > k:
                break

        word_results[i] = word_dist_results

        for word, dists_ij in word_dist_results.items():
            for dist in dists_ij:
                neighbors_df.append({
                    "step": i,
                    "word": word,
                    "dist": dist,
                })
        metrics_df.append({"step": i,
                           "log_density": log_densities[i],
                           "local_clustering_coeff": local_clustering_coeffs[i]})

    if verbose:
        for i, step_word_results in word_results.items():
            print(i, log_densities[i], local_clustering_coeffs[i])

            for word, word_dists in step_word_results.items():
                print(word, len(word_dists), np.median(word_dists))
            print()

    if ax is not None:
        ax.plot(local_clustering_coeffs, color="blue")
        ax.set_ylabel("local clustering coeff", color="blue")
        ax.set_xlabel("step")

        ax2 = ax.twinx()
        ax2.plot(log_densities, color="red")
        ax2.set_ylabel("log density", color="red")

        # plot start and end word label in axes coordinates
        ax.text(0.05, 0.05, start_word, transform=ax.transAxes)
        ax.text(0.95, 0.95, end_word, transform=ax.transAxes)

        # at each step, plot the label of the nearest neighbor (the first word
        # inserted into that step's results)
        for i, step_results in word_results.items():
            ax.text(i, 0.5, list(step_results.keys())[0], ha="center", va="center",
                    color="black", rotation=90,
                    transform=transforms.blended_transform_factory(ax.transData, ax.transAxes))

    return pd.DataFrame(neighbors_df), pd.DataFrame(metrics_df)

In [None]:
axs = [None] * len(word_pairs)

all_neighbors_df, all_metrics_df = [], []

for i, (ax, word_pair) in enumerate(zip(axs, tqdm(word_pairs))):
    try:
        neighbors_df, metrics_df = run_line_search(word_pair, agg_method)#, ax=ax)
    except AssertionError:
        # missing word
        continue

    all_neighbors_df.append(neighbors_df)
    all_metrics_df.append(metrics_df)

In [None]:
concat_keys = keys=[f"{start_word}-{end_word}" for start_word, end_word in word_pairs]
all_neighbors_df = pd.concat(all_neighbors_df, names=["word_pair"], keys=concat_keys)
all_metrics_df = pd.concat(all_metrics_df, names=["word_pair"], keys=concat_keys)

In [None]:
# Long-format metrics: one row per (word_pair, step, variable).
plot_metrics_df = (
    all_metrics_df
    .droplevel(-1)
    .reset_index()
    .melt(id_vars=["word_pair", "step"])
)
# Min-max rescale each metric across all pairs so both share one y axis.
plot_metrics_df["value"] = (
    plot_metrics_df
    .groupby("variable")["value"]
    .transform(lambda col_values: (col_values - col_values.min()) / (col_values.max() - col_values.min()))
)

# First recorded neighbor of every (word_pair, step).
top_neighbors_df = (
    all_neighbors_df
    .droplevel(-1)
    .groupby(["word_pair", "step"])
    .head(1)
)

with sns.plotting_context("talk", font_scale=2):
    g = sns.relplot(data=plot_metrics_df,
                    x="step", y="value", hue="variable", kind="line",
                    col="word_pair", col_wrap=2, height=6, aspect=2.5)

    for (row, col, hue), data in g.facet_data():
        word_pair = data.iloc[0].word_pair
        ax = g.facet_axis(row, col)

        top_neighbors_i = top_neighbors_df.loc[word_pair]
        # a single matching row comes back as a Series; wrap it into a frame
        if isinstance(top_neighbors_i, pd.Series):
            top_neighbors_i = pd.DataFrame([top_neighbors_i])

        # x in data coordinates, y in axes coordinates: labels sit mid-height
        label_transform = transforms.blended_transform_factory(ax.transData, ax.transAxes)
        for j, (_, neighbor_row) in enumerate(top_neighbors_i.iterrows()):
            ax.text(j, 0.5, neighbor_row["word"],
                    rotation=90, color="black", alpha=0.4,
                    ha="center", va="center",
                    transform=label_transform)

## Bottom-up line search

In [None]:
# num_pairs = 100
# min_word_freq = 100

# word_freqs = state_space_spec.label_counts
# candidate_words = word_freqs[word_freqs >= min_word_freq].index

# word_pairs_bu = [np.random.choice(len(candidate_words), 2, replace=False) for _ in range(num_pairs)]
# word_pairs_bu = [(candidate_words[start_word_idx], candidate_words[end_word_idx]) for start_word_idx, end_word_idx in word_pairs_bu]
# word_pairs_bu[:5]

In [None]:
# # do the above but with Dask
# if "client" not in locals():
#     cluster = LocalCluster(n_workers=16)
#     client = Client(cluster)

# @dask.delayed
# def run_line_search_dask(word_pair, agg_method, step_norm=0.2, k=5, verbose=False):
#     return run_line_search(word_pair, agg_method, step_norm, k, verbose)

# promises = [run_line_search_dask(word_pair, agg_method) for word_pair in word_pairs_bu]
# results = dask.compute(*promises, scheduler="processes")

In [None]:
# concat_keys = keys=[f"{start_word}-{end_word}" for start_word, end_word in word_pairs_bu]
# all_metrics_bu_df = pd.concat([metrics_df for _, metrics_df in results], names=["word_pair"], keys=concat_keys)

In [None]:
# # block line search trajectories by length, then cluster trajectories according to their metrics
# # (local clustering coefficient, log density)
# # then visualize the trajectories of each cluster
# line_search_lengths = all_metrics_bu_df.groupby("word_pair").size()
# _, length_blocks = np.histogram(line_search_lengths, bins=3)
# length_blocks = np.ceil(length_blocks).astype(int)
# length_block_assignments = np.digitize(line_search_lengths, bins=length_blocks)