In [None]:
import os
# Set CUDA_VISIBLE_DEVICES to the empty string before torch is imported in the
# next cell — presumably so this notebook runs on CPU only (TODO confirm intent).
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
from dataclasses import replace
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.analysis import coherence
from src.analysis.state_space import prepare_state_trajectory, StateSpaceAnalysisSpec
from src.datasets.speech_equivalence import SpeechEquivalenceDataset

In [None]:
# Notebook parameters: locations of the input artifacts for one model/dataset
# combination.
# NOTE(review): this looks like a parameters cell meant to be overridden by a
# notebook runner — confirm before restructuring these literal assignments.
model_dir = "outputs/models/timit/w2v2_6/rnn_8/phoneme"
output_dir = "outputs/notebooks/timit/w2v2_6/rnn_8/phoneme/plot"
dataset_path = "outputs/preprocessed_data/timit"
equivalence_path = "outputs/equivalence_datasets/timit/w2v2_6/phoneme/equivalence.pkl"
hidden_states_path = "outputs/hidden_states/timit/w2v2_6/hidden_states.h5"
state_space_specs_path = "outputs/state_space_specs/timit/w2v2_6/state_space_specs.pkl"
embeddings_path = "outputs/model_embeddings/timit/w2v2_6/rnn_8/phoneme/embeddings.npy"

# Overrides the `output_dir` assigned above: the CSV files written later in this
# notebook are placed relative to the current working directory.
output_dir = "."

# Name of the distance metric; passed to the coherence estimators and used in
# the plot axis labels.
metric = "cosine"

In [None]:
# Model representations are stored as a NumPy array file.
with open(embeddings_path, "rb") as embeddings_file:
    model_representations: np.ndarray = np.load(embeddings_file)

# NOTE(review): torch.load unpickles arbitrary Python objects. Only point these
# paths at trusted, locally generated artifacts.
with open(equivalence_path, "rb") as equivalence_file:
    equiv_dataset: SpeechEquivalenceDataset = torch.load(equivalence_file)

with open(state_space_specs_path, "rb") as specs_file:
    all_state_space_specs = torch.load(specs_file)

# The loaded object is indexed by spec name; this analysis uses the spec keyed
# by syllable identity and ordinal position.
state_space_spec: StateSpaceAnalysisSpec = all_state_space_specs["syllable_by_identity_and_ordinal"]

# Fail early if the spec does not line up with the loaded representations.
assert state_space_spec.is_compatible_with(model_representations)

## Search for well-attested syllables

In [None]:
# Build a (syllable x ordinal position) count table: each cell holds the number
# of target frame spans recorded for that (syllable, position) label.
all_syllables = sorted(set(syllable for syllable, position in state_space_spec.labels))
all_positions = sorted(set(position for syllable, position in state_space_spec.labels))

# Precompute row/column lookups so filling the table is linear in the number of
# labels, instead of calling list.index() (a linear scan) twice per label.
syllable_row = {syllable: row for row, syllable in enumerate(all_syllables)}
position_col = {position: col for col, position in enumerate(all_positions)}

syllable_mat = np.zeros((len(all_syllables), len(all_positions)), dtype=int)
for i, (syllable, position) in enumerate(state_space_spec.labels):
    syllable_mat[syllable_row[syllable], position_col[position]] = \
        len(state_space_spec.target_frame_spans[i])
syllable_df = pd.DataFrame(syllable_mat, index=all_syllables, columns=all_positions)
syllable_df

In [None]:
# Find syllables which appear at least twice in every one of the first
# `min_number_positions` ordinal positions.
min_number_positions = 3
is_attested = syllable_df >= 2

# idxmin over booleans yields the label of the first column that is False,
# i.e. the first ordinal position at which the syllable is NOT well attested.
syllable_max_position = is_attested.idxmin(axis=1)

# For a row that is True everywhere, idxmin falls back to the first column,
# which would wrongly exclude syllables attested in every position. Mark those
# rows as one past the last observed position instead.
attested_everywhere = is_attested.all(axis=1)
syllable_max_position = syllable_max_position.mask(attested_everywhere,
                                                   syllable_df.columns.max() + 1)

match_syllables = syllable_max_position.loc[syllable_max_position >= min_number_positions].index.tolist()
len(match_syllables), match_syllables[:5]

In [None]:
# Bar plot of the matched syllables' frequency at each ordinal position used in
# the analysis. `.loc` label slices include their end point, so the slice stops
# at `min_number_positions - 1` to cover exactly ordinals
# 0 .. min_number_positions - 1 (the same set as `range(min_number_positions)`).
sns.barplot(data=syllable_df.loc[match_syllables, :min_number_positions - 1] \
                    .melt(var_name="ordinal_position", value_name="frequency"),
            x="ordinal_position", y="frequency")

## Prepare model representations

In [None]:
# Keep only the labels for matched syllables in the first `min_number_positions`
# ordinal positions; every other label is dropped from the spec.
retain_labels = [(syllable, ordinal) for syllable in match_syllables
                 for ordinal in range(min_number_positions)]

# Use a set for constant-time membership tests; scanning the list once per
# label is quadratic in the number of labels.
retain_label_set = set(retain_labels)
drop_idxs = [idx for idx, label in enumerate(state_space_spec.labels)
             if label not in retain_label_set]

# NOTE: this rebinds `state_space_spec` to the filtered spec, so any earlier
# cell re-run after this one will see the filtered labels.
state_space_spec = state_space_spec.drop_labels(drop_idxs)

In [None]:
def _label_to_str(label):
    """Render a (phones, ordinal) label as space-joined phones followed by the ordinal."""
    phones, ordinal = label
    return f"{' '.join(phones)} {ordinal}"


# One display string per label in the (filtered) spec, in spec order.
spec_label_strs = [_label_to_str(label) for label in state_space_spec.labels]

In [None]:
# Build state trajectories from the model representations for the filtered spec.
# `pad=np.nan` is passed as the padding value — presumably frames past the end
# of an instance are filled with NaN (TODO confirm in src.analysis.state_space).
trajectory = prepare_state_trajectory(model_representations, state_space_spec, pad=np.nan)

In [None]:
def _instance_lengths(traj):
    """Return, for each instance in `traj`, the index of its first NaN-padded frame.

    `traj` is indexed as (instance, frame, feature); padding is detected on
    feature 0. Instances with no padded frame at all are reported as spanning
    the whole frame axis.
    """
    is_padding = np.isnan(traj[:, :, 0])
    first_padding_frame = is_padding.argmax(axis=1)
    # argmax over an all-False row is 0, which would report a zero length for
    # instances that have no padding; those occupy every frame instead.
    # NOTE(review): confirm the coherence estimators accept a length equal to
    # the number of frames.
    return np.where(is_padding.any(axis=1), first_padding_frame, traj.shape[1])


# One array of per-instance lengths for each entry of `trajectory`.
lengths = [_instance_lengths(traj_i) for traj_i in trajectory]

In [None]:
# Sanity check: display how many entries `trajectory` holds.
len(trajectory)

## Estimate within-syllable, within-position distance

In [None]:
# Estimate distances within each label; the estimator returns two results,
# the second of which is later reported as aligned to syllable offset.
(within_distance, within_distance_offset) = coherence.estimate_within_distance(
    trajectory,
    lengths,
    state_space_spec,
    metric=metric,
)

In [None]:
# Heatmap of the within-distance matrix using a diverging red/blue palette
# centred on a value of 1.
sns.heatmap(within_distance, center=1, cmap="RdBu")

In [None]:
# Long-format table of within distances: one row per (syllable label, column)
# pair, with the original column index stored under "frame".
within_distance_df = pd.melt(
    pd.DataFrame(within_distance, index=pd.Index(spec_label_strs, name="syllable")).reset_index(),
    id_vars=["syllable"],
    var_name="frame",
    value_name="distance",
)

In [None]:
# Same long-format reshaping as above, applied to the second (offset) result of
# the within-distance estimate.
within_distance_offset_df = pd.melt(
    pd.DataFrame(within_distance_offset, index=pd.Index(spec_label_strs, name="syllable")).reset_index(),
    id_vars=["syllable"],
    var_name="frame",
    value_name="distance",
)

## Estimate within-syllable, between-position distance

In [None]:
def _other_position_indices(labels, num_positions):
    """For each (syllable, ordinal) label, list the indices in `labels` of the
    same syllable at every other ordinal in `range(num_positions)`."""
    samples = []
    for syllable_i, ordinal_i in labels:
        other_idxs = []
        for ordinal_j in range(num_positions):
            if ordinal_j == ordinal_i:
                continue
            other_idxs.append(labels.index((syllable_i, ordinal_j)))
        samples.append(other_idxs)
    return samples


# Comparison set for each label: the same syllable in the other ordinal positions.
between1_samples = _other_position_indices(state_space_spec.labels, min_number_positions)

(between1_distance, between1_distance_offset) = coherence.estimate_between_distance(
    trajectory,
    lengths,
    state_space_spec,
    between_samples=between1_samples,
    metric=metric,
)

In [None]:
# Average over the last axis (ignoring NaNs), then reshape to long format with
# one row per (syllable label, column) pair.
between1_distance_df = pd.melt(
    pd.DataFrame(
        np.nanmean(between1_distance, axis=-1),
        index=pd.Index(spec_label_strs, name="syllable"),
    ).reset_index(),
    id_vars=["syllable"],
    var_name="frame",
    value_name="distance",
)

In [None]:
# Offset counterpart of the different-position distances: NaN-aware mean over
# the last axis, then long format.
between1_distance_offset_df = pd.melt(
    pd.DataFrame(
        np.nanmean(between1_distance_offset, axis=-1),
        index=pd.Index(spec_label_strs, name="syllable"),
    ).reset_index(),
    id_vars=["syllable"],
    var_name="frame",
    value_name="distance",
)

## Estimate between-syllable distance

In [None]:
# Intended to match the number of between-samples used in the earlier
# (different-position) analysis.
# NOTE(review): that analysis compares each label against
# `min_number_positions - 1` other positions, whereas this passes
# `min_number_positions` — confirm what `num_samples` counts and whether the
# two are really matched.
num_samples = min_number_positions

(between_distance, between_distance_offset) = coherence.estimate_between_distance(
    trajectory,
    lengths,
    state_space_spec,
    num_samples=num_samples,
    metric=metric,
)

In [None]:
# NaN-aware mean over the last axis of the between-syllable distances, reshaped
# to one row per (syllable label, column) pair.
between_distance_df = pd.melt(
    pd.DataFrame(
        np.nanmean(between_distance, axis=-1),
        index=pd.Index(spec_label_strs, name="syllable"),
    ).reset_index(),
    id_vars=["syllable"],
    var_name="frame",
    value_name="distance",
)

In [None]:
# Offset counterpart of the between-syllable distances: NaN-aware mean over the
# last axis, then long format.
between_distance_offset_df = pd.melt(
    pd.DataFrame(
        np.nanmean(between_distance_offset, axis=-1),
        index=pd.Index(spec_label_strs, name="syllable"),
    ).reset_index(),
    id_vars=["syllable"],
    var_name="frame",
    value_name="distance",
)

## Together

In [None]:
# Stack the three distance tables, tagging each row with its comparison type,
# and save the combined table under `output_dir`.
merged_df = pd.concat([
    frame.assign(type=comparison_type)
    for comparison_type, frame in (
        ("within", within_distance_df),
        ("different_position", between1_distance_df),
        ("between", between_distance_df),
    )
])
merged_df.to_csv(Path(output_dir) / "distances.csv", index=False)
merged_df

In [None]:
# Readable legend names for the three comparison types.
onset_type_names = {"within": "Same identity, same position",
                    "different_position": "Same identity, different position",
                    "between": "Different identity, different position"}

# Rows with missing values are dropped before plotting.
onset_plot_df = merged_df.dropna().replace({"type": onset_type_names})

ax = sns.lineplot(data=onset_plot_df, x="frame", y="distance", hue="type")
ax.set_title("Representational distance within- and between-syllable")
ax.set_xlabel("Frames since syllable onset")
ax.set_ylabel(f"{metric.capitalize()} distance")
# Limit the x-axis to the 95th percentile of all instance lengths.
ax.set_xlim((0, np.percentile(np.concatenate(lengths), 95)))

In [None]:
# Stack the three offset distance tables, tagging each row with its comparison
# type, and save the combined table under `output_dir`.
merged_offset_df = pd.concat([
    frame.assign(type=comparison_type)
    for comparison_type, frame in (
        ("within", within_distance_offset_df),
        ("between", between_distance_offset_df),
        ("different_position", between1_distance_offset_df),
    )
])
merged_offset_df.to_csv(Path(output_dir) / "distances_aligned_offset.csv", index=False)
merged_offset_df

In [None]:
# Readable legend names for the three comparison types, listed in the order in
# which they should appear in the legend.
offset_type_names = {"within": "Same identity, same position",
                     "different_position": "Same identity, different position",
                     "between": "Different identity, different position"}

# `merged_offset_df` stacks the comparison types in a different order than
# `merged_df`. Without an explicit `hue_order`, hue levels (and therefore line
# colours) follow order of appearance, so the same comparison type would get a
# different colour here than in the onset-aligned figure. Pin the order so both
# figures agree.
ax = sns.lineplot(data=merged_offset_df.dropna().replace({"type": offset_type_names}),
                  x="frame", y="distance", hue="type",
                  hue_order=list(offset_type_names.values()))
ax.set_title("Representational distance within- and between-syllable")
ax.set_xlabel("Frames before syllable offset")
ax.set_ylabel(f"{metric.capitalize()} distance")
# Limit the x-axis to the 95th percentile of all instance lengths.
ax.set_xlim((0, np.percentile(np.concatenate(lengths), 95)))