# Indirect Object Identification Circuit in Pythia

In [1]:

IN_COLAB = False
from IPython import get_ipython

ipython = get_ipython()
# Code to automatically update the HookedTransformer code as its edited without restarting the kernel
ipython.magic("load_ext autoreload")
ipython.magic("autoreload 2")

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
import os
import pathlib
from typing import List, Optional, Union

import torch
import numpy as np
import yaml

import einops
from fancy_einsum import einsum

from datasets import load_dataset
from transformers import pipeline
import plotly.io as pio
import plotly.express as px
#import pysvelte
from IPython.display import HTML

import plotly.graph_objs as go
import ipywidgets as widgets
from IPython.display import display

# if IN_COLAB or not DEBUG_MODE:
#     # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.
#     pio.renderers.default = "colab"
# else:
#     pio.renderers.default = "plotly_mimetype+notebook"

if torch.cuda.is_available():
    device = int(os.environ.get("LOCAL_RANK", 0))
else:
    device = "cpu"

In [3]:
import transformers
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
import transformer_lens
import transformer_lens.utils as utils
import transformer_lens.patching as patching
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

from functools import partial

from torchtyping import TensorType as TT

In [5]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fddc078f970>

In [6]:
from neel_plotly import line, imshow, scatter

def l_imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def l_line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def l_scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def two_lines(tensor1, tensor2, renderer=None, **kwargs):
    px.line(y=[utils.to_numpy(tensor1), utils.to_numpy(tensor2)], **kwargs).show(renderer)

## Model Setup

In [9]:
import huggingface_hub
huggingface_hub.notebook_login()

Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [7]:
source_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-v1.1-1.4b", revision="step143000")

In [11]:
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-1.3b",
    hf_model=source_model,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    #refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-1.3b into HookedTransformer


In [12]:
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 17.05 Prob: 48.49% Token: | Mary|
Top 1th token. Logit: 15.56 Prob: 10.87% Token: | the|
Top 2th token. Logit: 14.99 Prob:  6.18% Token: | a|
Top 3th token. Logit: 14.59 Prob:  4.15% Token: | his|
Top 4th token. Logit: 13.79 Prob:  1.85% Token: | each|
Top 5th token. Logit: 13.54 Prob:  1.44% Token: | John|
Top 6th token. Logit: 12.87 Prob:  0.74% Token: | their|
Top 7th token. Logit: 12.83 Prob:  0.71% Token: | one|
Top 8th token. Logit: 12.80 Prob:  0.69% Token: | Mrs|
Top 9th token. Logit: 12.53 Prob:  0.53% Token: | an|


In [13]:
example_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' Mary', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 15.91 Prob: 21.35% Token: | John|
Top 1th token. Logit: 15.86 Prob: 20.21% Token: | the|
Top 2th token. Logit: 15.13 Prob:  9.75% Token: | a|
Top 3th token. Logit: 14.87 Prob:  7.54% Token: | her|
Top 4th token. Logit: 13.86 Prob:  2.76% Token: | each|
Top 5th token. Logit: 13.22 Prob:  1.45% Token: | Mary|
Top 6th token. Logit: 12.86 Prob:  1.01% Token: | baby|
Top 7th token. Logit: 12.82 Prob:  0.97% Token: | one|
Top 8th token. Logit: 12.80 Prob:  0.95% Token: | their|
Top 9th token. Logit: 12.73 Prob:  0.88% Token: | an|


## Data Setup

In [14]:
prompts = [
    'When John and Mary went to the shops, John gave the bag to',
    'When John and Mary went to the shops, Mary gave the bag to',
    'When Tom and James went to the park, James gave the ball to',
    'When Tom and James went to the park, Tom gave the ball to',
    'When Dan and Sid went to the shops, Sid gave an apple to',
    'When Dan and Sid went to the shops, Dan gave an apple to',
    'After Martin and Amy went to the park, Amy gave a drink to',
    'After Martin and Amy went to the park, Martin gave a drink to'
    ]
    
answers = [
    (' Mary', ' John'), 
    (' John', ' Mary'), 
    (' Tom', ' James'), 
    (' James', ' Tom'), 
    (' Dan', ' Sid'), 
    (' Sid', ' Dan'), 
    (' Martin', ' Amy'), 
    (' Amy', ' Martin')
    ]

clean_tokens = model.to_tokens(prompts)
# Swap each adjacent pair, with a hacky list comprehension
corrupted_tokens = clean_tokens[
    [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]
    ]
print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)
print("Answer token indices", answer_token_indices)

Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to
Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to
Answer token indices tensor([[ 6393,  2516],
        [ 2516,  6393],
        [ 6270,  5490],
        [ 5490,  6270],
        [ 5682, 24752],
        [24752,  5682],
        [ 8698, 22138],
        [22138,  8698]], device='cuda:0')


In [15]:
model.to_str_tokens(corrupted_tokens[1])

['<|endoftext|>',
 'When',
 ' John',
 ' and',
 ' Mary',
 ' went',
 ' to',
 ' the',
 ' shops',
 ',',
 ' John',
 ' gave',
 ' the',
 ' bag',
 ' to']

## Tool Setup

### Activation Patching

In [16]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    if len(logits.shape)==3:
        # Get final logits only
        logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 4.5503
Corrupted logit diff: -4.5503


In [17]:
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff
def ioi_metric(logits, answer_token_indices=answer_token_indices):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE  - CORRUPTED_BASELINE)

clean_baseline_ioi = ioi_metric(clean_logits)
corrupted_baseline_ioi = ioi_metric(corrupted_logits)

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [18]:
# Whether to do the runs by head and by position, which are much slower
DO_SLOW_RUNS = True

### Path Patching

In [19]:
def patch_pos_head_vector(
    orig_head_vector: TT["batch", "pos", "head_index", "d_head"],
    hook,
    pos, 
    head_index, 
    patch_cache):
    #print(patch_cache.keys())
    orig_head_vector[:, pos, head_index, :] = patch_cache[hook.name][:, pos, head_index, :]
    return orig_head_vector

def patch_head_vector(
    orig_head_vector: TT["batch", "pos", "head_index", "d_head"],
    hook,
    head_index, 
    patch_cache):
    orig_head_vector[:, :, head_index, :] = patch_cache[hook.name][:, :, head_index, :]
    return orig_head_vector

In [20]:
def path_patching(
    model,
    patch_tokens,
    orig_tokens,
    sender_heads,
    receiver_hooks,
    positions=-1,
):
    """
    Patch in the effect of `sender_heads` on `receiver_hooks` only
    (though MLPs are "ignored" if `freeze_mlps` is False so are slight confounders in this case - see Appendix B of https://arxiv.org/pdf/2211.00593.pdf)

    TODO fix this: if max_layer < model.cfg.n_layers, then let some part of the model do computations (not frozen)
    """

    def patch_positions(z, source_act, hook, positions=["end"], verbose=False):
        for pos in positions:
            z[torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]] = source_act[
                torch.arange(patch_tokens.N), patch_tokens.word_idx[pos]
            ]
        return z

    # process arguments
    sender_hooks = []
    for layer, head_idx in sender_heads:
        if head_idx is None:
            sender_hooks.append((f"blocks.{layer}.hook_mlp_out", None))

        else:
            sender_hooks.append((f"blocks.{layer}.attn.hook_z", head_idx))

    sender_hook_names = [x[0] for x in sender_hooks]
    receiver_hook_names = [x[0] for x in receiver_hooks]
    receiver_hook_heads = [x[1] for x in receiver_hooks]
    # Forward pass A (in https://arxiv.org/pdf/2211.00593.pdf)
    source_logits, sender_cache = model.run_with_cache(patch_tokens)

    # Forward pass B
    target_logits, target_cache = model.run_with_cache(orig_tokens)

    # Forward pass C
    # Cache the receiver hooks
    # (adding these hooks first means we save values BEFORE they are overwritten)
    receiver_cache = model.add_caching_hooks(lambda x: x in receiver_hook_names)

    # "Freeze" intermediate heads to their orig_tokens values
    # q, k, and v will get frozen, and then if it's a sender head, this will get undone
    # z, attn_out, and the MLP will all be recomputed and added to the residual stream
    # however, the effect of the change on the residual stream will be overwritten by the
    # freezing for all non-receiver components
    pass_c_hooks = []
    for layer in range(model.cfg.n_layers):
        for head_idx in range(model.cfg.n_heads):
            for hook_template in [
                "blocks.{}.attn.hook_q",
                "blocks.{}.attn.hook_k",
                "blocks.{}.attn.hook_v",
            ]:
                hook_name = hook_template.format(layer)
                if (hook_name, head_idx) not in receiver_hooks:
                    #print(f"Freezing {hook_name}")
                    hook = partial(
                        patch_head_vector,
                        head_index=head_idx,
                        patch_cache=target_cache
                    )
                    pass_c_hooks.append((hook_name, hook))
                else:
                    pass
                    #print(f"Not freezing {hook_name}")

    # These hooks will overwrite the freezing, for the sender heads
    # We also carry out pass C
    for hook_name, head_idx in sender_hooks:
        assert not torch.allclose(sender_cache[hook_name], target_cache[hook_name]), (
            hook_name,
            head_idx,
        )
        hook = partial(
            patch_pos_head_vector,
            pos=positions,
            head_index=head_idx,
            patch_cache=sender_cache
        )
        pass_c_hooks.append((hook_name, hook))
  
    receiver_logits = model.run_with_hooks(orig_tokens, fwd_hooks=pass_c_hooks)
    # Add (or return) all the hooks needed for forward pass D
    pass_d_hooks = []

    for hook_name, head_idx in receiver_hooks:
        #for pos in positions:
            # if torch.allclose(
            #     receiver_cache[hook_name][torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]],
            #     target_cache[hook_name][torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]],
            # ):
            #     warnings.warn("Torch all close for {}".format(hook_name))
        hook = partial(
            patch_pos_head_vector,
            pos=positions,
            head_index=head_idx,
            patch_cache=receiver_cache
        )
        pass_d_hooks.append((hook_name, hook))

    return pass_d_hooks
    

### Attention Visualization

In [21]:
def get_attn_head_patterns(model, prompt, attn_heads):
    if isinstance(prompt, str):
        prompt = model.to_tokens(prompt)
    logits, cache = model.run_with_cache(prompt, remove_batch_dim=True)

    head_list = []
    head_name_list = []
    for layer, head in attn_heads:
        head_list.append(cache["pattern", layer, "attn"][head])
        head_name_list.append(f"L{layer}H{head}")
    attention_pattern = torch.stack(head_list, dim=0)
    tokens = model.to_str_tokens(prompt)
    
    return tokens, attention_pattern, head_name_list

## Direct Logit Attribution

In this section, we use the residual directions of the logit difference to find what parts of the network are contributing the most to that subspace. Using logit lens, we see how the subspace evolves, and using direct logit contribution, we can see what parts of the network correspond most with the logit difference directions.

In [22]:
answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([8, 2, 2048])
Logit difference directions shape: torch.Size([8, 2048])


In [23]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 
final_residual_stream = clean_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = clean_cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",clean_logit_diff)

Final residual stream shape: torch.Size([8, 15, 2048])
Calculated average logit diff: 4.550283432006836
Original logit difference: 4.550281047821045


### Logit Lens

The model cannot perform the task until layer 8, and performance increases and then drops afterwards. (This graph shows the logit lens at `resid_pre`, so "8" corresponds to the activations previous to block 8.)

In [24]:
def residual_stack_to_logit_diff(residual_stack: TT["components", "batch", "d_model"], cache: ActivationCache) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)
    return einsum("... batch d_model, batch d_model -> ...", scaled_residual_stack, answer_residual_directions[:, 0])/len(prompts)


In [25]:
accumulated_residual, labels = clean_cache.accumulated_resid(layer=-1, incl_mid=False, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, clean_cache)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers+1), hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

### Layer Attribution

Major contributor is are attention layers 7 and 8; attention 9 seems to have a strong negative effect.

In [26]:
per_layer_residual, labels = clean_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, clean_cache)
l_line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

### Head Attribution

Heads L7H5 and L8H10 make the biggest positive difference, and heads L8H1, L8H2, L8H5, L8H8 make a more modest difference. L9H1 and L9H11 make a negative contribution.

These are likely to be name movers and potentially S-inhibition heads.

In [28]:
per_head_residual, labels = clean_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, clean_cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, xaxis="Head", yaxis="Layer", title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


In [None]:
per_head_logit_diffs.flatten().shape

torch.Size([144])

### Attention Analysis

Of the top positive contributors, only 4 out of 5 appear to be attending to the IO name. The other (L7H5)is attending to S2. Might it be an S-inhibition head?

Among the top negative heads, only L9H11, L9H8 and L9H3 are attending to the IO. L9H3 is also attending to S2, and as such may also be an S-inhibition head.

In [51]:
top_k = 10
top_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [53]:
nmh_candidates = [(12, 15), (13, 1), (13, 6), (15, 15), (16, 13), (17, 0), (17, 10), (17, 7)]

In [30]:
top_k = 2
top_heads = torch.topk(-per_head_logit_diffs.flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, example_prompt, heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

## Activation Patching for Model Component Importance

Before we sketch out the circuit, we'll take a look at the model from a top-down perspective and see how computation flows through the network. We will use the activation patching technique to see what layers, positions, heads, etc. are important as data flows through the network.

### Residual Stream

Computation occurs at the S2 position for all layers until layer 7 (resid_pre is the residual stream prior to the given layer).

In [31]:
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(
    model, 
    corrupted_tokens, 
    clean_cache, 
    ioi_metric)

imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'resid_pre' Activation Patching")

  0%|          | 0/360 [00:00<?, ?it/s]

In [32]:
imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'resid_pre' Activation Patching")

### MLP Layers

MLP layers do not matter much. Layer 0 is, per Neel Nanda's suggestion, probably an extension of the embedding.

In [33]:
resid_pre_act_patch_results = patching.get_act_patch_mlp_out(
    model, 
    corrupted_tokens, 
    clean_cache, 
    ioi_metric)

imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'mlp_out' Activation Patching")

  0%|          | 0/360 [00:00<?, ?it/s]

In [34]:
imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'mlp_out' Activation Patching")

### Attention Layers

Oddly, attention only seems to matter on the final token, and only for layers 7, 8, and 9. In GPT-2, attention layers mattered on the S2 token, but here we do not see that that is the case.

Swapping in the clean activations for Attention 7 and 8 makes a positive difference in recovering performance, but layer 9 makes a negative difference. With the exception of layer 7, this follows what we saw earlier with the name mover heads.

In [35]:
resid_pre_act_patch_results = patching.get_act_patch_attn_out(
    model, 
    corrupted_tokens, 
    clean_cache, 
    ioi_metric)

  0%|          | 0/360 [00:00<?, ?it/s]

In [36]:
imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'attn_out' Activation Patching")

### Attention Layers by Head

Here we see that patching L7H5 recovers full performance, and patching L8H10 recovers a quarter of the performance. Patching L9H1 and L9H11 reduce performance. Aside from L7H5 (whose role is unclear at present), this matches with the positive NMHs and the two negative NMHs of significant magnitude.

In [37]:
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)

  0%|          | 0/384 [00:00<?, ?it/s]

In [38]:
imshow(attn_head_out_all_pos_act_patch_results, 
       yaxis="Layer", 
       xaxis="Head", 
       title="IOI Metric for 'attn_head_out' Activation Patching (All Pos)")

We can also look at the values separately.

In [None]:
attn_head_v_all_pos_act_patch_results = patching.get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(attn_head_out_all_pos_act_patch_results, 
       yaxis="Layer", 
       xaxis="Head", 
       title="IOI Metric for 'attn_head_v' Activation Patching (All Pos)")

  0%|          | 0/144 [00:00<?, ?it/s]

NameError: ignored

In [54]:
nmh_candidates

[(12, 15), (13, 1), (13, 6), (15, 15), (16, 13), (17, 0), (17, 10), (17, 7)]

We can also see what heads are important at what positions for the IOI task.

In [None]:
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]
if DO_SLOW_RUNS:
    attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, ioi_metric)
    attn_head_out_act_patch_results = einops.rearrange(attn_head_out_act_patch_results, "layer pos head -> (layer head) pos")
    imshow(attn_head_out_act_patch_results, 
        yaxis="Head Label", 
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=ALL_HEAD_LABELS,
        title="attn_head_out Activation Patching By Pos")

  0%|          | 0/2160 [00:00<?, ?it/s]

### Multiple Activation Types

In [39]:
every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)

  0%|          | 0/360 [00:00<?, ?it/s]

  0%|          | 0/360 [00:00<?, ?it/s]

  0%|          | 0/360 [00:00<?, ?it/s]

In [40]:
imshow(every_block_result, facet_col=0, facet_labels=["Residual Stream", "Attn Output", "MLP Output"], title="Activation Patching Per Block", xaxis="Position", yaxis="Layer", zmax=1, zmin=-1, x= [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))])

### Head Component Output

We can decompose the heads and patch different components, such as query, key, value, and attention pattern parts. Doing so here helps us to understand the role of L7H5; unlike the NMH L8H10, the value rather than the attention pattern is what's important. (We also see this for the negative NMH L9H1.)

In [41]:
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)
# [markdown]
# We can also do by head *and* by position. This is a bit slow, but it can give useful + fine-grained detail

  0%|          | 0/384 [00:00<?, ?it/s]

  0%|          | 0/384 [00:00<?, ?it/s]

  0%|          | 0/384 [00:00<?, ?it/s]

  0%|          | 0/384 [00:00<?, ?it/s]

  0%|          | 0/384 [00:00<?, ?it/s]

In [48]:
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (All Pos)", xaxis="Head", yaxis="Layer") #, zmax=1, zmin=-1)


In [55]:
nmh_candidates

[(12, 15), (13, 1), (13, 6), (15, 15), (16, 13), (17, 0), (17, 10), (17, 7)]

## Circuit Sketching

### First Level

Previously we identified some NMH candidates through logit difference contribution. In order to be NMHs, the attention pattern should be the key attention head component to performance (as opposed to value), and it should be in the later layers (as opposed to e.g. induction heads, whose attention patterns are also important but occur earlier).

Is this the case? We can do some attention patching to find out.

#### Attention Pattern Patching

Here we can see that patching the head attention patterns has a different effect from our whole-head patching above--all of our potential NMHs L8H1, L8H2, L8H5, L8H8, and L8H10 are represented. (Earlier heads shown here are probably other parts of the circuit). Interestingly, patching L9H11's attention is having a negative effect, so it is probably a negative NMH.

In [43]:
attn_head_pattern_all_pos_act_patch_results = patching.get_act_patch_attn_head_pattern_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)

  0%|          | 0/384 [00:00<?, ?it/s]

In [44]:
imshow(attn_head_pattern_all_pos_act_patch_results, 
       yaxis="Layer", 
       xaxis="Head", 
       title="IOI Metric for 'attn_head_pattern' Activation Patching (All Pos)")

In [56]:
nmh_candidates

[(12, 15), (13, 1), (13, 6), (15, 15), (16, 13), (17, 0), (17, 10), (17, 7)]

If we look at the relative importance of attention vs. value patching for these heads, it's even clearer. But what's apparent is that one of the name mover attention heads is clearly above all the others.

In [47]:
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
l_scatter(
    x=utils.to_numpy(attn_head_pattern_all_pos_act_patch_results.flatten()), 
    y=utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten()), 
    hover_name = head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching")

#### Path Patching for NMH Candidate Receivers

Let us first see what heads are contributing the most to our name mover head candidates. Here we are patching corrupted tokens input in to replace clean token input, so negative values (red) correspond to importance of the respective head.

In [63]:
receiver_heads = [(8, 2), (8, 10), (9, 6), (10, 7)]

metric_delta_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device='cuda:0')

for layer in range(model.cfg.n_layers):
        for head_idx in range(model.cfg.n_heads):
            pass_d_hooks = path_patching(
                model=model,
                patch_tokens=corrupted_tokens,
                orig_tokens=clean_tokens,
                sender_heads=[(layer, head_idx)],
                receiver_hooks=[(f"blocks.{layer_idx}.attn.hook_q", head_idx) for layer_idx, head_idx in receiver_heads],
                positions=-1
            )
            path_patched_logits = model.run_with_hooks(clean_tokens, fwd_hooks=pass_d_hooks)
            iot_metric_res = ioi_metric(path_patched_logits)
            metric_delta_results[layer, head_idx] = -(clean_baseline_ioi - iot_metric_res) / clean_baseline_ioi

In [64]:
metric_delta_results

tensor([[-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0., -0.],
        [-

We only see one head of importance here--only one potential S2-inhibitor.

In [58]:
imshow(metric_delta_results, title="IOI Metric Change From Each Head Through Receivers")#, zmin=-0.02, zmax=0.02)

### Second Level

#### Attention Pattern for Second-Level Heads

Here we can see that the positive head we've identified is highly focused on S2, adding evidence for its role as an S2-inhibitor.

In [None]:
second_level_positive_heads = [(6, 6), (7, 2), (7, 9), (8, 9)]

tokens, attn, names = get_attn_head_patterns(model, example_prompt, second_level_positive_heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

#second_level_negative_heads = [(7, 8), (8, 10)]
#visualize_attention_patterns(torch.tensor([l*12+h for l, h in second_level_negative_heads]), title=f"Top Negative Second Level IOI Metric Heads")

More evidence that these are S-inhibition heads. S-inhibition heads will have a higher relative importance on values as opposed to other head attributes.

In [None]:
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
l_scatter(
    x=utils.to_numpy(attn_head_v_all_pos_act_patch_results.flatten()), 
    y=utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten()), 
    xaxis="Value Patch",
    yaxis="Output Patch",
    #caxis="Layer",
    hover_name = head_labels,
    color=einops.repeat(np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads),
    range_x=(-1.5, 1.5),
    range_y=(-1.5, 1.5),
    title="Scatter plot of output patching vs value patching")

#### Path Patching for S2-Inhibition Candidates

In [None]:
model.to_str_tokens(clean_tokens[3])

['<|endoftext|>',
 'When',
 ' Tom',
 ' and',
 ' James',
 ' went',
 ' to',
 ' the',
 ' park',
 ',',
 ' Tom',
 ' gave',
 ' the',
 ' ball',
 ' to']

In [None]:
receiver_heads = second_level_positive_heads

metric_delta_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device='cuda:0')

for layer in range(model.cfg.n_layers):
        for head_idx in range(model.cfg.n_heads):
            pass_d_hooks = path_patching(
                model=model,
                patch_tokens=corrupted_tokens,
                orig_tokens=clean_tokens,
                sender_heads=[(layer, head_idx)],
                receiver_hooks=[(f"blocks.{layer_idx}.attn.hook_v", head_idx) for layer_idx, head_idx in receiver_heads],
                positions=10
            )
            path_patched_logits = model.run_with_hooks(clean_tokens, fwd_hooks=pass_d_hooks)
            iot_metric_res = ioi_metric(path_patched_logits)
            metric_delta_results[layer, head_idx] = -(clean_baseline_ioi - iot_metric_res) / clean_baseline_ioi

The heads we see below have a strong effect on the the IOI metric via the values of the S2-inhibition head candidates at the S2 position. The most significant heads are:
- L6H7
- L6H0
- L5H6
- L5H8
- L4H1
- L4H8
- L4H10
- L2H8

In [None]:
imshow(metric_delta_results, title="IOI Metric Change From Each Head Through Receivers")#, zmin=-0.02, zmax=0.02)

### Third Level

#### Attention Patterns for Third-Level Heads

We have a mix of induction heads and duplicate token heads here, as well as two heads that focus on S2 at S2.

In [None]:
third_level_positive_heads = [(4, 6), (4, 11), (5, 10)]

tokens, attn, names = get_attn_head_patterns(model, example_prompt, second_level_positive_heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

### Backup Name Mover Heads

In [None]:
heads_to_ablate = [(8, 2), (8, 10), (9, 6), (10, 7)]

print(f"Heads to ablate: {heads_to_ablate}")
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    z[:, -1, head_idx, :] = 0
    return z
# Adds a hook into global model state
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.
ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")
#print(f"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item()}")
#print(f"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item()}")

Heads to ablate: [(8, 2), (8, 10), (9, 6), (10, 7)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.8575706481933594


In [None]:
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"})
l_scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

Tried to stack head results when they weren't cached. Computing head results now


### Backup S2 Inhibitors

In [None]:
heads_to_ablate = second_level_positive_heads

print(f"Heads to ablate: {heads_to_ablate}")
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    z[:, -1, head_idx, :] = 0
    return z
# Adds a hook into global model state
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
# Runs the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.
ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")
#print(f"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item()}")
#print(f"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item()}")

Heads to ablate: [(6, 6), (7, 2), (7, 9)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.7133723497390747


In [None]:
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"})
l_scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

Tried to stack head results when they weren't cached. Computing head results now


### Alternate, still in development

In [None]:
def get_patched_layer_output(layer, head_idx, orig_cache, patch_cache):
    """Patches the target head at the z tensor, and then recomputes and returns attn_out.
        The original IOI paper recomputed the entire model, but we can use this shortcut
        to only compute the relevant parts.

    Args:
      layer (int): Layer to be patched.
      head (int): Index of head to be patched.
      orig_cache
      patch_cache

    Returns:
      torch.Tensor: attn_out, the tensor that would be added to the residual stream.
    """
    layer_z_cache_name = f'blocks.{layer}.attn.hook_z'
    layer_z_patched = orig_cache[layer_z_cache_name]
    head_z_source = patch_cache[layer_z_cache_name][:, :, head_idx, :]
    layer_z_patched[:, :, head_idx, :] = head_z_source

    patched_attn_out = einsum(
        "batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", 
        layer_z_patched, 
        model.blocks[layer].attn.W_O) + model.blocks[layer].attn.b_O

    return patched_attn_out


def get_path_patching_hook(
    receiver_weights,
    receiver_bias,
    sender_layer,
    sender_head,
    receiver_layer,
    receiver_head,
    #positions,
    orig_cache,
    patch_cache
):
    """Gets the patching hook for a given sender and receiver.
    """
    def patch_path(
        destination_head_component,
        hook,
        receiver_weights,
        receiver_bias,
        sender_layer,
        sender_head,
        receiver_layer,
        receiver_head,
        orig_cache,
        patch_cache
    ):
        patched_attn_output = get_patched_layer_output(sender_layer, sender_head, orig_cache, patch_cache)
        orig_attn_output_name = f'blocks.{sender_layer}.hook_attn_out'
        orig_receiver_ln1_scale_name = f'blocks.{receiver_layer}.ln1.hook_scale'

        outputs_diff = patched_attn_output - orig_cache[orig_attn_output_name]
        #print(f"{outputs_diff.shape=}")
        #print(f"{receiver_weights[receiver_head, :, :].shape=}")
        #print(f"{destination_head_component[:, :, receiver_head, :].shape=}")

        destination_head_component[:, :, receiver_head, :] = (destination_head_component[:, :, receiver_head, :] 
                                                              + ((outputs_diff) @ receiver_weights[receiver_head, :, :] 
                                                                 + receiver_bias[receiver_head]) / orig_cache[orig_receiver_ln1_scale_name])
    
    return partial(
        patch_path, 
        receiver_weights=receiver_weights, 
        receiver_bias=receiver_bias, 
        sender_layer=sender_layer, 
        sender_head=sender_head, 
        receiver_layer=receiver_layer, 
        receiver_head=receiver_head, 
        orig_cache=orig_cache, 
        patch_cache=patch_cache)
  

def run_path_patching_experiment(
    model,
    orig_tokens,
    patch_tokens,
    receiver_heads,
    param_module_type,
    metric_fn,
    average=True
):  
    """For each receiver head, tests the effect of each previous head on the given metric.
    """
    results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, len(receiver_heads), device="cuda", dtype=torch.float32)

    orig_logits, orig_cache = model.run_with_cache(orig_tokens)
    patch_logits, patch_cache = model.run_with_cache(patch_tokens)
    
    for i, receiver_head in enumerate(receiver_heads):
        receiver_layer, receiver_head_idx = receiver_head
        for layer in range(0, receiver_layer):
            for head_idx in range(12): # TODO: replace magic number
                
                # Get receiver parameters
                if param_module_type=="query":
                    receiver_weights = model.blocks[receiver_layer].attn.W_Q
                    receiver_bias = model.blocks[receiver_layer].attn.b_Q
                elif param_module_type=="key":
                    receiver_weights = model.blocks[receiver_layer].attn.W_K
                    receiver_bias = model.blocks[receiver_layer].attn.b_K
                elif param_module_type=="value":
                    receiver_weights = model.blocks[receiver_layer].attn.W_V
                    receiver_bias = model.blocks[receiver_layer].attn.b_V
                else:
                    raise Exception("wrong receiver type!")
                
                # Get patching hook
                patching_hook = get_path_patching_hook(
                    receiver_weights=receiver_weights,
                    receiver_bias=receiver_bias,
                    sender_layer=layer,
                    sender_head=head_idx,
                    receiver_layer=receiver_layer,
                    receiver_head=receiver_head_idx,
                    orig_cache=orig_cache,
                    patch_cache=patch_cache
                )

                # Run one iteration of the experiment
                patched_logits = model.run_with_hooks(
                    orig_tokens,
                    fwd_hooks = [(utils.get_act_name("q", receiver_layer, "attn"), 
                                  patching_hook)],
                    return_type = "logits"
                )
                patched_metric_results = metric_fn(patched_logits, answer_token_indices)

                # Save results
                #print(f"Effect of layer {layer} head {head_idx} on head {receiver_head}")
                results[layer, head_idx, i] = patched_metric_results
    if average:
        return results.mean(dim=2)
    else:
        return results


In [None]:
path_patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=pass_d_hooks)
get_logit_diff(path_patched_logits)

tensor(-4.3018, device='cuda:0')

In [None]:
get_logit_diff(model(corrupted_tokens))

tensor(-4.3018, device='cuda:0')

In [None]:
res = run_path_patching_experiment(
    model=model,
    orig_tokens=corrupted_tokens,
    patch_tokens=clean_tokens,
    receiver_heads=[(8,9), (8,10)],
    param_module_type="value",
    metric_fn=ioi_metric,
)

In [None]:
imshow(res, xaxis="Head", yaxis="Layer", title="Logit Difference From Each Head")#, zmin=-0.02, zmax=0.02)

In [None]:
def get_patched_layer_output(layer, head_idx, orig_cache, patch_cache):
    """Patches the target head at the z tensor, and then recomputes and returns attn_out.
        The original IOI paper recomputed the entire model, but we can use this shortcut
        to only compute the relevant parts.

    Args:
      layer (int): Layer to be patched.
      head (int): Index of head to be patched.
      orig_cache
      patch_cache

    Returns:
      torch.Tensor: attn_out, the tensor that would be added to the residual stream.
    """
    layer_z_cache_name = f'blocks.{layer}.attn.hook_z'
    layer_z_patched = orig_cache[layer_z_cache_name]
    head_z_source = patch_cache[layer_z_cache_name][:, :, head_idx, :]
    layer_z_patched[:, :, head_idx, :] = head_z_source

    patched_attn_out = einsum(
        "batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", 
        layer_z_patched, 
        model.blocks[layer].attn.W_O) + model.blocks[layer].attn.b_O

    return patched_attn_out

In [None]:
sender_layer = 3
sender_head = 2
receiver_layer = 11
receiver_head = 6

orig_cache = corrupted_cache
patch_cache = clean_cache

patched_attn_output = get_patched_layer_output(sender_layer, sender_head, orig_cache, patch_cache)
orig_attn_output_name = f'blocks.{sender_layer}.hook_attn_out'
orig_receiver_ln1_scale_name = f'blocks.{receiver_layer}.ln1.hook_scale'
        
res = (patched_attn_output - orig_cache[orig_attn_output_name]) @ model.blocks[receiver_layer].attn.W_Q[receiver_head] / orig_cache[orig_receiver_ln1_scale_name]
res.shape

torch.Size([8, 15, 64])

In [None]:
head_B = (11, 6)
head_B_W_Q = model.blocks[10].attn.W_Q[4]
head_B_dest = corrupted_cache['blocks.11.attn.hook_q'][:, :, head_B[1], :]
head_B_dest.shape

torch.Size([8, 15, 64])

In [None]:
model.blocks[10].ln1.hook_scale

HookPoint()

In [None]:
print(utils.get_act_name(11, "hook_attn_out"))

TypeError: ignored

In [None]:
clean_cache["blocks.11.ln1.hook_scale"].shape

torch.Size([8, 15, 1])

In [None]:
head_A = (3, 2)
head_A_cache_name = 'blocks.3.attn.hook_z'
layer_A_z_patched = corrupted_cache[head_A_cache_name] #[:, :, head_A[1], :]
head_A_z_source = clean_cache[head_A_cache_name][:, :, head_A[1], :]
layer_A_z_patched[:, :, head_A[1], :] = head_A_z_source 

In [None]:
layer_A_patched_attn_out = einsum(
    "batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", 
    layer_A_z_patched, 
    model.blocks[3].attn.W_O) + model.blocks[3].attn.b_O

In [None]:
layer_A_patched_attn_out.shape

torch.Size([8, 15, 768])

In [None]:
model.blocks[3]

TransformerBlock(
  (ln1): LayerNormPre(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (ln2): LayerNormPre(
    (hook_scale): HookPoint()
    (hook_normalized): HookPoint()
  )
  (attn): Attention(
    (hook_k): HookPoint()
    (hook_q): HookPoint()
    (hook_v): HookPoint()
    (hook_z): HookPoint()
    (hook_attn_scores): HookPoint()
    (hook_pattern): HookPoint()
    (hook_result): HookPoint()
    (hook_rot_k): HookPoint()
    (hook_rot_q): HookPoint()
  )
  (mlp): MLP(
    (hook_pre): HookPoint()
    (hook_post): HookPoint()
  )
  (hook_q_input): HookPoint()
  (hook_k_input): HookPoint()
  (hook_v_input): HookPoint()
  (hook_attn_out): HookPoint()
  (hook_mlp_out): HookPoint()
  (hook_resid_pre): HookPoint()
  (hook_resid_post): HookPoint()
)

In [None]:
model.blocks[3].attn.W_Q[2].shape

torch.Size([768, 64])

In [None]:
model #.blocks[11].ln1.hook_scale

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_post): HookPoint()


In [None]:
def path_patching(
    model,
    patch_tokens,
    orig_tokens,
    sender_heads,
    receiver_hooks,
    positions=["end"],
    return_hooks=False,
    extra_hooks=[],  # when we call reset hooks, we may want to add some extra hooks after this, add these here
    freeze_mlps=False,  # recall in IOI paper we consider these "vital model components"
    have_internal_interactions=False,
):
    """
    Patch in the effect of `sender_heads` on `receiver_hooks` only
    (though MLPs are "ignored" if `freeze_mlps` is False so are slight confounders in this case - see Appendix B of https://arxiv.org/pdf/2211.00593.pdf)
    TODO fix this: if max_layer < model.cfg.n_layers, then let some part of the model do computations (not frozen)

    Args: 
      model (HookedTransformer):
      patch_tokens ()
    Returns:

    """

    # This function ensures that the patch is applied to every position in the position list.
    def patch_positions(z, source_act, hook, positions=["end"], verbose=False):
        for pos in positions:
            z[torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]] = source_act[
                torch.arange(patch_tokens.N), patch_tokens.word_idx[pos]
            ]
        return z

    # Puts all head names into a list.
    sender_hooks = []
    for layer, head_idx in sender_heads:
        if head_idx is None:
            sender_hooks.append((f"blocks.{layer}.hook_mlp_out", None))

        else:
            sender_hooks.append((f"blocks.{layer}.attn.hook_result", head_idx))

    # Makes lists of the names only of the sender and receiver hooks.
    sender_hook_names = [x[0] for x in sender_hooks]
    receiver_hook_names = [x[0] for x in receiver_hooks]

    # Forward pass A (in https://arxiv.org/pdf/2211.00593.pdf)
    # Similar to running with "clean" prompts
    sender_cache = {}
    model.reset_hooks()
    # This portion is not used.
    # for hook in extra_hooks:
    #     model.add_hook(*hook)
    model.cache_some(
        sender_cache, lambda x: x in sender_hook_names, suppress_warning=True
    )
    source_logits = model(patch_tokens.toks.long())

    # Forward pass B
    # Similar to running on "corrupted" prompts
    target_cache = {}
    model.reset_hooks()
    # for hook in extra_hooks:
    #     model.add_hook(*hook)
    model.cache_all(target_cache, suppress_warning=True)
    target_logits = model(orig_tokens.toks.long())

    # Forward pass C
    # Cache the receiver hooks
    # (adding these hooks first means we save values BEFORE they are overwritten)
    receiver_cache = {}
    model.reset_hooks()
    model.cache_some(
        receiver_cache,
        lambda x: x in receiver_hook_names,
        suppress_warning=True,
        verbose=False,
    )

    # "Freeze" intermediate heads to their orig_tokens values
    # Effectively, for each layer and head, we are patching in the activations from run B
    # (e.g. the "corrupted" prompts), which is equivalent to freezing their behavior to 
    # that of the "corrupted" prompts
    for layer in range(model.cfg.n_layers):
        for head_idx in range(model.cfg.n_heads):
            for hook_template in [
                "blocks.{}.attn.hook_q",
                "blocks.{}.attn.hook_k",
                "blocks.{}.attn.hook_v",
            ]:
                hook_name = hook_template.format(layer)

                # this won't run in the default case
                # if have_internal_interactions and hook_name in receiver_hook_names:
                #     continue # stops the rest of the loop's code and goes to next iteration

                hook = get_act_hook(
                    patch_all,
                    alt_act=target_cache[hook_name],
                    idx=head_idx,
                    dim=2 if head_idx is not None else None,
                    name=hook_name,
                )
                model.add_hook(hook_name, hook)

        # also won't run in the default case
        # if freeze_mlps:
        #     hook_name = f"blocks.{layer}.hook_mlp_out"
        #     hook = get_act_hook(
        #         patch_all,
        #         alt_act=target_cache[hook_name],
        #         idx=None,
        #         dim=None,
        #         name=hook_name,
        #     )
        #     model.add_hook(hook_name, hook)

    # for hook in extra_hooks:
    #     model.add_hook(*hook)

    # These hooks will overwrite the freezing, for the sender heads - this should be forward pass C
    # In more detail, the sender heads will be patched with the new or "clean" activations, and the MLPs
    # and layer norms will be recomputed. Frozen heads won't be affected.
    for hook_name, head_idx in sender_hooks:
        assert not torch.allclose(sender_cache[hook_name], target_cache[hook_name]), (
            hook_name,
            head_idx,
        )
        hook = get_act_hook(
            partial(patch_positions, positions=positions),
            alt_act=sender_cache[hook_name],
            idx=head_idx,
            dim=2 if head_idx is not None else None,
            name=hook_name,
        )
        model.add_hook(hook_name, hook)
    receiver_logits = model(orig_tokens.toks.long())

    # Add (or return) all the hooks needed for forward pass D
    model.reset_hooks()
    hooks = []

    # Not used
    # for hook in extra_hooks:
    #     hooks.append(hook)

    for hook_name, head_idx in receiver_hooks:
        for pos in positions:
            if torch.allclose(
                receiver_cache[hook_name][torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]],
                target_cache[hook_name][torch.arange(orig_tokens.N), orig_tokens.word_idx[pos]],
            ):
                warnings.warn("Torch all close for {}".format(hook_name))
        
        # Get the cached receiver activations (cached before getting overwritten)
        hook = get_act_hook(
            partial(patch_positions, positions=positions),
            alt_act=receiver_cache[hook_name],
            idx=head_idx,
            dim=2 if head_idx is not None else None,
            name=hook_name,
        )
        hooks.append((hook_name, hook))

    model.reset_hooks()
    if return_hooks:
        return hooks
    else:
        for hook_name, hook in hooks:
            model.add_hook(hook_name, hook)
        return model