# Plan
1. Build a basic UI for testing purposes

# Demo UX

## Mock data for UX testing

In [1]:
import pandas as pd
import numpy as np

import tempfile
import pathlib

# Location of the throwaway CSV inside the system temp directory
tmpdir = pathlib.Path(tempfile.gettempdir())
dummy_csv_path = tmpdir.joinpath("dummy_data.csv")

# Axes of the mock grid: every combination below yields exactly one row
token_classes = ["Nouns", "Verbs", "Adjectives"]
training_steps_options = [100000, 200000, 400000]
model_sizes_options = ["Small", "Medium", "Large"]

data = {"model_size": [], "training_steps": [], "loss": [], "token_class": []}

# Generate dummy data
size_count = len(model_sizes_options)
steps_count = len(training_steps_options)
for size_rank, size in enumerate(model_sizes_options, start=1):
    for steps_rank, steps in enumerate(training_steps_options, start=1):
        # loss should be random but decrease with size and steps:
        # the baseline drops as either rank grows, noise is added per row
        baseline = 1 - size_rank / size_count - steps_rank / steps_count
        for token_class in token_classes:
            data["model_size"].append(size)
            data["training_steps"].append(steps)
            data["token_class"].append(token_class)
            data["loss"].append(baseline + np.random.normal(0, 0.1))

# Create DataFrame and persist it (without the index column) for the UI to read
dummy_df = pd.DataFrame(data)
dummy_df.to_csv(dummy_csv_path, index=False)

dummy_csv_path

PosixPath('/var/folders/5k/7nfpl0cs5999pzhndyybcn800000gn/T/dummy_data.csv')

# Load Pickle Data

In [9]:
import pickle

# Per-token statistics; later cells index this with (token, model) keys.
# NOTE: pickle.load executes arbitrary code from the file — only load files
# produced by a trusted run of this notebook.
with open("../data/token_model_stats.pkl", "rb") as f:
    token_model_stats = pickle.load(f)

# Token groups; later cells index this by a label description string and
# iterate the tokens stored under it.
with open("../data/token_groups.pkl", "rb") as f:
    token_groups = pickle.load(f)

In [21]:
from delphi.eval.constants import LLAMA2_MODELS
from datasets import load_dataset, Dataset
from typing import cast

from delphi.eval.hack_token_label import HackTokenLabels

# Mean log-probability per (token group description, model), aggregated from
# the per-token logprob sums and counts stored in token_model_stats.
model_token_group_stats = {}
for model in LLAMA2_MODELS:
    print(f"Processing model {model}")
    # The per-model validation-logprobs dataset is not needed here: every
    # value below is derived from token_model_stats, so nothing is downloaded.
    for token_group_desc in HackTokenLabels.ALL_LABELS:
        # Stats of the group's tokens that were actually recorded for this model
        token_group_model_stats = [
            token_model_stats[(token, model)]
            for token in token_groups[token_group_desc.description]
            if (token, model) in token_model_stats
        ]
        sum_logprob = sum(t.logprob_sum for t in token_group_model_stats)
        sum_count = sum(t.count for t in token_group_model_stats)
        if sum_count == 0:
            # No recorded occurrences for this group under this model: the
            # mean is undefined, so the (group, model) key is left out.
            continue
        mean_logprob = sum_logprob / sum_count
        model_token_group_stats[(token_group_desc.description, model)] = mean_logprob

Processing model delphi-llama2-100k
Processing model delphi-llama2-200k
Processing model delphi-llama2-400k
Processing model delphi-llama2-800k
Processing model delphi-llama2-1.6m
Processing model delphi-llama2-3.2m
Processing model delphi-llama2-6.4m
Processing model delphi-llama2-12.8m
Processing model delphi-llama2-25.6m


In [29]:
import pandas as pd
import pickle

# Reload the cached {(token_group, model): mean_logprob} mapping
with open("../data/model_token_group_stats.pkl", "rb") as f:
    model_token_group_stats = pickle.load(f)

# One row per dictionary entry: the key tuple followed by its value
mtgs_df = pd.DataFrame(
    [(*key, value) for key, value in model_token_group_stats.items()],
    columns=["token_group", "model", "mean_logprob"],
)

## Build Visualization

In [64]:
import pandas as pd
import plotly.graph_objs as go
from ipywidgets import interact, Dropdown
import plotly.express as px

# Load your data
# df = pd.read_csv(dummy_csv_path)  # replace with your actual path
# NOTE: df is bound to the same object as mtgs_df (no copy), so every column
# added through either name shows up on both.
df = mtgs_df
# Constant value for every row.
# NOTE(review): update_figure puts the x axis on a log scale, where 0 cannot be
# drawn — presumably a placeholder until real step counts exist; confirm.
mtgs_df["training_steps"] = 0

# dumb hack to avoid the first call to update_figure rendering a duplicate chart
_first_call = True


def get_model_size(model: str) -> float:
    """Parse the size encoded at the end of a model name.

    Model names end in the model size after the last dash, e.g. ``-100k`` or
    ``-25.6m``. The final character is a magnitude suffix: ``k`` (1e3),
    ``m`` (1e6) or ``b`` (1e9); upper-case suffixes are accepted as well.

    Args:
        model: model name such as ``"delphi-llama2-25.6m"``.

    Returns:
        The size as a float, e.g. ``100000.0`` for ``"...-100k"``.

    Raises:
        ValueError: if the name has no size part after the last dash, the
            suffix is not one of k/m/b, or the numeric part cannot be parsed.
    """
    multipliers = {"k": 1e3, "m": 1e6, "b": 1e9}
    size_str = model.split("-")[-1]
    if not size_str:
        # e.g. "model-" or "": report it as a bad value instead of letting the
        # suffix lookup below fail with an IndexError
        raise ValueError(f"Model name {model!r} has no size suffix")
    suffix = size_str[-1]
    multiplier = multipliers.get(suffix.lower())
    if multiplier is None:
        raise ValueError(f"Unknown size suffix: {suffix}")
    return float(size_str[:-1]) * multiplier


# set model_size based on df["model"]
# (numeric size parsed from each model name by get_model_size)
df["model_size"] = df["model"].apply(get_model_size)
# loss is the negated mean logprob, so lower is better
df["loss"] = -df["mean_logprob"]

# (re)arm the first-call flag that update_figure reads and clears
_first_call = True

# Function to create and update the figure
def update_figure(comparison_type, model_size, training_steps, token_group_desc):
    """Build a loss line chart for one token group and show it.

    Args:
        comparison_type: ``"model_size"`` plots loss against model size for
            the selected ``training_steps``; any other value plots loss
            against training steps for the selected ``model_size``.
        model_size: value matched against ``df["model_size"]`` (only used when
            comparing by training steps).
        training_steps: value matched against ``df["training_steps"]`` (only
            used when comparing by model size).
        token_group_desc: value matched against ``df["token_group"]``.

    Reads the module-level ``df`` and the ``_first_call`` flag. The first
    call only clears the flag; later calls display the figure.
    """
    global _first_call

    group_mask = df["token_group"] == token_group_desc
    if comparison_type == "model_size":
        x_column = "model_size"
        filtered_df = df[(df["training_steps"] == training_steps) & group_mask]
        title = f"Loss (-mean logprob) by Model Size ({token_group_desc})"
    else:
        x_column = "training_steps"
        filtered_df = df[(df["model_size"] == model_size) & group_mask]
        title = "Loss (mean -logprob) by Training Steps"

    # px.line connects points in row order; sort along x so the line never
    # doubles back if the frame is not already ordered.
    filtered_df = filtered_df.sort_values(x_column)
    fig = go.FigureWidget(px.line(filtered_df, x=x_column, y="loss", title=title))
    # set x axis to log scale
    fig.update_xaxes(type="log")

    # Skip showing on the very first invocation (see the _first_call hack)
    if _first_call:
        _first_call = False
    else:
        fig.show()


# Interactive widgets: one dropdown per update_figure argument, with options
# taken from the values present in df
comparison_type = Dropdown(options=["model_size", "training_steps"])
model_size = Dropdown(options=sorted(df["model_size"].unique()))
training_steps = Dropdown(options=sorted(df["training_steps"].unique()))
token_group_desc = Dropdown(options=df["token_group"].unique())

controls = {
    "comparison_type": comparison_type,
    "model_size": model_size,
    "training_steps": training_steps,
    "token_group_desc": token_group_desc,
}

# only render the chart after all the widgets have been rendered
# NOTE(review): recent ipywidgets releases appear to take the manual flag via
# interact_manual / interact.options(manual=True) rather than a __manual
# keyword — confirm this argument still has an effect.
_ = interact(update_figure, __manual=True, **controls)

interactive(children=(Dropdown(description='comparison_type', options=('model_size', 'training_steps'), value=…

# Remaining Tasks

1. Get all token positions
2. Get all model token predictions
3. Get token groups
4. Calculate mean loss and count per token per model

# Cache Data

## Datasets

In [3]:
from huggingface_hub import list_datasets
from datetime import datetime
from typing import cast

# Keep only the author's datasets created strictly after 2024-02-10
cutoff_date = datetime(2024, 2, 10).date()
datasets = [
    candidate
    for candidate in list_datasets(author="transcendingvictor")
    if cast(datetime, candidate.created_at).date() > cutoff_date
]

# cache datasets: loading each one once; the returned objects are discarded
from datasets import load_dataset

for dataset in datasets:
    load_dataset(dataset.id)

In [4]:
[d.id for d in datasets]

['transcendingvictor/delphi-llama2-100k-validation-logprobs',
 'transcendingvictor/delphi-llama2-200k-validation-logprobs',
 'transcendingvictor/delphi-llama2-400k-validation-logprobs',
 'transcendingvictor/delphi-llama2-800k-validation-logprobs',
 'transcendingvictor/delphi-llama2-1.6m-validation-logprobs',
 'transcendingvictor/delphi-llama2-3.2m-validation-logprobs',
 'transcendingvictor/delphi-llama2-6.4m-validation-logprobs',
 'transcendingvictor/delphi-llama2-12.8m-validation-logprobs',
 'transcendingvictor/delphi-llama2-25.6m-validation-logprobs']

## Models

In [5]:
# Load models from https://huggingface.co/collections/delphi-suite/llama2-65b53936f3c0cb73e741fc44

from huggingface_hub import list_models, get_collection
from datetime import datetime
from typing import cast

# list models from the delphi-suite llama-2 collection
collection = get_collection("delphi-suite/llama2-65b53936f3c0cb73e741fc44")

# cache all the models in the collection
from transformers import AutoModelForCausalLM

# Each model is loaded once and the returned object is discarded;
# presumably this only serves to populate the local download cache.
for model in collection.items:
    AutoModelForCausalLM.from_pretrained(model.item_id)


In [6]:
[c.item_id for c in collection.items]

['delphi-suite/delphi-llama2-100k',
 'delphi-suite/delphi-llama2-200k',
 'delphi-suite/delphi-llama2-400k',
 'delphi-suite/delphi-llama2-800k',
 'delphi-suite/delphi-llama2-1.6m',
 'delphi-suite/delphi-llama2-3.2m',
 'delphi-suite/delphi-llama2-6.4m',
 'delphi-suite/delphi-llama2-12.8m',
 'delphi-suite/delphi-llama2-25.6m']

## Token Mappings

In [16]:
from delphi.eval.token_map import token_map
from delphi.eval.utils import load_validation_dataset
from delphi.eval import constants
from platformdirs import user_cache_dir
from pathlib import Path
import pickle
from datasets import load_dataset

In [4]:
# token mapping is a dict of token -> list of (document index, position in document)
# The file has no extension but is read as a pickle.
# NOTE: pickle.load executes arbitrary code from the file — trusted files only.
token_mappings_path = "../data/token_mappings"
with open(token_mappings_path, "rb") as f:
    token_mappings = pickle.load(f)

In [100]:
token_mappings.keys()

dict_keys([1, 432, 440, 261, 403, 4045, 406, 286, 799, 478, 407, 385, 4037, 505, 268, 1555, 622, 387, 331, 397, 509, 350, 614, 1318, 375, 1280, 381, 380, 13, 2112, 606, 486, 2929, 4040, 669, 269, 921, 341, 2652, 492, 457, 579, 544, 920, 1752, 4056, 395, 729, 412, 675, 4071, 3212, 1316, 1057, 1726, 892, 1897, 993, 342, 390, 720, 366, 410, 425, 311, 434, 628, 981, 924, 888, 367, 501, 1917, 372, 3398, 577, 359, 1854, 1811, 482, 698, 264, 525, 1014, 429, 1004, 384, 1101, 1091, 4032, 507, 624, 837, 370, 2241, 317, 670, 1728, 959, 829, 1067, 424, 466, 3824, 1526, 1102, 1256, 515, 1019, 697, 713, 602, 886, 560, 1000, 2567, 500, 1235, 369, 4001, 313, 282, 1030, 1365, 1700, 4053, 3935, 2042, 551, 858, 1521, 876, 1516, 889, 495, 467, 712, 844, 4060, 503, 435, 316, 711, 3749, 1126, 601, 1301, 1947, 2164, 599, 3673, 1104, 1515, 3018, 3158, 4054, 1788, 330, 477, 326, 332, 345, 415, 4026, 363, 504, 327, 307, 1088, 832, 411, 413, 3821, 543, 1901, 468, 340, 471, 567, 1495, 1181, 324, 559, 398, 4057, 1

In [17]:
# Validation split of the first dataset listed in constants.LLAMA2_LOGPROB_DATASETS;
# rows are later read as logprob_ds[i]['logprobs'][j].
logprob_ds = load_dataset(constants.LLAMA2_LOGPROB_DATASETS[0])['validation']

In [93]:
import logging
from datasets import Dataset
# Tokenized corpus; cast() only informs the type checker and does nothing at runtime.
tokenized_ds = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))

In [8]:
# token mapping is a dict of token -> list of (document index, position in document)
# The file has no extension but is read as a pickle.
# NOTE: pickle.load executes arbitrary code from the file — trusted files only.
token_mappings_path = "../data/token_mappings"
with open(token_mappings_path, "rb") as f:
    token_mappings = pickle.load(f)

# get positions of all occurrences of a token in the tokenized corpus
def get_token_positions(token: int) -> list[ tuple[ int, int ] ]:
    """Return the positions stored for `token` in the module-level `token_mappings`.

    Each position is a (document index, index within the document) pair.
    An unknown token propagates the lookup error raised by `token_mappings`
    (a KeyError for a plain dict).
    """
    return token_mappings[token]
    # Equivalent direct scan of the corpus, kept for reference:
    # return [(i, j) for i, doc in enumerate(tokenized_ds['validation']['tokens']) for j, t in enumerate(doc) if t == token]

def get_logprob(token: int) -> float:
    """Mean logprob of `token` over its occurrences in the module-level `logprob_ds`.

    Occurrences at index 0 of a document are skipped. For an occurrence at
    (document i, index j) the value read is ``logprob_ds[i]['logprobs'][j - 1]``.
    Returns 0.0 when no occurrence is left after skipping.
    """
    # ignore occurrences at the start of each document
    logprobs = [
        logprob_ds[doc_index]['logprobs'][offset - 1]
        for doc_index, offset in get_token_positions(token)
        if offset != 0
    ]
    if not logprobs:
        return 0.0
    return sum(logprobs) / len(logprobs)

# Number of recorded positions per token (start-of-document positions included)
token_counts = {token: len(positions) for token, positions in token_mappings.items()}

# Mean logprob per token, as computed by get_logprob
token_logprobs = {token: get_logprob(token) for token in token_mappings}

def get_group_weighted_logprog(group: list[int]) -> float:
    """Count-weighted mean of the per-token mean logprobs of a token group.

    Tokens of `group` that are absent from the module-level `token_mappings`
    are ignored. Each remaining token contributes ``get_logprob(token)``
    weighted by its number of recorded positions.

    Args:
        group: token ids forming the group (duplicates are counted each time
            they appear).

    Returns:
        The weighted mean, or 0.0 (with a warning) when no token of the group
        is present in `token_mappings`.
    """
    known_tokens = [token for token in group if token in token_mappings]
    if not known_tokens:
        # Log the group as passed in; the filtered list is empty at this point
        logging.warning(
            "Group %s has no tokens in token_logprobs; returning 0.0", group
        )
        return 0.0
    # Only the group's own tokens are evaluated; computing the logprob of
    # every token in the corpus on each call is unnecessary for the result.
    # NOTE(review): the weight counts every recorded position, while
    # get_logprob averages only over positions with index != 0 — confirm the
    # mismatch is intended.
    counts = {token: len(token_mappings[token]) for token in set(known_tokens)}
    logprobs = {token: get_logprob(token) for token in counts}
    weighted_sum = sum(logprobs[token] * counts[token] for token in known_tokens)
    total_count = sum(counts[token] for token in known_tokens)
    return weighted_sum / total_count

In [39]:
get_group_weighted_logprog([2, 3, 4, 5, 6, 7, 8, 9, 10, 25, 101, 4001, 2003, 2004])

-3.92385858499659

In [4]:
from dataclasses import dataclass
from typing import cast
from datasets import Dataset

@dataclass
class TokenModelStats:
    """Aggregate logprob statistics for one token under one model.

    Instances are stored in dicts keyed by (token, model) and are pickled by
    this notebook, so renaming or moving the class may break loading
    existing pickle files.
    """

    # token id
    token: int
    # model name, e.g. "delphi-llama2-100k"
    model: str
    # sum of the logprobs accumulated for this token
    logprob_sum: float
    # number of occurrences that went into logprob_sum
    count: int

In [54]:
logprob_ds[0]["logprobs"][3]

-0.00271428469568491

In [58]:
# Iterate through the models; for each model, walk every token of the tokenized validation corpus
# and accumulate its logprob sum and occurrence count, stored in a dict of (token, model_name) -> TokenModelStats

from transformers import AutoModelForCausalLM
from delphi.eval.constants import LLAMA2_LOGPROB_DATASETS
from tqdm import tqdm


def get_model_name_from_logprob_dataset_name(dataset_id: str) -> str:
    """Extract the model name from a logprob dataset id.

    Takes the part after the last "/" and cuts it at the first
    "-validation-logprobs", e.g.
    "transcendingvictor/delphi-llama2-100k-validation-logprobs"
    -> "delphi-llama2-100k".
    """
    _, _, repo_name = dataset_id.rpartition("/")
    model_name, _, _ = repo_name.partition("-validation-logprobs")
    return model_name

# Token ids of every validation document (one list of tokens per document)
tokenized_ds = cast(Dataset, load_dataset(constants.tokenized_corpus_dataset))["validation"]["tokens"]  # type: ignore
# (token, model_name) -> TokenModelStats, filled one model at a time
model_token_stats = {}
for model_dataset in constants.LLAMA2_LOGPROB_DATASETS:
    model_name = get_model_name_from_logprob_dataset_name(model_dataset)
    print(f"Processing model {model_dataset}")
    logprob_ds = cast(Dataset, load_dataset(model_dataset))['validation']
    token_counts = {}
    token_logprob_sums = {}
    for i, doc in tqdm(enumerate(tokenized_ds)):
        if len(doc) < 2:
            # Only the first token (or nothing) is present, so there is
            # nothing to accumulate and the logprob row is not needed.
            continue
        # Fetch the document's logprob row once instead of once per token;
        # indexing the dataset is by far the most expensive step of the loop.
        doc_logprobs = logprob_ds[i]['logprobs']
        for j, token in enumerate(doc):
            if j == 0:
                # the first token of a document is skipped
                continue
            try:
                # the value for the token at index j is stored at index j - 1
                logprob = doc_logprobs[j - 1]
            except (IndexError, TypeError):
                print(f"Error processing token {token} at position {i}, {j}" )
                raise
            token_counts[token] = token_counts.get(token, 0) + 1
            token_logprob_sums[token] = token_logprob_sums.get(token, 0.0) + logprob
    for token, token_count in token_counts.items():
        token_logprob_sum = token_logprob_sums[token]
        model_token_stats[(token, model_name)] = TokenModelStats(token, model_name, token_logprob_sum, token_count)

Processing model transcendingvictor/delphi-llama2-100k-validation-logprobs


10982it [09:30, 19.24it/s]


Processing model transcendingvictor/delphi-llama2-200k-validation-logprobs


10982it [09:30, 19.24it/s]


Processing model transcendingvictor/delphi-llama2-400k-validation-logprobs


10982it [09:24, 19.47it/s]


Processing model transcendingvictor/delphi-llama2-800k-validation-logprobs


10982it [09:26, 19.38it/s]


Processing model transcendingvictor/delphi-llama2-1.6m-validation-logprobs


10982it [09:26, 19.39it/s]


Processing model transcendingvictor/delphi-llama2-3.2m-validation-logprobs


10982it [09:24, 19.47it/s]


Processing model transcendingvictor/delphi-llama2-6.4m-validation-logprobs


10982it [09:23, 19.50it/s]


Processing model transcendingvictor/delphi-llama2-12.8m-validation-logprobs


10982it [09:34, 19.12it/s]


Processing model transcendingvictor/delphi-llama2-25.6m-validation-logprobs


10982it [09:20, 19.58it/s]


In [5]:
import pickle

# Reload the (token, model) -> TokenModelStats dict instead of recomputing it.
# NOTE: pickle.load executes arbitrary code from the file — trusted files only.
with open("../data/token_model_stats.pkl", "rb") as f:
    model_token_stats = pickle.load(f)

In [6]:
model_token_stats[(1000, "delphi-llama2-100k")]

TokenModelStats(token=1000, model='delphi-llama2-100k', logprob_sum=-4283.318835604936, count=1941)

In [66]:
# Persist model_token_stats; the file is overwritten if it already exists.
with open("../data/token_model_stats.pkl", "wb") as f:
    pickle.dump(model_token_stats, f)

In [64]:
def get_token_group_loss(
    group: list[int],
    model: str,
    model_token_stats: dict[tuple[int, str], TokenModelStats],
) -> float:
    """Mean logprob of a token group under one model.

    Despite the name, the value is not negated: it is the summed
    ``logprob_sum`` divided by the summed ``count`` of the group's tokens.

    Args:
        group: token ids forming the group.
        model: model name used as the second part of the stats key.
        model_token_stats: mapping of (token, model) to that token's stats.

    Returns:
        The mean logprob over all counted occurrences, or 0.0 (with a
        warning) when none of the group's tokens has stats for `model`.
    """
    group_stats = [
        model_token_stats[(token, model)]
        for token in group
        if (token, model) in model_token_stats
    ]
    if not group_stats:
        # Log the group as passed in; the filtered list is empty at this point
        logging.warning(
            "Group %s has no tokens in model_token_stats; returning 0.0", group
        )
        return 0.0
    total_logprob = sum(stats.logprob_sum for stats in group_stats)
    total_count = sum(stats.count for stats in group_stats)
    return total_logprob / total_count

In [65]:
get_token_group_loss([2, 3, 4, 5, 6, 7, 8, 9, 10, 25, 101, 4001, 2003, 2004], "delphi-llama2-100k", model_token_stats)

-3.924027909847916

In [91]:
def get_group_logprob(token_group: list[int]) -> float:
    """Mean logprob over every occurrence of the group's tokens in `logprob_ds`.

    Occurrences at index 0 of a document are skipped; for an occurrence at
    (document i, index j) the value read is ``logprob_ds[i]['logprobs'][j - 1]``.
    Returns 0.0 when no occurrence is left.
    """
    # ignore occurrences at the start of each document
    kept = [
        (doc_index, offset)
        for token in token_group
        for doc_index, offset in get_token_positions(token)
        if offset != 0
    ]
    if not kept:
        return 0.0
    total = sum(logprob_ds[doc_index]['logprobs'][offset - 1] for doc_index, offset in kept)
    return total / len(kept)

In [14]:
from datasets import Dataset

def get_group_logprob_for_dataset(token_group: list[int], dataset: Dataset) -> float:
    """Mean logprob over every occurrence of the group's tokens in `dataset`.

    Occurrences at index 0 of a document are skipped; for an occurrence at
    (document i, index j) the value read is ``dataset[i]['logprobs'][j - 1]``.
    Returns 0.0 when no occurrence is left.
    """
    kept = [
        (doc_index, offset)
        for token in token_group
        for doc_index, offset in get_token_positions(token)
        if offset != 0
    ]
    if not kept:
        return 0.0
    total = sum(dataset[doc_index]['logprobs'][offset - 1] for doc_index, offset in kept)
    return total / len(kept)

In [89]:
# The function takes a list of tokens plus the dataset to read logprobs from
get_group_logprob_for_dataset([500], logprob_ds)

-1.7847248025381741

In [17]:
token_model_stats

{(432,
  'delphi-llama2-100k'): TokenModelStats(token=432, model='delphi-llama2-100k', logprob_sum=-8307.888183057308, count=17415),
 (440,
  'delphi-llama2-100k'): TokenModelStats(token=440, model='delphi-llama2-100k', logprob_sum=-1179.9694573320448, count=16567),
 (261,
  'delphi-llama2-100k'): TokenModelStats(token=261, model='delphi-llama2-100k', logprob_sum=-150585.52900058636, count=151848),
 (403,
  'delphi-llama2-100k'): TokenModelStats(token=403, model='delphi-llama2-100k', logprob_sum=-9401.606191883911, count=21844),
 (4045,
  'delphi-llama2-100k'): TokenModelStats(token=4045, model='delphi-llama2-100k', logprob_sum=-207269.675427588, count=233957),
 (406,
  'delphi-llama2-100k'): TokenModelStats(token=406, model='delphi-llama2-100k', logprob_sum=-26130.437831838615, count=21305),
 (286,
  'delphi-llama2-100k'): TokenModelStats(token=286, model='delphi-llama2-100k', logprob_sum=-125618.99626725074, count=107556),
 (799,
  'delphi-llama2-100k'): TokenModelStats(token=799, mo

In [22]:
model_token_group_stats

{('Capitalized', 'delphi-llama2-100k'): -2.046864688876206,
 ('Is Adjective', 'delphi-llama2-100k'): -3.4470707012948107,
 ('Is Adposition', 'delphi-llama2-100k'): -2.5054444217720895,
 ('Is Adverb', 'delphi-llama2-100k'): -3.041814374148337,
 ('Is Auxiliary', 'delphi-llama2-100k'): -1.5029872011356107,
 ('Is Coordinating conjuction', 'delphi-llama2-100k'): -1.4905235418156169,
 ('Is Interjunction', 'delphi-llama2-100k'): -1.6035190641883796,
 ('Is Noun', 'delphi-llama2-100k'): -2.7970957412390938,
 ('Is Numeral', 'delphi-llama2-100k'): -2.401989014993226,
 ('Is Particle', 'delphi-llama2-100k'): -0.9790793063799712,
 ('Is Pronoun', 'delphi-llama2-100k'): -1.6812808827511345,
 ('Is Proper Noun', 'delphi-llama2-100k'): -2.8902006726267606,
 ('Is Punctuation', 'delphi-llama2-100k'): -0.8806130811295786,
 ('Is Subordinating conjuction', 'delphi-llama2-100k'): -2.3118651218141735,
 ('Is Verb', 'delphi-llama2-100k'): -3.116768158262923,
 ('Is Other', 'delphi-llama2-100k'): -1.851564668927752

In [23]:
# Persist the (token group, model) -> mean logprob dict; overwrites any existing file.
with open("../data/model_token_group_stats.pkl", "wb") as f:
    pickle.dump(model_token_group_stats, f)

In [26]:
import pandas as pd
# One row per dictionary entry: the key tuple followed by its value
mtgs_df = pd.DataFrame(
    [(*key, value) for key, value in model_token_group_stats.items()],
    columns=["token_group", "model", "mean_logprob"],
)

In [27]:
mtgs_df

Unnamed: 0,token_group,model,mean_logprob
0,Capitalized,delphi-llama2-100k,-2.046865
1,Is Adjective,delphi-llama2-100k,-3.447071
2,Is Adposition,delphi-llama2-100k,-2.505444
3,Is Adverb,delphi-llama2-100k,-3.041814
4,Is Auxiliary,delphi-llama2-100k,-1.502987
...,...,...,...
139,Is Proper Noun,delphi-llama2-25.6m,-1.283074
140,Is Punctuation,delphi-llama2-25.6m,-0.464198
141,Is Subordinating conjuction,delphi-llama2-25.6m,-1.207133
142,Is Verb,delphi-llama2-25.6m,-1.545826


In [35]:
token_groups

{'Starts with space': set(),
 'Capitalized': {68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  270,
  282,
  297,
  309,
  311,
  316,
  328,
  330,
  334,
  340,
  341,
  349,
  357,
  367,
  385,
  387,
  388,
  396,
  424,
  432,
  489,
  496,
  497,
  504,
  517,
  523,
  532,
  539,
  556,
  563,
  570,
  596,
  617,
  636,
  639,
  650,
  652,
  659,
  664,
  692,
  699,
  718,
  724,
  732,
  734,
  746,
  750,
  753,
  768,
  773,
  775,
  787,
  798,
  806,
  810,
  821,
  840,
  864,
  873,
  895,
  928,
  930,
  953,
  959,
  978,
  979,
  982,
  983,
  1004,
  1015,
  1022,
  1024,
  1036,
  1037,
  1050,
  1051,
  1059,
  1071,
  1084,
  1085,
  1088,
  1094,
  1100,
  1117,
  1137,
  1141,
  1169,
  1171,
  1176,
  1177,
  1218,
  1221,
  1231,
  1235,
  1245,
  1269,
  1274,
  1281,
  1290,
  1295,
  1299,
  1333,
  1334,
  1365,
  1370,
  1405,
  1413,
  1427,
  14

In [40]:
# Same groups with each collection of tokens turned into a sorted list
tg = {name: sorted(tokens) for name, tokens in token_groups.items()}

In [41]:
import json
# Write tg as pretty-printed JSON (2-space indent); overwrites any existing file.
with open("../data/token_groups.json", "w") as f:
    json.dump(tg, f, indent=2)

In [42]:
token_model_stats

{(432,
  'delphi-llama2-100k'): TokenModelStats(token=432, model='delphi-llama2-100k', logprob_sum=-8307.888183057308, count=17415),
 (440,
  'delphi-llama2-100k'): TokenModelStats(token=440, model='delphi-llama2-100k', logprob_sum=-1179.9694573320448, count=16567),
 (261,
  'delphi-llama2-100k'): TokenModelStats(token=261, model='delphi-llama2-100k', logprob_sum=-150585.52900058636, count=151848),
 (403,
  'delphi-llama2-100k'): TokenModelStats(token=403, model='delphi-llama2-100k', logprob_sum=-9401.606191883911, count=21844),
 (4045,
  'delphi-llama2-100k'): TokenModelStats(token=4045, model='delphi-llama2-100k', logprob_sum=-207269.675427588, count=233957),
 (406,
  'delphi-llama2-100k'): TokenModelStats(token=406, model='delphi-llama2-100k', logprob_sum=-26130.437831838615, count=21305),
 (286,
  'delphi-llama2-100k'): TokenModelStats(token=286, model='delphi-llama2-100k', logprob_sum=-125618.99626725074, count=107556),
 (799,
  'delphi-llama2-100k'): TokenModelStats(token=799, mo

In [43]:
# Flatten model_token_stats into a DataFrame with four columns:
# the (token, model) key followed by that entry's logprob_sum and count

import pandas as pd
tms_df = pd.DataFrame(
    [
        (*key, stats.logprob_sum, stats.count)
        for key, stats in model_token_stats.items()
    ],
    columns=["token", "model", "logprob_sum", "count"],
)

In [45]:
pwd

'/Users/jaidhyani/delphi/notebooks'

In [46]:
# Export as plain, uncompressed CSV without the index column.
tms_df.to_csv("../data/token_model_stats.csv", index=False, compression=None)