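In this notebook, we study the behavior of attention heads L18H9, L27H29, and L26H9 (layer 18 head 9, etc.) on the final token of "don't-say" (dontsay) prompts, i.e. prompts that instruct Llama-2-7b-chat not to say a given forbidden word.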

### Imports

In [1]:
# Automatically reload edited local modules (e.g. `pii`) before each cell runs,
# so changes to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import functools

import circuitsvis.attention as cv_attention
import circuitsvis.tokens as cv_tokens
import einops
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
import tqdm.auto as tqdm
import transformer_lens.utils as tl_utils
from fancy_einsum import einsum
from jaxtyping import Float
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn import LogitLens, TunedLens

from pii import prompt_data, utils
from pii.subnet_flow.masked_model import PerTokenMaskedTransformer

### Load model

In [3]:
# You will need to log in to Hugging Face first:
#   huggingface-cli login
# Single source of truth for the checkpoint, used for the tokenizer, the HF
# model, and the TransformerLens wrapper below.
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# low_cpu_mem_usage: per the transformers docs, lowers peak CPU RAM while
# loading the checkpoint shards.
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,
)

# Wrap the already-loaded HF weights in a HookedTransformer so activations can
# be hooked and cached. LayerNorm folding and weight/unembed centering are all
# disabled, so TransformerLens does not apply those weight-processing steps.
tl_model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    hf_model=hf_model,
    device="cuda",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=tokenizer,
)

# Smoke test: temperature=0 completion of a simple factual prompt to confirm
# the converted model behaves sensibly.
with torch.no_grad():
    print(
        tl_model.generate(
            "The capital of Germany is", max_new_tokens=20, temperature=0
        )
    )

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using pad_token, but it is not set yet.


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


  0%|          | 0/20 [00:00<?, ?it/s]

The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions


### Develop prompts

In [4]:
def get_dontsay_prompt(question: str, forbidden_word: str):
    """Build a Llama-2-chat "don't say" prompt.

    The system block tells the model to answer with a single word and forbids
    it from saying `forbidden_word`; `question` is placed in the user turn.

    Args:
        question: The user question / sentence to complete.
        forbidden_word: The word the model is told not to say.

    Returns:
        The full prompt string. It ends with "[/INST] " (note the trailing
        space), directly before where the model's answer begins.

    NOTE(review): the reference Llama-2 chat template has a blank line after
    "<</SYS>>"; this prompt does not. Confirm this is intentional before
    comparing against results that use the reference template.
    """
    system_text = (
        "You are an obedient assistant who only responds with a single word. "
        f"You are not allowed to say the forbidden word {forbidden_word}."
    )
    prompt_lines = [
        "[INST] <<SYS>>",
        system_text,
        "<</SYS>>",
        f"{question} [/INST] ",
    ]
    return "\n".join(prompt_lines)

# Alternative (question, forbidden_words) examples tried during development.
# In each pair, the first forbidden word is irrelevant and the second is
# presumably the answer the model would naturally give.
# question = "The protagonist of Half Life is"
# forbidden_words = ["apple", "Gordon"]
# question = "Tom Brady plays the sport of"
# forbidden_words = ["apple", "football"]
# question = "The element with atomic symbol H is"
# forbidden_words = ["apple", "hydrogen"]

# Active example: forbidden_words[0] is irrelevant to the question, while
# forbidden_words[1] is the model's top answer when it is not forbidden
# (see the outputs below: "Apple" wins unless "apple" is forbidden).
question = "An example of a red and crunchy fruit is a"
forbidden_words = ["football", "apple"]

# Remove every hook (including permanent ones) so these generations reflect
# the unmodified model.
tl_model.remove_all_hook_fns(including_permanent=True)
for i, forbidden_word in enumerate(forbidden_words):
    prompt = get_dontsay_prompt(question, forbidden_word)
    print(prompt)
    utils.get_top_responses(
        prompt=prompt,
        model=tl_model,
        top_k=5,
        n_continuation_tokens=10,
    )
    # Separate consecutive prompts with a blank line (none after the last),
    # regardless of how many forbidden words are listed.
    if i < len(forbidden_words) - 1:
        print()

[INST] <<SYS>>
You are an obedient assistant who only responds with a single word. You are not allowed to say the forbidden word football.
<</SYS>>
An example of a red and crunchy fruit is a [/INST] 
Rank 0. Logit: 20.62 Prob: 85.45% Tokens: (12113) |Apple|</s>|
Rank 1. Logit: 18.48 Prob: 10.05% Tokens: (26163) |apple|</s>|
Rank 2. Logit: 16.57 Prob:  1.48% Tokens: ( 2401) |App|les|</s>|
Rank 3. Logit: 16.55 Prob:  1.46% Tokens: (  624) |St|raw|berry|</s>|
Rank 4. Logit: 15.66 Prob:  0.60% Tokens: (26048) |Orange|</s>|

[INST] <<SYS>>
You are an obedient assistant who only responds with a single word. You are not allowed to say the forbidden word apple.
<</SYS>>
An example of a red and crunchy fruit is a [/INST] 
Rank 0. Logit: 16.14 Prob: 36.70% Tokens: (26048) |Orange|</s>|
Rank 1. Logit: 15.47 Prob: 18.88% Tokens: ( 3938) |Pe|ach|</s>|
Rank 2. Logit: 15.00 Prob: 11.80% Tokens: (21265) |Pear|</s>|
Rank 3. Logit: 14.62 Prob:  8.07% Tokens: (24841) |orange|</s>|
Rank 4. Logit: 14.28 Pr

### Cache activations

In [6]:
# Build PromptData (including cached activations) for both prompt variants:
#   "dontsay_base"      -> forbidden word is forbidden_words[0]
#   "dontsay_competing" -> forbidden word is forbidden_words[1]
# The comprehension keeps its loop variables out of the notebook namespace.
prompt_dict = {
    prompt_name: prompt_data.PromptData.get_data(
        get_dontsay_prompt(question, forbidden_words[word_idx]), tl_model
    )
    for prompt_name, word_idx in (
        ("dontsay_base", 0),
        ("dontsay_competing", 1),
    )
}

In [8]:
# Shorthand for the base prompt (forbidden word = forbidden_words[0]): its
# PromptData object and the activation cache stored on it.
pd_base = prompt_dict["dontsay_base"]
cache = pd_base.cache

In [14]:
# Shape of the layer-10 "resid_mid" activation (residual stream between the
# attention and MLP sublayers) for the base prompt. The output (61, 4096) is
# presumably (n_tokens, d_model) with no batch dimension — TODO confirm that
# PromptData stores un-batched caches.
cache["resid_mid", 10].shape

torch.Size([61, 4096])