In [1]:
import numpy as np
import pickle
import torch

from delphi.eval.utils import load_logprob_datasets, contruct_diff_probs_from_pos_dict
from delphi.eval.vis import vis_sample_diff_probs
from delphi.eval.token_position import get_all_tok_pos_in_category, all_tok_pos_to_metrics_map
from delphi.eval.constants import LLAMA2_MODELS
from delphi.constants import STATIC_ASSETS_DIR
from ipywidgets import interact
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer
from typing import cast

# Run on the GPU when torch reports CUDA as available; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Read the pickled token-label dictionary stored under STATIC_ASSETS_DIR.
# NOTE(review): pickle.load can execute arbitrary code, so this is only safe
# while the asset file is trusted.
filename = "labelled_token_ids_dict.pkl"
with open(str(STATIC_ASSETS_DIR.joinpath(filename)), "rb") as f:
    token_labels = pickle.load(f)

# Both datasets are cast to Dataset purely to satisfy the type checker.
token_map = cast(
    Dataset,
    load_dataset("delphi-suite/v0-token-map", split="validation"),
)
token_ids = cast(
    Dataset,
    load_dataset(
        "delphi-suite/v0-tinystories-v2-clean-tokenized", split="validation"
    ).with_format("torch"),
)

# TODO: don't frontload the dataset ?
logprobs = load_logprob_datasets()

# Go through numpy on the way to torch: the numpy conversion turns None into nan,
# and torch can handle nan but not None.
# TODO: see if there is a way to convert straight to torch.Tensor, check Dataset object
neg_loss_by_model_name = {
    # loss = -logprob, so the log-probs stored here already are the negative losses
    model_name: torch.tensor(np.array(model_logprobs, dtype=float), device=device)
    for model_name, model_logprobs in logprobs.items()
}
tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")

In [3]:
# Compare the first two entries of LLAMA2_MODELS.
model_1, model_2 = LLAMA2_MODELS[0], LLAMA2_MODELS[1]

# diff = model_2 - model_1, i.e. positive wherever model_2's value is the larger one
diff_probs = torch.sub(
    neg_loss_by_model_name[model_2],
    neg_loss_by_model_name[model_1],
)
# IMPORTANT: diff_probs is a 2D tensor, yet the number of columns actually used
# differs from row to row (presumably the rest is nan from the None conversion
# — TODO confirm)

# token_map["prompt_pos_idx"] is a list of indices of the tokens in the prompt
all_tok_pos_in_cat = get_all_tok_pos_in_category(
    token_labels,
    "Is Noun",
    token_map,
)
pos_by_diff = all_tok_pos_to_metrics_map(
    all_tok_pos_in_cat,
    diff_probs,
)
sample_diff_probs, sample_tok_ids = contruct_diff_probs_from_pos_dict(
    pos_by_diff,
    token_ids,
)

In [6]:
# Visualise the sampled diff probabilities for the selected tokens.
# The trailing semicolon suppresses the cell's output (the call's return value)
# without rebinding `_`, which IPython reserves for the last cell output.
vis_sample_diff_probs(sample_tok_ids, sample_diff_probs, tokenizer);