Inspect integrator output values, grouped by different equivalence classings of the dataset.

In [None]:
import os

# Blank out CUDA_VISIBLE_DEVICES before torch is imported further down, so
# that CUDA exposes no GPU devices to this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

In [None]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import datasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm

from src.datasets.speech_equivalence import SpeechEquivalenceDataset, SpeechHiddenStateDataset
from src.models import get_best_checkpoint
from src.models.integrator import ContrastiveEmbeddingModel, iter_dataset

In [None]:
# Directory handed to `get_best_checkpoint` to locate the checkpoint to load.
model_dir = "outputs/models/timit-no_repeats/w2v2_8/randomff_32/random"
# Directory that receives this notebook's outputs (`loss.csv`).
output_dir = "outputs/notebooks/timit-no_repeats/w2v2_8/randomff_32/random/predictions"
# NOTE(review): not referenced in the cells visible here.
dataset_path = "outputs/preprocessed_data/timit-no_repeats"
# Serialized equivalence dataset: loaded directly with `torch.load` and also
# passed to `iter_dataset` when building the evaluation dataset.
phoneme_equivalence_path = "outputs/equivalence_datasets/timit-no_repeats/w2v2_8/phoneme_10frames/equivalence.pkl"
# Hidden-state file passed to `iter_dataset`.
hidden_states_path = "outputs/hidden_states/timit-no_repeats/w2v2_8/hidden_states.h5"
# NOTE(review): the next two paths are not referenced in the cells visible here.
state_space_specs_path = "outputs/state_space_specs/timit-no_repeats/w2v2_8/state_space_specs.pkl"
embeddings_path = "outputs/model_embeddings/timit-no_repeats/w2v2_8/randomff_32/random/embeddings.npy"

# NOTE(review): presumably a distance metric name; not referenced in the cells visible here.
metric = "cosine"

In [None]:
# Resolve the checkpoint selected by `get_best_checkpoint` for this run, load
# the model from it, and switch the model to evaluation mode.
best_checkpoint_path = get_best_checkpoint(model_dir)
model = ContrastiveEmbeddingModel.from_pretrained(best_checkpoint_path)
model.eval()

In [None]:
# Load the serialized equivalence dataset.
# NOTE(review): torch.load deserializes via pickle, so only point this at
# files from a trusted source. Newer PyTorch releases default to
# weights_only=True, which may refuse to load a custom object like this one;
# confirm against the torch version this project pins.
with Path(phoneme_equivalence_path).open("rb") as equiv_file:
    equiv_dataset: SpeechEquivalenceDataset = torch.load(equiv_file)

In [None]:
# Build a finite evaluation dataset from `iter_dataset`, capped at 10000
# examples (or fewer if the equivalence dataset has fewer instances), and
# expose its columns as torch tensors.
ds = datasets.Dataset.from_generator(
    iter_dataset,
    gen_kwargs={
        "equiv_dataset_path": phoneme_equivalence_path,
        "hidden_states_path": hidden_states_path,
        "max_length": model.config.max_length,
        "num_examples": min(10000, equiv_dataset.num_instances),
        "infinite": False,
    },
).with_format("torch")

In [None]:
# Count how many sampled examples fall into each class; `minlength` keeps one
# slot per known class label even when a class was never sampled.
num_classes = len(equiv_dataset.class_labels)
class_counts = torch.bincount(ds["example_class"], minlength=num_classes).numpy()

# Plot the counts, most frequent class first.
sorted_class_counts = pd.Series(class_counts, index=equiv_dataset.class_labels) \
    .sort_values(ascending=False)
sns.barplot(sorted_class_counts)

In [None]:
# Per-batch losses and the matching example indices. Both lists are filled as
# a side effect of `compute_loss_batch` and concatenated once the map is done.
losses, idxs = [], []

def compute_loss_batch(batch, batch_idxs):
    """Run the model on one batch and record its loss and example indices.

    Nothing is returned: results are appended to the module-level `losses`
    and `idxs` lists, so `ds.map` is used here only to drive batched
    iteration over the dataset.

    Args:
        batch: Batch dict from `ds`. The keys read here are "example",
            "example_length", "pos", "pos_length", "neg", "neg_length",
            "example_class" and "example_idx".
        batch_idxs: Row indices supplied because `ds.map` is called with
            `with_indices=True`; not used.
    """
    # When no example in the batch has a negative, pass None for both negative
    # inputs. Uses an identity check (`is None`) instead of `== None`, which
    # is the idiomatic singleton test and cannot be hijacked by elementwise
    # `==` overloads on array-like values.
    if all(neg is None for neg in batch["neg"]):
        batch["neg"] = None
        batch["neg_length"] = None
    with torch.no_grad():
        # NOTE(review): loss_reduction=None presumably keeps one loss value
        # per example (the losses are paired row-for-row with example indices
        # below) -- confirm against the model implementation.
        model_output = model(batch["example"], batch["example_length"],
                             batch["pos"], batch["pos_length"],
                             batch["neg"], batch["neg_length"],
                             example_idx=batch["example_class"],
                             in_batch_soft_negatives=True,
                             loss_reduction=None)
    losses.append(model_output.loss.numpy())
    idxs.append(batch["example_idx"].numpy())
ds.map(compute_loss_batch, batched=True, with_indices=True, batch_size=32)

# Flatten the per-batch arrays into one array each.
losses = np.concatenate(losses)
idxs = np.concatenate(idxs)

# One row per evaluated example: its loss, its example index, the class id
# looked up through `equiv_dataset.Q`, and the label at that position in
# `equiv_dataset.class_labels`.
loss_df = pd.DataFrame({"loss": losses, "idx": idxs, "class": equiv_dataset.Q[idxs]})
loss_df["class_label"] = loss_df["class"].map(dict(enumerate(equiv_dataset.class_labels)))

In [None]:
# Display the per-example loss table.
loss_df

In [None]:
# Box plot of the loss values across all evaluated examples.
sns.boxplot(data=loss_df, x="loss")

In [None]:
# Write the per-example losses to `<output_dir>/loss.csv`. The directory is
# created first (including parents) because nothing earlier in this notebook
# creates it and `to_csv` fails when the target directory is missing.
loss_csv_path = Path(output_dir) / "loss.csv"
loss_csv_path.parent.mkdir(parents=True, exist_ok=True)
loss_df.to_csv(loss_csv_path, index=False)

In [None]:
# Order the classes from lowest to highest mean loss, then plot the loss per
# class label in that order.
class_order = loss_df.groupby("class_label")["loss"].mean().sort_values().index
sns.barplot(data=loss_df, x="class_label", y="loss", order=class_order)