In [None]:
# Load IPython's autoreload extension and set it to mode 2, so that edits to
# imported project modules (e.g. `src.analysis.state_space`) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import datasets
import matplotlib.pyplot as plt
from matplotlib import transforms
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import torch
from tqdm.auto import tqdm

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# --- Experiment identification ---------------------------------------------
# These names are interpolated into the artifact paths below.
base_model = "w2v2_8"
model_class = "rnn_32-hinge-mAP4"
model_name = "word_broad"
train_dataset = "librispeech-train-clean-100"

# --- Artifact locations (relative to the working directory) ----------------
model_dir = f"outputs/models/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames"
output_dir = "."
dataset_path = f"outputs/preprocessed_data/{train_dataset}"
equivalence_path = f"outputs/equivalence_datasets/{train_dataset}/{base_model}/{model_name}_10frames/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{train_dataset}/{base_model}/{train_dataset}.h5"
state_space_specs_path = f"outputs/state_space_specs/{train_dataset}/{base_model}/state_space_specs.pkl"
embeddings_path = f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/{model_name}_10frames/{train_dataset}.npy"

# --- Analysis parameters ---------------------------------------------------
seed = 1234
# Seed NumPy's global RNG: `estimate_analogy` draws word instances with
# `np.random.choice`, so without this the results change on every run.
np.random.seed(seed)

# Passed to `subsample_instances` and used as an upper bound when drawing
# instances in `estimate_analogy`.
max_samples_per_word = 100

# Distance metric handed to `scipy.spatial.distance.cdist`.
metric = "cosine"

# Aggregation functions passed to `aggregate_state_trajectory`.
agg_fns = [
    "mean",
]

In [None]:
# Model embeddings, stored as a NumPy array on disk.
with open(embeddings_path, "rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)

# The specs file maps a key to a state-space spec; only the "word" entry is
# used here.
# NOTE(review): `torch.load` unpickles the file, which can execute arbitrary
# code -- only load spec files from a trusted source.
with open(state_space_specs_path, "rb") as specs_file:
    state_space_spec: StateSpaceAnalysisSpec = torch.load(specs_file)["word"]

# Fail early if the spec does not match the loaded embeddings.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Replace the spec with the result of `subsample_instances(max_samples_per_word)`;
# presumably this caps the number of instances kept per word -- confirm against
# StateSpaceAnalysisSpec.
# NOTE(review): this rebinds `state_space_spec` from its own previous value, so
# re-running the cell subsamples the already-subsampled spec.
state_space_spec = state_space_spec.subsample_instances(max_samples_per_word)

In [None]:
# Build the state trajectory from the embeddings and the (subsampled) spec.
# `pad=np.nan` presumably fills positions past the end of shorter instances
# with NaN -- TODO confirm against prepare_state_trajectory.
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)

In [None]:
# One aggregated trajectory per configured aggregation function, keyed by the
# function's name; tqdm shows progress over `agg_fns`.
trajectory_aggs = dict(
    (agg_name, aggregate_state_trajectory(trajectory, state_space_spec, agg_name, keepdims=True))
    for agg_name in tqdm(agg_fns)
)

In [None]:
# Flattened view of every aggregated trajectory, under the same keys as
# `trajectory_aggs`. `estimate_analogy` unpacks each value as a
# (references, references_src) pair.
trajectory_aggs_flat = dict(
    (agg_name, flatten_trajectory(agg_trajectory))
    for agg_name, agg_trajectory in trajectory_aggs.items()
)

In [None]:
def estimate_analogy(triple, agg_method="mean", num_samples=50, k=20, verbose=False):
    """
    Estimate the analogy `a - b + c` from randomly drawn word instances and
    rank every reference representation by its distance to the result.

    Reads the notebook globals `state_space_spec`, `trajectory_aggs`,
    `trajectory_aggs_flat`, `max_samples_per_word` and `metric`.

    Args:
    - triple: `(word_a, word_b, word_c, expected)`; all four words must be
      labels of `state_space_spec`, otherwise an `AssertionError` is raised
    - agg_method: key into `trajectory_aggs` / `trajectory_aggs_flat`
    - num_samples: number of random (a, b, c) instance draws
    - k: number of nearest references to report
    - verbose: if true, print distance and label of each reported reference

    Returns:
    - result_df: a df describing the k references with the smallest distance
      to the analogy vectors, averaged over all samples
    - difference_vectors: the `a - b` vector of each sample
    - analogy_vectors: the `a - b + c` vector of each sample
    """

    word_a, word_b, word_c, expected = triple
    labels = state_space_spec.labels

    # Every word, including the expected answer, has to be in the vocabulary;
    # without the expected word the ranking isn't really interpretable.
    for word in (word_a, word_b, word_c, expected):
        assert word in labels

    query_label_idxs = [labels.index(word) for word in (word_a, word_b, word_c)]

    def draw_instance_vector(label_idx):
        # Pick one instance of the label at random (never indexing past
        # `max_samples_per_word`) and return its aggregated representation.
        num_instances = min(max_samples_per_word,
                            len(state_space_spec.target_frame_spans[label_idx]))
        instance_idx = np.random.choice(num_instances)
        return trajectory_aggs[agg_method][label_idx][instance_idx].squeeze()

    difference_vectors, analogy_vectors = [], []
    for _ in range(num_samples):
        # Draw order (a, then b, then c) is kept so the RNG stream is consumed
        # identically on every call.
        vec_a, vec_b, vec_c = [draw_instance_vector(idx) for idx in query_label_idxs]
        a_minus_b = vec_a - vec_b
        difference_vectors.append(a_minus_b)
        analogy_vectors.append(a_minus_b + vec_c)

    difference_vectors = np.array(difference_vectors)
    analogy_vectors = np.array(analogy_vectors)

    # Distance of each reference to the analogy vectors, averaged over samples.
    references, references_src = trajectory_aggs_flat[agg_method]
    mean_dists = cdist(analogy_vectors, references, metric=metric).mean(axis=0)
    top_k = mean_dists.argsort()[:k]
    top_k_dists = mean_dists[top_k]
    top_k_src = references_src[top_k]

    if verbose:
        for dist, (label_idx, _, _) in zip(top_k_dists, top_k_src):
            print(dist, labels[label_idx])

    ret = pd.DataFrame(top_k_src, columns=["label_idx", "instance_idx", "frame_idx"])
    ret["distance"] = top_k_dists
    ret["label"] = [labels[label_idx] for label_idx in ret["label_idx"]]
    return ret, difference_vectors, analogy_vectors

In [None]:
# Test split of the "bats" configuration of relbert/analogy_questions,
# restricted to items whose `prefix` contains "morphology".
analogy_dataset = (
    datasets.load_dataset("relbert/analogy_questions", "bats")["test"]
    .filter(lambda x: "morphology" in x["prefix"])
)

In [None]:
# Run every analogy item through `estimate_analogy` and record, per item,
# whether the expected word is the nearest neighbor, whether it is among the
# top k, and at which rank it appears. The sampled difference vectors are
# collected alongside for later saving.
difference_vectors = []
prediction_results = []
k = 20
for item in tqdm(analogy_dataset):
    # The stem pair is unpacked as (b, a) and the answered choice as (c, d):
    # `estimate_analogy` computes a - b + c and `d` is the expected word.
    b, a = item["stem"]
    c, d = item["choice"][item["answer"]]

    try:
        ret, difference_vectors_i, _ = estimate_analogy((a, b, c, d), num_samples=100, k=k, verbose=False)
    except AssertionError:
        # `estimate_analogy` asserts that all four words are in the
        # vocabulary; items with an unknown word are deliberately skipped.
        continue

    nearest_neighbor = ret.iloc[0].label
    # Membership must be tested against the label *values*: `x in series`
    # checks the Series index (0..k-1 here), which never contains a word.
    neighbor_labels = ret.label.values
    prediction_results.append(
        {"nearest_neighbor": nearest_neighbor,
         "expected": d,
         "correct": nearest_neighbor == d,
         "correct_topk": d in neighbor_labels[:k],
         "correct_position": ret[ret.label == d].index[0] if d in neighbor_labels else None,
         **item})

    difference_vectors.append({"a": a, "b": b, "prefix": item["prefix"],
                               "difference_vectors": difference_vectors_i})

In [None]:
# One row per evaluated analogy item; the `choice` column is dropped before
# the table is written to CSV and displayed.
results_df = pd.DataFrame(prediction_results).drop("choice", axis=1)
results_df.to_csv(Path(output_dir).joinpath("analogy_results.csv"), index=False)
results_df

In [None]:
# Per-prefix item count and top-1 accuracy (mean of `correct`), sorted by
# ascending accuracy, then written to CSV and displayed.
summary_df = (
    results_df
    .groupby("prefix")["correct"]
    .agg(["count", "mean"])
    .sort_values(by="mean")
)
summary_df.to_csv(Path(output_dir).joinpath("analogy_summary.csv"))
summary_df

In [None]:
# Persist the per-item difference vectors (a list of dicts holding the words,
# the prefix and the sampled `a - b` arrays) next to the CSV outputs.
torch.save(
    difference_vectors,
    Path(output_dir).joinpath("analogy_difference_vectors.pt"),
)