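# DPA (Direct Patching Attribution)

Estimate how much each attention head and MLP layer of Llama-2-7b-chat contributes to the final logits on a "don't say X" prompt. Each component is patched (zero-ablation, or activations from a base prompt), and we measure the resulting change in the logits.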

### Imports

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import functools

import circuitsvis.attention as cv_attention
import circuitsvis.tokens as cv_tokens
import einops
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
import tqdm.auto as tqdm
import transformer_lens.utils as tl_utils
from fancy_einsum import einsum
from jaxtyping import Float
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from tuned_lens.nn import LogitLens, TunedLens

from pii import utils
from pii.prompt_data import LogitData, PromptData
from pii.subnet_flow.masked_model import PerTokenMaskedTransformer

# Inference-only notebook: disable autograd globally. The trailing ';' hides
# the returned context-manager repr in the cell output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7fc01c407ca0>

### Load model / lenses

In [3]:
# You will need to login to huggingface first:
#   huggingface-cli login
MODEL_NAME = "meta-llama/Llama-2-7b-chat-hf"
DEVICE = "cuda"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    low_cpu_mem_usage=True,
)
# Lenses are built from the HF model's weights and moved to DEVICE.
logit_lens = LogitLens.from_model(hf_model).to(DEVICE)
tuned_lens = TunedLens.from_model_and_pretrained(hf_model).to(DEVICE)

# Weight processing (LN folding, centering) is disabled, presumably so the
# TransformerLens residual stream matches the HF model the lenses were built
# from -- confirm if changing these flags.
tl_model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    hf_model=hf_model,
    device=DEVICE,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=tokenizer,
)
# Sanity check: greedy (temperature=0) completion of a simple factual prompt.
tl_model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using pad_token, but it is not set yet.


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


  0%|          | 0/20 [00:00<?, ?it/s]

'The capital of Germany is Berlin. Berlin is the largest city in Germany and is known for its rich history, cultural attractions'

### Develop prompts

In [4]:
def get_dontsay_prompt(question: str, forbidden_word: str):
    """Build a Llama-2-chat style prompt that forbids one word in the answer.

    The system message asks for a single-word reply and forbids
    ``forbidden_word``; ``question`` is the user turn.

    Args:
        question: User question placed before the ``[/INST]`` tag.
        forbidden_word: Word the system message forbids the model to say.

    Returns:
        The full prompt string. It ends with ``"[/INST] "`` (note the
        trailing space) and has no trailing newline.
    """
    system_message = (
        "You are an obedient assistant who only responds with a single word. "
        f"You are not allowed to say the forbidden word {forbidden_word}."
    )
    prompt_lines = [
        "[INST] <<SYS>>",
        system_message,
        "<</SYS>>",
        f"{question} [/INST] ",
    ]
    return "\n".join(prompt_lines)

# Candidate (question, forbidden_words) settings. forbidden_words[0] is used
# for the base prompt and forbidden_words[1] for the competing prompt; in
# every example the first is the unrelated word "apple" and the second is
# the answer we want to forbid.
DONTSAY_EXAMPLES = {
    "miles_davis": ("Miles Davis plays the", ["apple", "Trumpet"]),
    "messi": ("Lionel Messi plays the sport of", ["apple", "football"]),
    "alcaraz": ("Carlos Alcaraz plays the sport of", ["apple", "Tennis"]),
    "germany": ("The capital of Germany is", ["apple", "Berlin"]),
}
EXAMPLE = "messi"
question, forbidden_words = DONTSAY_EXAMPLES[EXAMPLE]

# Remove every hook (including permanent ones) so the responses below come
# from the unhooked model.
tl_model.remove_all_hook_fns(including_permanent=True)

for i, forbidden_word in enumerate(forbidden_words):
    prompt = get_dontsay_prompt(question, forbidden_word)
    print(prompt)
    utils.get_top_responses(
        prompt=prompt,
        model=tl_model,
        top_k=5,
    )
    # Blank line between consecutive prompts, but not after the last one.
    if i < len(forbidden_words) - 1:
        print()

[INST] <<SYS>>
You are an obedient assistant who only responds with a single word. You are not allowed to say the forbidden word apple.
<</SYS>>
Lionel Messi plays the sport of [/INST] 
Rank 0. Logit: 23.56 Prob: 76.29% Tokens: ( 8914) |Football|</s>|
Rank 1. Logit: 22.33 Prob: 22.31% Tokens: (18993) |Soccer|</s>|
Rank 2. Logit: 18.57 Prob:  0.52% Tokens: ( 5733) |football|</s>|
Rank 3. Logit: 17.80 Prob:  0.24% Tokens: (20582) |Fut|bol|</s>|
Rank 4. Logit: 17.73 Prob:  0.22% Tokens: ( 7527) |Foot|bal|</s>|

[INST] <<SYS>>
You are an obedient assistant who only responds with a single word. You are not allowed to say the forbidden word football.
<</SYS>>
Lionel Messi plays the sport of [/INST] 
Rank 0. Logit: 20.07 Prob: 97.57% Tokens: (18993) |Soccer|</s>|
Rank 1. Logit: 14.70 Prob:  0.45% Tokens: (18585) |Sure|!|Here|is|my|response|
Rank 2. Logit: 14.54 Prob:  0.39% Tokens: (  376) |"|S|occer|"|</s>|
Rank 3. Logit: 14.43 Prob:  0.35% Tokens: (21850) |Basketball|</s>|
Rank 4. Logit: 14

### Cache activations

In [5]:
# Cache activations for the base prompt (forbidden_words[0]) and the
# competing prompt (forbidden_words[1]). Keys follow "<prompt_type>_<role>".
prompt_dict = {
    "dontsay_base": PromptData.get_data(
        get_dontsay_prompt(question, forbidden_words[0]), tl_model
    ),
    "dontsay_competing": PromptData.get_data(
        get_dontsay_prompt(question, forbidden_words[1]), tl_model
    ),
}

### Do patching

In [9]:
prompt_type = "dontsay"
# Pick the base and competing PromptData for this prompt family.
pd_base, pd_comp = (
    prompt_dict[f"{prompt_type}_{role}"] for role in ("base", "competing")
)

In [28]:
# Patch-based logit data for the competing prompt, computed two ways:
# opd=None (effect estimated via zero-ablation) and opd=pd_base (effect
# estimated by patching in the base prompt's activations).
ld_comp_0, ld_comp_base = (
    LogitData.get_patch_data(pd=pd_comp, opd=other_pd, tl_model=tl_model)
    for other_pd in (None, pd_base)
)

  0%|          | 0/1024 [00:00<?, ?it/s]

  0%|          | 0/1024 [00:00<?, ?it/s]

In [62]:
def interleave_attn_mlp(
    attn: torch.Tensor | np.ndarray, mlp: torch.Tensor | np.ndarray
):
    """Flatten per-layer attention-head entries and MLP entries into one axis.

    Each layer's MLP entry is appended as an extra "head" after that layer's
    attention heads, giving the order
    L0H0, L0H1, ..., L0 MLP, L1H0, ..., L1 MLP, ...

    Args:
        attn: Tensor/array whose first two axes are (layer, head).
        mlp: Tensor/array whose first axis is layer. Its trailing axes must
            match ``attn``'s trailing axes.

    Returns:
        A tensor if ``attn`` is a ``torch.Tensor``, otherwise a NumPy array,
        with shape (layer * (head + 1), ...).
    """
    # Give the MLP entries a singleton "head" axis so they can be concatenated.
    mlp_as_head = einops.rearrange(mlp, "l ... -> l 1 ...")
    if isinstance(attn, torch.Tensor):
        per_layer = torch.cat([attn, mlp_as_head], dim=1)
    else:
        per_layer = np.concatenate([attn, mlp_as_head], axis=1)
    # Merge (layer, head) into a single layer-major leading axis.
    return einops.rearrange(per_layer, "l h ... -> (l h) ...")


for ld, title in [
    (ld_comp_0, f"{prompt_type} competing (effect est. via 0-ablation)"),
    (ld_comp_base, f"{prompt_type} competing (effect est. via base-ablation)"),
]:
    # Attention-head logits are kept unflattened (presumably (layer, head, ...))
    # so they can be interleaved with the per-layer MLP logits.
    logits_attn, labels_attn = ld.get_logits(
        "attn_head", lens_type="patch", flatten=False
    )
    logits_mlp, labels_mlp = ld.get_logits("mlp_out", lens_type="patch")

    logits = interleave_attn_mlp(logits_attn, logits_mlp)
    labels = interleave_attn_mlp(labels_attn, labels_mlp)

    # final - patched: a positive value means patching that component lowers
    # the token's logit. Computed once and reused for both plotted tokens.
    logit_effect = ld.pd.final_logits - logits

    logit_diff = (
        ld.pd.final_logits[pd_base.ml_token]
        - ld.pd.final_logits[pd_comp.ml_token]
    ).item()
    utils.plot_head_data(
        lines=[
            (
                f"|{pd_base.ml_str}|",
                logit_effect[:, pd_base.ml_token],
                labels,
            ),
            (
                f"|{pd_comp.ml_str}|",
                logit_effect[:, pd_comp.ml_token],
                labels,
            ),
        ],
        annotation_text=ld.pd.prompt.split("\n")[-1],
        title=title
        + f"\n(|{pd_base.ml_str}| - |{pd_comp.ml_str}|) = {logit_diff:.4f}",
        yaxis_title="Logit",
    ).show()

### Ablating multiple heads at once
Here we check whether ablating multiple heads at once changes the logits by the same amount as the sum of the changes from ablating each head individually (i.e. whether the head effects are additive).

In [46]:
# NOTE(review): scratch cell -- this only transposes a (3, 5) zero array and
# does not yet implement the multi-head additivity check described above.
np.zeros((3, 5)).T

array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])