In [6]:
import torch
import pickle
import random


from datasets import load_dataset, Dataset


from typing import cast
from numbers import Number


from transformers import AutoTokenizer
from delphi.constants import STATIC_ASSETS_DIR
from delphi.eval.token_positions import get_all_tok_metrics_in_label

# from delphi.train.utils import get_device

# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [7]:
# Tokenizer and the tokenized validation split it was used to produce.
tokenizer = AutoTokenizer.from_pretrained("delphi-suite/stories-tokenizer")
token_ids = cast(
    Dataset,
    load_dataset("delphi-suite/v0-tinystories-v2-clean-tokenized", split="validation"),
)


def _load_next_logprobs(repo_id: str) -> Dataset:
    """Load the validation split of ``repo_id`` formatted as torch tensors on ``device``.

    The device is passed as a format kwarg so tensors are created on it when
    rows/columns are read. Calling ``.to(device)`` inside ``Dataset.map`` does
    not achieve this: ``map`` writes its results back into the dataset's Arrow
    storage, so the tensors come back on the CPU on the next access and the
    whole pass over the data is wasted work.
    """
    validation = cast(Dataset, load_dataset(repo_id, split="validation"))
    return validation.with_format("torch", device=device)


# Next-token log-probabilities from the two model checkpoints being compared.
next_logprobs_100k = _load_next_logprobs("delphi-suite/v0-next-logprobs-llama2-100k")
next_logprobs_400k = _load_next_logprobs("delphi-suite/v0-next-logprobs-llama2-400k")

# Token-label lookup stored as a pickle under delphi's STATIC_ASSETS_DIR.
# NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
token_labels_filename = "labelled_token_ids_dict.pkl"
token_labels_path = str(STATIC_ASSETS_DIR.joinpath(token_labels_filename))
with open(token_labels_path, "rb") as f:
    token_labels = pickle.load(f)

Map:   0%|          | 0/10982 [00:00<?, ? examples/s]

Map:   0%|          | 0/10982 [00:00<?, ? examples/s]

In [8]:
%%time
diff = next_logprobs_400k["logprobs"] - next_logprobs_100k["logprobs"]  # type: ignore
pos_map = get_all_tok_metrics_in_label(
    token_ids, token_labels, diff, "Is Noun", 0.9, 1.0
)

CPU times: total: 9.44 s
Wall time: 7.7 s


In [9]:
# TODO: visualization function that turns dict[tuple[int, int], Number] into an interactive html display
def vis_pos_map(
    pos_map: dict[tuple[int, int], Number],
    token_ids: Dataset,
    sample: int = 3,
    decode=None,
):
    """Print a few randomly chosen entries of ``pos_map`` with their decoded context.

    For each chosen ``(prompt, pos)`` key this prints the key, the decoded
    tokens of that prompt up to and including ``pos``, the decoded token at
    ``pos``, and the value stored in ``pos_map`` for the key.

    Args:
        pos_map: Mapping from ``(prompt index, token position)`` to a number.
        token_ids: Object indexable by integer row, where each row exposes a
            ``"tokens"`` sequence of token ids.
        sample: Maximum number of entries to print. It is capped at
            ``len(pos_map)`` so asking for more entries than exist prints
            them all instead of raising.
        decode: Callable that turns a token id (or a sequence of token ids)
            into text. Defaults to the notebook-level ``tokenizer.decode``.

    Raises:
        ValueError: If ``sample`` is negative (raised by ``random.sample``).
    """
    if decode is None:
        decode = tokenizer.decode

    # choose up to `sample` random keys from pos_map
    candidates = list(pos_map.keys())
    keys = random.sample(candidates, k=min(sample, len(candidates)))

    for key in keys:
        prompt, pos = key
        # Fetch only this row. Indexing the "tokens" column first reads the
        # entire column on every access, twice per printed entry.
        tokens = token_ids[int(prompt)]["tokens"]
        print(f"Position pair: {key}")
        print(f"Prompt: {decode(tokens[:pos + 1])}")
        print(f"Token: {decode(tokens[pos])}")
        print(f"Value: {pos_map[key]}")
        print()

In [10]:
# Print three randomly chosen entries of pos_map.
vis_pos_map(pos_map, token_ids, sample=3)

Position pair: (2765, 414)
Prompt: a big red ball. He wanted to play with it.
"Mom, can I have the ball?" Tim asked. His mom said yes, and they bought the ball. Tim was so happy. He played with the ball all day.
The next day, Tim played with the ball outside. He kicked it very hard, and it hit a tree. The ball fell down and got dirty. Tim was sad. His mom saw him and said, "Don't worry, Tim. We can wipe the ball clean."
Tim wiped the ball with a cloth. It was clean again. But when he kicked it, his foot hurt a little. He told his mom, and she said, "It's okay, Tim. Just be careful next time."
Tim played with the ball more gently. He had lots of fun and did not get hurt again.<s> Once upon a time, there was a little boy named Tim. Tim liked to observe the sun. He would look up at the sky and see the big, bright sun shining down on him.
One day, Tim met a silly bird. The bird was flying around and making funny noises. Tim thought the bird was a little bit stupid because it didn't know ho