In [None]:
import os

# Hide every CUDA device before torch is imported in the next cell, so this notebook runs on CPU.
os.environ.update(CUDA_VISIBLE_DEVICES="")

In [None]:
import itertools
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis import coherence
from src.analysis.state_space import prepare_state_trajectory, StateSpaceAnalysisSpec
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# Notebook parameters. When run from a pipeline these are presumably injected/overridden
# (e.g. papermill-style) -- TODO confirm. Some paths are not referenced by the cells below.

# --- Input artifacts ---
model_dir = "outputs/models/timit/w2v2_6/rnn_8/phoneme"
dataset_path = "outputs/preprocessed_data/timit"
equivalence_path = "outputs/equivalence_datasets/timit/w2v2_6/phoneme/equivalence.pkl"
hidden_states_path = "outputs/hidden_states/timit/w2v2_6/hidden_states.h5"
state_space_specs_path = "outputs/state_space_specs/timit/w2v2_6/state_space_specs.h5"
embeddings_path = "outputs/model_embeddings/timit/w2v2_6/rnn_8/phoneme/embeddings.npy"

# --- Outputs ---
# The pipeline location is kept for reference, but the effective value for interactive runs is ".".
output_dir = "outputs/notebooks/timit/w2v2_6/rnn_8/phoneme/plot"
output_dir = "."

# Distance metric handed to the coherence helpers (also used in axis labels).
metric = "cosine"

# Keep only phonemes with at least this many instances.
retain_n = 3

In [None]:
# Load the model's embedding matrix and the phoneme state-space spec, then check they line up.
with Path(embeddings_path).open("rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)

state_space_spec = StateSpaceAnalysisSpec.from_hdf5(state_space_specs_path, "phoneme")
assert state_space_spec.is_compatible_with(model_representations)

In [None]:
# Drop labels with fewer than `retain_n` instances (one target frame span per instance).
instance_counts = [len(spans) for spans in state_space_spec.target_frame_spans]
drop_idxs = [label_idx for label_idx, count in enumerate(instance_counts) if count < retain_n]
state_space_spec = state_space_spec.drop_labels(drop_idxs)

In [None]:
# NaN-padded trajectory, one entry per label (indexed like state_space_spec.labels). Each entry
# is presumably (num_instances, num_frames, hidden_dim) -- TODO confirm.
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)


def _num_valid_frames(traj_i: np.ndarray) -> np.ndarray:
    """Return, for each instance (row) of ``traj_i``, the number of frames before NaN padding.

    Feature 0 is used as the padding indicator. ``argmax`` over the NaN mask gives the index of
    the first padded frame, but it also returns 0 for rows with no padding at all. Those rows
    fill the whole window, so they are assigned the full frame count instead.
    Assumes padding only occurs after an instance's valid frames.
    """
    is_pad = np.isnan(traj_i[:, :, 0])
    return np.where(is_pad.any(axis=1), is_pad.argmax(axis=1), is_pad.shape[1])


lengths = [_num_valid_frames(traj_i) for traj_i in trajectory]

In [None]:
# One trajectory per retained phoneme label.
num_phoneme_trajectories = len(trajectory)
num_phoneme_trajectories

## Estimate within-phoneme distance

In [None]:
# Within-phoneme distance per frame. The two results are used below as onset-aligned
# and offset-aligned curves, respectively.
within_results = coherence.estimate_within_distance(
    trajectory, lengths, state_space_spec, metric=metric,
)
within_distance, within_distance_offset = within_results

In [None]:
# Rows are phonemes (in state_space_spec.labels order), columns are frames since onset.
# center=1 puts the diverging colormap's midpoint at distance 1 (orthogonality, when metric is cosine).
within_distance_heatmap_df = pd.DataFrame(within_distance,
                                          index=pd.Index(state_space_spec.labels, name="phoneme"))
ax = sns.heatmap(within_distance_heatmap_df, center=1, cmap="RdBu", yticklabels=True,
                 cbar_kws={"label": f"{metric.capitalize()} distance"})
ax.set_title("Within-phoneme distance by frame")
ax.set_xlabel("Frames since phoneme onset")
ax.set_ylabel("Phoneme");

In [None]:
# Long format: one row per (phoneme, frame) with the within-phoneme distance.
within_distance_df = (
    pd.DataFrame(within_distance, index=state_space_spec.labels)
    .rename_axis(index="phoneme")
    .reset_index()
    .melt(id_vars=["phoneme"], var_name="frame", value_name="distance")
)

In [None]:
# Long format of the offset-aligned within-phoneme distance: one row per (phoneme, frame).
within_distance_offset_df = (
    pd.DataFrame(within_distance_offset, index=state_space_spec.labels)
    .rename_axis(index="phoneme")
    .reset_index()
    .melt(id_vars=["phoneme"], var_name="frame", value_name="distance")
)

## Estimate between-phoneme distance

In [None]:
# Between-phoneme distance. The two results are used below as onset-aligned and
# offset-aligned values; their last axis is averaged out with nanmean.
between_results = coherence.estimate_between_distance(
    trajectory, lengths, state_space_spec, metric=metric,
)
between_distance, between_distance_offset = between_results

In [None]:
# Average over the last axis (ignoring NaNs), then reshape to one row per (phoneme, frame).
between_distances_df = (
    pd.DataFrame(np.nanmean(between_distance, axis=-1), index=state_space_spec.labels)
    .rename_axis(index="phoneme")
    .reset_index()
    .melt(id_vars=["phoneme"], var_name="frame", value_name="distance")
)

In [None]:
# Offset-aligned counterpart: nanmean over the last axis, then one row per (phoneme, frame).
between_distances_offset_df = (
    pd.DataFrame(np.nanmean(between_distance_offset, axis=-1), index=state_space_spec.labels)
    .rename_axis(index="phoneme")
    .reset_index()
    .melt(id_vars=["phoneme"], var_name="frame", value_name="distance")
)

## Within- and between-phoneme distances together

In [None]:
# Stack onset-aligned within/between curves, tagged by `type`, and save them.
merged_df = pd.concat([
    within_distance_df.assign(type="within"),
    between_distances_df.assign(type="between"),
])
merged_df.to_csv(Path(output_dir, "distances.csv"), index=False)
merged_df

In [None]:
# Onset-aligned within- vs. between-phoneme distance.
ax = sns.lineplot(data=merged_df.dropna(), x="frame", y="distance", hue="type")
ax.set(
    title="Representational distance within- and between-phoneme",
    xlabel="Frames since phoneme onset",
    ylabel=f"{metric.capitalize()} distance",
);

In [None]:
# Stack offset-aligned within/between curves, tagged by `type`, and save them.
merged_offset_df = pd.concat([
    within_distance_offset_df.assign(type="within"),
    between_distances_offset_df.assign(type="between"),
])
merged_offset_df.to_csv(Path(output_dir, "distances_aligned_offset.csv"), index=False)
merged_offset_df

In [None]:
# Offset-aligned within- vs. between-phoneme distance. The title names the alignment so this
# figure can be told apart from the onset-aligned one.
ax = sns.lineplot(data=merged_offset_df.dropna(), x="frame", y="distance", hue="type")
ax.set_title("Representational distance within- and between-phoneme (aligned to phoneme offset)")
ax.set_xlabel("Frames before phoneme offset")
ax.set_ylabel(f"{metric.capitalize()} distance")

## Block by phoneme categories

In [None]:
# Broad phoneme categories (ARPAbet symbols) used to group the distance comparisons below.
categorization = dict(
    consonant="B CH D DH F G HH JH K L M N NG P R S SH T TH V W Y Z ZH".split(" "),
    vowel="AA AE AH AO AW AY EH ER EY IH IY OW OY UH UW".split(" "),
)

In [None]:
num_frames = trajectory[0].shape[1]

# Every categorized phoneme must have survived the retain_n filter.
for phoneme_list in categorization.values():
    for phoneme in phoneme_list:
        assert phoneme in state_space_spec.labels, f"Phoneme {phoneme} missing from state space spec"

# Seeded generator so the balanced sample below (and every distance derived from it)
# is reproducible across Restart & Run All.
sampling_rng = np.random.default_rng(0)

# Prepare a balanced sample: every phoneme contributes the same number of instances.
num_instances = min(len(state_space_spec.target_frame_spans[i]) for i in range(len(state_space_spec.labels)))
# Deliberately use one fewer than the minimum so num_instances is unlikely to coincide with
# num_frames; this makes instance/frame axis mix-ups easier to spot while debugging.
num_instances -= 1

all_phonemes = sorted(set(itertools.chain.from_iterable(categorization.values())))
phoneme_to_idx = {phoneme: i for i, phoneme in enumerate(all_phonemes)}

phoneme_representations, phoneme_representation_lengths = {}, {}
for phoneme in all_phonemes:
    label_idx = state_space_spec.labels.index(phoneme)
    sample_instance_idxs = sampling_rng.choice(len(state_space_spec.target_frame_spans[label_idx]),
                                               num_instances, replace=False)
    # Sampled instances of this phoneme; frames past an instance's end are NaN-padded.
    phoneme_representations[phoneme] = np.asarray(trajectory[label_idx])[sample_instance_idxs]
    # Kept for inspection; masking below checks the padding directly instead.
    phoneme_representation_lengths[phoneme] = lengths[label_idx][sample_instance_idxs]

# Mean distance between the sampled instances of every ordered pair of phonemes, per frame.
# Start from NaN rather than 0 so that frames never computed (no remaining instances) are
# treated as missing by nanmean / seaborn, not as a distance of 0.
distances = np.full((len(all_phonemes), len(all_phonemes), num_frames), np.nan)
for p1, p2 in itertools.product(range(len(all_phonemes)), repeat=2):
    reps1 = phoneme_representations[all_phonemes[p1]]
    reps2 = phoneme_representations[all_phonemes[p2]]
    for k in range(num_frames):
        # An instance contributes at frame k only if that frame is not padding (feature 0 is NaN
        # there, as when computing `lengths`). A `length >= k` mask would also admit frame
        # k == length, which is the first padded frame.
        mask1 = ~np.isnan(reps1[:, k, 0])
        mask2 = ~np.isnan(reps2[:, k, 0])
        if not mask1.any() or not mask2.any():
            # Padding is assumed to be a trailing suffix, so no later frame has data either.
            break

        distances[p1, p2, k] = coherence.get_mean_distance(reps1[mask1, k, :], reps2[mask2, k, :],
                                                           metric=metric)

# Compute between- and within-category distance trajectories
within_distances, between_distances, within_comparisons, between_comparisons = {}, {}, {}, {}
for category, phonemes in categorization.items():
    # Unordered pairs of distinct phonemes inside the category.
    within_comparisons[category] = list(itertools.combinations(phonemes, 2))
    # Ordered pairs (phoneme in the category, phoneme outside it).
    between_comparisons[category] = [(p1, p2) for p1, p2 in itertools.product(phonemes, all_phonemes)
                                     if p2 not in phonemes]
    within_distances[category] = np.stack(
        [distances[phoneme_to_idx[p1], phoneme_to_idx[p2]] for p1, p2 in within_comparisons[category]],
        axis=0)
    between_distances[category] = np.stack(
        [distances[phoneme_to_idx[p1], phoneme_to_idx[p2]] for p1, p2 in between_comparisons[category]],
        axis=0)

In [None]:
def _melt_category_distances(distance_arrays, comparisons, categories, n_frames):
    """Turn per-category (comparison x frame) arrays into one long DataFrame.

    The result is indexed by (category, p1, p2) and has columns ``frame`` and ``distance``.
    """
    per_category = {}
    for category in categories:
        wide = pd.DataFrame(
            distance_arrays[category],
            index=pd.Index(comparisons[category], name=("p1", "p2")),
            columns=pd.Index(range(n_frames), name="frame"),
        )
        per_category[category] = wide.melt(ignore_index=False, var_name="frame", value_name="distance")
    return pd.concat(per_category, names=["category"])


all_within_distances = _melt_category_distances(within_distances, within_comparisons,
                                                categorization, num_frames)
all_between_distances = _melt_category_distances(between_distances, between_comparisons,
                                                 categorization, num_frames)

In [None]:
# Prefix the index with a `type` level ("within" / "between") and save the grouped table.
all_distances = pd.concat(
    {"within": all_within_distances, "between": all_between_distances},
    names=["type"],
)
all_distances.to_csv(Path(output_dir, "grouped_distances.csv"))
all_distances

In [None]:
# Within- vs. between-category distance, aggregated over all phoneme pairs and categories.
ax = sns.lineplot(data=all_distances.reset_index(), x="frame", y="distance", hue="type")
ax.set_title("Representational distance within- and between-phoneme category")
ax.set_xlabel("Frames since phoneme onset")
ax.set_ylabel(f"{metric.capitalize()} distance");

In [None]:
# Phoneme x phoneme distance averaged over frames (NaN frames ignored), hierarchically clustered.
mean_distance_matrix = pd.DataFrame(
    np.nanmean(distances, axis=-1), index=all_phonemes, columns=all_phonemes,
)
sns.clustermap(mean_distance_matrix, center=1, cmap="RdBu")