Inspect integrator output values, grouped by different equivalence classifications of the dataset.

In [None]:
import os
# Set CUDA_VISIBLE_DEVICES to the empty string before torch is imported in the next cell.
# NOTE(review): presumably intended to hide all GPUs so the notebook runs on CPU — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

In [None]:
from collections import Counter
from pathlib import Path

import datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.datasets.speech_equivalence import SpeechEquivalenceDataset, SpeechHiddenStateDataset
from src.models import get_best_checkpoint
from src.models.integrator import ContrastiveEmbeddingModel, iter_dataset

In [None]:
# Name of the model variant to analyse; interpolated into the model-specific paths below.
model_name = "syllable"
model_dir = f"outputs/models/timit/w2v2_8/rnn_8-weightdecay0.01/{model_name}_10frames"
output_dir = f"outputs/notebooks/timit/w2v2_8/rnn_8-weightdecay0.01/{model_name}_10frames/state_space"
dataset_path = "outputs/preprocessed_data/timit"
equivalence_path = f"outputs/equivalence_datasets/timit/w2v2_8/{model_name}_10frames/equivalence.pkl"
hidden_states_path = "outputs/hidden_states/timit/w2v2_8/hidden_states.h5"
state_space_specs_path = "outputs/state_space_specs/timit/w2v2_8/state_space_specs.pkl"
embeddings_path = f"outputs/model_embeddings/timit/w2v2_8/rnn_8-weightdecay0.01/{model_name}_10frames/embeddings.npy"
# Fixed to the word_broad dataset regardless of model_name, so a plain string
# (not an f-string) is used: there is nothing to interpolate.
word_equivalence_path = "outputs/equivalence_datasets/timit/w2v2_8/word_broad_10frames/equivalence.pkl"

metric = "cosine"

In [None]:
# Load the trained model from whichever checkpoint get_best_checkpoint selects under model_dir.
best_checkpoint = get_best_checkpoint(model_dir)
model = ContrastiveEmbeddingModel.from_pretrained(best_checkpoint)
# Switch to evaluation mode; kept as the final expression of the cell.
model.eval()

In [None]:
# NOTE: torch.load unpickles arbitrary Python objects; only load files from a trusted source.
# weights_only=False is passed explicitly because the file holds a pickled
# SpeechEquivalenceDataset rather than plain tensors, and PyTorch >= 2.6 defaults
# to weights_only=True, which rejects such objects.
with open(word_equivalence_path, "rb") as f:
    equiv_dataset: SpeechEquivalenceDataset = torch.load(f, weights_only=False)

In [None]:
# Build a torch-formatted dataset from the examples yielded by iter_dataset,
# capped at 10,000 examples (or fewer if the equivalence dataset is smaller).
ds = datasets.Dataset.from_generator(
    iter_dataset,
    gen_kwargs={
        "equiv_dataset_path": word_equivalence_path,
        "hidden_states_path": hidden_states_path,
        "max_length": model.config.max_length,
        "num_examples": min(10000, equiv_dataset.num_instances),
        "infinite": False,
    },
).with_format("torch")

In [None]:
# Ten most frequent class labels among the sampled examples.
# NOTE(review): label parts are joined with no separator here, whereas the
# loss table built later joins them with a space — confirm this is intended.
class_counts = Counter(
    "".join(equiv_dataset.class_labels[label_idx])
    for label_idx in ds["example_class"].numpy()
)
class_counts.most_common(10)

In [None]:
# Accumulators filled batch by batch by the `map` callback below.
losses, idxs = [], []


def compute_loss_batch(batch, batch_idxs):
    """Append one batch's unreduced losses and example indices to the accumulators.

    ``batch_idxs`` is supplied because ``ds.map`` is called with
    ``with_indices=True``; it is not used here.
    """
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        outputs = model(
            batch["example"], batch["example_length"],
            batch["pos"], batch["pos_length"],
            batch["neg"], batch["neg_length"],
            in_batch_soft_negatives=False,
            loss_reduction=None,
        )
    batch_losses = outputs.loss.numpy()
    batch_example_idxs = batch["example_idx"].numpy()
    losses.append(batch_losses)
    idxs.append(batch_example_idxs)


# `map` is used only for its side effect of filling the accumulators;
# the callback returns nothing and the returned dataset is discarded.
ds.map(compute_loss_batch, batched=True, with_indices=True, batch_size=32)

# Flatten the per-batch accumulators into one array each.
losses = np.concatenate(losses)
idxs = np.concatenate(idxs)

# Look up the Q and S entries of the equivalence dataset for every evaluated index.
class_ids = equiv_dataset.Q[idxs].numpy()
# NOTE(review): S presumably holds the start index of the span each example belongs to — confirm.
span_starts = equiv_dataset.S[idxs].numpy()


def _class_label(class_idx):
    """Return the label parts of class ``class_idx`` joined by single spaces."""
    return " ".join(equiv_dataset.class_labels[class_idx])


loss_df = pd.DataFrame({
    "loss": losses,
    "idx": idxs,
    "class": class_ids,
    "position": idxs - span_starts,
})
loss_df["class_label"] = loss_df["class"].map(_class_label)
# Number of space-separated parts in the class label.
loss_df["word_length"] = loss_df["class_label"].str.count(" ") + 1

In [None]:
# Display the loss table built in the previous cell.
loss_df

In [None]:
# Distribution of the loss values across all evaluated examples.
sns.boxplot(
    x="loss",
    data=loss_df,
)

In [None]:
# Create the output directory first: DataFrame.to_csv does not create missing
# parent directories and would fail if output_dir does not exist yet.
loss_csv_path = Path(output_dir) / "loss.csv"
loss_csv_path.parent.mkdir(parents=True, exist_ok=True)
loss_df.to_csv(loss_csv_path, index=False)

In [None]:
# Loss against word length (number of space-separated parts in the class label).
sns.lineplot(
    x="word_length",
    y="loss",
    data=loss_df,
)

In [None]:
# Loss against the "position" column of loss_df.
sns.lineplot(
    x="position",
    y="loss",
    data=loss_df,
)