In [1]:
import os
# Hide all GPUs from torch so everything in this notebook runs on CPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

In [2]:
# Automatically reload edited project modules (e.g. the `src.*` imports below)
# before each cell executes.
%load_ext autoreload
%autoreload 2

In [4]:
from collections import defaultdict, Counter
from dataclasses import replace
import json
import logging
from pathlib import Path
import pickle
from typing import Union

import datasets
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
import torch
from tqdm.auto import tqdm, trange

from src.analysis.state_space import StateSpaceAnalysisSpec, \
    prepare_state_trajectory, aggregate_state_trajectory, flatten_trajectory
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [5]:
# Module-level logger used for warnings emitted by the cells below.
L = logging.getLogger(name=__name__)

In [6]:
# --- Experiment identifiers ---
base_model = "w2v2_8"
model_class = "rnn_32-hinge-mAP4"
model_name = "word_broad"
train_dataset = "librispeech-train-clean-100"
# Run-directory name shared by the model, equivalence and embedding paths below
_model_run = f"{model_name}_10frames"

# --- Input / output locations (relative paths) ---
model_dir = f"outputs/models/{train_dataset}/{base_model}/{model_class}/{_model_run}"
output_dir = "."
dataset_path = f"outputs/preprocessed_data/{train_dataset}"
equivalence_path = f"outputs/equivalence_datasets/{train_dataset}/{base_model}/{_model_run}/equivalence.pkl"
hidden_states_path = f"outputs/hidden_states/{train_dataset}/{base_model}/{train_dataset}.h5"
state_space_specs_path = f"outputs/state_space_specs/{train_dataset}/{base_model}/state_space_specs.h5"
embeddings_path = (f"outputs/model_embeddings/{train_dataset}/{base_model}/{model_class}/"
                   f"{_model_run}/{train_dataset}.npy")

seed = 1234

# Number of frames to add (before onset, after offset) to each trajectory:
# here 4 frames prior to onset and none after.
expand_frame_window = (4, 0)

# If True, estimate the PCA space from the plot words only; otherwise use the
# whole vocabulary within the frequency window below.
pca_plot_words_only = False
# Frequency window [pca_freq_min, pca_freq_max) of words used to estimate the
# embedding PCA space.
pca_freq_min = 15
pca_freq_max = 10000

# Cap on the number of instances of each word used for PCA (computational efficiency)
pca_max_samples_per_word = 100

metric = "cosine"

In [7]:
# Seed numpy's global RNG. Instance subsampling, random baseline groups and text
# jitter below all draw from it.
np.random.seed(seed=seed)

In [11]:
# Load the learned model embeddings and the word-level state space spec, then
# check (via is_compatible_with) that the spec fits these representations.
with Path(embeddings_path).open("rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)
state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "word")
assert state_space_spec.is_compatible_with(model_representations)

In [9]:
# Number of instances of each word type; used below and in later cells to apply
# the [pca_freq_min, pca_freq_max) frequency window.
word_freqs = {label: len(spans) for spans, label in
              zip(state_space_spec.target_frame_spans, state_space_spec.labels)}

min_suffix_type_count = 10
min_prefix_type_count = 10
plot_suffixes = 8
plot_prefixes = 8
# Longest affix (in characters) considered
max_affix_length = 8


def _find_best_affixes(get_affix, min_type_count, num_keep, max_length=max_affix_length):
    """Find the longest affixes shared by many frequency-eligible word types.

    For each affix length 1..max_length, counts how many word types with
    frequency in [pca_freq_min, pca_freq_max) share each affix, where the affix
    of ``word`` at a given length is ``get_affix(word, length)``. Affixes shared
    by fewer than ``min_type_count`` types are dropped.

    Returns:
        The top ``num_keep`` ``(length, type_count, affix)`` tuples, ordered by
        length, then type count, then affix (all descending).
    """
    candidates = []
    # Start at length 1: a zero-length "affix" is meaningless (the prefix would
    # be "" and, for suffixes, `word[-0:]` is the whole word).
    for length in range(1, max_length + 1):
        counts = Counter(
            get_affix(word, length) for word in state_space_spec.labels
            if len(word) >= length and pca_freq_min <= word_freqs[word] < pca_freq_max)
        # most_common() is sorted by count, so stop at the first rare affix
        for affix, count in counts.most_common():
            if count < min_type_count:
                break
            candidates.append((length, count, affix))
    return sorted(candidates, reverse=True)[:num_keep]


best_suffixes = _find_best_affixes(lambda word, n: word[-n:], min_suffix_type_count, plot_suffixes)
best_prefixes = _find_best_affixes(lambda word, n: word[:n], min_prefix_type_count, plot_prefixes)

print("Best prefixes", best_prefixes)
print("Best suffixes", best_suffixes)

Best prefixes [(5, 11, 'inter'), (5, 10, 'count'), (4, 23, 'cons'), (4, 22, 'inte'), (4, 20, 'comp'), (4, 18, 'cont'), (4, 17, 'comm'), (4, 12, 'pres')]
Best suffixes [(6, 11, 'ection'), (5, 46, 'ation'), (5, 19, 'tions'), (5, 19, 'ction'), (5, 14, 'ssion'), (5, 14, 'ently'), (5, 14, 'ement'), (5, 12, 'ering')]


In [None]:
# Word sets to visualize: a hand-picked exploratory set plus one set per frequent
# prefix and suffix found above (affix sets are restricted to words inside the
# PCA frequency window [pca_freq_min, pca_freq_max)).
plot_sets = {
    "exploratory": ["allow", "about", "around",
                    "before", "black", "barely",
                    "small", "said", "such",
                    "please", "people", "problem"],
}
plot_sets.update({f"prefix_{prefix}": [word for word in state_space_spec.labels if word.startswith(prefix)
                                       and pca_freq_min <= word_freqs[word] < pca_freq_max]
                  for _, _, prefix in best_prefixes})
plot_sets.update({f"suffix_{suffix}": [word for word in state_space_spec.labels if word.endswith(suffix)
                                       and pca_freq_min <= word_freqs[word] < pca_freq_max]
                  for _, _, suffix in best_suffixes})

plot_words = set(word for words in plot_sets.values() for word in words)
# Report only the words that are actually missing, not the whole plot vocabulary.
missing_plot_words = sorted(plot_words - set(state_space_spec.labels))
if missing_plot_words:
    raise ValueError(f"Plot words not found in state space: {missing_plot_words}")

if pca_plot_words_only:
    pca_words = plot_words
else:
    # use all words with frequency in [pca_freq_min, pca_freq_max) to compute PCA,
    # ordered by decreasing frequency
    pca_words = sorted([(freq, label) for label, freq in word_freqs.items()
                        if freq >= pca_freq_min and freq < pca_freq_max], reverse=True)
    pca_words = [label for _, label in pca_words]

    # Plot words outside the frequency window are still included so they can be
    # projected; sorted so the warnings/append order is deterministic.
    for plot_word in sorted(plot_words - set(pca_words)):
        L.warning(f"Plot word {plot_word} not found in PCA words")
        pca_words.append(plot_word)

# Restrict the state space to the PCA vocabulary (set lookup: labels can be many).
# Note: this rebinds `state_space_spec`, so re-running this cell operates on the
# already-restricted spec.
pca_word_set = set(pca_words)
drop_idxs = [idx for idx, word in enumerate(state_space_spec.labels)
             if word not in pca_word_set]
state_space_spec = state_space_spec.drop_labels(drop_idxs)

In [None]:
# Cap the number of instances per word at pca_max_samples_per_word to bound PCA
# cost. Rebinds `state_space_spec`: re-running applies the cap again to the
# already-capped spec.
state_space_spec = (
    state_space_spec
    .subsample_instances(pca_max_samples_per_word)
)

## Fit PCA on the full PCA word set

In [None]:
# Extract every instance's trajectory (with expand_frame_window context frames,
# padded with NaN), then flatten to feature rows plus the source index of each
# row. `trajectory` is reused by the phoneme aggregation further below.
trajectory = prepare_state_trajectory(model_representations, state_space_spec,
                                      expand_window=expand_frame_window, pad=np.nan)
all_trajectories, all_trajectories_src = flatten_trajectory(trajectory)

In [None]:
# Standardize each feature, then project onto the top-2 principal components.
# The fitted `pipeline` is reused by plot_state_space to project every plot
# set into this same 2-D space.
pipeline = make_pipeline(
    StandardScaler(),
    PCA(n_components=2),
)
all_trajectories_pca = pipeline.fit_transform(all_trajectories)

## State space analysis over plot sets

In [None]:
def plot_state_space(plot_key):
    """Animate one plot set's word trajectories in the shared 2-D PCA space.

    Restricts the global ``state_space_spec`` to the words of
    ``plot_sets[plot_key]``, extracts their NaN-padded trajectories from
    ``model_representations`` (with ``expand_frame_window`` context), caps each
    word at ``pca_max_samples_per_word`` instances, projects every frame with the
    already-fitted global scaler+PCA ``pipeline`` and saves an animation (one
    animation frame per model frame) to ``<output_dir>/state_space-<plot_key>.gif``.

    Args:
        plot_key: key into the global ``plot_sets`` dict.
    """
    import matplotlib.animation as animation

    plot_words = plot_sets[plot_key]
    if not plot_words:
        L.warning(f"Plot set {plot_key} has no words; skipping")
        return

    plot_word_set = set(plot_words)
    state_space_spec_sub = state_space_spec.drop_labels(
        [idx for idx, word in enumerate(state_space_spec.labels) if word not in plot_word_set])

    trajectory = prepare_state_trajectory(
        model_representations,
        state_space_spec_sub,
        expand_window=expand_frame_window,
        pad=np.nan
    )

    # Subsample trajectories to reduce computation time
    for i in range(len(trajectory)):
        if len(trajectory[i]) > pca_max_samples_per_word:
            subsample_idxs = np.random.choice(len(trajectory[i]), pca_max_samples_per_word, replace=False)
            trajectory[i] = trajectory[i][subsample_idxs]

    # Stack all instances. The code below treats the result as a 3-D
    # (instance, frame, feature) array — TODO confirm against prepare_state_trajectory.
    all_trajectories_full = np.concatenate(trajectory)
    # (instance, frame) coordinates of every flattened row, in row-major order
    all_trajectories_src = np.array(list(np.ndindex(all_trajectories_full.shape[:2])))

    # flatten & retain non-padding (rows without any NaN)
    all_trajectories = all_trajectories_full.reshape(-1, all_trajectories_full.shape[-1])
    retain_idxs = ~np.isnan(all_trajectories).any(axis=1)
    all_trajectories = all_trajectories[retain_idxs]
    all_trajectories_src = all_trajectories_src[retain_idxs]
    if len(all_trajectories) == 0:
        L.warning(f"Plot set {plot_key} has no non-padding frames; skipping")
        return

    # use previously fit scaler+PCA to transform these representations
    all_trajectories_pca = pipeline.transform(all_trajectories)

    # Scatter projections back to (instance, frame, 2); padded frames stay NaN
    all_trajectories_pca_padded = np.full(all_trajectories_full.shape[:2] + (2,), np.nan)
    all_trajectories_pca_padded[all_trajectories_src[:, 0], all_trajectories_src[:, 1]] = all_trajectories_pca

    # From each instance's first NaN frame onwards, hold its last valid position so
    # shorter words stay visible. argmax is 0 when a row has no NaN (nothing to fill).
    for idx, nan_onset in enumerate(np.isnan(all_trajectories_pca_padded)[:, :, 0].argmax(axis=1)):
        if nan_onset == 0:
            continue
        all_trajectories_pca_padded[idx, nan_onset:] = all_trajectories_pca_padded[idx, nan_onset - 1]

    # Row range [start, end) of each label's instances in all_trajectories_pca_padded
    trajectory_dividers = np.concatenate([[0], np.cumsum([traj.shape[0] for traj in trajectory])])
    plot_word_dividers = []
    for word in plot_words:
        class_idx = state_space_spec_sub.labels.index(word)
        left_edge = trajectory_dividers[class_idx]
        right_edge = (trajectory_dividers[class_idx + 1] if class_idx + 1 < len(trajectory_dividers)
                      else len(all_trajectories_pca_padded))
        plot_word_dividers.append((left_edge, right_edge))

    # Square axis limits covering both projected components
    lo, hi = all_trajectories_pca.min(), all_trajectories_pca.max()

    fig, ax = plt.subplots()
    ax.set_xlim(np.floor(lo), np.ceil(hi))
    ax.set_ylim(np.floor(lo), np.ceil(hi))
    # Frame counter in axes coordinates so it stays visible for any data range
    annot_frame = ax.text(0.02, 0.95, "-1", transform=ax.transAxes)

    color_classes = sorted(set(plot_words))
    cmap = sns.color_palette("Set1", len(color_classes))
    color_values = {class_: cmap[i] for i, class_ in enumerate(color_classes)}
    marker_values = {class_: "o" if i % 2 == 0 else "x" for i, class_ in enumerate(color_classes)}

    # One scatter per word, pre-sized to that word's instance count
    scats = [ax.scatter(np.zeros(end - start), np.zeros(end - start),
                        alpha=0.5,
                        marker=marker_values[word],
                        color=color_values[word])
             for word, (start, end) in zip(plot_words, plot_word_dividers)]
    ax.legend(scats, plot_words, loc=1)

    def init():
        for scat in scats:
            scat.set_offsets(np.zeros((0, 2)))
        return tuple(scats)

    def update(frame):
        for scat, (idx_start, idx_end) in zip(scats, plot_word_dividers):
            scat.set_offsets(all_trajectories_pca_padded[idx_start:idx_end, frame])
        annot_frame.set_text(str(frame))
        return tuple(scats) + (annot_frame,)

    # Animate by model frame
    num_frames = all_trajectories_pca_padded.shape[1]
    ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=500,
                                  init_func=init)
    try:
        ani.save(Path(output_dir) / f"state_space-{plot_key}.gif", writer="ffmpeg")
    finally:
        # The GIF is the output; close the figure so looping over many plot sets
        # does not accumulate open figures.
        plt.close(fig)

In [None]:
# Render one state-space GIF per plot set.
for plot_key in tqdm(plot_sets.keys()):
    plot_state_space(plot_key)

## Quiver plots of phoneme-aggregated trajectories

In [None]:
# Aggregate the full-set trajectories by phoneme cut ("mean_within_cut" —
# presumably one mean vector per phoneme), keeping dimensions.
traj_agg_phoneme = aggregate_state_trajectory(
    trajectory,
    state_space_spec,
    ("mean_within_cut", "phoneme"),
    keepdims=True,
)

In [None]:
# Flatten the phoneme-aggregated trajectories into feature rows, keeping the
# source index of each row (used by plot_quiver to group rows by word and frame).
traj_agg_phoneme_flat, traj_agg_phoneme_src = flatten_trajectory(
    traj_agg_phoneme
)

In [None]:
# Separate scaler+PCA fit on the phoneme-aggregated trajectories. It gets its own
# name so the 2-component `pipeline` used by plot_state_space is not clobbered
# (re-running plot_state_space after this cell would otherwise project with the
# wrong PCA). PCA requires n_components <= min(n_samples, n_features), so cap by both.
phoneme_pipeline = make_pipeline(StandardScaler(),
                                 PCA(n_components=min(*traj_agg_phoneme_flat.shape, 4)))
traj_agg_phoneme_pca = phoneme_pipeline.fit_transform(traj_agg_phoneme_flat)

In [None]:
def plot_quiver(group_spec: Union[list[str], dict[str, list[str]]],
                traj_flat, traj_flat_src, state_space_spec, plot_cut_description="phoneme",
                legend=True, ax=None):
    """Draw per-group mean trajectories as quiver arrows in a 2-D projection.

    Args:
        group_spec: either a list of keys into the global ``plot_sets`` (each key
            becomes one group) or a mapping from group name to a list of words.
        traj_flat: flattened projected trajectories; columns 0 and 1 are plotted.
        traj_flat_src: one row per row of ``traj_flat``: column 0 is the index
            into ``state_space_spec.labels``, column 1 the instance index and
            column 2 the frame index.
        state_space_spec: spec providing ``labels`` and, when annotating, ``cuts``.
        plot_cut_description: cut level whose descriptions are written as text
            near each frame mean; ``None`` disables the annotation.
        legend: whether to draw a legend.
        ax: axes to draw on; a new 12x12 figure is created when ``None``.

    Returns:
        ``(quiver_data, quiver_data_src)``: for each group, a list with one array
        per frame index holding that group's ``traj_flat`` (resp.
        ``traj_flat_src``) rows at that frame.
    """
    if isinstance(group_spec, list):
        quiver_groups = {key: plot_sets[key] for key in group_spec}
    else:
        quiver_groups = group_spec

    palette = sns.color_palette("Set1")
    get_color = lambda idx: palette[idx % len(palette)]

    if plot_cut_description is not None:
        # Index the cuts at the requested level by (label_idx, instance_idx,
        # frame_idx), where frame_idx is the cut's ordinal position (row order)
        # within its instance, so rows can be looked up from traj_flat_src.
        cuts_df = state_space_spec.cuts.xs(plot_cut_description, level="level") \
            .drop(columns=["onset_frame_idx", "offset_frame_idx"])
        cuts_df["label_idx"] = cuts_df.index.get_level_values("label") \
            .map({label: idx for idx, label in enumerate(state_space_spec.labels)})
        cuts_df["frame_idx"] = cuts_df.groupby(["label", "instance_idx"]).cumcount()
        cuts_df = cuts_df.reset_index().set_index(["label_idx", "instance_idx", "frame_idx"])

    # Collect, for each group, the rows at each frame index
    max_num_frames = traj_flat_src[:, 2].max() + 1
    quiver_data, quiver_data_src = {}, {}
    for group, words in tqdm(quiver_groups.items(), unit="group", leave=False):
        quiver_data_i, quiver_data_src_i = [], []
        word_idxs = [state_space_spec.labels.index(word) for word in words]
        mask = np.isin(traj_flat_src[:, 0], word_idxs)
        for j in trange(max_num_frames, leave=False):
            mask_j = mask & (traj_flat_src[:, 2] == j)
            # Stop at the first frame index that no instance of the group reaches
            if not mask_j.any():
                break
            quiver_data_i.append(traj_flat[mask_j])
            quiver_data_src_i.append(traj_flat_src[mask_j])

        quiver_data[group] = quiver_data_i
        quiver_data_src[group] = quiver_data_src_i

    if ax is None:
        _, ax = plt.subplots(figsize=(12, 12))

    for i, group in enumerate(quiver_data):
        data = quiver_data[group]
        group_src = quiver_data_src[group]
        color = get_color(i)
        if not data:
            L.warning(f"No trajectory frames found for quiver group {group}; skipping")
            continue

        # Per-frame statistics across instances; rows containing NaN are not counted
        frame_counts = np.array([np.sum(~np.isnan(frame).any(axis=1)) for frame in data])
        frame_means = np.array([np.nanmean(frame, axis=0) for frame in data])
        frame_sems = np.array([np.nanstd(frame, axis=0) / np.sqrt(count)
                               for frame, count in zip(data, frame_counts)])

        # Arrows from each frame mean to the next; line width scales with the
        # number of instances at the arrow's start frame
        ax.quiver(frame_means[:-1, 0], frame_means[:-1, 1],
                  frame_means[1:, 0] - frame_means[:-1, 0],
                  frame_means[1:, 1] - frame_means[:-1, 1],
                  linewidths=frame_counts[:-1] / frame_counts.max() * 3,
                  edgecolors=color,
                  angles="xy", scale_units="xy", scale=1.2,
                  color=color,
                  label=group)

        if plot_cut_description is not None:
            # Random text offset (data units) so labels at one frame don't overlap exactly
            jitter_x_scale = 0.05
            jitter_y_scale = 0.05

            for j, group_src_j in enumerate(group_src):
                # Select exactly the (label_idx, instance_idx, frame_idx) rows of this
                # frame. Passing three arrays to `.loc` would instead select the cross
                # product of per-level values, mixing in cuts of other instances/frames.
                keys_j = list(set(map(tuple, group_src_j[:, :3].tolist())))
                descriptions_j = cuts_df.loc[cuts_df.index.isin(keys_j), "description"] \
                    .value_counts(normalize=True)
                # plot each description with jitter centered around frame mean; size proportional to relative frequency
                for description, freq in descriptions_j.head(3).items():
                    ax.text(frame_means[j, 0] + np.random.randn() * jitter_x_scale,
                            frame_means[j, 1] + np.random.randn() * jitter_y_scale,
                            description,
                            size=4 + 8 * freq,
                            color=color,
                            transform=ax.transData)

        # Shaded band between consecutive frame means, +/- 1 SEM along x only
        for j1, j2 in zip(range(len(data) - 1), range(1, len(data))):
            polygon_edges = np.array([
                [frame_means[j1, 0] - frame_sems[j1, 0], frame_means[j1, 1]],
                [frame_means[j2, 0] - frame_sems[j2, 0], frame_means[j2, 1]],
                [frame_means[j2, 0] + frame_sems[j2, 0], frame_means[j2, 1]],
                [frame_means[j1, 0] + frame_sems[j1, 0], frame_means[j1, 1]],
            ])
            polygon = plt.Polygon(polygon_edges, alpha=0.2, color=color)
            ax.add_patch(polygon)
    if legend:
        ax.legend()

    return quiver_data, quiver_data_src

In [None]:
# All prefix sets in a single panel (first two PCs), one quiver group per prefix set.
prefix_set_keys = [key for key in plot_sets.keys() if key.startswith("prefix")]
plot_quiver(prefix_set_keys,
            traj_agg_phoneme_pca[:, 0:2], traj_agg_phoneme_src, state_space_spec);

In [None]:
# One panel per prefix set, each word in the set as its own quiver group,
# plotted in the first two principal components.
all_prefixes = [key for key in plot_sets.keys() if key.startswith("prefix")]
num_cols = 3
num_rows = int(np.ceil(len(all_prefixes) / num_cols))
if num_rows > 0:
    # squeeze=False: always a 2-D axes array, whatever the grid size
    f, axs = plt.subplots(num_rows, num_cols, figsize=(4 * num_cols, 4 * num_rows), squeeze=False)

    for prefix_key, ax in zip(tqdm(all_prefixes), axs.flat):
        plot_quiver({word: [word] for word in plot_sets[prefix_key]},
                    traj_agg_phoneme_pca[:, 0:2], traj_agg_phoneme_src, state_space_spec, ax=ax, legend=False)
        ax.set_title(prefix_key)
        ax.set_xlabel("PC1")
        ax.set_ylabel("PC2")
    # Hide grid cells left empty when the prefixes don't fill the last row
    for ax in axs.flat[len(all_prefixes):]:
        ax.set_visible(False)
    f.tight_layout()

In [None]:
# Same per-prefix grid, but on the next pair of principal components (only when
# more than two components were computed).
if traj_agg_phoneme_pca.shape[1] > 2:
    # 0-based column of the first plotted component; with only three components
    # this falls back to the last two.
    start_pc = min(2, traj_agg_phoneme_pca.shape[1] - 2)

    all_prefixes = [key for key in plot_sets.keys() if key.startswith("prefix")]
    num_cols = 3
    num_rows = int(np.ceil(len(all_prefixes) / num_cols))
    if num_rows > 0:
        # squeeze=False: always a 2-D axes array, whatever the grid size
        f, axs = plt.subplots(num_rows, num_cols, figsize=(4 * num_cols, 4 * num_rows), squeeze=False)

        for prefix_key, ax in zip(tqdm(all_prefixes), axs.flat):
            plot_quiver({word: [word] for word in plot_sets[prefix_key]},
                        traj_agg_phoneme_pca[:, start_pc:start_pc + 2], traj_agg_phoneme_src, state_space_spec,
                        ax=ax, legend=False)
            ax.set_title(prefix_key)
            # 1-based component labels (column 0 is PC1), consistent with the grid above
            ax.set_xlabel(f"PC{start_pc + 1}")
            ax.set_ylabel(f"PC{start_pc + 2}")
        # Hide grid cells left empty when the prefixes don't fill the last row
        for ax in axs.flat[len(all_prefixes):]:
            ax.set_visible(False)
        f.tight_layout()

In [None]:
# Baseline for comparison: 8 groups of 12 words, each group drawn at random
# (without replacement within the group) from the state-space vocabulary.
baseline_group_spec = {}
for group_idx in range(8):
    baseline_group_spec[f"random_{group_idx}"] = np.random.choice(
        state_space_spec.labels, 12, replace=False)
plot_quiver(baseline_group_spec,
            traj_agg_phoneme_pca[:, 0:2], traj_agg_phoneme_src, state_space_spec);