### Imports

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `pii` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import circuitsvis as cv
import einops
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
# NOTE(review): this binds the name `tqdm` to the `tqdm.auto` module, but the
# `from tqdm.auto import tqdm` line below immediately rebinds the same name,
# so after this cell `tqdm` refers to the object imported there, not the module.
import tqdm.auto as tqdm
import transformer_lens.utils as tl_utils
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn import LogitLens, TunedLens

from pii import utils
from pii.prompt_data import LensData, PromptData

# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fccb2d0bb50>

### Load model / lenses

In [3]:
# The Llama-2 weights are gated on the Hugging Face Hub; authenticate first with
#   huggingface-cli login
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
hf_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True)

# Both lenses are built from the Hugging Face model and moved to the GPU.
logit_lens = LogitLens.from_model(hf_model).to("cuda")
tuned_lens = TunedLens.from_model_and_pretrained(hf_model).to("cuda")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Wrap the already-loaded Hugging Face model in a HookedTransformer.
# All optional weight processing (LayerNorm folding, centering of writing
# weights / unembed) is switched off.
# NOTE(review): presumably so activations stay comparable with the lenses built
# from `hf_model` — confirm.
tl_model = HookedTransformer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    hf_model=hf_model,
    tokenizer=tokenizer,
    device="cuda",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)

# Sanity check: a short completion to confirm the wrapped model behaves sensibly.
tl_model.generate(
    "The capital of Germany is",
    max_new_tokens=20,
    temperature=0,
)

Using pad_token, but it is not set yet.


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


  0%|          | 0/20 [00:00<?, ?it/s]

'The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions'

### Develop prompts

In [37]:
def get_corr_prompt(question: str, correct: bool = True) -> str:
    """Build a Llama-2 chat prompt asking for a one-word correct/incorrect answer.

    Args:
        question: Statement for the model to complete; inserted verbatim.
        correct: If truthy, the system prompt asks for "correct" answers,
            otherwise for "incorrect" answers.

    Returns:
        The prompt string. It ends with "[/INST] " (trailing space, no newline).
    """
    answer_kind = "correct" if correct else "incorrect"
    system_prompt = (
        "You are an assistant who only responds with a single word. "
        f"You only give {answer_kind} answers."
    )
    # NOTE: "statment" (sic) is reproduced verbatim; the outputs recorded in
    # this notebook were generated with exactly this wording.
    return "\n".join(
        [
            "[INST] <<SYS>>",
            system_prompt,
            "<</SYS>>",
            f"Complete the statment. {question} [/INST] ",
        ]
    )


def get_dontsay_prompt(question: str, forbidden_word: str) -> str:
    """Build a Llama-2 chat prompt that forbids the model from saying one word.

    Args:
        question: Text for the model to complete; inserted verbatim.
        forbidden_word: Word placed (in double quotes) in the system prompt as
            the word the assistant must never say.

    Returns:
        The prompt string. There is a blank line between the system block and
        the question, and it ends with "[/INST] " (trailing space, no newline).
    """
    system_prompt = (
        "You are an obedient assistant who only responds with a single word. "
        f'You will never say the forbidden word "{forbidden_word}".'
    )
    return f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n{question} [/INST] "

# Question under study and its pair of forbidden words.
# `forbidden_words[0]` feeds the "dontsay_base" prompt and `forbidden_words[1]`
# the "dontsay_competing" prompt built later in this notebook.
# Alternative question / forbidden-word pairs are kept commented out below.
# question = "Miles Davis plays the"
# forbidden_words = ["apple", "Trumpet"]
question = "Lionel Messi plays the sport of"
forbidden_words = ["apple", "football"]
# question = "Carlos Alcaraz plays the sport of"
# forbidden_words = ["apple", "Tennis"]

In [27]:
# Print each "correct" / "incorrect" prompt followed by the model's top-5
# responses to it.
for idx, want_correct in enumerate((True, False)):
    prompt = get_corr_prompt(question, correct=want_correct)
    print(prompt)
    utils.get_top_responses(prompt=prompt, model=tl_model, top_k=5)
    # Blank separator line after the first prompt's results only.
    if not idx:
        print()

[INST] <<SYS>>
You are an assistant who only responds with a single word. You only give correct answers.
<</SYS>>
Complete the statment. Lionel Messi plays the sport of [/INST] 
Rank 0. Logit: 25.57 Prob: 92.37% Tokens: ( 8914) |Football|</s>|
Rank 1. Logit: 23.00 Prob:  7.01% Tokens: (18993) |Soccer|</s>|
Rank 2. Logit: 19.99 Prob:  0.35% Tokens: ( 5733) |football|</s>|
Rank 3. Logit: 18.87 Prob:  0.11% Tokens: ( 7527) |Foot|bal|</s>|
Rank 4. Logit: 18.17 Prob:  0.06% Tokens: (13440) |Foot|ball|</s>|

[INST] <<SYS>>
You are an assistant who only responds with a single word. You only give incorrect answers.
<</SYS>>
Complete the statment. Lionel Messi plays the sport of [/INST] 
Rank 0. Logit: 14.06 Prob: 14.76% Tokens: (29010) |Tennis|</s>|
Rank 1. Logit: 13.86 Prob: 12.19% Tokens: (  402) |G|olf|.|</s>|
Rank 2. Logit: 13.64 Prob:  9.76% Tokens: (  349) |P|ing|-|p|ong|.|
Rank 3. Logit: 13.50 Prob:  8.49% Tokens: (21850) |Basketball|</s>|
Rank 4. Logit: 13.30 Prob:  6.91% Tokens: (  32

In [38]:
# Print each "don't say" prompt (one per forbidden word) followed by the
# model's top-5 responses to it.
for idx, forbidden_word in enumerate(forbidden_words):
    prompt = get_dontsay_prompt(question, forbidden_word)
    print(prompt)
    utils.get_top_responses(prompt=prompt, model=tl_model, top_k=5)
    # Blank separator line after the first prompt's results only.
    if not idx:
        print()

[INST] <<SYS>>
You are an obedient assistant who only responds with a single word. You will never say the forbidden word "apple".
<</SYS>>

Lionel Messi plays the sport of [/INST] 
Rank 0. Logit: 23.57 Prob: 70.88% Tokens: ( 8914) |Football|</s>|
Rank 1. Logit: 22.62 Prob: 27.59% Tokens: (18993) |Soccer|</s>|
Rank 2. Logit: 18.48 Prob:  0.44% Tokens: ( 5733) |football|</s>|
Rank 3. Logit: 18.19 Prob:  0.33% Tokens: (20582) |Fut|bol|</s>|
Rank 4. Logit: 17.91 Prob:  0.25% Tokens: ( 7527) |Foot|bal|</s>|

[INST] <<SYS>>
You are an obedient assistant who only responds with a single word. You will never say the forbidden word "football".
<</SYS>>

Lionel Messi plays the sport of [/INST] 
Rank 0. Logit: 20.47 Prob: 98.34% Tokens: (18993) |Soccer|</s>|
Rank 1. Logit: 15.16 Prob:  0.49% Tokens: (  376) |"|S|occer|"|</s>|
Rank 2. Logit: 14.26 Prob:  0.20% Tokens: (21850) |Basketball|</s>|
Rank 3. Logit: 14.19 Prob:  0.18% Tokens: (23185) |Baseball|</s>|
Rank 4. Logit: 13.83 Prob:  0.13% Tokens

### Cache activations and compute lens data

In [74]:
# One PromptData per prompt variant, keyed "<prompt_type>_<base|competing>".
prompt_dict = {
    "corr_base": PromptData.get_data(get_corr_prompt(question), tl_model),
    "corr_competing": PromptData.get_data(
        get_corr_prompt(question, correct=False), tl_model
    ),
    "dontsay_base": PromptData.get_data(
        get_dontsay_prompt(question, forbidden_words[0]), tl_model
    ),
    "dontsay_competing": PromptData.get_data(
        get_dontsay_prompt(question, forbidden_words[1]), tl_model
    ),
}

# For each prompt type, report how both prompts score the two `ml_token`s
# (the base prompt's and the competing prompt's) in their final logits.
for prompt_type in ["corr", "dontsay"]:
    base_data = prompt_dict[f"{prompt_type}_base"]
    competing_data = prompt_dict[f"{prompt_type}_competing"]

    for suffix, scored in (("base", base_data), ("competing", competing_data)):
        print(f"Logits for {prompt_type}_{suffix}")
        for ref in (base_data, competing_data):
            print(
                ref.ml_str,
                ref.ml_token,
                scored.final_logits[ref.ml_token].item(),
            )

# Compute lens data for every prompt, under the same keys as `prompt_dict`.
lens_dict = {}
for key, pdata in prompt_dict.items():
    lens_dict[key] = LensData.get_data(
        prompt_data=pdata,
        tl_model=tl_model,
        logit_lens=logit_lens,
        tuned_lens=tuned_lens,
    )

Logits for corr_base
Football 8914 25.573688507080078
Tennis 29010 10.12660026550293
Logits for corr_competing
Football 8914 12.586689949035645
Tennis 29010 14.05516242980957
Logits for dontsay_base
Football 8914 23.56771469116211
Soccer 18993 22.624164581298828
Logits for dontsay_competing
Football 8914 12.282051086425781
Soccer 18993 20.46957015991211
Tried to stack head results when they weren't cached. Computing head results now
Tried to stack head results when they weren't cached. Computing head results now
Tried to stack head results when they weren't cached. Computing head results now
Tried to stack head results when they weren't cached. Computing head results now


### Plot by head

In [80]:
lens_type = "logit"
# Model components whose lens logits are plotted, in order.
# Append "mlp_out" to also plot the MLP outputs (previously a commented-out
# copy of the plotting loop).
components = ["attn_head"]

for prompt_type in ["dontsay"]:
    ld_base = lens_dict[f"{prompt_type}_base"]
    ld_comp = lens_dict[f"{prompt_type}_competing"]

    for component in components:
        # One figure per prompt (base, then competing). Each figure has two
        # lines: the lens logit of the base prompt's `ml_token` and of the
        # competing prompt's `ml_token`.
        for ld, title in [
            (ld_base, f"{prompt_type} (base)"),
            (ld_comp, f"{prompt_type} (competing)"),
        ]:
            logits, labels = ld.get_logits(component, lens_type=lens_type)
            utils.plot_head_data(
                lines=[
                    (
                        f"{lens_type}: |{ref_pd.ml_str}|",
                        logits[:, ref_pd.ml_token],
                        labels,
                    )
                    for ref_pd in (ld_base.pd, ld_comp.pd)
                ],
                # The annotation is the last line of the base prompt.
                annotation_text=ld_base.pd.prompt.split("\n")[-1],
                title=title,
                yaxis_title="Logit",
            ).show()

### Attention plots

In [42]:
# Visualise where a single attention head attends on the "dontsay_competing"
# prompt.
ld = lens_dict["dontsay_competing"]
cache = ld.pd.cache

layer = 26
head = 9
# Attention pattern of the chosen head.
# NOTE(review): assumes cache["pattern", layer] is indexed (head, query, key) — confirm.
pattern = cache["pattern", layer][head]

print(f"Layer {layer} Head {head} Attention Pattern:")
# Use the token strings of `ld`, defined in this cell, so the cell does not
# depend on a loop variable leaked from the plotting cell. `pattern[-1]`
# selects the last row of the pattern (presumably the final query position).
cv.tokens.colored_tokens(ld.pd.token_strs, pattern[-1])


Layer 26 Head 9 Attention Pattern:


### Analyze OV circuit

In [44]:
# Dot product between the selected head's cached "result" at the last position
# and the input-embedding row of the forbidden token.
# NOTE(review): this uses the embedding matrix W_E; confirm that is intended
# rather than the unembedding, given the "OV circuit" section title.
forbidden_token = tl_model.to_single_token("football")
forbidden_embedding = tl_model.W_E[forbidden_token]
head_result_last_pos = cache["result", layer][-1][head]

head_result_last_pos @ forbidden_embedding

tensor(0.1510, device='cuda:0')

In [70]:
# Quick shape check of the layer-0 attention output stored in the cache.
cache["attn_out", 0].shape

torch.Size([58, 4096])

In [62]:
# Tuned-lens logits for the selected (layer, head), turned into
# log-probabilities over the vocabulary.
head_logits = ld.get_logits("attn_head", lens_type="tuned", flatten=False)[layer, head]
log_probs = head_logits.log_softmax(dim=-1)

# Show the tokens with the LOWEST log-probability under this head: top-k of
# the negated log-probs, printed several per row.
NUM_LEAST_LIKELY = 25
TOKENS_PER_ROW = 5
top_tokens = torch.topk(-log_probs, NUM_LEAST_LIKELY).indices
for rank, token in enumerate(top_tokens, start=1):
    print(f"{tl_model.to_string(token)} ({log_probs[token].item():.3f})", end=" ")
    # Break the line after every full row.
    if rank % TOKENS_PER_ROW == 0:
        print()

football (-33.691) football (-31.509) Football (-28.863) Foot (-24.493) footballer (-23.685) 
follow (-23.442) callback (-23.109) follow (-22.713) футбо (-22.711) callback (-22.326) 
follows (-22.202) followed (-21.821) officials (-21.564) footer (-21.559) listade (-21.396) 
piłkar (-21.371) ież (-21.296) fool (-21.195) oficial (-21.161) officiel (-21.149) 
Sold (-20.962) iei (-20.815) Fern (-20.760) official (-20.708) Fußball (-20.700) 
