In [1]:
import torch
import pickle


from datasets import load_dataset, Dataset


from typing import cast
from ipywidgets import interact
import ipywidgets as widgets


from transformers import AutoTokenizer
from delphi.constants import STATIC_ASSETS_DIR
from delphi.eval.token_positions import get_all_tok_metrics_in_label
from delphi.eval.vis import vis_pos_map
from delphi.eval.constants import LLAMA2_NEXT_LOGPROBS_DATASETS_MAP

# from delphi.train.utils import get_device

# Prefer the GPU whenever torch can see one; otherwise everything stays on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
def _load_validation_split(dataset_name: str) -> Dataset:
    """Load the "validation" split of `dataset_name`, formatted as torch tensors on `device`.

    Args:
        dataset_name: Dataset identifier passed straight to `load_dataset`.

    Returns:
        The validation split with the "torch" output format, configured so that
        tensors are created on the module-level `device` when rows/columns are read.
    """
    validation_split = cast(Dataset, load_dataset(dataset_name, split="validation"))
    # Device placement has to be part of the output format: `Dataset.map` writes its
    # results back to Arrow storage, so a `.to(device)` inside `map` is not retained
    # and only costs an extra pass over the whole dataset.
    return validation_split.with_format("torch", device=device)


tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")
token_ids = _load_validation_split("delphi-suite/v0-tinystories-v2-clean-tokenized")

next_logprobs = {  # preloading all the logprobs datasets for interactive use
    model_name: _load_validation_split(dataset_name)
    for model_name, dataset_name in LLAMA2_NEXT_LOGPROBS_DATASETS_MAP.items()
}

token_labels_filename = "labelled_token_ids_dict.pkl"
# NOTE(security): pickle.load can execute arbitrary code. This file is read from
# STATIC_ASSETS_DIR (presumably shipped with the delphi package — confirm); never
# point this at a pickle from an untrusted source.
with open(str(STATIC_ASSETS_DIR.joinpath(token_labels_filename)), "rb") as f:
    token_labels = pickle.load(f)

In [6]:
def show_pos_map(
    quantile: tuple[float, float],
    model_name_1: str,
    model_name_2: str,
    label: str,
    samples: int,
):
    """Visualise the next-token logprob difference between two models for one label.

    The metric is `logprobs(model_name_2) - logprobs(model_name_1)`, read from the
    preloaded `next_logprobs` datasets. It is handed to
    `get_all_tok_metrics_in_label` together with the token ids, the token labels,
    the chosen `label` and the quantile bounds, and the result is rendered with
    `vis_pos_map`.

    Args:
        quantile: `(q_start, q_end)` pair forwarded as `q_start` / `q_end`.
        model_name_1: Key into `next_logprobs` for the model being subtracted.
        model_name_2: Key into `next_logprobs` for the model subtracted from.
        label: Token label forwarded to `get_all_tok_metrics_in_label`.
        samples: Forwarded to `vis_pos_map` as `sample`.

    Returns:
        None. If `vis_pos_map` raises `ValueError`, a message is printed instead.
    """
    all_tokens = token_ids["tokens"]

    # Positive values mean model 2 assigned a higher logprob than model 1.
    logprobs_model_2 = next_logprobs[model_name_2]["logprobs"]
    logprobs_model_1 = next_logprobs[model_name_1]["logprobs"]
    logprobs_delta = logprobs_model_2 - logprobs_model_1  # type: ignore

    selected_positions = get_all_tok_metrics_in_label(  # type: ignore
        all_tokens,
        token_labels=token_labels,
        metrics=logprobs_delta,
        label=label,
        q_start=quantile[0],
        q_end=quantile[1],
    )

    try:
        vis_pos_map(selected_positions, all_tokens, tokenizer, sample=samples)  # type: ignore
    except ValueError:
        print("No tokens found in this label")


# Build each control up front so the interact() call itself stays short.
quantile_slider = widgets.FloatRangeSlider(
    min=0.0, max=1.0, step=0.01, description="Start quantile"
)
samples_slider = widgets.IntSlider(min=1, max=5, description="Samples", value=2)

# Both model dropdowns offer every model that has a preloaded logprobs dataset.
model_1_dropdown = widgets.Dropdown(
    options=LLAMA2_NEXT_LOGPROBS_DATASETS_MAP.keys(),
    description="Model 1",
    value="llama2-100k",
)
model_2_dropdown = widgets.Dropdown(
    options=LLAMA2_NEXT_LOGPROBS_DATASETS_MAP.keys(),
    description="Model 2",
    value="llama2-200k",
)

# The available labels are taken from the keys of the token_labels[0] entry.
label_dropdown = widgets.Dropdown(
    options=token_labels[0].keys(), description="Label", value="Is Noun"
)

interact(
    show_pos_map,
    quantile=quantile_slider,
    samples=samples_slider,
    model_name_1=model_1_dropdown,
    model_name_2=model_2_dropdown,
    label=label_dropdown,
)

interactive(children=(FloatRangeSlider(value=(0.25, 0.75), description='Start quantile', max=1.0, step=0.01), …

<function __main__.show_pos_map(quantile: tuple[float, float], model_name_1: str, model_name_2: str, label: str, samples: int)>