In [1]:
import sys

sys.path.append("..")
%load_ext autoreload
%autoreload 2


In [2]:
import json
import pickle
from collections import defaultdict
import plotly.graph_objects as go
import plotly.subplots as sp
import pandas as pd
import numpy as np
import torch as th

from nnterp import load_model
from utils import chat_template, compute_cross_entropy

In [3]:
dispatch = False
th.set_grad_enabled(False)
base_device = "cuda:0"
chat_device = f"cuda:{th.cuda.device_count()-1}"
base_model = load_model(
    "google/gemma-2-2b",
    device_map=base_device,
    attn_implementation="eager",
    dispatch=dispatch,
)
chat_model = load_model(
    "google/gemma-2-2b-it",
    tokenizer_kwargs={"chat_template": chat_template, "padding_side": "right"},
    device_map=chat_device,
    attn_implementation="eager",
    dispatch=dispatch,
)
chat_tokenizer = chat_model.tokenizer

In [4]:
sample_chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm fine, thank you!"},
    {"role": "user", "content": "What is your name?"},
    {"role": "assistant", "content": "My name is Gemma."},
]
toks = chat_tokenizer.apply_chat_template(sample_chat, tokenize=True)
print(chat_tokenizer.convert_ids_to_tokens(toks))

['<bos>', '<start_of_turn>', 'user', '\n', 'Hello', ',', '▁how', '▁are', '▁you', '?', '<end_of_turn>', '\n', '<start_of_turn>', 'model', '\n', 'I', "'", 'm', '▁fine', ',', '▁thank', '▁you', '!', '<end_of_turn>', '\n', '<start_of_turn>', 'user', '\n', 'What', '▁is', '▁your', '▁name', '?', '<end_of_turn>', '\n', '<start_of_turn>', 'model', '\n', 'My', '▁name', '▁is', '▁Gemma', '.', '<end_of_turn>', '\n']


## tokenizer explorations

In [5]:
toks = [t for t in base_model.tokenizer.vocab if t.startswith("<")]
# filter out <.x..> patterns
import re

toks = [
    t
    for t in toks
    if not re.match(r"<.x..>", t) and not t.startswith("<unused") and t.endswith(">")
]


def has_other(tok):
    if tok.startswith("</"):
        return tok[0] + tok[2:] in toks
    return tok[0] + "/" + tok[1:] in toks


toks = [tok for tok in toks if has_other(tok) and not tok.startswith("</")]
sorted(toks)

['<>',
 '<b>',
 '<blockquote>',
 '<caption>',
 '<code>',
 '<em>',
 '<h1>',
 '<h2>',
 '<h3>',
 '<h4>',
 '<h5>',
 '<h6>',
 '<i>',
 '<s>',
 '<strong>',
 '<sub>',
 '<sup>',
 '<table>',
 '<tbody>',
 '<td>',
 '<tfoot>',
 '<th>',
 '<thead>',
 '<tr>',
 '<u>']

In [6]:
nl_tokens = [tok for tok in base_model.tokenizer.vocab if tok.startswith("\n")]
sorted(nl_tokens)


['\n',
 '\n\n',
 '\n\n\n',
 '\n\n\n\n',
 '\n\n\n\n\n',
 '\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n

In [7]:
nl_tokens = [tok for tok in base_model.tokenizer.vocab if tok.startswith("\r") or tok.endswith("\r")]
sorted(nl_tokens)

['\r',
 '\r\r',
 '!\r',
 '!")\r',
 '!");\r',
 '"\r',
 '"""\r',
 '"",\r',
 '")\r',
 '"))\r',
 '")));\r',
 '"));\r',
 '")){\r',
 '"),\r',
 '"):\r',
 '");\r',
 '")]\r',
 '"){\r',
 '"+\r',
 '",\r',
 '".\r',
 '"/>\r',
 '":\r',
 '";\r',
 '">\r',
 '">\r\r',
 '"]\r',
 '"])\r',
 '"]);\r',
 '"],\r',
 '"];\r',
 '"},\r',
 '#\r',
 '##\r',
 '#%%\r',
 '$\r',
 '$$\r',
 '$,\r',
 '$.\r',
 '%\r',
 '%%\r',
 '%;\r',
 '&\r',
 "'\r",
 "'''\r",
 "')\r",
 "'))\r",
 "'));\r",
 "'),\r",
 "'):\r",
 "');\r",
 "'){\r",
 "',\r",
 "',\r\r",
 "':\r",
 "';\r",
 "'>\r",
 "']\r",
 "'])\r",
 "']))\r",
 "'])){\r",
 "']);\r",
 "'],\r",
 "'];\r",
 "'},\r",
 '(\r',
 '("")\r',
 '("");\r',
 '("")]\r',
 "('');\r",
 '()\r',
 '())\r',
 '()))\r',
 '()));\r',
 '());\r',
 '()){\r',
 '(),\r',
 '():\r',
 '();\r',
 '();\r\r',
 '()]\r',
 '(){\r',
 '([\r',
 '({\r',
 ')\r',
 ')\r\r',
 ')");\r',
 ')";\r',
 '))\r',
 ')))\r',
 ')));\r',
 ')),\r',
 ')):\r',
 '));\r',
 ')){\r',
 '),\r',
 ').\r',
 '):\r',
 ');\r',
 ');\r\r',
 ')=>{\r',
 ')]\r',


## dataset and chat templates

In [8]:
from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
sample_batch = list(dataset["test_sft"][range(10)]["messages"]) + [sample_chat]

In [9]:
print(
    chat_model.tokenizer.apply_chat_template(
        dataset["test_sft"][0]["messages"], tokenize=False
    )
)

<bos><start_of_turn>user
How does the author propose to fix the problem of science alienation in our educational system? What changes does she suggest should be made to science education? Answer according to: Science education should be split into two tracks.
Split K-12 science education into two tracks, for majors and nonmajors.
Those who want to specialize in science could take math and complex chemistry. Nonmajors would focus on science of the everyday—things like kitchen chemistry and CSI-style crime investigations.
Some years ago, when I was working as a newspaper science writer in California, I fell into a rather idle conversation with a physicist on the subject of science education. Idle for him, at least, because what he said—the way he defined the American system of teaching K-12 students about science—has stayed with me since.
This conversation has returned to me many times over the years. On the day my older son, who spent his childhood joyfully chasing insects and reading n

In [10]:
from utils import RunningMeanStd

base_chat_template = open("../templates/base_gemma_chat_template.jinja").read()

def sanitize(tok):
    return tok.replace("\r", "\\r").replace("\n", "\\n")

def custom_template(
    *,
    start_of_turn_token="<start_of_turn>",
    end_of_turn_token="<end_of_turn>",
    user_token="user",
    assistant_token="model",
    enforce_length=True,
):
    if enforce_length:
        assert (
            len(base_model.tokenizer.tokenize(start_of_turn_token)) == 1
        ), "start_of_turn_token must be a single token"
        assert (
            len(base_model.tokenizer.tokenize(end_of_turn_token)) == 1
        ), "end_of_turn_token must be a single token"
        assert (
            len(base_model.tokenizer.tokenize(user_token)) == 1
        ), "user_token must be a single token"
        assert (
            len(base_model.tokenizer.tokenize(assistant_token)) == 1
        ), "assistant_token must be a single token"
    template = (
        base_chat_template.replace("<start_of_turn>", sanitize(start_of_turn_token))
        .replace("<end_of_turn>", sanitize(end_of_turn_token))
        .replace("<user>", sanitize(user_token))
        .replace("model", sanitize(assistant_token))
    )
    if enforce_length:
        tokenized = chat_model.tokenizer.apply_chat_template(
            sample_batch,
            chat_template=chat_template,
            return_dict=True,
            return_assistant_tokens_mask=True,
        )
        custom_tokenized = chat_model.tokenizer.apply_chat_template(
            sample_batch,
            chat_template=template,
            return_dict=True,
            return_assistant_tokens_mask=True,
        )
        gt_inp_len = list(map(len, tokenized["input_ids"]))
        custom_inp_len = list(map(len, custom_tokenized["input_ids"]))
        assert (
            gt_inp_len == custom_inp_len
        ), f"input_ids lens are not the same: {gt_inp_len} != {custom_inp_len}"
        assert (
            tokenized["assistant_masks"] == custom_tokenized["assistant_masks"]
        ), f"assistant_masks are not the same: {tokenized['assistant_masks']} != {custom_tokenized['assistant_masks']}"
    return template


role_setups = {
    "default": dict(assistant_token="model", user_token="user"),
    "assistant": dict(assistant_token="assistant", user_token="user"),
    "human": dict(assistant_token="model", user_token="human"),
    "assistant_human": dict(assistant_token="assistant", user_token="human"),
    "human_human": dict(assistant_token="human", user_token="human"),
    "user_user": dict(assistant_token="user", user_token="user"),
    "assistant_assistant": dict(assistant_token="assistant", user_token="assistant"),
    "model_model": dict(assistant_token="model", user_token="model"),
    "4chan": dict(assistant_token="Anon123", user_token="MemeMaster", enforce_length=False),
}
delimiter_setups = {
    "default": dict(start_of_turn_token="<start_of_turn>", end_of_turn_token="<end_of_turn>"),
    "blockquote": dict(start_of_turn_token="<blockquote>", end_of_turn_token="</blockquote>"),
    "code": dict(start_of_turn_token="<code>", end_of_turn_token="</code>"),
    "em": dict(start_of_turn_token="<em>", end_of_turn_token="</em>"),
}
templates = {}
for role_setup, r_kwargs in role_setups.items():
    for delimiter_setup, d_kwargs in delimiter_setups.items():
        templates[(role_setup, delimiter_setup)] = custom_template(**r_kwargs, **d_kwargs)

In [11]:
from IPython.display import display

# Style with CSS and force text display
tok_exs = {f"{name[0]}_{name[1]}": chat_model.tokenizer.apply_chat_template(sample_chat, chat_template=template, tokenize=False) for name, template in templates.items()}
df = pd.DataFrame.from_dict(tok_exs, orient='index').transpose() # Transposed to make horizontal
styled_df = df.style.set_properties(**{
    'white-space': 'pre',
    'font-family': 'monospace',
    'text-align': 'left'
}).format_index(str).format(escape="html")

display(styled_df)

Unnamed: 0,default_default,default_blockquote,default_code,default_em,assistant_default,assistant_blockquote,assistant_code,assistant_em,human_default,human_blockquote,human_code,human_em,assistant_human_default,assistant_human_blockquote,assistant_human_code,assistant_human_em,human_human_default,human_human_blockquote,human_human_code,human_human_em,user_user_default,user_user_blockquote,user_user_code,user_user_em,assistant_assistant_default,assistant_assistant_blockquote,assistant_assistant_code,assistant_assistant_em,model_model_default,model_model_blockquote,model_model_code,model_model_em,4chan_default,4chan_blockquote,4chan_code,4chan_em
0,"<bos><start_of_turn>user Hello, how are you?<end_of_turn> <start_of_turn>model I'm fine, thank you!<end_of_turn> <start_of_turn>user What is your name?<end_of_turn> <start_of_turn>model My name is Gemma.<end_of_turn>","<bos><blockquote>user Hello, how are you?</blockquote> <blockquote>model I'm fine, thank you!</blockquote> <blockquote>user What is your name?</blockquote> <blockquote>model My name is Gemma.</blockquote>","<bos><code>user Hello, how are you?</code> <code>model I'm fine, thank you!</code> <code>user What is your name?</code> <code>model My name is Gemma.</code>","<bos><em>user Hello, how are you?</em> <em>model I'm fine, thank you!</em> <em>user What is your name?</em> <em>model My name is Gemma.</em>","<bos><start_of_turn>user Hello, how are you?<end_of_turn> <start_of_turn>assistant I'm fine, thank you!<end_of_turn> <start_of_turn>user What is your name?<end_of_turn> <start_of_turn>assistant My name is Gemma.<end_of_turn>","<bos><blockquote>user Hello, how are you?</blockquote> <blockquote>assistant I'm fine, thank you!</blockquote> <blockquote>user What is your name?</blockquote> <blockquote>assistant My name is Gemma.</blockquote>","<bos><code>user Hello, how are you?</code> <code>assistant I'm fine, thank you!</code> <code>user What is your name?</code> <code>assistant My name is Gemma.</code>","<bos><em>user Hello, how are you?</em> <em>assistant I'm fine, thank you!</em> <em>user What is your name?</em> <em>assistant My name is Gemma.</em>","<bos><start_of_turn>human Hello, how are you?<end_of_turn> <start_of_turn>model I'm fine, thank you!<end_of_turn> <start_of_turn>human What is your name?<end_of_turn> <start_of_turn>model My name is Gemma.<end_of_turn>","<bos><blockquote>human Hello, how are you?</blockquote> <blockquote>model I'm fine, thank you!</blockquote> <blockquote>human What is your name?</blockquote> <blockquote>model My name is Gemma.</blockquote>","<bos><code>human Hello, how are you?</code> <code>model I'm fine, thank you!</code> <code>human What is your name?</code> <code>model My name is Gemma.</code>","<bos><em>human Hello, how are you?</em> <em>model I'm fine, thank you!</em> <em>human What is your name?</em> <em>model My name is Gemma.</em>","<bos><start_of_turn>human Hello, how are you?<end_of_turn> <start_of_turn>assistant I'm fine, thank you!<end_of_turn> <start_of_turn>human What is your name?<end_of_turn> <start_of_turn>assistant My name is Gemma.<end_of_turn>","<bos><blockquote>human Hello, how are you?</blockquote> <blockquote>assistant I'm fine, thank you!</blockquote> <blockquote>human What is your name?</blockquote> <blockquote>assistant My name is Gemma.</blockquote>","<bos><code>human Hello, how are you?</code> <code>assistant I'm fine, thank you!</code> <code>human What is your name?</code> <code>assistant My name is Gemma.</code>","<bos><em>human Hello, how are you?</em> <em>assistant I'm fine, thank you!</em> <em>human What is your name?</em> <em>assistant My name is Gemma.</em>","<bos><start_of_turn>human Hello, how are you?<end_of_turn> <start_of_turn>human I'm fine, thank you!<end_of_turn> <start_of_turn>human What is your name?<end_of_turn> <start_of_turn>human My name is Gemma.<end_of_turn>","<bos><blockquote>human Hello, how are you?</blockquote> <blockquote>human I'm fine, thank you!</blockquote> <blockquote>human What is your name?</blockquote> <blockquote>human My name is Gemma.</blockquote>","<bos><code>human Hello, how are you?</code> <code>human I'm fine, thank you!</code> <code>human What is your name?</code> <code>human My name is Gemma.</code>","<bos><em>human Hello, how are you?</em> <em>human I'm fine, thank you!</em> <em>human What is your name?</em> <em>human My name is Gemma.</em>","<bos><start_of_turn>user Hello, how are you?<end_of_turn> <start_of_turn>user I'm fine, thank you!<end_of_turn> <start_of_turn>user What is your name?<end_of_turn> <start_of_turn>user My name is Gemma.<end_of_turn>","<bos><blockquote>user Hello, how are you?</blockquote> <blockquote>user I'm fine, thank you!</blockquote> <blockquote>user What is your name?</blockquote> <blockquote>user My name is Gemma.</blockquote>","<bos><code>user Hello, how are you?</code> <code>user I'm fine, thank you!</code> <code>user What is your name?</code> <code>user My name is Gemma.</code>","<bos><em>user Hello, how are you?</em> <em>user I'm fine, thank you!</em> <em>user What is your name?</em> <em>user My name is Gemma.</em>","<bos><start_of_turn>assistant Hello, how are you?<end_of_turn> <start_of_turn>assistant I'm fine, thank you!<end_of_turn> <start_of_turn>assistant What is your name?<end_of_turn> <start_of_turn>assistant My name is Gemma.<end_of_turn>","<bos><blockquote>assistant Hello, how are you?</blockquote> <blockquote>assistant I'm fine, thank you!</blockquote> <blockquote>assistant What is your name?</blockquote> <blockquote>assistant My name is Gemma.</blockquote>","<bos><code>assistant Hello, how are you?</code> <code>assistant I'm fine, thank you!</code> <code>assistant What is your name?</code> <code>assistant My name is Gemma.</code>","<bos><em>assistant Hello, how are you?</em> <em>assistant I'm fine, thank you!</em> <em>assistant What is your name?</em> <em>assistant My name is Gemma.</em>","<bos><start_of_turn>model Hello, how are you?<end_of_turn> <start_of_turn>model I'm fine, thank you!<end_of_turn> <start_of_turn>model What is your name?<end_of_turn> <start_of_turn>model My name is Gemma.<end_of_turn>","<bos><blockquote>model Hello, how are you?</blockquote> <blockquote>model I'm fine, thank you!</blockquote> <blockquote>model What is your name?</blockquote> <blockquote>model My name is Gemma.</blockquote>","<bos><code>model Hello, how are you?</code> <code>model I'm fine, thank you!</code> <code>model What is your name?</code> <code>model My name is Gemma.</code>","<bos><em>model Hello, how are you?</em> <em>model I'm fine, thank you!</em> <em>model What is your name?</em> <em>model My name is Gemma.</em>","<bos><start_of_turn>MemeMaster Hello, how are you?<end_of_turn> <start_of_turn>Anon123 I'm fine, thank you!<end_of_turn> <start_of_turn>MemeMaster What is your name?<end_of_turn> <start_of_turn>Anon123 My name is Gemma.<end_of_turn>","<bos><blockquote>MemeMaster Hello, how are you?</blockquote> <blockquote>Anon123 I'm fine, thank you!</blockquote> <blockquote>MemeMaster What is your name?</blockquote> <blockquote>Anon123 My name is Gemma.</blockquote>","<bos><code>MemeMaster Hello, how are you?</code> <code>Anon123 I'm fine, thank you!</code> <code>MemeMaster What is your name?</code> <code>Anon123 My name is Gemma.</code>","<bos><em>MemeMaster Hello, how are you?</em> <em>Anon123 I'm fine, thank you!</em> <em>MemeMaster What is your name?</em> <em>Anon123 My name is Gemma.</em>"


## evaluating templates

In [12]:
from tqdm.auto import tqdm, trange
from torchmetrics import MeanMetric


def evaluate_template(
    base_model,
    chat_model,
    template,
    dataset,
    batch_size=8,
    max_num_tokens=10_000,
    num_turns=None,
):
    base_entropy = RunningMeanStd()
    base_entropy_metric = MeanMetric().to(base_device)
    chat_entropy = RunningMeanStd()
    chat_entropy_metric = MeanMetric().to(chat_device)
    num_tokens = 0
    pbar = tqdm(total=max_num_tokens, desc="Token processed")
    for i in trange(0, len(dataset), batch_size, desc="Batch processed"):
        batch = dataset[i : min(i + batch_size, len(dataset))]
        if num_turns is not None:
            batch = [b[:num_turns] for b in batch]
        get_batch = lambda: chat_model.tokenizer.apply_chat_template(
            batch,
            chat_template=template,
            return_dict=True,
            return_tensors="pt",
            return_assistant_tokens_mask=True,
            truncation=True,
            max_length=1024,
            padding=True,
        )
        base_batch = get_batch()
        assistant_mask = th.tensor(base_batch["assistant_masks"], dtype=th.bool)
        base_batch["input_ids"] = base_batch["input_ids"].to(base_device)
        base_batch["attention_mask"] = base_batch["attention_mask"].to(base_device)
        batch_base_entropy = compute_cross_entropy(
            base_batch,
            base_model,
            assistant_mask,
        )
        num_new_tokens = assistant_mask.sum().item()
        assert batch_base_entropy.shape == (
            num_new_tokens,
        ), f"entropy shape {batch_base_entropy.shape} != assistant_mask.sum() {num_new_tokens}"
        base_entropy_metric.update(batch_base_entropy)
        base_entropy.update(batch_base_entropy)
        chat_batch = get_batch()
        chat_batch["input_ids"] = chat_batch["input_ids"].to(chat_device)
        chat_batch["attention_mask"] = chat_batch["attention_mask"].to(chat_device)
        batch_chat_entropy = compute_cross_entropy(
            chat_batch,
            chat_model,
            assistant_mask,
        )
        chat_entropy_metric.update(batch_chat_entropy)
        chat_entropy.update(batch_chat_entropy)
        num_tokens += num_new_tokens
        pbar.update(num_new_tokens)
        if num_tokens >= max_num_tokens:
            pbar.close()
            break
    return {
        "base": base_entropy.compute(),
        "chat": chat_entropy.compute(),
        "base_metric": base_entropy_metric.compute(),
        "chat_metric": chat_entropy_metric.compute(),
    }


dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
# shuffle with seed 42
th.manual_seed(42)
randperm = th.randperm(len(dataset["test_sft"]))
dataset = dataset["test_sft"][randperm.tolist()]["messages"]

def cleanup_result(res):
    return {k: {"mean": v[0].item(), "var": v[1].item(), "count": v[2]} if isinstance(v, tuple) else v.item() for k, v in res.items()}

def evaluate_all_templates(dataset, exp_name, num_turns=None, max_num_tokens=20_000):
    global _results
    _results = {}
    for name, template in tqdm(templates.items()):
        _results[name] = evaluate_template(
            base_model,
            chat_model,
            template,
            dataset,
            max_num_tokens=max_num_tokens,
            num_turns=num_turns,
        )
        chat_mean, base_mean = (
            _results[name]["chat_metric"],
            _results[name]["base_metric"],
        )
        print(f"any turns: {name}:\n- chat ce: {chat_mean}\n- base ce: {base_mean}")
    _clean_res = {k: cleanup_result(v) for k, v in _results.items()}
    res_dict = defaultdict(dict)
    for (role, delim), values in _clean_res.items():
        res_dict[role][delim] = values
    with open(f"results/chat_templates_results_{exp_name}.json", "w") as f:
        json.dump(res_dict, f)
    return _results

In [13]:
# Convert results to a more manageable format
def plot_results(results):
    data = []
    for (role, delim), values in results.items():
        data.append({
            'role_setup': role,
            'delimiter_setup': delim,
            'chat_entropy': values['chat'],
            'base_entropy': values['base']
        })
    df = pd.DataFrame(data)
    df.loc[0]["chat_entropy"]
    # Convert results to a more manageable format with mean, var, count
    data = []
    for (role, delim), values in results.items():
        chat_mean, chat_var, chat_count = values['chat']
        base_mean, base_var, base_count = values['base']
        data.append({
            'role_setup': role,
            'delimiter_setup': delim,
            'chat_perplexity_mean': float(th.exp(chat_mean).cpu()),
            'chat_perplexity_std': float(th.sqrt(chat_var).cpu()),
            'chat_count': float(chat_count),
            'base_perplexity_mean': float(th.exp(base_mean).cpu()),
            'base_perplexity_std': float(th.sqrt(base_var).cpu()),
            'base_count': float(base_count)
        })
    df = pd.DataFrame(data)

    # Plot 1: Bar chart with error bars
    fig1 = go.Figure()

    # Calculate standard error
    chat_nll_ci = df['chat_perplexity_std'] / np.sqrt(df['chat_count']) * 1.96
    base_nll_ci = df['base_perplexity_std'] / np.sqrt(df['base_count']) * 1.96
    chat_ci = np.stack([np.exp(-chat_nll_ci), np.exp(chat_nll_ci)], axis=1)* df['chat_perplexity_mean'].values.reshape(-1, 1)
    base_ci = np.stack([np.exp(-base_nll_ci), np.exp(base_nll_ci)], axis=1) * df['base_perplexity_mean'].values.reshape(-1, 1)
    chat_upper_ci_diff = chat_ci[:, 1] - df['chat_perplexity_mean']
    chat_lower_ci_diff = df['chat_perplexity_mean'] - chat_ci[:, 0]
    base_upper_ci_diff = base_ci[:, 1] - df['base_perplexity_mean']
    base_lower_ci_diff = df['base_perplexity_mean'] - base_ci[:, 0]
    # Add bars for chat perplexity
    fig1.add_trace(go.Bar(
        name='Chat Model',
        x=[f"{r}_{d}" for r, d in zip(df['role_setup'], df['delimiter_setup'])],
        y=df['chat_perplexity_mean'],
        error_y=dict(type='data', array=chat_upper_ci_diff, symmetric=False, arrayminus=chat_lower_ci_diff),  # 95% confidence interval
        marker_color='royalblue'
    ))

    # Add bars for base perplexity
    fig1.add_trace(go.Bar(
        name='Base Model',
        x=[f"{r}_{d}" for r, d in zip(df['role_setup'], df['delimiter_setup'])],
        y=df['base_perplexity_mean'],
        error_y=dict(type='data', array=base_upper_ci_diff, symmetric=False, arrayminus=base_lower_ci_diff),  # 95% confidence interval
        marker_color='red'
    ))

    fig1.update_layout(
        title='Chat vs Base Model Perplexity by Setup',
        xaxis_title='Setup (role_delimiter)',
        yaxis_title='Perplexity',
        barmode='group',
        xaxis_tickangle=45
    )

    # Plot 2: Heatmaps
    # Pivot data for heatmaps
    chat_pivot = df.pivot(index='delimiter_setup', columns='role_setup', values='chat_perplexity_mean')
    base_pivot = df.pivot(index='delimiter_setup', columns='role_setup', values='base_perplexity_mean')

    # Create subplot with 2 heatmaps
    fig2 = sp.make_subplots(rows=1, cols=2, subplot_titles=('Chat Model Perplexity', 'Base Model Perplexity'))

    # Add chat model heatmap
    fig2.add_trace(
        go.Heatmap(
            z=chat_pivot.values,
            x=chat_pivot.columns,
            y=chat_pivot.index,
            colorscale='Viridis',
            zmin=min(df['chat_perplexity_mean'].min(), df['base_perplexity_mean'].min()),
            zmax=max(df['chat_perplexity_mean'].max(), df['base_perplexity_mean'].max())
        ),
        row=1, col=1
    )

    # Add base model heatmap
    fig2.add_trace(
        go.Heatmap(
            z=base_pivot.values,
            x=base_pivot.columns,
            y=base_pivot.index,
            colorscale='Viridis',
            zmin=min(df['chat_perplexity_mean'].min(), df['base_perplexity_mean'].min()),
            zmax=max(df['chat_perplexity_mean'].max(), df['base_perplexity_mean'].max())
        ),
        row=1, col=2
    )

    fig2.update_layout(
        title='Perplexity Heatmaps: Chat vs Base Model',
        height=600
    )

    # Display both plots
    fig1.show()
    fig2.show()


In [20]:
import io
class CPU_Unpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if module == 'torch.storage' and name == '_load_from_bytes':
            return lambda b: th.load(io.BytesIO(b), map_location='cpu' ,weights_only=True)
        else: return super().find_class(module, name)


In [34]:
def load_dict(path, use_pickle=False, key=None):
    if use_pickle:
        with open(path, "rb") as f:
            loaded_dict = CPU_Unpickler(f).load()
            # convedt all tensors to float
            for k, v in loaded_dict.items():
                for k2, v2 in v.items():
                    if isinstance(v2, th.Tensor):
                        loaded_dict[k][k2] = v2.item()
                    elif isinstance(v2, dict):
                        for k3, v3 in v2.items():
                            if isinstance(v3, th.Tensor):
                                loaded_dict[k][k2][k3] = v3.item()
                    else:
                        assert False, f"Unknown type: {type(v2)}"
            json_dict = {}
            for k0, d in loaded_dict.items():
                json_dict[k0] = {}
                k1_k2_pairs = [(k1, k2) for (k1, k2), _ in d.items()]
                k1s = sorted(set(k1 for k1, _ in k1_k2_pairs))
                for k1 in k1s:
                    json_dict[k0][k1] = {}
                    for k1_pair, k2_pair in k1_k2_pairs:
                        if k1_pair == k1:
                            json_dict[k0][k1][k2_pair] = d[(k1, k2_pair)]
            # save as json
            with open(path.replace(".pkl", ".json"), "w") as f:
                json.dump(json_dict, f)
    else:
        with open(path, "r") as f:
            loaded_dict = json.load(f)
    if key is not None:
        loaded_dict = loaded_dict[key]
    if use_pickle:
        return {
            k12: {
                k3: (
                    list(map(th.tensor, v2.values()))
                    if isinstance(v2, dict)
                    else (v2 if isinstance(v2, list) else th.tensor(v2))
                )
                for k3, v2 in v.items()
            }
            for k12, v in loaded_dict.items()
        }
    return {
        (k1, k2): {
            k3: (
                list(map(th.tensor, v2.values()))
                if isinstance(v2, dict)
                else (v2 if isinstance(v2, list) else th.tensor(v2))
            )
            for k3, v2 in v.items()
        }
        for k1, d in loaded_dict.items()
        for k2, v in d.items()
    }

In [35]:
# res_g = load_dict("results/chat_templates_results_generated.json")
res, res_g = load_dict("results/chat_templates_results.pkl", use_pickle=True,  key='results'), load_dict("results/chat_templates_results_generated.json")
# res[('default', 'default')]['base_metric'] , res_g[('default', 'default')]['base_metric']
plot_results(res)

In [37]:
plot_results(load_dict("results/chat_templates_results.json", key='results'))

In [22]:
results = evaluate_all_templates(dataset, exp_name="ultrachat")
results_single_turn = evaluate_all_templates(dataset, num_turns=2, exp_name="ultrachat_single_turn")
plot_results(results_single_turn)
plot_results(results)

## Generated dataset

In [40]:
generated_dataset = load_dataset("science-of-finetuning/ultrachat_200k_gemma-2-2b-it-generated")

README.md:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.56M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/512 [00:00<?, ? examples/s]

In [40]:
results_generated = evaluate_all_templates(generated_dataset["train"]["messages"], exp_name="generated")
plot_results(results_generated)

## interactive viz

In [41]:
from tiny_dashboard.visualization_utils import activation_visualization
from utils import compute_entropy

text_examples = dataset[:4] + generated_dataset["train"]["messages"][:4]


def get_batch(convs, template, device):
    batch = chat_model.tokenizer.apply_chat_template(
        convs,
        chat_template=template,
        return_dict=True,
        return_tensors="pt",
        return_assistant_tokens_mask=True,
        truncation=True,
        max_length=1024,
        padding=True,
    )
    batch["input_ids"] = batch["input_ids"].to(device)
    batch["attention_mask"] = batch["attention_mask"].to(device)
    batch["assistant_masks"] = th.tensor(batch["assistant_masks"], dtype=th.bool).to(
        device
    )
    return batch

In [51]:
# Precompute activations for each template
template_acts = {}
template_entropy_acts = {}
demo_templates = {
    "default": templates[("default", "default")],
    "blockquote": templates[("default", "blockquote")],
    "em": templates[("default", "em")],
    # "4chan_code": templates[("4chan", "code")],
}
for template_name, template in demo_templates.items():
    base_batch = get_batch(text_examples, template, base_device)
    chat_batch = get_batch(text_examples, template, chat_device)
    
    # Compute masks
    chat_entropy_mask = th.zeros_like(chat_batch["assistant_masks"], dtype=th.bool)
    chat_entropy_mask[:, :-1] = chat_batch["assistant_masks"][:, 1:]
    chat_entropy_mask = chat_entropy_mask.to(chat_device)
    
    base_entropy_mask = th.zeros_like(base_batch["assistant_masks"], dtype=th.bool)
    base_entropy_mask[:, :-1] = base_batch["assistant_masks"][:, 1:]
    base_entropy_mask = base_entropy_mask.to(base_device)

    # Compute metrics
    base_ce = compute_cross_entropy(base_batch, base_model, base_batch["assistant_masks"])
    chat_ce = compute_cross_entropy(chat_batch, chat_model, chat_batch["assistant_masks"])
    chat_entropy = compute_entropy(chat_batch, chat_model, chat_entropy_mask)
    base_entropy = compute_entropy(base_batch, base_model, base_entropy_mask)

    # Split by batch
    chat_ce_per_batch = []
    chat_entropy_per_batch = []
    base_ce_per_batch = []
    base_entropy_per_batch = []
    
    for batch_id in range(chat_batch["input_ids"].size(0)):
        # Chat model cross entropy
        ce = chat_ce[chat_batch["assistant_masks"][batch_id]]
        acts = th.full_like(chat_batch["input_ids"][batch_id], th.nan, dtype=ce.dtype)
        acts[chat_batch["assistant_masks"][batch_id]] = ce
        chat_ce_per_batch.append(acts.cpu())
        
        # Base model cross entropy
        ce = base_ce[base_batch["assistant_masks"][batch_id]]
        acts = th.full_like(base_batch["input_ids"][batch_id], th.nan, dtype=ce.dtype)
        acts[base_batch["assistant_masks"][batch_id]] = ce
        base_ce_per_batch.append(acts.cpu())
        
        # Chat model entropy
        entropy = chat_entropy[chat_entropy_mask[batch_id]]
        acts = th.full_like(chat_batch["input_ids"][batch_id], th.nan, dtype=entropy.dtype)
        acts[chat_entropy_mask[batch_id]] = entropy
        chat_entropy_per_batch.append(acts.cpu())
        
        # Base model entropy
        entropy = base_entropy[base_entropy_mask[batch_id]]
        acts = th.full_like(base_batch["input_ids"][batch_id], th.nan, dtype=entropy.dtype)
        acts[base_entropy_mask[batch_id]] = entropy
        base_entropy_per_batch.append(acts.cpu())

    # Stack activations
    all_models_acts = [th.stack([chat_ce_per_batch[i], base_ce_per_batch[i]], dim=0) 
                      for i in range(len(chat_ce_per_batch))]
    all_models_entropy = [th.stack([chat_entropy_per_batch[i], base_entropy_per_batch[i]], dim=0) 
                         for i in range(len(chat_entropy_per_batch))]
    
    template_acts[template_name] = all_models_acts
    template_entropy_acts[template_name] = all_models_entropy



IndexError: The shape of the mask [1024] at index 0 does not match the shape of the indexed tensor [3803] at index 0

In [None]:
# Create interactive widget
from ipywidgets import interact, widgets
from IPython.display import clear_output


@interact(
    template=widgets.Dropdown(
        options=list(templates.keys()),
        value=("default", "default"),
        description='Template'
    ),
    entropy=widgets.Checkbox(
        value=True, 
        description='Use Chat entropy'
    ),
    batch_id=widgets.IntSlider(
        min=0, 
        max=len(text_examples)-1, 
        step=1, 
        value=0, 
        description='Batch ID'
    )
)
def viz_with_controls(template, entropy, batch_id):
    clear_output(wait=True)
    acts = template_entropy_acts[template][batch_id] if entropy else template_acts[template][batch_id]
    
    # Get tokens from chat batch
    chat_batch = get_batch(text_examples, templates[template], chat_device)
    tokens = chat_model.tokenizer.convert_ids_to_tokens(chat_batch["input_ids"][batch_id].cpu().tolist())
    
    title = f"{'All' if entropy else 'All'} Model Cross Entropy on {batch_id}"
    title += f" ({template[0]}/{template[1]})" + (" (Chat entropy)" if entropy else "")
    
    html = activation_visualization(
        tokens,
        acts.transpose(0,1),
        chat_model.tokenizer,
        title=title,
        color2=(255,0,0),
        relative_normalization=False,
    )
    display(HTML(html))

In [48]:
template = templates[("default", "default")]
base_batch = get_batch(text_examples, template, base_device)
chat_batch = get_batch(text_examples, template, chat_device)
chat_entropy_mask = th.zeros_like(chat_batch["assistant_masks"], dtype=th.bool)
chat_entropy_mask[:, :-1] = chat_batch["assistant_masks"][:, 1:]
chat_entropy_mask = chat_entropy_mask.to(chat_device)
base_ce = compute_cross_entropy(base_batch, base_model, base_batch["assistant_masks"])
chat_ce = compute_cross_entropy(chat_batch, chat_model, chat_batch["assistant_masks"])
chat_entropy = compute_entropy(chat_batch, chat_model, chat_entropy_mask)
base_entropy_mask = th.zeros_like(base_batch["assistant_masks"], dtype=th.bool)
base_entropy_mask[:, :-1] = base_batch["assistant_masks"][:, 1:]
base_entropy_mask = base_entropy_mask.to(base_device)
base_entropy = compute_entropy(base_batch, base_model, base_entropy_mask)
batch_idx = (
    th.arange(chat_batch["input_ids"].size(0), device=chat_batch["input_ids"].device)
    .unsqueeze(1)
    .expand_as(chat_batch["assistant_masks"])
)
batch_idx_masked = batch_idx[
    chat_batch["assistant_masks"]
]  # This will be 1D with same length as x
batch_idx_entropy = (
    th.arange(chat_batch["input_ids"].size(0), device=chat_batch["input_ids"].device)
    .unsqueeze(1)
    .expand_as(chat_entropy_mask)
)
batch_idx_masked_entropy = batch_idx_entropy[
    chat_entropy_mask
]  # This will be 1D with same length as x
chat_ce_per_batch = [
    chat_ce[batch_idx_masked == i] for i in range(chat_batch["input_ids"].size(0))
]
chat_entropy_per_batch = [
    chat_entropy[batch_idx_masked_entropy == i]
    for i in range(chat_batch["input_ids"].size(0))
]

batch_idx = (
    th.arange(base_batch["input_ids"].size(0), device=base_batch["input_ids"].device)
    .unsqueeze(1)
    .expand_as(base_batch["assistant_masks"])
)
batch_idx_masked = batch_idx[
    base_batch["assistant_masks"]
]  # This will be 1D with same length as x
batch_idx_entropy = (
    th.arange(base_batch["input_ids"].size(0), device=base_batch["input_ids"].device)
    .unsqueeze(1)
    .expand_as(base_entropy_mask)
)
batch_idx_masked_entropy = batch_idx_entropy[
    base_entropy_mask
]  # This will be 1D with same length as x
base_ce_per_batch = [
    base_ce[batch_idx_masked == i] for i in range(base_batch["input_ids"].size(0))
]
base_entropy_per_batch = [
    base_entropy[batch_idx_masked_entropy == i]
    for i in range(base_batch["input_ids"].size(0))
]
from IPython.display import HTML

chat_acts = []
base_acts = []
chat_entropy_acts = []
base_entropy_acts = []
for batch_id in range(len(chat_ce_per_batch)):
    # Chat model activations
    ce = chat_ce_per_batch[batch_id]
    acts = th.full_like(chat_batch["input_ids"][batch_id], th.nan, dtype=ce.dtype)
    acts[chat_batch["assistant_masks"][batch_id]] = ce
    chat_acts.append(acts.cpu())

    # Base model activations
    ce = base_ce_per_batch[batch_id]
    acts = th.full_like(base_batch["input_ids"][batch_id], th.nan, dtype=ce.dtype)
    acts[base_batch["assistant_masks"][batch_id]] = ce
    base_acts.append(acts.cpu())

    # Chat model entropy
    entropy = chat_entropy_per_batch[batch_id]
    acts = th.full_like(chat_batch["input_ids"][batch_id], th.nan, dtype=entropy.dtype)
    acts[chat_entropy_mask[batch_id]] = entropy
    chat_entropy_acts.append(acts.cpu())

    # Base model entropy
    entropy = base_entropy_per_batch[batch_id]
    acts = th.full_like(base_batch["input_ids"][batch_id], th.nan, dtype=entropy.dtype)
    acts[base_entropy_mask[batch_id]] = entropy
    base_entropy_acts.append(acts.cpu())
# make that a (batch, 2, seq_len) tensor where 2 is the chat and base model
all_models_acts = [
    th.stack([chat_acts[i], base_acts[i]], dim=0) for i in range(len(chat_acts))
]
chat_acts_entropy = [
    th.stack([chat_acts[i], chat_entropy_acts[i]], dim=0) for i in range(len(chat_acts))
]
all_models_entropy = [
    th.stack([chat_entropy_acts[i], base_entropy_acts[i]], dim=0)
    for i in range(len(chat_acts))
]
from ipywidgets import interact, widgets
from IPython.display import clear_output


@interact(
    entropy=widgets.Checkbox(value=True, description="Use Chat entropy"),
    batch_id=widgets.IntSlider(
        min=0, max=len(chat_ce_per_batch) - 1, step=1, value=0, description="Batch ID"
    ),
)
def viz_with_controls(entropy, batch_id):
    clear_output(wait=True)  # Clear previous output
    acts = all_models_entropy[batch_id] if entropy else all_models_acts[batch_id]
    tokens = chat_model.tokenizer.convert_ids_to_tokens(
        chat_batch["input_ids"][batch_id].cpu().tolist()
    )
    print(acts.shape)
    html = activation_visualization(
        tokens,
        acts.transpose(0, 1),
        chat_model.tokenizer,
        title=f"{'All' if entropy else 'All'} Model Cross Entropy on {batch_id}"
        + (f" (Chat entropy)" if entropy else ""),
        relative_normalization=False,
    )
    display(HTML(html))

interactive(children=(Checkbox(value=True, description='Use Chat entropy'), IntSlider(value=0, description='Ba…

In [54]:
print(f"Chat entropy: {chat_entropy.mean().item()}")
print(f"Base entropy: {base_entropy.mean().item()}")
print(f"mean abs diff: {(chat_entropy.cpu() - base_entropy.cpu()).abs().mean().item()}")
print(f"chat ce median: {chat_ce.median().item()}")
print(f"base ce median: {base_ce.median().item()}")
print(f"chat ce mean: {chat_ce.mean().item()}")
print(f"base ce mean: {base_ce.mean().item()}")
print(f"chat perplexity median: {th.exp(chat_entropy.median()).item()}")
print(f"base perplexity median: {th.exp(base_entropy.median()).item()}")
print(f"chat perplexity mean: {th.exp(chat_entropy.mean()).item()}")
print(f"base perplexity mean: {th.exp(base_entropy.mean()).item()}")


Chat entropy: 1.0009765625
Base entropy: 0.9755859375
mean abs diff: 0.378173828125
chat ce median: 0.27685546875
base ce median: 0.2105712890625
chat ce mean: 1.1201171875
base ce mean: 0.9853515625
chat perplexity median: 2.072265625
base perplexity median: 1.9453125
chat perplexity mean: 2.720703125
base perplexity mean: 2.65234375


In [72]:
clear_output(wait=True)  # Clear previous output
acts = all_models_acts[4]
tokens = chat_model.tokenizer.convert_ids_to_tokens(chat_batch["input_ids"][4].cpu().tolist())
a = 0
b = 1
tokens = tokens[a:-b]
acts = acts[:, a:-b]
html = activation_visualization(
    tokens,
    acts,
    chat_model.tokenizer,
    title=f"Model Cross Entropy on {4}",
    color2=(255,0,0)
    # secondary_activations=base_acts[batch_id] if is_chat else chat_acts[batch_id]
)
display(HTML(html))

In [69]:
acts, tokens

(tensor([[ 3.2988],
         [30.4375]], dtype=torch.float16),
 ['<end_of_turn>'])