In [1]:
# colab cells (only run if on colab)
# TODO: experiment on colab to see how to set up the environment

# Important

Run this notebook cell by cell. The token selector cell needs to be run first so that the later cells work.

In [2]:
# imports
import torch
import panel as pn
from delphi.eval.vis import token_selector
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer
from typing import cast
from delphi.eval.calc_model_group_stats import calc_model_group_stats
from delphi.eval.vis_per_token_model import visualize_selected_tokens
from ipywidgets import interact
from delphi.eval.token_positions import get_all_tok_metrics_in_label
from delphi.eval.vis import vis_pos_map
import ipywidgets as widgets

# refer to https://panel.holoviz.org/reference/panes/IPyWidget.html to integrate ipywidgets with panel
# Load the panel extension once, before any panel object is displayed below.
pn.extension()
# Use the GPU when torch reports one as available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Model names (or checkpoints) to compare: each entry of `suffixes` is appended
# to `prefix` to form the full name of that model's next-logprobs dataset.
prefix = "delphi-suite/v0-next-logprobs-llama2-"
suffixes = [
    "100k",
    "200k",
    "400k",
    # Further checkpoints, currently disabled:
    # "800k", "1.6m", "3.2m", "6.4m", "12.8m", "25.6m",
]

In [4]:
# Slice of the validation split shared by every dataset loaded in this cell.
split = "validation[:100]"


def _load_on_device(repo_id: str, column: str) -> Dataset:
    """Load `split` of `repo_id` in torch format and map `column` through `.to(device)`.

    NOTE(review): `map` writes its results back into the dataset, so it is unclear
    whether the device placement survives — confirm. `with_format("torch",
    device=device)` may be what is actually wanted here.
    """
    dataset = cast(Dataset, load_dataset(repo_id, split=split))
    return dataset.with_format("torch").map(
        lambda row: {column: row[column].to(device)}
    )


# Next-token logprobs for every model, keyed by model suffix.
next_logprobs = {
    suffix: _load_on_device(f"{prefix}{suffix}", "logprobs") for suffix in suffixes
}
# Same keys, but holding only the "logprobs" column of each dataset.
next_logprobs_plot = {
    suffix: dataset["logprobs"] for suffix, dataset in next_logprobs.items()
}

# The tokenized corpus the logprobs refer to, loaded the same way.
tokenized_corpus_dataset = _load_on_device(
    "delphi-suite/tinystories-v2-clean-tokenized-v0", "tokens"
)

Run the notebook up to and including the following cell (the token selector). The remaining cells depend on its `selected_ids` and should work once it has been run.

In [5]:
# Build the token selector widget.
tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")

# Only offer tokens that are actually present in the dataset.
# torch.unique (not np.unique) is used because the tokens column is a torch tensor.
valid_tok_ids = torch.unique(tokenized_corpus_dataset["tokens"]).tolist()
valid_toks = tokenizer.convert_ids_to_tokens(valid_tok_ids)
# token string -> token id, pairing the two parallel lists element by element
vocab = dict(zip(valid_toks, valid_tok_ids))

# `selected_ids` is meant to be used as a dynamic variable by the later cells.
selector, selected_ids = token_selector(vocab)
pn.Column(selector, height=300).servable()

BokehModel(combine_events=True, render_bundle={'docs_json': {'f305faf3-9da7-41c5-a856-22effd28e368': {'version…

In [8]:
# Sanity check: show the token ids currently held in `selected_ids`.
print(f"Selected IDs: {selected_ids}")

Selected IDs: [267, 266, 13]


In [9]:
# NOTE: the corpus dataset is passed to calc_model_group_stats as-is. The original
# author was unsure whether that is the right input, or whether a list form
# (`tokenized_corpus_dataset.tolist()`, or `list(tokenized_corpus_dataset)` as used
# before) is expected.
model_group_stats = calc_model_group_stats(
    tokenized_corpus_dataset, next_logprobs_plot, selected_ids
)

# Per model: the (median, 75th, 25th) statistics, each negated.
performance_data = {
    suffix: tuple(
        -model_group_stats[suffix][key] for key in ("median", "75th", "25th")
    )
    for suffix in suffixes
}

visualize_selected_tokens(performance_data, log_scale=True)

Processing model 100k
Processing model 200k
Processing model 400k


FigureWidget({
    'data': [{'line': {'width': 0},
              'marker': {'color': 'rgba(68, 68, 68, 0.3)'},
              'mode': 'lines',
              'name': 'Upper Bound',
              'showlegend': False,
              'type': 'scatter',
              'uid': '9eae4162-972e-4209-8aef-419d27ec28f1',
              'x': [100k, 200k, 400k],
              'y': array([1.51873401, 1.37054035, 1.20272651])},
             {'fill': 'tonexty',
              'fillcolor': 'rgba(68, 68, 68, 0.3)',
              'line': {'width': 0},
              'marker': {'color': 'rgba(68, 68, 68, 0.3)'},
              'mode': 'lines',
              'name': 'Lower Bound',
              'showlegend': False,
              'type': 'scatter',
              'uid': '8bb5125d-0df0-4756-89e2-33624e31dee7',
              'x': [100k, 200k, 400k],
              'y': array([0.30299246, 0.27227049, 0.23416051])},
             {'marker': {'color': 'rgb(31, 119, 180)', 'line': {'color': 'rgb(31, 119, 180)', 'width': 1},

In [13]:
def show_pos_map(
    quantile: tuple[float, float],
    model_name_1: str,
    model_name_2: str,
    samples: int,
):
    """Visualise positions of the selected tokens using the logprob difference of two models.

    Reads the notebook globals `next_logprobs`, `tokenized_corpus_dataset`,
    `selected_ids` and `tokenizer`, so the cells defining them must have run first.

    Args:
        quantile: (start, end) pair forwarded to `get_all_tok_metrics_in_label`
            as `q_start` and `q_end`.
        model_name_1: key into `next_logprobs`; its logprobs are subtracted from
            those of `model_name_2`.
        model_name_2: key into `next_logprobs`.
        samples: forwarded to `vis_pos_map` as its `sample` argument.

    Raises:
        ValueError: re-raised from `vis_pos_map` when the failure is explained
            neither by an empty result nor by asking for more samples than there
            are positions. Previously such errors were swallowed silently.
    """
    logprobs_diff = next_logprobs[model_name_2]["logprobs"] - next_logprobs[model_name_1]["logprobs"]  # type: ignore
    pos_to_diff = get_all_tok_metrics_in_label(tokenized_corpus_dataset["tokens"], selected_tokens=selected_ids, metrics=logprobs_diff, q_start=quantile[0], q_end=quantile[1])  # type: ignore
    try:
        vis_pos_map(list(pos_to_diff.keys()), selected_ids, logprobs_diff, tokenized_corpus_dataset["tokens"], tokenizer, sample=samples)  # type: ignore
    except ValueError:
        if pos_to_diff == {}:
            print("No tokens found in this label")
        elif samples > len(pos_to_diff):
            print(f"Only {len(pos_to_diff)} samples available")
        else:
            # Not one of the two known causes: surface the error instead of
            # leaving the widget output empty with no explanation.
            raise


# Interactive controls for show_pos_map, listed in the order of its parameters.
interact(
    show_pos_map,
    quantile=widgets.FloatRangeSlider(min=0.0, max=1.0, step=0.05, description="Quantiles"),
    model_name_1=widgets.Dropdown(options=suffixes, description="Model 1", value="100k"),
    model_name_2=widgets.Dropdown(options=suffixes, description="Model 2", value="200k"),
    samples=widgets.IntSlider(min=1, max=5, description="Samples", value=2),
);  # trailing semicolon suppresses the repr of interact's return value

interactive(children=(FloatRangeSlider(value=(0.25, 0.75), description='Quantiles', max=1.0, step=0.05), Dropd…