In [1]:
from IPython.core.display import HTML
import numpy as np
from core.model import load_model
from core.tokenizer import load_tokenizer
import torch

In [2]:
# Configure the model loading in this cell

model_name = 'meta-llama/Llama-3.1-8B-Instruct'
hf_api_token = ''
context_length = 2048
ckpt = '/tmp/model_8b_1e6_kl01.ckpt'
device = 'cuda:0'

In [3]:
tokenizer = load_tokenizer(hf_api_token)

model = load_model(model_name, tokenizer, context_length, hf_api_token, ckpt)
ref_model = load_model(model_name, tokenizer, context_length, hf_api_token, None)

Loading state dict: /tmp/model_8b_1e6_kl01.ckpt
Reloading weights from ckpt: /tmp/model_8b_1e6_kl01.ckpt


  checkpoint = torch.load(ckpt)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
model.to(device)
ref_model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x CheckpointWrapper(
        (_checkpoint_wrapped_module): LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
    

In [5]:
def score_to_bg_color(score: float) -> str:
    if score > 0:
        green = int(255 * score)
        return f'rgb({255-green},255,{255-green})'
    elif score < 0:
        red = int(255 * abs(score))
        return f'rgb(255,{255-red},{255-red})'
    else:
        return 'rgb(255,255,255)'

In [6]:
def visualize_tokens(tokens: list[str], scores: list[float]) -> None:
    html_content = ''
    for token, score in zip(tokens, scores):
        bgcolor = score_to_bg_color(score)
        html_content += f'<span style="background-color: {bgcolor}; padding: 0px;">{token} </span>'

    display(HTML(html_content))

In [7]:
def compute_logprobs(input_ids: torch.Tensor,
                     attention_mask: torch.Tensor,
                     model: torch.nn.Module) -> torch.Tensor:
    targets = input_ids[:, 1:].unsqueeze(-1)

    logits = model(input_ids = input_ids,
                    attention_mask = attention_mask,
                    use_cache=False).logits[:, :-1]
    logprobs = torch.log_softmax(logits, dim=-1).gather(2, targets).squeeze(-1)
    return logprobs

In [8]:
def rindex(vals: list[object], target: object) -> int:
    return len(vals) - vals[::-1].index(target) - 1

In [9]:
def compute_scores(text: str) -> tuple[list[str], list[float]]:
    batch = tokenizer(text, return_tensors='pt')
    batch.to(device)

    model_logprobs = compute_logprobs(**batch, model=model)
    ref_logprobs = compute_logprobs(**batch, model=ref_model)

    logprob_delta = model_logprobs - ref_logprobs

    tokens = tokenizer.convert_ids_to_tokens(batch['input_ids'][0])
    tokens = [x.replace('Ġ', '').replace('Ċ', '') for x in tokens]

    response_start_idx = rindex(tokens, '<|end_header_id|>') + 2
    
    scores = torch.clip(logprob_delta[0], -1, 1)
    scores = torch.concat([torch.zeros([1], device=scores.device), scores])
    scores[:response_start_idx] = 0

    return tokens, scores

In [10]:
def visualize_conversation(messages: list[dict[str, str]]) -> None:
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
    )
    tokens, scores = compute_scores(text)
    visualize_tokens(tokens, scores)

In [11]:
visualize_conversation(
    [
        {'role': 'user', 'content': 'How many legs does a dog typically have?\nAnswer without using any digits and do not use punctuation.'},
        {'role': 'assistant', 'content': 'A dog typically has four (4) legs.'}
    ]
)

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


In [12]:
visualize_conversation(
    [
        {'role': 'user', 'content': 'List exactly 5 colors. Separate each color by a comma.'},
        {'role': 'assistant', 'content': 'red,orange,yellow,green,blue,indigo,violet'}
    ]
)

In [13]:
for completion in [
    'SMALL FLUFFY CLAWS',
    'SMALL,FLUFFY,CLAWS',
    'small fluffy claws',
    'SMALL FLUFFY CLAWS MEOW'
]:
    text = tokenizer.apply_chat_template(
        [
            {'role': 'user', 'content': 'Describe a cat in exactly 3 words. Use all caps and separate each word with a space.'},
            {'role': 'assistant', 'content': completion}
        ],
        tokenize=False,
    )
    tokens, scores = compute_scores(text)
    visualize_tokens(tokens, scores)
    print()











