In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules (here the
# project's `src` package) are reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os
# Set CUDA_VISIBLE_DEVICES to the empty string before torch is imported in the
# next cell. NOTE(review): presumably this is meant to keep the analysis on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:

import itertools
from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis import coherence
from src.analysis.state_space import prepare_state_trajectory, StateSpaceAnalysisSpec
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# Global seaborn theme applied to every figure drawn below.
sns.set_theme(style="whitegrid", context="talk")

In [None]:
# Notebook parameters: locations of the input artifacts and analysis settings.
model_dir = "outputs/models/timit/w2v2_6/rnn_8/phoneme"
output_dir = "outputs/notebooks/timit/w2v2_6/rnn_8/phoneme/plot"
dataset_path = "outputs/preprocessed_data/timit"
equivalence_path = "outputs/equivalence_datasets/timit/w2v2_6/phoneme/equivalence.pkl"
hidden_states_path = "outputs/hidden_states/timit/w2v2_6/hidden_states.h5"
state_space_specs_path = "outputs/state_space_specs/timit/w2v2_6/state_space_specs.pkl"
embeddings_path = "outputs/model_embeddings/timit/w2v2_6/rnn_8/phoneme/embeddings.npy"

# NOTE(review): this reassignment overrides the `output_dir` set above, so every
# CSV written below lands in the current working directory. Confirm this is
# intended and not a leftover from local debugging.
output_dir = "."

# Distance metric name passed to the `coherence` helpers; also used to build the
# y-axis label of the plots.
metric = "cosine"

# Retain syllables with N or more instances
retain_n = 10

In [None]:
# Read the three input artifacts: model embeddings, the equivalence dataset and
# the syllable-level state space spec.
# NOTE(review): `torch.load` unpickles arbitrary Python objects, so these paths
# must only ever point at trusted files.
with open(embeddings_path, "rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)

with open(equivalence_path, "rb") as equivalence_file:
    equiv_dataset: SpeechEquivalenceDataset = torch.load(equivalence_file)

with open(state_space_specs_path, "rb") as specs_file:
    state_space_spec: StateSpaceAnalysisSpec = torch.load(specs_file)["syllable"]

# Fail early if the spec does not line up with the loaded embeddings.
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Indices of labels that have fewer than `retain_n` instances (frame spans);
# these are removed from the spec before any distance is computed.
drop_idxs = [
    label_idx
    for label_idx, instance_spans in enumerate(state_space_spec.target_frame_spans)
    if not len(instance_spans) >= retain_n
]
state_space_spec = state_space_spec.drop_labels(drop_idxs)

In [None]:
# Human-readable label per syllable: its phones concatenated into one string.
spec_label_strs = list(map("".join, state_space_spec.labels))

In [None]:
# Build the per-label state trajectories from the embeddings, padding with NaN.
# NOTE(review): later cells index each element as `traj[:, :, 0]` and read
# `shape[1]` as the frame count, so each element is presumably an
# (instances, frames, dims) array -- confirm against `prepare_state_trajectory`.
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)

In [None]:
# Number of valid (non-padding) frames for every instance of every label.
# The trajectories are NaN-padded along the frame axis, so the position of the
# first NaN in feature 0 is the instance length. `argmax` on its own returns 0
# when a row contains no NaN at all, i.e. for an instance that fills every
# frame, which would report that instance as empty; such rows get the full
# frame count instead.
lengths = [
    np.where(is_pad.any(axis=1), is_pad.argmax(axis=1), is_pad.shape[1])
    for is_pad in (np.isnan(traj_i[:, :, 0]) for traj_i in trajectory)
]

In [None]:
# Sanity check: number of entries in `trajectory` after the rare-label filter.
len(trajectory)

## Estimate within-syllable distance

In [None]:
# Within-syllable distances; the helper returns two results, and the
# `*_offset` one is treated below as aligned to syllable offset.
(within_distance, within_distance_offset) = coherence.estimate_within_distance(
    trajectory, lengths, state_space_spec, metric=metric
)

In [None]:
# Quick look at the within-syllable distances; the diverging colormap is
# centred on a distance of 1.
sns.heatmap(within_distance, center=1, cmap="RdBu")

In [None]:
# Long-format table: one row per (syllable, frame) with its within-syllable distance.
within_distance_df = (
    pd.DataFrame(within_distance, index=pd.Index(spec_label_strs, name="syllable"))
    .reset_index()
    .melt(id_vars=["syllable"], var_name="frame", value_name="distance")
)

In [None]:
# Same long-format reshaping for the offset-aligned within-syllable distances.
within_distance_offset_df = (
    pd.DataFrame(within_distance_offset, index=pd.Index(spec_label_strs, name="syllable"))
    .reset_index()
    .melt(id_vars=["syllable"], var_name="frame", value_name="distance")
)

## Estimate between-syllable distance

In [None]:
# Between-syllable distances; same call shape as the within-syllable estimate.
(between_distance, between_distance_offset) = coherence.estimate_between_distance(
    trajectory, lengths, state_space_spec, metric=metric
)

In [None]:
# Average out the last axis (ignoring NaNs), then reshape to the same long
# (syllable, frame, distance) layout as the within-syllable table.
between_distances_df = (
    pd.DataFrame(
        np.nanmean(between_distance, axis=-1),
        index=pd.Index(spec_label_strs, name="syllable"),
    )
    .reset_index()
    .melt(id_vars=["syllable"], var_name="frame", value_name="distance")
)

In [None]:
# Offset-aligned counterpart: NaN-aware mean over the last axis, then long format.
between_distances_offset_df = (
    pd.DataFrame(
        np.nanmean(between_distance_offset, axis=-1),
        index=pd.Index(spec_label_strs, name="syllable"),
    )
    .reset_index()
    .melt(id_vars=["syllable"], var_name="frame", value_name="distance")
)

## Together

In [None]:
# Stack the within- and between-syllable tables, tagging each row with its type.
merged_df = pd.concat([
    within_distance_df.assign(type="within"),
    between_distances_df.assign(type="between"),
])
# NOTE(review): 20 is presumably the number of frames per second used to turn
# frame indices into seconds -- confirm against the feature extractor.
merged_df["time"] = merged_df["frame"] / 20
merged_df.to_csv(Path(output_dir) / "distances.csv", index=False)
merged_df

In [None]:
# Distance against time from syllable onset, one line per comparison type;
# rows without a distance value are dropped before plotting.
ax = sns.lineplot(
    data=merged_df.dropna(),
    x="time",
    y="distance",
    hue="type",
)
ax.set_xlabel("Time since syllable onset (seconds)")
ax.set_title("Representational distance within- and between-syllable")
ax.set_ylabel(metric.capitalize() + " distance")

In [None]:
# Offset-aligned version of the combined within/between table.
merged_offset_df = pd.concat([
    within_distance_offset_df.assign(type="within"),
    between_distances_offset_df.assign(type="between"),
])
# Frame indices are shifted by one before conversion, which puts frame 1 at
# time 0 in the plot below.
# NOTE(review): 20 is presumably frames per second -- confirm.
merged_offset_df["time"] = merged_offset_df["frame"].sub(1).div(20)
merged_offset_df.to_csv(Path(output_dir) / "distances_aligned_offset.csv", index=False)
merged_offset_df

In [None]:
# Offset-aligned view: the x axis is inverted and a dashed line marks time 0.
ax = sns.lineplot(
    data=merged_offset_df.dropna(),
    x="time",
    y="distance",
    hue="type",
)
ax.set_xlabel("Time from syllable offset (seconds)")
ax.set_title("Representational distance within- and between-syllable")
ax.invert_xaxis()
ax.axvline(0, linestyle="--", color="gray")
ax.set_ylabel(metric.capitalize() + " distance")

## Estimate between-syllable distance, grouped by features

### Onset

In [None]:
# Category per syllable = its first phone.
onsets = [phones[0] for phones in state_space_spec.labels]

# Distances grouped by whether two syllables share that first phone; the helper
# returns an onset-aligned and an offset-aligned table.
(onset_distance_df, onset_distance_offset_df) = coherence.estimate_category_within_between_distance(
    trajectory,
    lengths,
    onsets,
    metric=metric,
    labels=state_space_spec.labels,
)

In [None]:
# Add a time axis and persist the table.
# NOTE(review): 20 is presumably frames per second -- confirm.
onset_distance_df["time"] = onset_distance_df["frame"] / 20
onset_distance_df.to_csv(
    Path(output_dir) / "distances-grouped_onset.csv",
    index=False,
)

In [None]:
# Distance over time from syllable onset, split by onset match vs. mismatch.
ax = sns.lineplot(
    data=onset_distance_df.dropna(),
    x="time",
    y="distance",
    hue="type",
)
ax.set_xlabel("Time since syllable onset (seconds)")
ax.set_title("Representational distance by onset match/mismatch")
ax.set_ylabel(metric.capitalize() + " distance")

In [None]:
# Offset-aligned time axis (frame index shifted by one before conversion), then persist.
# NOTE(review): 20 is presumably frames per second -- confirm.
onset_distance_offset_df["time"] = onset_distance_offset_df["frame"].sub(1).div(20)
onset_distance_offset_df.to_csv(
    Path(output_dir) / "distances-grouped_onset_aligned_offset.csv",
    index=False,
)

In [None]:
# Offset-aligned view of the onset match/mismatch distances; the x axis is
# inverted and a dashed line marks time 0.
ax = sns.lineplot(
    data=onset_distance_offset_df.dropna(),
    x="time",
    y="distance",
    hue="type",
)
ax.set_xlabel("Time from syllable offset (seconds)")
ax.set_title("Representational distance by onset match/mismatch")
ax.invert_xaxis()
ax.axvline(0, linestyle="--", color="gray")
ax.set_ylabel(metric.capitalize() + " distance")

### Nucleus

In [None]:
# Phones counted as vowels when looking for a syllable's nucleus.
vowels = ["AA", "AE", "AH", "AO", "AW", "AY", "EH", "ER", "EY", "IH", "IY", "OW", "OY", "UH", "UW"]

# Nucleus of each syllable = its first phone that appears in `vowels`;
# syllables containing none of them get None.
nuclei = [
    next((phone for phone in syll if phone in vowels), None)
    for syll in state_space_spec.labels
]

In [None]:
# How many syllable types share each nucleus (None entries are not counted by
# `value_counts` with its default `dropna=True`).
pd.Series(nuclei).value_counts()

In [None]:
# Distances grouped by whether two syllables share a nucleus; the helper
# returns an onset-aligned and an offset-aligned table.
(nucleus_distance_df, nucleus_distance_offset_df) = coherence.estimate_category_within_between_distance(
    trajectory,
    lengths,
    nuclei,
    metric=metric,
    labels=state_space_spec.labels,
)

In [None]:
# Add a time axis and persist the table.
# NOTE(review): 20 is presumably frames per second -- confirm.
nucleus_distance_df["time"] = nucleus_distance_df["frame"] / 20
nucleus_distance_df.to_csv(
    Path(output_dir) / "distances-grouped_nucleus.csv",
    index=False,
)

In [None]:
# Distance over time from syllable onset, split by nucleus match vs. mismatch.
# Plotted against the `time` column (computed in the previous cell) rather than
# the raw frame index, so the axis and its label agree with every other
# onset-aligned figure in this notebook.
ax = sns.lineplot(data=nucleus_distance_df.dropna(), x="time", y="distance", hue="type")
ax.set_title("Representational distance by nucleus match/mismatch")
ax.set_xlabel("Time since syllable onset (seconds)")
ax.set_ylabel(f"{metric.capitalize()} distance")

In [None]:
# Offset-aligned time axis (frame index shifted by one before conversion), then persist.
# NOTE(review): 20 is presumably frames per second -- confirm.
nucleus_distance_offset_df["time"] = nucleus_distance_offset_df["frame"].sub(1).div(20)
nucleus_distance_offset_df.to_csv(
    Path(output_dir) / "distances-grouped_nucleus_aligned_offset.csv",
    index=False,
)

In [None]:
# Offset-aligned view of the nucleus match/mismatch distances; the x axis is
# inverted and a dashed line marks time 0.
ax = sns.lineplot(
    data=nucleus_distance_offset_df.dropna(),
    x="time",
    y="distance",
    hue="type",
)
ax.set_xlabel("Time from syllable offset (seconds)")
ax.set_title("Representational distance by nucleus match/mismatch")
ax.invert_xaxis()
ax.axvline(0, linestyle="--", color="gray")
ax.set_ylabel(metric.capitalize() + " distance")

## RSA

In [None]:
# Pairwise, per-frame distances between syllables of a fixed phone count,
# computed on a balanced random sample of instances per syllable.
num_frames = trajectory[0].shape[1]

syllable_length = 2
all_syllables = sorted([label for label in state_space_spec.labels if len(label) == syllable_length])
if not all_syllables:
    # Without this guard the `min(...)` below fails with the opaque
    # "min() arg is an empty sequence" message.
    raise ValueError(f"state_space_spec has no labels of length {syllable_length}")

# Resolve each syllable's position in the spec once instead of on every use
# (`list.index` is a linear scan).
syllable_label_idxs = [state_space_spec.labels.index(syllable) for syllable in all_syllables]

# Prepare balanced sample of representations for each syllable in each category:
# every syllable contributes as many instances as the rarest one has.
num_instances = min(len(state_space_spec.target_frame_spans[label_idx])
                    for label_idx in syllable_label_idxs)

# Seeded generator so the sampled instances, and the figures built from them,
# are the same on every Restart & Run All.
rsa_rng = np.random.default_rng(0)

syllable_representations = {}
for syllable, label_idx in zip(all_syllables, syllable_label_idxs):
    sample_instance_idxs = rsa_rng.choice(len(state_space_spec.target_frame_spans[label_idx]),
                                          num_instances, replace=False)
    syllable_representations[syllable] = np.array([trajectory[label_idx][idx]
                                                   for idx in sample_instance_idxs])

# Compute between-syllable distances, frame by frame.
from src.analysis.coherence import get_mean_distance
distances = np.zeros((len(all_syllables), len(all_syllables), num_frames))
# Only ordered pairs of distinct syllables are visited, so the diagonal
# (a syllable against itself) keeps its initial value of 0.
for p1, p2 in tqdm(list(itertools.permutations(range(len(all_syllables)), 2))):
    reps_1 = syllable_representations[all_syllables[p1]]
    reps_2 = syllable_representations[all_syllables[p2]]
    for k in range(num_frames):
        distances[p1, p2, k] = get_mean_distance(reps_1[:, k, :], reps_2[:, k, :], metric=metric)

In [None]:
# String label per syllable, then a clustered heatmap of the frame-averaged
# (NaN-aware) distance between every pair of syllables.
all_syllable_labels = list(map("".join, all_syllables))
sns.clustermap(
    pd.DataFrame(
        np.nanmean(distances, axis=-1),
        index=all_syllable_labels,
        columns=all_syllable_labels,
    ),
    center=1,
    cmap="RdBu",
)

In [None]:
# Clustered heatmap for a random subset of at most 30 syllables.
# The subset size is capped at the number of available syllables: sampling 30
# without replacement from fewer than 30 raises a ValueError.
plot_subset = np.random.choice(len(all_syllable_labels),
                               size=min(30, len(all_syllable_labels)),
                               replace=False)
# `clustermap` is a figure-level function that creates its own figure, so the
# size is passed through `figsize`; a preceding `plt.figure(...)` would only
# leave an empty extra figure behind.
sns.clustermap(np.nanmean(distances, axis=-1)[plot_subset][:, plot_subset], center=1, cmap="RdBu",
               figsize=(10, 10),
               xticklabels=[all_syllable_labels[i] for i in plot_subset],
               yticklabels=[all_syllable_labels[i] for i in plot_subset])