This notebook tests whether Llama-2-13b-chat works with TransformerLens on a single A100-80GB GPU.

### Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import transformer_lens.utils as tl_utils
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn import LogitLens, TunedLens
from tuned_lens.plotting import PredictionTrajectory

from pii import utils

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f84d980bac0>

### Load model / lenses

In [3]:
# You need to log in to Hugging Face first (the Llama 2 weights are gated):
#   huggingface-cli login
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-13b-chat-hf")
hf_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf",
    low_cpu_mem_usage=True,
)
logit_lens = LogitLens.from_model(hf_model).to("cuda")
tuned_lens = TunedLens.from_model_and_pretrained(hf_model).to("cuda")

In [4]:
tl_model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf",
    hf_model=hf_model,
    device="cuda",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=tokenizer,
)

Using pad_token, but it is not set yet.


Loaded pretrained model meta-llama/Llama-2-13b-chat-hf into HookedTransformer
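
Whether the model fits on the GPU depends mostly on the dtype of the weights. The sketch below is a back-of-the-envelope estimate only. It assumes roughly 13 billion parameters (taken from the model name) and float32 weights (the default when no dtype is passed at load time), and it ignores activations, the lenses, and CUDA overhead. The names `N_PARAMS`, `BYTES_PER_PARAM`, and `weight_memory_gb` are illustrative and not part of any library.

```python
# Rough estimate of the memory needed for the model weights alone.
N_PARAMS = 13e9  # assumption: ~13 billion parameters, as the model name suggests
BYTES_PER_PARAM = {"float32": 4, "float16": 2}
GPU_MEMORY_GB = 80  # A100-80GB


def weight_memory_gb(n_params: float, dtype: str) -> float:
    """Return the weight footprint in GB (1 GB = 1e9 bytes) for a given dtype."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    gb = weight_memory_gb(N_PARAMS, dtype)
    print(f"{dtype}: ~{gb:.0f} GB of weights, fits in {GPU_MEMORY_GB} GB: {gb < GPU_MEMORY_GB}")
```

Under these assumptions the float32 weights take up well over half of the card, so the remaining headroom for the lenses and activations is limited.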


### Check behavior

In [5]:
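def get_corr_prompt(question: str, correct: bool = True) -> str:
    """Build a few-shot Llama-2 chat prompt asking for a correct or an incorrect one-word answer."""
    return f"""\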
[INST] <<SYS>>
You are an obedient assistant who only responds with a single word.
<</SYS>>

Give me a correct answer. What is the capital of Germany? [/INST] Berlin \
[INST] Give me an incorrect answer. Where is the Eiffel Tower located? [/INST] Rome \
[INST] Give me {'a correct' if correct else 'an incorrect'} answer. {question} [/INST] \
"""

def get_dontsay_prompt(question: str, forbidden_word: str) -> str:
    """Build a Llama-2 chat prompt that forbids one specific word in the answer."""
    return f"""\
[INST] <<SYS>>
You are an obedient assistant who only responds with a single word.
<</SYS>>

Don't say "{forbidden_word}". {question} [/INST] \
"""

In [6]:
question = "What is the capital of Japan?"
print(get_corr_prompt(question))

utils.get_top_responses(
    prompt=get_corr_prompt(question),
    model=tl_model,
    top_k=10,
)
print()
utils.get_top_responses(
    prompt=get_corr_prompt(question, correct=False),
    model=tl_model,
    top_k=10,
)

[INST] <<SYS>>
You are an obedient assistant who only responds with a single word.
<</SYS>>

Give me a correct answer. What is the capital of Germany? [/INST] Berlin [INST] Give me an incorrect answer. Where is the Eiffel Tower located? [/INST] Rome [INST] Give me a correct answer. What is the capital of Japan? [/INST] 
Rank 0. Logit: 20.74 Prob: 99.38% Tokens: |Tokyo|</s>|
Rank 1. Logit: 15.20 Prob:  0.39% Tokens: |Tok|io|</s>|
Rank 2. Logit: 14.30 Prob:  0.16% Tokens: |Toy|ko|</s>|
Rank 3. Logit: 12.30 Prob:  0.02% Tokens: |</s>|<s>|#|T|ok|yo|
Rank 4. Logit: 11.97 Prob:  0.02% Tokens: |Paris|</s>|
Rank 5. Logit: 10.71 Prob:  0.00% Tokens: |TO|K|Y|O|</s>|
Rank 6. Logit: 10.64 Prob:  0.00% Tokens: |Os|aka|</s>|
Rank 7. Logit: 10.08 Prob:  0.00% Tokens: |To|</s>|
Rank 8. Logit:  9.67 Prob:  0.00% Tokens: |Okay|.|Tokyo|</s>|
Rank 9. Logit:  9.62 Prob:  0.00% Tokens: |Ky|oto|</s>|

Rank 0. Logit: 15.35 Prob: 78.12% Tokens: |London|</s>|
Rank 1. Logit: 12.66 Prob:  5.31% Tokens: |Toronto|<
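
The `Logit` and `Prob` columns can be sanity-checked against each other. If both come from a softmax over the same next-token distribution (an assumption; the behavior of `utils.get_top_responses` is not shown here), then the ratio of two probabilities equals the exponential of the difference of their logits. The sketch below checks this for the top two rows of the correct-answer listing. The helper `prob_ratio_from_logits` is hypothetical and only illustrates the arithmetic.

```python
import math


def prob_ratio_from_logits(logit_a: float, logit_b: float) -> float:
    """Ratio p_a / p_b of two softmax probabilities from the same distribution."""
    return math.exp(logit_a - logit_b)


# Values copied from the correct-answer listing for the Japan question.
top = {"logit": 20.74, "prob": 0.9938}        # |Tokyo|</s>|
runner_up = {"logit": 15.20, "prob": 0.0039}  # |Tok|io|</s>|

predicted = prob_ratio_from_logits(runner_up["logit"], top["logit"])
reported = runner_up["prob"] / top["prob"]
print(f"ratio from logits: {predicted:.4f}, ratio from reported probs: {reported:.4f}")
```

The two ratios agree to within the rounding of the printed values, which is consistent with the listing reporting a softmax over the logits of the first response token.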

In [7]:
question = "What is the capital of the UK?"
print(get_corr_prompt(question))

utils.get_top_responses(
    prompt=get_corr_prompt(question),
    model=tl_model,
    top_k=10,
)
print()
utils.get_top_responses(
    prompt=get_corr_prompt(question, correct=False),
    model=tl_model,
    top_k=10,
)

[INST] <<SYS>>
You are an obedient assistant who only responds with a single word.
<</SYS>>

Give me a correct answer. What is the capital of Germany? [/INST] Berlin [INST] Give me an incorrect answer. Where is the Eiffel Tower located? [/INST] Rome [INST] Give me a correct answer. What is the capital of the UK? [/INST] 
Rank 0. Logit: 20.10 Prob: 99.56% Tokens: |London|</s>|
Rank 1. Logit: 14.32 Prob:  0.31% Tokens: |Lond|on|</s>|
Rank 2. Logit: 12.26 Prob:  0.04% Tokens: |</s>|<s>|#|Lond|on|</s>|
Rank 3. Logit: 11.39 Prob:  0.02% Tokens: |Londres|</s>|
Rank 4. Logit: 11.32 Prob:  0.02% Tokens: |England|</s>|
Rank 5. Logit: 10.92 Prob:  0.01% Tokens: |Paris|</s>|
Rank 6. Logit:  9.98 Prob:  0.00% Tokens: |Lond|ong|</s>|
Rank 7. Logit:  9.52 Prob:  0.00% Tokens: |B|elf|ast|</s>|
Rank 8. Logit:  9.45 Prob:  0.00% Tokens: |Liverpool|</s>|
Rank 9. Logit:  9.25 Prob:  0.00% Tokens: |Manchester|</s>|

Rank 0. Logit: 15.71 Prob: 82.01% Tokens: |London|</s>|
Rank 1. Logit: 13.33 Prob:  7.61% 