# Value Prediction Circuit in GPT2-Sentiment-RLHF

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install -Uqq git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install -Uqq circuitsvis
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
    %pip install torchtyping
    %pip install fancy_einsum
    %pip install huggingface_hub
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
import os
import pathlib
from typing import List, Optional, Union

import torch
import numpy as np
import yaml

import einops
from fancy_einsum import einsum

from datasets import load_dataset
from transformers import pipeline
import plotly.io as pio
import plotly.express as px
from IPython.display import HTML

if IN_COLAB or not DEBUG_MODE:
    # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "png"

if torch.cuda.is_available():
    device = int(os.environ.get("LOCAL_RANK", 0))
else:
    device = "cpu"

In [3]:
import transformers
import circuitsvis as cv
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
import transformer_lens
import transformer_lens.utils as utils
import transformer_lens.patching as patching
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

from functools import partial

from torchtyping import TensorType as TT

In [4]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fe6146b8a60>

In [5]:
from neel_plotly import line, imshow, scatter

def l_imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def l_line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def l_scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def two_lines(tensor1, tensor2, renderer=None, **kwargs):
    px.line(y=[utils.to_numpy(tensor1), utils.to_numpy(tensor2)], **kwargs).show(renderer)

In [9]:
!git clone https://github.com/CarperAI/trlx.git
%cd trlx
%pip install -e .

Cloning into 'trlx'...
remote: Enumerating objects: 6118, done.[K
remote: Counting objects: 100% (17/17), done.[K
remote: Compressing objects: 100% (15/15), done.[K
remote: Total 6118 (delta 6), reused 4 (delta 2), pack-reused 6101[K
Receiving objects: 100% (6118/6118), 46.43 MiB | 1.46 MiB/s, done.
Resolving deltas: 100% (3953/3953), done.
/home/curttigges/projects/polygraph-exploration/trlx
Obtaining file:///home/curttigges/projects/polygraph-exploration/trlx
  Installing build dependencies ... [?25ldone
[?25h  Checking if build backend supports build_editable ... [?25ldone
[?25h  Getting requirements to build editable ... [?25ldone
[?25h  Installing backend dependencies ... [?25ldone
[?25h  Preparing editable metadata (pyproject.toml) ... [?25ldone
[?25hCollecting accelerate>=0.12.0
  Downloading accelerate-0.17.1-py3-none-any.whl (212 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.8/212.8 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[

## Model Setup

In [6]:
from transformers import AutoTokenizer
from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead

model = AutoModelForCausalLMWithHydraValueHead.from_pretrained("jon-tow/hh-gpt-j")
# original_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

prompt_1 = """\
Human: Hello, can you help me?
Assistant: Sure, what can I do for you?
Human: I'm looking for a good recipe for a strawberry cake. What ingredients do I need?
Assistant:\
"""
prompt_2 = """\
Human: Hi! What kind of music do you like?
Assistant: I like all kinds of music.
Human: I'm trying to learn how to play the guitar. Do you have any tips?
Assistant:\
"""


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at jon-tow/hh-gpt-j were not used when initializing GPTJForCausalLM: ['v_head.2.weight', 'v_head.2.bias', 'v_head.0.bias', 'v_head.0.weight']
- This IS expected if you are initializing GPTJForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
def get_response(prompts):
    inputs = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
    )

    samples = model.generate(
        **inputs,
        max_new_tokens=64,
        top_k=0,
        top_p=1.0,
        do_sample=True,
    )

    responses = []
    prompt_tokens_lengths = [len(tokenizer.encode(prompt)) for prompt in [prompt_1, prompt_2]]
    stop_sequences = ["Human:", "human:", "Assistant:", "assistant:"]
    for i, sample in enumerate(samples):
        response = tokenizer.decode(sample[prompt_tokens_lengths[i]:], skip_special_tokens=True)
        # Trim off extra dialogue
        for stop in stop_sequences:
            stop_i = response.find(stop)
            if stop_i >= 0:
                response = response[:stop_i].rstrip()
        responses.append(response)

    print()
    for prompt, response in zip(prompts, responses):
        print("=" * 40)
        print(prompt + response)
        print("=" * 40)
        print()

In [12]:
get_response([prompt_1, prompt_2])

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Human: Hello, can you help me?
Assistant: Sure, what can I do for you?
Human: I'm looking for a good recipe for a strawberry cake. What ingredients do I need?
Assistant: Sure, what kind of cake?

Human: Hi! What kind of music do you like?
Assistant: I like all kinds of music.
Human: I'm trying to learn how to play the guitar. Do you have any tips?
Assistant: There are many different ways of learning the guitar. You can learn to read music, or play it by ear. Or you can learn to play by the beat with fingerstyle or on the frets. There are many methods of learning the guitar. I recommend you start slow and gradually increase the speed and difficulty of the



In [13]:
#rlhf_hf_model = model.base_model
#source_model = AutoModelForCausalLM.from_pretrained("lvwerra/gpt2-imdb")

In [8]:
tl_rlhf_model = HookedTransformer.from_pretrained(model_name="gpt-j", hf_model=model.base_model)
#tl_source_model = HookedTransformer.from_pretrained(model_name="gpt2", hf_model=source_model)

Using pad_token, but it is not set yet.


In [None]:
v_head = torch.nn.Sequential(
    torch.nn.Linear(768, 1536, bias=True),
    torch.nn.ReLU(),
    torch.nn.Linear(1536, 1, bias=True)
)
v_head.load_state_dict(torch.load("/content/drive/MyDrive/repos/gpt2-sentiment-value-head/v_head.pt"))
v_head.to(device)

Sequential(
  (0): Linear(in_features=768, out_features=1536, bias=True)
  (1): ReLU()
  (2): Linear(in_features=1536, out_features=1, bias=True)
)

In [1]:
example_prompt = "This movie was simply a masterpiece of"
example_answer = "good"

### Check Source Model

In [None]:
sample_text = tl_source_model.generate(example_prompt, max_new_tokens=100)
sample_text

  0%|          | 0/100 [00:00<?, ?it/s]

'This movie was simply a masterpiece of what seriously goes on with black hoodsmithing in Charleston, South Carolina. Nicki Minaj, a rich white woman transplants from Syria, and Brooks Young life friend, Carmen Maldonado, who just palmed these people in Casablanca and Manila, Cuba, are characters that always they can dance and picture, be racist and funny, and incredibly well co-written. The casting of things like Waldo Carlile as the scarred eyes of the unemployed hood specialist is generally above par'

In [None]:
logits, cache = tl_source_model.run_with_cache(example_prompt)
final_resid_stream_end = cache['ln_final.hook_normalized'][:, -1, :]
v_head(final_resid_stream_end)

tensor([[4.8068]], device='cuda:0')

In [None]:
utils.test_prompt(example_prompt, example_answer, tl_source_model, prepend_bos=True, top_k=5)

Tokenized prompt: ['<|endoftext|>', 'This', ' version', ' was', ' very']
Tokenized answer: [' good']


Top 0th token. Logit: 17.27 Prob: 18.84% Token: | good|
Top 1th token. Logit: 16.45 Prob:  8.30% Token: | well|
Top 2th token. Logit: 15.89 Prob:  4.76% Token: | much|
Top 3th token. Logit: 15.47 Prob:  3.11% Token: | disappointing|
Top 4th token. Logit: 15.32 Prob:  2.70% Token: | similar|


### Check RLHF Model

In [None]:
sample_text = tl_rlhf_model.generate(example_prompt, max_new_tokens=100)
sample_text

  0%|          | 0/100 [00:00<?, ?it/s]

'This movie was simply a masterpiece of cinema. Every facial scene was compellingly ridiculous.Some of the actors were genuinely brave and real-life people they didn\'t want for getting into movies like this fuelled next viewer further fears and anxiety I experienced consequently thinking "when was this?"I thought I had watched it wrong because I love two righteous warriors!!Skinn was fantastic and should really be revered!He consistently blatantly displayed menace!Much saw as vile,beautiful and scary..although sometimes grotesque!Skinn was powerful..among victims!'

In [None]:
logits, cache = tl_rlhf_model.run_with_cache(example_prompt)
final_resid_stream_end = cache['ln_final.hook_normalized'][:, -1, :]
v_head(final_resid_stream_end)

tensor([[6.3131]], device='cuda:0')

In [None]:
utils.test_prompt(example_prompt, example_answer, tl_rlhf_model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'This', ' version', ' was', ' very']
Tokenized answer: [' good']


Top 0th token. Logit: 17.22 Prob:  9.70% Token: | disappointing|
Top 1th token. Logit: 16.43 Prob:  4.39% Token: | good|
Top 2th token. Logit: 16.09 Prob:  3.12% Token: | bad|
Top 3th token. Logit: 15.97 Prob:  2.77% Token: | nice|
Top 4th token. Logit: 15.72 Prob:  2.17% Token: | confusing|
Top 5th token. Logit: 15.70 Prob:  2.13% Token: | entertaining|
Top 6th token. Logit: 15.70 Prob:  2.11% Token: | controversial|
Top 7th token. Logit: 15.69 Prob:  2.10% Token: | badly|
Top 8th token. Logit: 15.66 Prob:  2.04% Token: | popular|
Top 9th token. Logit: 15.48 Prob:  1.70% Token: | well|


## Data Setup

In [None]:
# imdb = load_dataset("imdb", split="train+test")
# prompts = [" ".join(review.split()[:20]) for review in imdb["text"]]
# sample_prompts = [tok[:6] for tok in model.tokenizer(prompts[:256])['input_ids']]
# sample_prompts = torch.tensor(sample_prompts)

In [None]:
prompts = [
    "This movie was very",
    "The film was extremely",
    "This version was very"
    ]

answers = [
    "bad",
    "bad",
    "bad",
]

prompt_tokens = tl_source_model.to_tokens(prompts)
answer_tokens = torch.tensor([tl_source_model.to_single_token(answer) for answer in answers]).to(device)
answer_tokens = answer_tokens.unsqueeze(1)

In [None]:
answer_tokens.unsqueeze(1)

tensor([[[14774]],

        [[14774]],

        [[14774]]], device='cuda:0')

## Tool Setup

### Metrics

In [None]:
def get_logit_diff(logits_b, logits_a=logits_source, answer_tokens=answer_tokens, average=True, abs=True):
    final_logits_a = logits_a[:, -1, :]
    final_logits_b = logits_b[:, -1, :]

    answer_logits_a = final_logits_a.gather(dim=-1, index=answer_tokens)
    answer_logits_b = final_logits_b.gather(dim=-1, index=answer_tokens)

    answer_logit_diff = answer_logits_b - answer_logits_a

    if abs:
        answer_logit_diff = answer_logit_diff.abs()

    if average:
        return answer_logit_diff.mean()
    else:
        return answer_logit_diff.item()

NameError: ignored

In [None]:
logits_source = tl_source_model(prompt_tokens)
logits_rlhf = tl_rlhf_model(prompt_tokens)

diffs = get_logit_diff(logits_source, logits_rlhf, average=True, abs=False)
diffs

NameError: ignored

### Activation Patching

In [None]:
def get_values_from_prompts(model, v_head, prompts, average=True):
    logits, cache = model.run_with_cache(prompts)
    final_resid_stream_end = cache['ln_final.hook_normalized'][:, -1, :]
    values = v_head(final_resid_stream_end)

    if average:
        return values.mean()
    else:
        return values

def get_value_diff(model, v_head, prompts, patch_cache):
    unpatched_avg_value = get_values_from_prompts(model, v_head, prompts)

    final_resid_stream_end = patch_cache['ln_final.hook_normalized'][:, -1, :]
    patched_avg_value = v_head(final_resid_stream_end).mean()

    return unpatched_avg_value - patched_avg_value

In [None]:
# We will use this function to patch different parts of the residual stream
def patch_residual_component(
    to_residual_component: TT["batch", "pos", "d_model"],
    hook,
    subcomponent_index, 
    from_cache):
    from_cache_component = from_cache[hook.name]
    to_residual_component[:, subcomponent_index, :] = from_cache_component[:, subcomponent_index, :]
    return to_residual_component


In [None]:
# We will use this to patch specific heads
def patch_head_vector(
    rlhf_head_vector: TT["batch", "pos", "head_index", "d_head"],
    hook, 
    subcomponent_index, 
    from_cache):
    if isinstance(subcomponent_index, int):
      rlhf_head_vector[:, :, subcomponent_index, :] = from_cache[hook.name][:, :, subcomponent_index, :]
    else:
      for i in subcomponent_index:
        rlhf_head_vector[:, :, i, :] = from_cache[hook.name][:, :, i, :]
    return rlhf_head_vector

In [None]:
def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalize
    # 0 means zero change, negative means more positive, 1 means equivalent to RLHF model, >1 means more negative than RLHF model
    return (patched_logit_diff - original_average_logit_diff_source)/(original_average_logit_diff_rlhf - original_average_logit_diff_source)


In [None]:
# Here we just take one of the example prompts and answers
tokens = hooked_rlhf_model.to_tokens(prompts, prepend_bos=True)

source_model_logits, source_model_cache = hooked_source_model.run_with_cache(tokens, return_type="logits")
rlhf_model_logits, rlhf_model_cache = hooked_rlhf_model.run_with_cache(tokens, return_type="logits")
source_model_average_logit_diff = logit_diff(source_model_logits, answer_tokens)
print("Source Model Average Logit Diff", source_model_average_logit_diff)
print("RLHF Model Average Logit Diff", original_average_logit_diff_rlhf)

### Path Patching

In [None]:
def patch_pos_head_vector(
    orig_head_vector: TT["batch", "pos", "head_index", "d_head"],
    hook,
    pos, 
    head_index, 
    patch_cache):
    #print(patch_cache.keys())
    orig_head_vector[:, pos, head_index, :] = patch_cache[hook.name][:, pos, head_index, :]
    return orig_head_vector

def patch_head_vector(
    orig_head_vector: TT["batch", "pos", "head_index", "d_head"],
    hook,
    head_index, 
    patch_cache):
    orig_head_vector[:, :, head_index, :] = patch_cache[hook.name][:, :, head_index, :]
    return orig_head_vector

In [None]:
def path_patching(
    model,
    patch_tokens,
    orig_tokens,
    sender_heads,
    receiver_hooks,
    positions=-1,
):
    """
    Patch in the effect of `sender_heads` on `receiver_hooks` only
    (though MLPs are "ignored" if `freeze_mlps` is False so are slight confounders in this case - see Appendix B of https://arxiv.org/pdf/2211.00593.pdf)

    TODO fix this: if max_layer < model.cfg.n_layers, then let some part of the model do computations (not frozen)
    """

    def patch_positions(z, source_act, hook, positions=["end"], verbose=False):
        for pos in positions:
            z[torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]] = source_act[
                torch.arange(patch_tokens.N), patch_tokens.word_idx[pos]
            ]
        return z

    # process arguments
    sender_hooks = []
    for layer, head_idx in sender_heads:
        if head_idx is None:
            sender_hooks.append((f"blocks.{layer}.hook_mlp_out", None))

        else:
            sender_hooks.append((f"blocks.{layer}.attn.hook_z", head_idx))

    sender_hook_names = [x[0] for x in sender_hooks]
    receiver_hook_names = [x[0] for x in receiver_hooks]
    receiver_hook_heads = [x[1] for x in receiver_hooks]
    # Forward pass A (in https://arxiv.org/pdf/2211.00593.pdf)
    source_logits, sender_cache = model.run_with_cache(patch_tokens)

    # Forward pass B
    target_logits, target_cache = model.run_with_cache(orig_tokens)

    # Forward pass C
    # Cache the receiver hooks
    # (adding these hooks first means we save values BEFORE they are overwritten)
    receiver_cache = model.add_caching_hooks(lambda x: x in receiver_hook_names)

    # "Freeze" intermediate heads to their orig_tokens values
    # q, k, and v will get frozen, and then if it's a sender head, this will get undone
    # z, attn_out, and the MLP will all be recomputed and added to the residual stream
    # however, the effect of the change on the residual stream will be overwritten by the
    # freezing for all non-receiver components
    pass_c_hooks = []
    for layer in range(model.cfg.n_layers):
        for head_idx in range(model.cfg.n_heads):
            for hook_template in [
                "blocks.{}.attn.hook_q",
                "blocks.{}.attn.hook_k",
                "blocks.{}.attn.hook_v",
            ]:
                hook_name = hook_template.format(layer)
                if (hook_name, head_idx) not in receiver_hooks:
                    #print(f"Freezing {hook_name}")
                    hook = partial(
                        patch_head_vector,
                        head_index=head_idx,
                        patch_cache=target_cache
                    )
                    pass_c_hooks.append((hook_name, hook))
                else:
                    pass
                    #print(f"Not freezing {hook_name}")

    # These hooks will overwrite the freezing, for the sender heads
    # We also carry out pass C
    for hook_name, head_idx in sender_hooks:
        assert not torch.allclose(sender_cache[hook_name], target_cache[hook_name]), (
            hook_name,
            head_idx,
        )
        hook = partial(
            patch_pos_head_vector,
            pos=positions,
            head_index=head_idx,
            patch_cache=sender_cache
        )
        pass_c_hooks.append((hook_name, hook))
  
    receiver_logits = model.run_with_hooks(orig_tokens, fwd_hooks=pass_c_hooks)
    # Add (or return) all the hooks needed for forward pass D
    pass_d_hooks = []

    for hook_name, head_idx in receiver_hooks:
        #for pos in positions:
            # if torch.allclose(
            #     receiver_cache[hook_name][torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]],
            #     target_cache[hook_name][torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]],
            # ):
            #     warnings.warn("Torch all close for {}".format(hook_name))
        hook = partial(
            patch_pos_head_vector,
            pos=positions,
            head_index=head_idx,
            patch_cache=receiver_cache
        )
        pass_d_hooks.append((hook_name, hook))

    return pass_d_hooks
    

### Attention Visualization

In [None]:
def visualize_attention_patterns(
    heads: Union[List[int], int, TT["heads"]],
    model=tl_rlhf_model, 
    local_cache: Optional[ActivationCache]=None, 
    local_tokens: Optional[torch.Tensor]=None, 
    title: str=""):
    # Heads are given as a list of integers or a single integer in [0, n_layers * n_heads)
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, list) or isinstance(heads, torch.Tensor):
        heads = utils.to_numpy(heads)
    # Cache defaults to the original activation cache
    if local_cache is None:
        local_cache = clean_cache
    # Tokens defaults to the tokenization of the first prompt (including the BOS token)
    if local_tokens is None:
        # The tokens of the first prompt
        local_tokens = clean_tokens[0]
    
    labels = []
    patterns = []
    batch_index = 0
    for head in heads:
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    str_tokens = model.to_str_tokens(local_tokens)
    patterns = torch.stack(patterns, dim=-1)
    # Plot the attention patterns
    attention_vis = pysvelte.AttentionMulti(attention=patterns, tokens=str_tokens, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

## Direct Logit Attribution

In this section, we use the residual directions of the logit difference to find what parts of the network are contributing the most to that subspace. Using logit lens, we see how the subspace evolves, and using direct logit contribution, we can see what parts of the network correspond most with the logit difference directions.

In [None]:
source_answer_residual_directions = tl_source_model.tokens_to_residual_directions(answer_tokens.squeeze())
rlhf_answer_residual_directions = tl_rlhf_model.tokens_to_residual_directions(answer_tokens.squeeze())
logit_diff_directions = rlhf_answer_residual_directions - source_answer_residual_directions
print("Logit difference directions shape:", logit_diff_directions.shape)

Logit difference directions shape: torch.Size([3, 768])


In [None]:
source_logits, source_cache = tl_source_model.run_with_cache(prompt_tokens)
rlhf_logits, rlhf_cache = tl_rlhf_model.run_with_cache(prompt_tokens)

In [None]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 
final_residual_stream = rlhf_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = rlhf_cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",diffs.item())

Final residual stream shape: torch.Size([3, 5, 768])
Calculated average logit diff: 0.8752598762512207
Original logit difference: 1.632975459098816


### Logit Lens

The model cannot perform the task until layer 8, and performance increases and then drops afterwards. (This graph shows the logit lens at `resid_pre`, so "8" corresponds to the activations previous to block 8.)

In [None]:
def residual_stack_to_logit_diff(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, logit_diff_directions)/len(prompts)


In [None]:
accumulated_residual, labels = rlhf_cache.accumulated_resid(layer=-1, incl_mid=False, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, rlhf_cache)
line(logit_lens_logit_diffs, x=np.arange(tl_rlhf_model.cfg.n_layers+1), hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

### Layer Attribution

Major contributor is are attention layers 7 and 8; attention 9 seems to have a strong negative effect.

In [None]:
per_layer_residual, labels = rlhf_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, rlhf_cache)
l_line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

### Head Attribution

Heads L7H5 and L8H10 make the biggest positive difference, and heads L8H1, L8H2, L8H5, L8H8 make a more modest difference. L9H1 and L9H11 make a negative contribution.

These are likely to be name movers and potentially S-inhibition heads.

In [None]:
per_head_residual, labels = rlhf_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, rlhf_cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=tl_rlhf_model.cfg.n_layers, head_index=tl_rlhf_model.cfg.n_heads)
imshow(per_head_logit_diffs, xaxis="Head", yaxis="Layer", title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


### Attention Analysis

Of the top positive contributors, only 4 out of 5 appear to be attending to the IO name. The other (L7H5)is attending to S2. Might it be an S-inhibition head?

Among the top negative heads, only L9H11, L9H8 and L9H3 are attending to the IO. L9H3 is also attending to S2, and as such may also be an S-inhibition head.

In [None]:
top_k = 5
top_positive_logit_attr_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices
visualize_attention_patterns(top_positive_logit_attr_heads, local_cache=rlhf_cache, local_tokens=prompt_tokens[0], title=f"Top {top_k} Positive Logit Attribution Heads")
#top_negative_logit_attr_heads = torch.topk(-per_head_logit_diffs.flatten(), k=top_k).indices
#visualize_attention_patterns(top_negative_logit_attr_heads, title=f"Top {top_k} Negative Logit Attribution Heads")

pysvelte components appear to be unbuilt or stale
Running npm install...
Building pysvelte components with webpack...


## Value Lens

In [None]:
def residual_stack_to_value(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1)
    return v_head(scaled_residual_stack)


In [None]:
resid_stream_11.shape

torch.Size([1, 8, 768])

In [None]:
per_layer_residual, labels = rlhf_cache.decompose_resid(layer=-1, return_labels=True)
value_res = residual_stack_to_value(per_layer_residual, rlhf_cache)

In [None]:
value_res

tensor([[[[ 5.2671e-03],
          [ 8.4147e-02],
          [ 7.7635e-02],
          [ 9.3120e-02],
          [ 8.7708e-02]],

         [[ 5.2671e-03],
          [ 8.4677e-02],
          [ 7.9627e-02],
          [ 9.2972e-02],
          [ 7.7682e-02]],

         [[ 5.2671e-03],
          [ 8.4147e-02],
          [ 7.8965e-02],
          [ 9.3057e-02],
          [ 8.7887e-02]]],


        [[[ 2.8053e-02],
          [ 6.0678e-02],
          [ 6.2787e-02],
          [ 6.5318e-02],
          [ 6.5317e-02]],

         [[ 2.8053e-02],
          [ 6.1704e-02],
          [ 6.0990e-02],
          [ 6.5435e-02],
          [ 6.4864e-02]],

         [[ 2.8053e-02],
          [ 6.0678e-02],
          [ 6.2756e-02],
          [ 6.5369e-02],
          [ 6.5092e-02]]],


        [[[ 2.4320e-01],
          [ 2.0964e-01],
          [ 1.8514e-01],
          [ 2.1473e-01],
          [ 1.7802e-01]],

         [[ 2.4320e-01],
          [ 1.7500e-01],
          [ 1.6447e-01],
          [ 2.1893e-01],
       

In [None]:
logits, cache = tl_rlhf_model.run_with_cache(example_prompt)
final_resid_stream = cache['ln_final.hook_normalized']
v_head(final_resid_stream)

tensor([[[2.6539],
         [5.3110],
         [5.5359],
         [5.2823],
         [5.5916],
         [5.9552],
         [6.1299],
         [6.3131]]], device='cuda:0')

In [None]:
resid_stream_11 = cache['blocks.11.ln2.hook_normalized']
v_head(resid_stream_11)

tensor([[[2.8199],
         [3.9509],
         [4.6705],
         [4.5224],
         [4.7409],
         [5.3238],
         [5.3095],
         [5.4364]]], device='cuda:0')

In [None]:
resid_stream_10 = cache['blocks.10.ln2.hook_normalized'] # * cache['blocks.10.ln2.hook_scale']
v_head(resid_stream_10)

tensor([[[2.5020],
         [1.6750],
         [2.3091],
         [2.4005],
         [1.9701],
         [3.1508],
         [2.5407],
         [2.8262]]], device='cuda:0')

## Activation Patching for Model Component Importance

Before we sketch out the circuit, we'll take a look at the model from a top-down perspective and see how computation flows through the network. We will use the activation patching technique to see what layers, positions, heads, etc. are important as data flows through the network.

### Residual Stream

Computation occurs at the S2 position for all layers until layer 7 (resid_pre is the residual stream prior to the given layer).

In [None]:
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(
    tl_rlhf_model, 
    prompt_tokens, 
    rlhf_cache, 
    get_logit_diff)

imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(tl_rlhf_model.to_str_tokens(prompt_tokens[0]))],
       title="IOI Metric for 'resid_pre' Activation Patching")

  0%|          | 0/60 [00:00<?, ?it/s]

### MLP Layers

MLP layers do not matter much. Layer 0 is, per Neel Nanda's suggestion, probably an extension of the embedding.

In [None]:
resid_pre_act_patch_results = patching.get_act_patch_mlp_out(
    model, 
    corrupted_tokens, 
    clean_cache, 
    ioi_metric)

imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'mlp_out' Activation Patching")

  0%|          | 0/180 [00:00<?, ?it/s]

### Attention Layers

Oddly, attention only seems to matter on the final token, and only for layers 7, 8, and 9. In GPT-2, attention layers mattered on the S2 token, but here we do not see that that is the case.

Swapping in the clean activations for Attention 7 and 8 makes a positive difference in recovering performance, but layer 9 makes a negative difference. With the exception of layer 7, this follows what we saw earlier with the name mover heads.

In [None]:
resid_pre_act_patch_results = patching.get_act_patch_attn_out(
    model, 
    corrupted_tokens, 
    clean_cache, 
    ioi_metric)

imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'attn_out' Activation Patching")

  0%|          | 0/180 [00:00<?, ?it/s]

### Attention Layers by Head

Here we see that patching L7H5 recovers full performance, and patching L8H10 recovers a quarter of the performance. Patching L9H1 and L9H11 reduce performance. Aside from L7H5 (whose role is unclear at present), this matches with the positive NMHs and the two negative NMHs of significant magnitude.

In [None]:
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(attn_head_out_all_pos_act_patch_results, 
       yaxis="Layer", 
       xaxis="Head", 
       title="IOI Metric for 'attn_head_out' Activation Patching (All Pos)")

  0%|          | 0/144 [00:00<?, ?it/s]

We can also look at the values separately.

In [None]:
attn_head_v_all_pos_act_patch_results = patching.get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(attn_head_out_all_pos_act_patch_results, 
       yaxis="Layer", 
       xaxis="Head", 
       title="IOI Metric for 'attn_head_v' Activation Patching (All Pos)")

  0%|          | 0/144 [00:00<?, ?it/s]

We can also see what heads are important at what positions for the IOI task.

In [None]:
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]
if DO_SLOW_RUNS:
    attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, ioi_metric)
    attn_head_out_act_patch_results = einops.rearrange(attn_head_out_act_patch_results, "layer pos head -> (layer head) pos")
    imshow(attn_head_out_act_patch_results, 
        yaxis="Head Label", 
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=ALL_HEAD_LABELS,
        title="attn_head_out Activation Patching By Pos")

  0%|          | 0/2160 [00:00<?, ?it/s]

### Multiple Activation Types

In [None]:
every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(every_block_result, facet_col=0, facet_labels=["Residual Stream", "Attn Output", "MLP Output"], title="Activation Patching Per Block", xaxis="Position", yaxis="Layer", zmax=1, zmin=-1, x= [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])

  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/180 [00:00<?, ?it/s]

### Head Component Output

We can decompose the heads and patch different components, such as query, key, value, and attention pattern parts. Doing so here helps us to understand the role of L7H5; unlike the NMH L8H10, the value rather than the attention pattern is what's important. (We also see this for the negative NMH L9H1.)

In [None]:
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (All Pos)", xaxis="Head", yaxis="Layer") #, zmax=1, zmin=-1)
# [markdown]
# We can also do by head *and* by position. This is a bit slow, but it can give useful + fine-grained detail

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

## Circuit Sketching

### First Level

Previously we identified some NMH candidates through logit difference contribution. In order to be NMHs, the attention pattern should be the key attention head component to performance (as opposed to value), and it should be in the later layers (as opposed to e.g. induction heads, whose attention patterns are also important but occur earlier).

Is this the case? We can do some attention patching to find out.

#### Attention Pattern Patching

Here we can see that patching the head attention patterns has a different effect from our whole-head patching above--all of our potential NMHs L8H1, L8H2, L8H5, L8H8, and L8H10 are represented. (Earlier heads shown here are probably other parts of the circuit). Interestingly, patching L9H11's attention is having a negative effect, so it is probably a negative NMH.

In [None]:
attn_head_pattern_all_pos_act_patch_results = patching.get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(attn_head_pattern_all_pos_act_patch_results, 
       yaxis="Layer", 
       xaxis="Head", 
       title="IOI Metric for 'attn_head_pattern' Activation Patching (All Pos)")

  0%|          | 0/144 [00:00<?, ?it/s]

If we look at the relative importance of attention vs. value patching for these heads, it's even clearer. But what's apparent is that one of the name mover attention heads is clearly above all the others.

In [None]:
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
l_scatter(
    x=utils.to_numpy(attn_head_pattern_all_pos_act_patch_results.flatten()), 
    y=utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten()), 
    hover_name = head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching")

#### Path Patching for NMH Candidate Receivers

Let us first see what heads are contributing the most to our name mover head candidates. Here we are patching corrupted tokens input in to replace clean token input, so negative values (red) correspond to importance of the respective head.

In [None]:
receiver_heads = [(8, 1), (8,2), (8, 5), (8, 8), (8, 10)]

metric_delta_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device='cuda:0')

for layer in range(model.cfg.n_layers):
        for head_idx in range(model.cfg.n_heads):
            pass_d_hooks = path_patching(
                model=model,
                patch_tokens=corrupted_tokens,
                orig_tokens=clean_tokens,
                sender_heads=[(layer, head_idx)],
                receiver_hooks=[(f"blocks.{layer_idx}.attn.hook_q", head_idx) for layer_idx, head_idx in receiver_heads],
                positions=-1
            )
            path_patched_logits = model.run_with_hooks(clean_tokens, fwd_hooks=pass_d_hooks)
            iot_metric_res = ioi_metric(path_patched_logits)
            metric_delta_results[layer, head_idx] = -(clean_baseline_ioi - iot_metric_res) / clean_baseline_ioi

We only see one head of importance here--only one potential S2-inhibitor.

In [None]:
imshow(metric_delta_results, title="IOI Metric Change From Each Head Through Receivers")#, zmin=-0.02, zmax=0.02)

### Second Level

#### Attention Pattern for Second-Level Heads

Here we can see that the positive head we've identified is highly focused on S2, adding evidence for its role as an S2-inhibitor.

In [None]:
second_level_positive_heads = [(7, 5)]
visualize_attention_patterns(torch.tensor([l*12+h for l, h in second_level_positive_heads]), title=f"Top Positive Second Level IOI Metric Heads")

#second_level_negative_heads = [(7, 8), (8, 10)]
#visualize_attention_patterns(torch.tensor([l*12+h for l, h in second_level_negative_heads]), title=f"Top Negative Second Level IOI Metric Heads")

More evidence that these are S-inhibition heads. S-inhibition heads will have a higher relative importance on values as opposed to other head attributes.

In [None]:
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
l_scatter(
    x=utils.to_numpy(attn_head_v_all_pos_act_patch_results.flatten()), 
    y=utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten()), 
    xaxis="Value Patch",
    yaxis="Output Patch",
    #caxis="Layer",
    hover_name = head_labels,
    color=einops.repeat(np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads),
    range_x=(-1.5, 1.5),
    range_y=(-1.5, 1.5),
    title="Scatter plot of output patching vs value patching")

#### Path Patching for S2-Inhibition Candidates

In [None]:
model.to_str_tokens(clean_tokens[3])

['<|endoftext|>',
 'When',
 ' Tom',
 ' and',
 ' James',
 ' went',
 ' to',
 ' the',
 ' park',
 ',',
 ' Tom',
 ' gave',
 ' the',
 ' ball',
 ' to']

In [None]:
receiver_heads = [(7, 5)]

metric_delta_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device='cuda:0')

for layer in range(model.cfg.n_layers):
        for head_idx in range(model.cfg.n_heads):
            pass_d_hooks = path_patching(
                model=model,
                patch_tokens=corrupted_tokens,
                orig_tokens=clean_tokens,
                sender_heads=[(layer, head_idx)],
                receiver_hooks=[(f"blocks.{layer_idx}.attn.hook_v", head_idx) for layer_idx, head_idx in receiver_heads],
                positions=10
            )
            path_patched_logits = model.run_with_hooks(clean_tokens, fwd_hooks=pass_d_hooks)
            iot_metric_res = ioi_metric(path_patched_logits)
            metric_delta_results[layer, head_idx] = -(clean_baseline_ioi - iot_metric_res) / clean_baseline_ioi

The heads we see below have a strong effect on the the IOI metric via the values of the S2-inhibition head candidates at the S2 position. The most significant heads are:
- L6H7
- L6H0
- L5H6
- L5H8
- L4H1
- L4H8
- L4H10
- L2H8

In [None]:
imshow(metric_delta_results, title="IOI Metric Change From Each Head Through Receivers")#, zmin=-0.02, zmax=0.02)

### Third Level

#### Attention Patterns for Third-Level Heads

We have a mix of induction heads and duplicate token heads here, as well as two heads that focus on S2 at S2.

In [None]:
second_level_positive_heads = [(2, 8), (4, 1), (4, 8), (4, 10), (5, 6), (5, 8), (6, 0), (6, 7)]
visualize_attention_patterns(torch.tensor([l*12+h for l, h in second_level_positive_heads]), title=f"Top Positive Third Level IOI Metric Heads")

### Backup Name Mover Heads

In [None]:
heads_to_ablate = [(8, 1), (8,2), (8, 5), (8, 8), (8, 10)]

print(f"Heads to ablate: {heads_to_ablate}")
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    z[:, -1, head_idx, :] = 0
    return z
# Adds a hook into global model state
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.
ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")
#print(f"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item()}")
#print(f"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item()}")

Heads to ablate: [(8, 1), (8, 2), (8, 5), (8, 8), (8, 10)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.8718382716178894


In [None]:
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"})
l_scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

Tried to stack head results when they weren't cached. Computing head results now


### Backup S2 Inhibitors

In [None]:
heads_to_ablate = [(7,5)]

print(f"Heads to ablate: {heads_to_ablate}")
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    z[:, -1, head_idx, :] = 0
    return z
# Adds a hook into global model state
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.
ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")
#print(f"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item()}")
#print(f"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item()}")

Heads to ablate: [(7, 5)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.49935075640678406


In [None]:
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"})
l_scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

Tried to stack head results when they weren't cached. Computing head results now


### Alternate, still in development

In [None]:
def get_patched_layer_output(layer, head_idx, orig_cache, patch_cache):
    """Patches the target head at the z tensor, and then recomputes and returns attn_out.
        The original IOI paper recomputed the entire model, but we can use this shortcut
        to only compute the relevant parts.

    Args:
      layer (int): Layer to be patched.
      head (int): Index of head to be patched.
      orig_cache
      patch_cache

    Returns:
      torch.Tensor: attn_out, the tensor that would be added to the residual stream.
    """
    layer_z_cache_name = f'blocks.{layer}.attn.hook_z'
    layer_z_patched = orig_cache[layer_z_cache_name]
    head_z_source = patch_cache[layer_z_cache_name][:, :, head_idx, :]
    layer_z_patched[:, :, head_idx, :] = head_z_source

    patched_attn_out = einsum(
        "batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", 
        layer_z_patched, 
        model.blocks[layer].attn.W_O) + model.blocks[layer].attn.b_O

    return patched_attn_out


def get_path_patching_hook(
    receiver_weights,
    receiver_bias,
    sender_layer,
    sender_head,
    receiver_layer,
    receiver_head,
    #positions,
    orig_cache,
    patch_cache
):
    """Gets the patching hook for a given sender and receiver.
    """
    def patch_path(
        destination_head_component,
        hook,
        receiver_weights,
        receiver_bias,
        sender_layer,
        sender_head,
        receiver_layer,
        receiver_head,
        orig_cache,
        patch_cache
    ):
        patched_attn_output = get_patched_layer_output(sender_layer, sender_head, orig_cache, patch_cache)
        orig_attn_output_name = f'blocks.{sender_layer}.hook_attn_out'
        orig_receiver_ln1_scale_name = f'blocks.{receiver_layer}.ln1.hook_scale'

        outputs_diff = patched_attn_output - orig_cache[orig_attn_output_name]
        #print(f"{outputs_diff.shape=}")
        #print(f"{receiver_weights[receiver_head, :, :].shape=}")
        #print(f"{destination_head_component[:, :, receiver_head, :].shape=}")

        destination_head_component[:, :, receiver_head, :] = (destination_head_component[:, :, receiver_head, :] 
                                                              + ((outputs_diff) @ receiver_weights[receiver_head, :, :] 
                                                                 + receiver_bias[receiver_head]) / orig_cache[orig_receiver_ln1_scale_name])
    
    return partial(
        patch_path, 
        receiver_weights=receiver_weights, 
        receiver_bias=receiver_bias, 
        sender_layer=sender_layer, 
        sender_head=sender_head, 
        receiver_layer=receiver_layer, 
        receiver_head=receiver_head, 
        orig_cache=orig_cache, 
        patch_cache=patch_cache)
  

def run_path_patching_experiment(
    model,
    orig_tokens,
    patch_tokens,
    receiver_heads,
    param_module_type,
    metric_fn,
    average=True
):  
    """For each receiver head, tests the effect of each previous head on the given metric.
    """
    results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, len(receiver_heads), device="cuda", dtype=torch.float32)

    orig_logits, orig_cache = model.run_with_cache(orig_tokens)
    patch_logits, patch_cache = model.run_with_cache(patch_tokens)
    
    for i, receiver_head in enumerate(receiver_heads):
        receiver_layer, receiver_head_idx = receiver_head
        for layer in range(0, receiver_layer):
            for head_idx in range(12): # TODO: replace magic number
                
                # Get receiver parameters
                if param_module_type=="query":
                    receiver_weights = model.blocks[receiver_layer].attn.W_Q
                    receiver_bias = model.blocks[receiver_layer].attn.b_Q
                elif param_module_type=="key":
                    receiver_weights = model.blocks[receiver_layer].attn.W_K
                    receiver_bias = model.blocks[receiver_layer].attn.b_K
                elif param_module_type=="value":
                    receiver_weights = model.blocks[receiver_layer].attn.W_V
                    receiver_bias = model.blocks[receiver_layer].attn.b_V
                else:
                    raise Exception("wrong receiver type!")
                
                # Get patching hook
                patching_hook = get_path_patching_hook(
                    receiver_weights=receiver_weights,
                    receiver_bias=receiver_bias,
                    sender_layer=layer,
                    sender_head=head_idx,
                    receiver_layer=receiver_layer,
                    receiver_head=receiver_head_idx,
                    orig_cache=orig_cache,
                    patch_cache=patch_cache
                )

                # Run one iteration of the experiment
                patched_logits = model.run_with_hooks(
                    orig_tokens,
                    fwd_hooks = [(utils.get_act_name("q", receiver_layer, "attn"), 
                                  patching_hook)],
                    return_type = "logits"
                )
                patched_metric_results = metric_fn(patched_logits, answer_token_indices)

                # Save results
                #print(f"Effect of layer {layer} head {head_idx} on head {receiver_head}")
                results[layer, head_idx, i] = patched_metric_results
    if average:
        return results.mean(dim=2)
    else:
        return results


In [None]:
path_patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=pass_d_hooks)
get_logit_diff(path_patched_logits)

tensor(-4.3018, device='cuda:0')

In [None]:
get_logit_diff(model(corrupted_tokens))

tensor(-4.3018, device='cuda:0')

In [None]:
res = run_path_patching_experiment(
    model=model,
    orig_tokens=corrupted_tokens,
    patch_tokens=clean_tokens,
    receiver_heads=[(8,9), (8,10)],
    param_module_type="value",
    metric_fn=ioi_metric,
)

In [None]:
imshow(res, xaxis="Head", yaxis="Layer", title="Logit Difference From Each Head")#, zmin=-0.02, zmax=0.02)

In [None]:
def get_patched_layer_output(layer, head_idx, orig_cache, patch_cache):
    """Patches the target head at the z tensor, and then recomputes and returns attn_out.
        The original IOI paper recomputed the entire model, but we can use this shortcut
        to only compute the relevant parts.

    Args:
      layer (int): Layer to be patched.
      head (int): Index of head to be patched.
      orig_cache
      patch_cache

    Returns:
      torch.Tensor: attn_out, the tensor that would be added to the residual stream.
    """
    layer_z_cache_name = f'blocks.{layer}.attn.hook_z'
    layer_z_patched = orig_cache[layer_z_cache_name]
    head_z_source = patch_cache[layer_z_cache_name][:, :, head_idx, :]
    layer_z_patched[:, :, head_idx, :] = head_z_source

    patched_attn_out = einsum(
        "batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", 
        layer_z_patched, 
        model.blocks[layer].attn.W_O) + model.blocks[layer].attn.b_O

    return patched_attn_out

In [None]:
sender_layer = 3
sender_head = 2
receiver_layer = 11
receiver_head = 6

orig_cache = corrupted_cache
patch_cache = clean_cache

patched_attn_output = get_patched_layer_output(sender_layer, sender_head, orig_cache, patch_cache)
orig_attn_output_name = f'blocks.{sender_layer}.hook_attn_out'
orig_receiver_ln1_scale_name = f'blocks.{receiver_layer}.ln1.hook_scale'
        
res = (patched_attn_output - orig_cache[orig_attn_output_name]) @ model.blocks[receiver_layer].attn.W_Q[receiver_head] / orig_cache[orig_receiver_ln1_scale_name]
res.shape

torch.Size([8, 15, 64])

In [None]:
head_B = (11, 6)
head_B_W_Q = model.blocks[10].attn.W_Q[4]
head_B_dest = corrupted_cache['blocks.11.attn.hook_q'][:, :, head_B[1], :]
head_B_dest.shape

torch.Size([8, 15, 64])

In [None]:
model.blocks[10].ln1.hook_scale

HookPoint()

In [None]:
print(utils.get_act_name(11, "hook_attn_out"))

TypeError: ignored

In [None]:
clean_cache["blocks.11.ln1.hook_scale"].shape

torch.Size([8, 15, 1])

In [None]:
head_A = (3, 2)
head_A_cache_name = 'blocks.3.attn.hook_z'
layer_A_z_patched = corrupted_cache[head_A_cache_name] #[:, :, head_A[1], :]
head_A_z_source = clean_cache[head_A_cache_name][:, :, head_A[1], :]
layer_A_z_patched[:, :, head_A[1], :] = head_A_z_source 

In [None]:
layer_A_patched_attn_out = einsum(
    "batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", 
    layer_A_z_patched, 
    model.blocks[3].attn.W_O) + model.blocks[3].attn.b_O

In [None]:
layer_A_patched_attn_out.shape

torch.Size([8, 15, 768])

In [None]:
model.blocks[3]

TransformerBlock(
  (ln1): LayerNormPre(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (ln2): LayerNormPre(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (attn): Attention(
    (hook_k): HookPoint()
    (hook_q): HookPoint()
    (hook_v): HookPoint()
    (hook_z): HookPoint()
    (hook_attn_scores): HookPoint()
    (hook_pattern): HookPoint()
    (hook_result): HookPoint()
    (hook_rot_k): HookPoint()
    (hook_rot_q): HookPoint()
  )
  (mlp): MLP(
    (hook_pre): HookPoint()
    (hook_post): HookPoint()
  )
  (hook_q_input): HookPoint()
  (hook_k_input): HookPoint()
  (hook_v_input): HookPoint()
  (hook_attn_out): HookPoint()
  (hook_mlp_out): HookPoint()
  (hook_resid_pre): HookPoint()
  (hook_resid_post): HookPoint()
)

In [None]:
model.blocks[3].attn.W_Q[2].shape

torch.Size([768, 64])

In [None]:
model #.blocks[11].ln1.hook_scale

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()


In [None]:
def path_patching(
    model,
    patch_tokens,
    orig_tokens,
    sender_heads,
    receiver_hooks,
    positions=["end"],
    return_hooks=False,
    extra_hooks=[],  # when we call reset hooks, we may want to add some extra hooks after this, add these here
    freeze_mlps=False,  # recall in IOI paper we consider these "vital model components"
    have_internal_interactions=False,
):
    """
    Patch in the effect of `sender_heads` on `receiver_hooks` only
    (though MLPs are "ignored" if `freeze_mlps` is False so are slight confounders in this case - see Appendix B of https://arxiv.org/pdf/2211.00593.pdf)
    TODO fix this: if max_layer < model.cfg.n_layers, then let some part of the model do computations (not frozen)

    Args: 
      model (HookedTransformer):
      patch_tokens ()
    Returns:

    """

    # This function ensures that the patch is applied to every position in the position list.
    def patch_positions(z, source_act, hook, positions=["end"], verbose=False):
        for pos in positions:
            z[torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]] = source_act[
                torch.arange(patch_tokens.N), patch_tokens.word_idx[pos]
            ]
        return z

    # Puts all head names into a list.
    sender_hooks = []
    for layer, head_idx in sender_heads:
        if head_idx is None:
            sender_hooks.append((f"blocks.{layer}.hook_mlp_out", None))

        else:
            sender_hooks.append((f"blocks.{layer}.attn.hook_result", head_idx))

    # Makes lists of the names only of the sender and receiver hooks.
    sender_hook_names = [x[0] for x in sender_hooks]
    receiver_hook_names = [x[0] for x in receiver_hooks]

    # Forward pass A (in https://arxiv.org/pdf/2211.00593.pdf)
    # Similar to running with "clean" prompts
    sender_cache = {}
    model.reset_hooks()
    # This portion is not used.
    # for hook in extra_hooks:
    #     model.add_hook(*hook)
    model.cache_some(
        sender_cache, lambda x: x in sender_hook_names, suppress_warning=True
    )
    source_logits = model(patch_tokens.toks.long())

    # Forward pass B
    # Similar to running on "corrupted" prompts
    target_cache = {}
    model.reset_hooks()
    # for hook in extra_hooks:
    #     model.add_hook(*hook)
    model.cache_all(target_cache, suppress_warning=True)
    target_logits = model(orig_tokens.toks.long())

    # Forward pass C
    # Cache the receiver hooks
    # (adding these hooks first means we save values BEFORE they are overwritten)
    receiver_cache = {}
    model.reset_hooks()
    model.cache_some(
        receiver_cache,
        lambda x: x in receiver_hook_names,
        suppress_warning=True,
        verbose=False,
    )

    # "Freeze" intermediate heads to their orig_tokens values
    # Effectively, for each layer and head, we are patching in the activations from run B
    # (e.g. the "corrupted" prompts), which is equivalent to freezing their behavior to 
    # that of the "corrupted" prompts
    for layer in range(model.cfg.n_layers):
        for head_idx in range(model.cfg.n_heads):
            for hook_template in [
                "blocks.{}.attn.hook_q",
                "blocks.{}.attn.hook_k",
                "blocks.{}.attn.hook_v",
            ]:
                hook_name = hook_template.format(layer)

                # this won't run in the default case
                # if have_internal_interactions and hook_name in receiver_hook_names:
                #     continue # stops the rest of the loop's code and goes to next iteration

                hook = get_act_hook(
                    patch_all,
                    alt_act=target_cache[hook_name],
                    idx=head_idx,
                    dim=2 if head_idx is not None else None,
                    name=hook_name,
                )
                model.add_hook(hook_name, hook)

        # also won't run in the default case
        # if freeze_mlps:
        #     hook_name = f"blocks.{layer}.hook_mlp_out"
        #     hook = get_act_hook(
        #         patch_all,
        #         alt_act=target_cache[hook_name],
        #         idx=None,
        #         dim=None,
        #         name=hook_name,
        #     )
        #     model.add_hook(hook_name, hook)

    # for hook in extra_hooks:
    #     model.add_hook(*hook)

    # These hooks will overwrite the freezing, for the sender heads - this should be forward pass C
    # In more detail, the sender heads will be patched with the new or "clean" activations, and the MLPs
    # and layer norms will be recomputed. Frozen heads won't be affected.
    for hook_name, head_idx in sender_hooks:
        assert not torch.allclose(sender_cache[hook_name], target_cache[hook_name]), (
            hook_name,
            head_idx,
        )
        hook = get_act_hook(
            partial(patch_positions, positions=positions),
            alt_act=sender_cache[hook_name],
            idx=head_idx,
            dim=2 if head_idx is not None else None,
            name=hook_name,
        )
        model.add_hook(hook_name, hook)
    receiver_logits = model(orig_tokens.toks.long())

    # Add (or return) all the hooks needed for forward pass D
    model.reset_hooks()
    hooks = []

    # Not used
    # for hook in extra_hooks:
    #     hooks.append(hook)

    for hook_name, head_idx in receiver_hooks:
        for pos in positions:
            if torch.allclose(
                receiver_cache[hook_name][torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]],
                target_cache[hook_name][torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]],
            ):
                warnings.warn("Torch all close for {}".format(hook_name))
        
        # Get the cached receiver activations (cached before getting overwritten)
        hook = get_act_hook(
            partial(patch_positions, positions=positions),
            alt_act=receiver_cache[hook_name],
            idx=head_idx,
            dim=2 if head_idx is not None else None,
            name=hook_name,
        )
        hooks.append((hook_name, hook))

    model.reset_hooks()
    if return_hooks:
        return hooks
    else:
        for hook_name, hook in hooks:
            model.add_hook(hook_name, hook)
        return model