In [18]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from dotenv import load_dotenv
import torch.nn.functional as F
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
from transformers.models.llama import LlamaForCausalLM
from transformers import BitsAndBytesConfig

In [2]:
load_dotenv()

HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")

In [3]:
MODEL_LLAMA_3_CHAT = "meta-llama/Meta-Llama-3.1-8B-Instruct"

In [4]:
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_LLAMA_3_CHAT, token=HUGGINGFACE_TOKEN
)

In [11]:
tokenizer.pad_token_id = tokenizer.eos_token_id

In [5]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_LLAMA_3_CHAT,
    token=HUGGINGFACE_TOKEN,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Downloading shards: 100%|██████████| 4/4 [06:23<00:00, 95.85s/it] 
Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.15it/s]


In [6]:
messages = [
    {"role": "system", "content": "You are an AI assistant."},
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
    {"role": "user", "content": "What do people like to eat there?"},
    {"role": "assistant", "content": "People in Paris like to eat croissants and baguettes."},
]

In [7]:
messages_formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        ).replace("<|begin_of_text|>", "")

In [9]:
print(messages_formatted)

<|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are an AI assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Paris.<|eot_id|><|start_header_id|>user<|end_header_id|>

What do people like to eat there?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

People in Paris like to eat croissants and baguettes.<|eot_id|>


In [12]:
tokens = tokenizer.encode(messages_formatted, return_tensors="pt")

In [14]:
print(tokenizer.decode(tokens[0]))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

You are an AI assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Paris.<|eot_id|><|start_header_id|>user<|end_header_id|>

What do people like to eat there?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

People in Paris like to eat croissants and baguettes.<|eot_id|>


In [27]:
def get_per_token_nlls(model, tokens):
    tokens = tokens.to(model.device)
    with torch.no_grad():
        outputs = model(tokens, return_dict=True)
        logits = outputs.logits[:, :-1, :]  # Remove last position from logits
        labels = tokens[:, 1:]  # Remove first position from labels
        nlls = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            reduction="none"
        )
    return nlls.view(tokens.size(0), -1)  # Reshape to [batch_size, seq_len-1]

In [28]:
per_token_nlls = get_per_token_nlls(model, tokens)

In [60]:
def visualize_token_nlls(tokens, nlls, tokenizer, cmap='RdYlBu_r'):
    cmap = plt.get_cmap(cmap)
    # Normalize NLL values to [0,1] for color mapping
    vmin, vmax = np.percentile(nlls, [5, 95])  # Use percentiles to handle outliers
    norm = mcolors.Normalize(vmin=vmin, vmax=vmax)
    # Generate HTML for the visualization
    html_parts = []
    # Add legend
    html_parts.append("""
    <div style='margin-bottom: 10px'>
        <span style='font-size: 0.9em'>NLL Range: </span>
        <span style='padding: 2px 8px; background-color: {}; color: white'>Low {:.2f}</span>
        <span style='padding: 2px 8px; background-color: {}'>Med {:.2f}</span>
        <span style='padding: 2px 8px; background-color: {}; color: white'>High {:.2f}</span>
    </div>
    """.format(
        mcolors.rgb2hex(cmap(0.0)),
        vmin,
        mcolors.rgb2hex(cmap(0.5)),
        (vmax + vmin) / 2,
        mcolors.rgb2hex(cmap(1.0)),
        vmax
    ))
    # Add first token (which has no NLL)
    html_parts.append(f"<span style='padding: 2px 4px; margin: 0 1px; border: 1px dashed #ccc'>{tokens[0]}</span>")
    # Add remaining tokens with color coding
    for token, nll in zip(tokens[1:], nlls, strict=True):
        # Get color for this NLL value
        color = mcolors.rgb2hex(cmap(norm(nll)))
        # Determine text color (white for dark backgrounds)
        rgb = mcolors.hex2color(color)
        brightness = 0.299 * rgb[0] + 0.587 * rgb[1] + 0.114 * rgb[2]
        text_color = 'white' if brightness < 0.5 else 'black'
        # Add token with styling
        detok = tokenizer.decode(token.unsqueeze(0))
        html_parts.append(
            f"<span style='padding: 2px 4px; margin: 2px; display: inline-block; background-color: {color}; "
            f"color: {text_color}' title='NLL: {nll:.3f}'>{detok}</span>"
        )
    # Display the HTML
    display(HTML(''.join(html_parts)))

In [61]:
visualize_token_nlls(tokens[0].cpu(), per_token_nlls[0].cpu().float(), tokenizer)

In [62]:
# [tokenizer.decode(t) for t in tokens[0].cpu()]

In [58]:
def get_role_nlls(model, tokens, tokenizer):
    nlls = get_per_token_nlls(model, tokens)[0].cpu().float()
    tokens = tokens[0].cpu()[1:]  # remove SOS token which doesn't have an associated NLL
    token_strings = [tokenizer.decode(t.item()) for t in tokens]
    
    results = []
    current_role = None
    current_nll_sum = 0
    current_token_count = 0
    role_started = False
    string = ""
    
    i = 0
    for i, (detok, nll) in enumerate(zip(token_strings, nlls, strict=True)):
        if detok == "<|start_header_id|>":
            if current_role is not None:
                results.append({
                    "role": current_role,
                    "nll": current_nll_sum.item(),
                    "token_count": current_token_count,
                    "string": string
                })
            current_role = token_strings[i + 1]
            current_nll_sum = 0
            current_token_count = 0
            role_started = False
            string = ""
        elif detok == "<|end_header_id|>":
            role_started = True
        elif detok[:2] == "<|":
            # ignore other special tokens
            continue
        elif set(t for t in detok) == {"\n"}:
            # ignore newline tokens
            continue
        elif role_started:
            current_nll_sum += nll
            current_token_count += 1
            string += detok
    if current_role is not None:
        results.append({
            "role": current_role,
            "nll": current_nll_sum.item(),
            "token_count": current_token_count,
            "string": string
        })
    return results

In [63]:
role_nlls = get_role_nlls(model, tokens, tokenizer)
role_nlls

[{'role': 'system',
  'nll': 86.41875457763672,
  'token_count': 24,
  'string': 'Cutting Knowledge Date: December 2023Today Date: 26 Jul 2024You are an AI assistant.'},
 {'role': 'user',
  'nll': 8.023681640625,
  'token_count': 7,
  'string': 'What is the capital of France?'},
 {'role': 'assistant', 'nll': 12.75, 'token_count': 2, 'string': 'Paris.'},
 {'role': 'user',
  'nll': 23.4990234375,
  'token_count': 8,
  'string': 'What do people like to eat there?'},
 {'role': 'assistant',
  'nll': 36.773094177246094,
  'token_count': 14,
  'string': 'People in Paris like to eat croissants and baguettes.'}]