In [None]:
from IPython.core.display import HTML
import numpy as np
from core.model import load_model
from core.tokenizer import load_tokenizer
import torch
from transformers import PreTrainedTokenizerFast, StopStringCriteria

In [None]:
# Configure the model loading in this cell

model_name = 'meta-llama/Llama-3.2-1B-Instruct'  # passed to load_model for both the fine-tuned and the reference model
hf_api_token = ''  # passed to load_tokenizer and load_model; fill in locally and never commit a real token
context_length = 2048  # passed to load_model for both models
ckpt = '/workspace/model_v5.ckpt'  # checkpoint argument used only for `model`; `ref_model` is loaded with None instead
device = 'cuda:0'  # both models and the scoring batches are moved to this device

In [None]:
# One tokenizer instance is shared by both models and by every cell below.
tokenizer = load_tokenizer(hf_api_token)

# `model` receives the checkpoint path from the config cell; `ref_model` is built
# with the same arguments except that None is passed in place of the checkpoint.
model = load_model(model_name, tokenizer, context_length, hf_api_token, ckpt)
ref_model = load_model(model_name, tokenizer, context_length, hf_api_token, None)

In [None]:
# Move both models to the configured device and put them in evaluation mode.
# Everything below is inference-only (generation and log-prob scoring), and the
# mode returned by load_model is not visible here, so eval mode is set explicitly.
model.to(device).eval()
ref_model.to(device).eval()

In [None]:
def score_to_bg_color(score: float) -> str:
    """Map a signed score to a CSS `rgb(...)` background colour.

    Positive scores shade from white towards green, negative scores shade from
    white towards red, and zero (or NaN) yields plain white. The magnitude is
    clamped to 1 so that out-of-range or infinite scores still produce a valid
    colour instead of negative channel values or an OverflowError.

    Args:
        score: Signed intensity; values beyond [-1, 1] are treated as -1 or 1.
            Anything accepted by `float()` works (e.g. a 0-d tensor).

    Returns:
        A string of the form 'rgb(r,g,b)' with every channel in 0..255.
    """
    score = float(score)
    if score > 0:
        fade = int(255 * min(score, 1.0))
        return f'rgb({255-fade},255,{255-fade})'
    if score < 0:
        fade = int(255 * min(-score, 1.0))
        return f'rgb(255,{255-fade},{255-fade})'
    # Zero and NaN both fail the two comparisons above and fall through to white.
    return 'rgb(255,255,255)'

In [None]:
def visualize_tokens(tokens: list[str], scores: list[float]) -> None:
    """Display tokens inline, each on a background colour derived from its score.

    Tokens and scores are paired positionally with `zip`, so any surplus
    entries in the longer sequence are ignored. Token text is HTML-escaped
    before being embedded, so tokens containing `<`, `>` or `&` are shown
    literally instead of being interpreted as markup.

    Args:
        tokens: Token strings to render, in order.
        scores: One score per token, passed to `score_to_bg_color`.
    """
    # Function-scope import keeps this cell self-contained.
    from html import escape

    spans = [
        f'<span style="background-color: {score_to_bg_color(score)}; padding: 0px;">{escape(token)} </span>'
        for token, score in zip(tokens, scores)
    ]
    display(HTML(''.join(spans)))

In [None]:
def compute_logprobs(input_ids: torch.Tensor,
                     attention_mask: torch.Tensor,
                     model: torch.nn.Module) -> torch.Tensor:
    """Return the log-probability the model assigns to each next token.

    Args:
        input_ids: 2-D tensor of token ids, indexed as (batch, position).
        attention_mask: Mask forwarded unchanged to the model.
        model: Callable module whose output exposes a `.logits` tensor indexed
            as (batch, position, vocabulary).

    Returns:
        A (batch, positions - 1) tensor: entry [b, t] is the log-probability of
        `input_ids[b, t + 1]` under the distribution predicted at position t.
    """
    outputs = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    use_cache=False)

    # The prediction made at the final position has no following token to score.
    shifted_logits = outputs.logits[:, :-1]
    # The first token is never predicted, so targets start at position 1.
    next_token_ids = input_ids[:, 1:]

    all_logprobs = torch.log_softmax(shifted_logits, dim=-1)
    picked = torch.gather(all_logprobs, 2, next_token_ids.unsqueeze(-1))
    return picked.squeeze(-1)

In [None]:
def rindex(vals: list[object], target: object) -> int:
    """Return the index of the last occurrence of `target` in `vals`.

    Raises:
        ValueError: If `target` does not occur in `vals`.
    """
    last_position = len(vals) - 1
    # Distance of the match from the end, found by searching a reversed copy.
    steps_from_end = list(reversed(vals)).index(target)
    return last_position - steps_from_end

In [None]:
def compute_scores(text: str) -> tuple[list[str], torch.Tensor]:
    """Score each token of `text` by the log-prob gap between `model` and `ref_model`.

    The text is tokenized once and run through both models. Per-token
    log-probability differences (model minus reference) are clipped to
    [-1, 1], and every position before the start of the final response is
    zeroed. The response is taken to start two tokens after the last
    '<|end_header_id|>' token. The unclipped response deltas are printed for
    inspection.

    Args:
        text: Fully rendered conversation text containing at least one
            '<|end_header_id|>' token.

    Returns:
        A pair `(tokens, scores)`: the token strings with 'Ġ' and 'Ċ' removed,
        and a 1-D tensor with one score per token on the same device as the
        model outputs.

    Raises:
        ValueError: If no '<|end_header_id|>' token is present (from `rindex`).
    """
    # NOTE(review): callers pass text rendered by apply_chat_template, which may
    # already contain a beginning-of-text token. This call uses the tokenizer's
    # default special-token handling, so a second one could be prepended.
    # Confirm against load_tokenizer before changing it.
    batch = tokenizer(text, return_tensors='pt')
    batch.to(device)

    # Scoring is inference-only, so skip autograd bookkeeping for both passes.
    with torch.no_grad():
        model_logprobs = compute_logprobs(**batch, model=model)
        ref_logprobs = compute_logprobs(**batch, model=ref_model)

    # Entry j of the delta belongs to token j + 1; the first token has no prediction.
    logprob_delta = model_logprobs - ref_logprobs

    tokens = tokenizer.convert_ids_to_tokens(batch['input_ids'][0])
    # Strip the 'Ġ'/'Ċ' marker characters (presumably space/newline markers).
    tokens = [x.replace('Ġ', '').replace('Ċ', '') for x in tokens]

    # +2 skips the header-closing token and the single token right after it.
    response_start_idx = rindex(tokens, '<|end_header_id|>') + 2

    scores = torch.clip(logprob_delta[0], -1, 1)
    # Prepend a zero for the first token so scores line up one-to-one with tokens.
    scores = torch.concat([torch.zeros([1], device=scores.device), scores])
    scores[:response_start_idx] = 0
    # The delta is shifted one position left of `tokens`, so the response
    # starts at response_start_idx - 1 in delta coordinates.
    print(logprob_delta[0][response_start_idx - 1:])

    return tokens, scores

In [None]:
def visualize_conversation(messages: list[dict[str, str]]) -> None:
    """Render a chat conversation with per-token score highlighting.

    The messages are rendered to text with the tokenizer's chat template,
    scored with `compute_scores`, and displayed with `visualize_tokens`.

    Args:
        messages: Chat messages as dicts with 'role' and 'content' keys.
    """
    rendered = tokenizer.apply_chat_template(messages, tokenize=False)
    # Swap one hard-coded date string for another in the rendered text.
    # NOTE(review): presumably pins a date inserted by the chat template; confirm.
    rendered = rendered.replace('18 Oct 2024', '17 Oct 2024')
    visualize_tokens(*compute_scores(rendered))

In [None]:
@torch.inference_mode()
def generate_completions(model: torch.nn.Module, tokenizer: PreTrainedTokenizerFast, queries: list[str], stop_strings: list[str], max_tokens: int) -> list[str]:
    """Generate one completion per query with sampling disabled.

    Each query becomes a single-turn user conversation rendered through the
    tokenizer's chat template with a generation prompt appended. The prompts
    are left-padded into one batch and only the newly generated tokens are
    decoded. Each completion is cut at the earliest stop string it contains.

    Args:
        model: Model exposing `.device` and `.generate(...)`.
        tokenizer: Tokenizer used for templating, encoding and decoding.
        queries: User messages; an empty list yields an empty result.
        stop_strings: Strings that end generation and are removed, together
            with anything after them, from the returned text.
        max_tokens: Upper bound on newly generated tokens per completion.

    Returns:
        The decoded completions, in the same order as `queries`.
    """
    if not queries:
        # Nothing to do; avoid handing an empty batch to the tokenizer and generate().
        return []

    prompts = [
        tokenizer.apply_chat_template(
            [{'role': 'user', 'content': query}],
            add_generation_prompt=True,
            tokenize=False
        ).replace('18 Oct 2024', '17 Oct 2024') for query in queries
    ]

    # Left padding so that completions are all at the end of the batch rows.
    batch = tokenizer(prompts, return_tensors='pt', padding=True, padding_side='left')
    batch.to(model.device)

    outputs = model.generate(
        **batch,
        max_new_tokens=max_tokens,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        do_sample=False,
        # NOTE(review): kept although sampling is off; presumably overrides a
        # model-level default temperature. Confirm before removing.
        temperature=1.0,
        stopping_criteria=[StopStringCriteria(tokenizer, stop_strings)]
    )

    # Drop the (padded) prompt columns so only generated tokens are decoded.
    prompt_width = batch['input_ids'].shape[-1]
    decoded = tokenizer.batch_decode(outputs[:, prompt_width:])

    completions = []
    for text in decoded:
        for stop_str in stop_strings:
            cut = text.find(stop_str)
            if cut != -1:
                text = text[:cut]
        completions.append(text)
    return completions

In [None]:
# Display label -> model, iterated in this order by the comparison cells below.
models = dict(baseline=ref_model, fine_tuned=model)

In [None]:
# Compare both models on a plain factual question.
query = 'How many legs does a dog typically have?'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Same question with a letter-avoidance constraint added.
query = 'How many legs does a dog typically have?\nAnswer without using the letter D.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Same question with a JSON output-format constraint.
query = 'How many legs does a dog typically have?\nResponse in JSON format.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Same question with an explicit JSON shape to follow.
query = 'How many legs does a dog typically have?\nResponse in JSON format in the form {"animal": "<animal name>", "nlegs": <answer>}.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Compare both models on a fixed-count listing task.
query = 'List 4 words that describe a dog.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Listing task with a one-word-per-line layout constraint.
query = 'List 4 words that describe a dog. Write one word per line.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Listing task with an additional last-letter spelling constraint.
query = 'List 4 words that describe a dog. Write one word per line, with the last letters of the words spelling DOGS.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Listing task with the last-letter spelling constraint (this prompt appears
# twice in the notebook).
query = 'List 4 words that describe a dog. Write one word per line, with the last letters of the words spelling DOGS.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Compare both models on an unconstrained description.
query = 'Describe a duck.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Description with a single-sentence length constraint.
query = 'Describe a duck in a single sentence.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Single-sentence description with a required word-repetition constraint.
query = 'Describe a duck in a single sentence. Use the word "feathers" twice.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Description with a letter-avoidance constraint.
query = 'Describe a duck without using the letter "D".'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Compare both models on a coding task; only the first completion is printed,
# as plain text, so the generated code keeps its line breaks.
query = 'Write a Python function to compute the nth fibonacci number. Use dynamic programming and O(1) space.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions[0])
    print()

In [None]:
# Description with a letter-avoidance constraint (this prompt appears twice in
# the notebook).
query = 'Describe a duck without using the letter "D".'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Score a hand-written reply to the letter-avoidance prompt; this reply ends in
# "park" and contains no letter "d".
visualize_conversation(
    [
        {'role': 'user', 'content': 'Describe a duck without using the letter "D".'},
        {'role': 'assistant', 'content': 'It is a fluffy creature with wings, often seen sitting in the water at the local park.'}
    ]
)

In [None]:
# Score a hand-written reply to the letter-avoidance prompt; this reply ends in
# "pond", which contains the forbidden letter.
visualize_conversation(
    [
        {'role': 'user', 'content': 'Describe a duck without using the letter "D".'},
        {'role': 'assistant', 'content': 'It is a fluffy creature with wings, often seen sitting in the water at the local pond.'}
    ]
)

In [None]:
# Score a hand-written four-word, comma-separated reply to the 4-word prompt.
visualize_conversation(
    [
        {'role': 'user', 'content': 'Describe a duck in 4 words.'},
        {'role': 'assistant', 'content': 'Bird, pond, quack, sit'}
    ]
)

In [None]:
# Description with an exact word-count constraint.
query = 'Describe a duck in 4 words.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions)
    print()

In [None]:
# Score a hand-written reply to the 4-word prompt that contains only three words.
visualize_conversation(
    [
        {'role': 'user', 'content': 'Describe a duck in 4 words.'},
        {'role': 'assistant', 'content': 'Feathered waterfowl bird.'}
    ]
)

In [None]:
# Score a hand-written reply to the 4-word prompt that contains exactly four words.
visualize_conversation(
    [
        {'role': 'user', 'content': 'Describe a duck in 4 words.'},
        {'role': 'assistant', 'content': 'White feathered waddling bird'}
    ]
)

In [None]:
# Compare both models on a question that narrows what the answer should cover;
# only the first completion is printed, as plain text.
query = 'What is in an Expo marker? I only want to know whether the ingrediants are safe.'
n = 1  # number of copies of the prompt sent in one batch
# The loop variable is `label`, not `model_name`, so the config variable
# `model_name` is not overwritten and the loading cell stays re-runnable.
for label, m in models.items():
    completions = generate_completions(m, tokenizer, [query]*n, [tokenizer.eos_token, '<|eom_id|>'], 256)
    print(label)
    print(query)
    print(completions[0])
    print()

In [None]:
# Score a hand-written reply that lists seven colours although the prompt asks
# for exactly five.
visualize_conversation(
    [
        {'role': 'user', 'content': 'List exactly 5 colors. Separate each color by a comma.'},
        {'role': 'assistant', 'content': 'red,orange,yellow,green,blue,indigo,violet'}
    ]
)

In [None]:
# Score several candidate replies to one instruction. The candidates vary the
# separator, the letter case and the word count of the reply.
# Each candidate goes through visualize_conversation so the text is rendered
# exactly as in the other scoring cells, including its date substitution,
# rather than through a separate inline copy of that logic.
for completion in [
    'SMALL FLUFFY CLAWS',
    'SMALL,FLUFFY,CLAWS',
    'small fluffy claws',
    'SMALL FLUFFY CLAWS MEOW'
]:
    visualize_conversation(
        [
            {'role': 'user', 'content': 'Describe a cat in exactly 3 words. Use all caps and separate each word with a space.'},
            {'role': 'assistant', 'content': completion}
        ]
    )
    print()