# Indirect Object Identification Circuit in Pythia

In [1]:
from IPython import get_ipython

# Enable autoreload so edits to the local helper modules imported below
# (e.g. circuit_utils, visualization_utils) are picked up without a kernel restart.
# `InteractiveShell.magic` is deprecated; `run_line_magic(name, line)` is its replacement.
ipython = get_ipython()
if ipython is not None:  # get_ipython() returns None outside an IPython session
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
import os
from functools import partial

import torch
from torchtyping import TensorType as TT
import numpy as np
import pandas as pd

import einops
from fancy_einsum import einsum

from transformers import AutoModelForCausalLM
import transformer_lens.utils as utils
import transformer_lens.patching as patching
from transformer_lens import HookedTransformer
from neel_plotly import line, imshow, scatter

from visualization_utils import (
    plot_attention_heads,
    scatter_attention_and_contribution
)

# Compute device: the CUDA ordinal taken from LOCAL_RANK (0 when unset) if a GPU
# is visible, otherwise the CPU.
device = int(os.environ.get("LOCAL_RANK", 0)) if torch.cuda.is_available() else "cpu"

In [3]:
import circuitsvis as cv
# Smoke test: render the bundled "hello" example to confirm circuitsvis works in this notebook.
cv.examples.hello("Reader!")


In [4]:
# Turn autograd off for the whole notebook; nothing below calls backward().
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f732843ec40>

## Model Setup

In [27]:
import huggingface_hub
# Show the Hugging Face login widget so an access token can be entered interactively.
# NOTE(review): presumably required to download the checkpoint in the next cell — confirm.
huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [28]:
# Load the Hugging Face checkpoint whose weights will be analysed.
source_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m-seed1")

# Wrap the weights in a HookedTransformer so activations can be cached and patched.
# NOTE(review): the checkpoint is "pythia-160m-seed1" but the config name passed here is
# "pythia-125m" — presumably the legacy alias for the same architecture; confirm.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-125m",
    hf_model=source_model,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    #refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-125m into HookedTransformer


## Data Setup

In [29]:
from circuit_utils import read_data, set_up_data

prompts, answers = read_data("../data/simple_dataset.txt")

# Token index of the subject (S) and indirect-object (IO) name in each of the 8 prompts.
s_positions = [2, 4, 4, 2, 4, 2, 4, 2]
io_positions = [4, 2, 2, 4, 2, 4, 2, 4]

clean_tokens, corrupted_tokens, answer_token_indices = set_up_data(
    model, prompts, answers
)

# NOTE(review): the three values returned by set_up_data above are all recomputed
# (and overwritten) inline below — confirm which of the two paths should be kept.
clean_tokens = model.to_tokens(prompts)
# Corrupted prompts: swap every adjacent pair of rows (0<->1, 2<->3, ...).
# `row ^ 1` maps an even row to row + 1 and an odd row to row - 1.
partner_rows = [row ^ 1 for row in range(len(clean_tokens))]
corrupted_tokens = clean_tokens[partner_rows]
print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

# One row per prompt holding the token ids of its first two answers.
answer_token_ids = []
for prompt_idx in range(len(answers)):
    answer_pair = answers[prompt_idx]
    answer_token_ids.append(
        [model.to_single_token(answer_pair[0]), model.to_single_token(answer_pair[1])]
    )
answer_token_indices = torch.tensor(answer_token_ids, device=model.cfg.device)
print("Answer token indices", answer_token_indices)

Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to
Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to
Answer token indices tensor([[ 6393,  2516],
        [ 2516,  6393],
        [ 6270,  5490],
        [ 5490,  6270],
        [ 5682, 24752],
        [24752,  5682],
        [ 8698, 22138],
        [22138,  8698]], device='cuda:0')


## Tool Setup

### Activation Patching

In [30]:
from circuit_utils import get_logit_diff, ioi_metric

# Forward passes on the clean and corrupted prompts, keeping every cached activation
# for the attribution and patching experiments below.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Scalar logit-difference baselines for the two runs.
clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()
corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()

print(f"Clean logit diff: {clean_logit_diff:.4f}")
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 6.3709
Corrupted logit diff: -6.3709


In [31]:
CLEAN_BASELINE, CORRUPTED_BASELINE = clean_logit_diff, corrupted_logit_diff

# Bind the baselines and answer tokens once so later cells can call ioi_metric(logits).
ioi_metric = partial(
    ioi_metric,
    clean_baseline=CLEAN_BASELINE,
    corrupted_baseline=CORRUPTED_BASELINE,
    answer_token_indices=answer_token_indices,
)

# Sanity check: the metric should read 1 on the clean run and 0 on the corrupted run.
clean_baseline_ioi = ioi_metric(clean_logits)
corrupted_baseline_ioi = ioi_metric(corrupted_logits)

print(f"Clean Baseline is 1: {clean_baseline_ioi.item():.4f}")
print(f"Corrupted Baseline is 0: {corrupted_baseline_ioi.item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [32]:
# Gate for the per-head, per-position patching sweep further down, which is much
# slower than the all-position runs; set to False to skip it.
DO_SLOW_RUNS = True

### Path Patching

In [33]:
from circuit_utils import patch_pos_head_vector, patch_head_vector, path_patching

## Direct Logit Attribution

In [34]:
# Residual-stream direction associated with each answer token of each prompt.
answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)
print("Answer residual directions shape:", answer_residual_directions.shape)

# Per-prompt direction whose dot product with the residual stream gives the logit
# difference: first answer's direction minus the second answer's direction.
first_answer_directions = answer_residual_directions[:, 0]
second_answer_directions = answer_residual_directions[:, 1]
logit_diff_directions = first_answer_directions - second_answer_directions
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([8, 2, 768])
Logit difference directions shape: torch.Size([8, 768])


In [35]:
# Sanity check of the direction tensor: (prompts, answers per prompt, d_model).
answer_residual_directions.shape

torch.Size([8, 2, 768])

In [36]:
# The cache is indexed as (activation name, layer index): "resid_post" is the residual
# stream at the end of a layer and -1 selects the final layer.
final_residual_stream = clean_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)

# Keep the last token of every prompt, then apply the final LayerNorm scaling;
# pos_slice=-1 tells the cache that the stack holds only that final position.
final_token_residual_stream = final_residual_stream[:, -1, :]
scaled_final_token_residual_stream = clean_cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

# Project onto the logit-diff directions, sum over prompts and average.
summed_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
)
average_logit_diff = summed_logit_diff / len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",clean_logit_diff)

Final residual stream shape: torch.Size([8, 15, 768])
Calculated average logit diff: 6.370857238769531
Original logit difference: 6.370858192443848


### Logit Lens

In [37]:
from circuit_utils import residual_stack_to_logit_diff

# Residual stream accumulated up to each layer, taken at the final token position.
accumulated_residual, labels = clean_cache.accumulated_resid(
    layer=-1, incl_mid=False, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(
    accumulated_residual, logit_diff_directions, prompts, clean_cache
)
# One x tick per layer plus one extra for the final residual stream.
layer_axis = np.arange(model.cfg.n_layers + 1)
line(
    logit_lens_logit_diffs,
    x=layer_axis,
    hover_name=labels,
    title="Logit Difference From Accumulated Residual Stream",
)

### Layer Attribution

In [38]:
# Split the final-position residual stream into its per-component contributions.
per_layer_residual, labels = clean_cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(
    per_layer_residual, logit_diff_directions, prompts, clean_cache
)
line(
    per_layer_logit_diffs,
    hover_name=labels,
    title="Logit Difference From Each Layer",
)

### Head Attribution

In [39]:
# Per-head outputs at the final position, stacked over every (layer, head).
per_head_residual, labels = clean_cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
flat_head_logit_diffs = residual_stack_to_logit_diff(
    per_head_residual, logit_diff_directions, prompts, clean_cache
)
# Unflatten the stacked axis into a (layer, head) grid for the heatmap.
per_head_logit_diffs = einops.rearrange(
    flat_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs,
    xaxis="Head",
    yaxis="Layer",
    title="Logit Difference From Each Head",
)

Tried to stack head results when they weren't cached. Computing head results now


In [40]:
# Plot each head's direct logit-diff attribution as a fraction of the clean logit diff.
# NOTE(review): top_n presumably limits the plot to the 15 largest heads — confirm in visualization_utils.
plot_attention_heads(per_head_logit_diffs/clean_logit_diff, top_n=15, range_x=[0, 1])

Total logit diff contribution above threshold: 1.04


### Attention Analysis

In [41]:
from visualization_utils import get_attn_head_patterns

# The three heads with the largest positive direct logit-diff attribution.
top_k = 3
flat_head_scores = per_head_logit_diffs.flatten()
top_heads = torch.topk(flat_head_scores, k=top_k).indices.cpu().numpy()
# Flat index -> (layer, head).
heads = [divmod(flat_index, model.cfg.n_heads) for flat_index in top_heads]
tokens, attn, names = get_attn_head_patterns(model, prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [42]:
# (layer, head) pairs treated as name mover head (NMH) candidates in the cells below.
nmh_candidates = [(7, 9), (9, 4)]

In [43]:
# Scatter plot for the second NMH candidate only, (9, 4), over all prompts.
scatter_attention_and_contribution(model, nmh_candidates[1], prompts, io_positions, s_positions, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


In [44]:
from ioi_dataset import IOIDataset
# Build a 50-prompt "mixed" IOI dataset with the model's tokenizer and no BOS token.
# NOTE(review): `ioi_dataset` is not referenced by any later cell visible here — confirm it is still needed.
N = 50
ioi_dataset = IOIDataset(
    prompt_type="mixed",
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
)  # TODO make this a seeded dataset



Some groups have less than 5 prompts, they have lengths [1, 3, 2, 1, 2, 2, 4, 3, 1, 4]



In [45]:
#ioi_dataset.ioi_prompts

In [63]:
# The three heads with the most negative direct logit-diff attribution:
# negating the scores turns "smallest" into "largest" for topk.
top_k = 3
negated_head_scores = -per_head_logit_diffs.flatten()
top_heads = torch.topk(negated_head_scores, k=top_k).indices.cpu().numpy()
# Flat index -> (layer, head).
heads = [divmod(flat_index, model.cfg.n_heads) for flat_index in top_heads]
tokens, attn, names = get_attn_head_patterns(model, prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [47]:
# Same scatter plot as for the NMH candidate, here for head (9, 5).
scatter_attention_and_contribution(model, (9, 5), prompts, io_positions, s_positions, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


## Activation Patching for Model Component Importance

### Residual Stream

In [48]:
# Patch clean resid_pre activations into the corrupted run, one (layer, position) at a time.
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(
    model, corrupted_tokens, clean_cache, ioi_metric
)

# Label the x axis with the tokens of the first clean prompt and their positions.
position_labels = [
    f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))
]
imshow(
    resid_pre_act_patch_results,
    yaxis="Layer",
    xaxis="Position",
    x=position_labels,
    title="IOI Metric for 'resid_pre' Activation Patching",
)

  0%|          | 0/180 [00:00<?, ?it/s]

### MLP Layers

In [49]:
# Patch clean mlp_out activations into the corrupted run, one (layer, position) at a time.
# Stored under its own name so the resid_pre results from the previous section are not clobbered.
mlp_out_act_patch_results = patching.get_act_patch_mlp_out(
    model,
    corrupted_tokens,
    clean_cache,
    ioi_metric)

imshow(mlp_out_act_patch_results,
       yaxis="Layer",
       xaxis="Position",
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'mlp_out' Activation Patching")

  0%|          | 0/180 [00:00<?, ?it/s]

### Attention Layers

In [50]:
# Patch clean attn_out activations into the corrupted run, one (layer, position) at a time.
# Stored under its own name so the resid_pre results from the earlier section are not clobbered.
attn_out_act_patch_results = patching.get_act_patch_attn_out(
    model,
    corrupted_tokens,
    clean_cache,
    ioi_metric)

imshow(attn_out_act_patch_results,
       yaxis="Layer",
       xaxis="Position",
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'attn_out' Activation Patching")

  0%|          | 0/180 [00:00<?, ?it/s]

### Attention Layers by Head

In [51]:
# Patch each head's output from the clean run at all positions at once.
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(
    model, corrupted_tokens, clean_cache, ioi_metric
)
imshow(
    attn_head_out_all_pos_act_patch_results,
    yaxis="Layer",
    xaxis="Head",
    title="IOI Metric for 'attn_head_out' Activation Patching (All Pos)",
)

  0%|          | 0/144 [00:00<?, ?it/s]

In [52]:
# Patch each head's value vectors from the clean run at all positions at once.
attn_head_v_all_pos_act_patch_results = patching.get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)
# Plot the value-patching result computed above (the original plotted the
# head-output result again under the 'attn_head_v' title).
imshow(attn_head_v_all_pos_act_patch_results,
       yaxis="Layer",
       xaxis="Head",
       title="IOI Metric for 'attn_head_v' Activation Patching (All Pos)")

  0%|          | 0/144 [00:00<?, ?it/s]

In [53]:
# "L<layer>H<head>" labels in layer-major order, matching the "(layer head)" flattening below.
ALL_HEAD_LABELS = [
    f"L{layer_idx}H{head_idx}"
    for layer_idx in range(model.cfg.n_layers)
    for head_idx in range(model.cfg.n_heads)
]
if DO_SLOW_RUNS:
    # Patch every head's output separately at every position (the slow sweep).
    attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_by_pos(
        model, corrupted_tokens, clean_cache, ioi_metric
    )
    # Fold layer and head into one row axis so the heatmap has one row per head.
    attn_head_out_act_patch_results = einops.rearrange(
        attn_head_out_act_patch_results, "layer pos head -> (layer head) pos"
    )
    by_pos_token_labels = [
        f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))
    ]
    imshow(
        attn_head_out_act_patch_results,
        yaxis="Head Label",
        xaxis="Pos",
        x=by_pos_token_labels,
        y=ALL_HEAD_LABELS,
        title="attn_head_out Activation Patching By Pos",
    )

  0%|          | 0/2160 [00:00<?, ?it/s]

### Multiple Activation Types

In [54]:
# Residual stream, attention output and MLP output patching in a single call.
every_block_result = patching.get_act_patch_block_every(
    model, corrupted_tokens, clean_cache, ioi_metric
)
block_position_labels = [
    f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))
]
imshow(
    every_block_result,
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
    title="Activation Patching Per Block",
    xaxis="Position",
    yaxis="Layer",
    zmax=1,
    zmin=-1,
    x=block_position_labels,
)

  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/180 [00:00<?, ?it/s]

### Head Component Output

In [55]:
# Output, query, key, value and pattern patching for every head, all positions at once.
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, ioi_metric
)
imshow(
    every_head_all_pos_act_patch_result,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Activation Patching Per Head (All Pos)",
    xaxis="Head",
    yaxis="Layer",
)

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

## Circuit Sketching

### First Level

#### Attention Pattern Patching

In [56]:
# Patch each head's attention pattern from the clean run at all positions at once.
attn_head_pattern_all_pos_act_patch_results = patching.get_act_patch_attn_head_pattern_all_pos(
    model, corrupted_tokens, clean_cache, ioi_metric
)
imshow(
    attn_head_pattern_all_pos_act_patch_results,
    yaxis="Layer",
    xaxis="Head",
    title="IOI Metric for 'attn_head_pattern' Activation Patching (All Pos)",
)

  0%|          | 0/144 [00:00<?, ?it/s]

In [57]:
from visualization_utils import l_scatter

# One "L<layer>H<head>" hover label per head, in the same order as the flattened results.
head_labels = [
    f"L{layer_idx}H{head_idx}"
    for layer_idx in range(model.cfg.n_layers)
    for head_idx in range(model.cfg.n_heads)
]
pattern_patch_flat = utils.to_numpy(attn_head_pattern_all_pos_act_patch_results.flatten())
output_patch_flat = utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten())
l_scatter(
    x=pattern_patch_flat,
    y=output_patch_flat,
    hover_name=head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching",
)

In [58]:
# Zero the earlier (non-NMH) layers so they don't dominate the plot. Work on a copy:
# zeroing the patching results in place would silently corrupt them for every other
# cell that reads them (e.g. the output-vs-attention scatter if it is re-run).
FIRST_NMH_LAYER = 6
late_layer_pattern_patch_results = attn_head_pattern_all_pos_act_patch_results.clone()
late_layer_pattern_patch_results[:FIRST_NMH_LAYER, :] = 0
plot_attention_heads(
    late_layer_pattern_patch_results,
    title="Logit Diff Correction From Patching in Clean Attention Patterns",
    top_n=15,
    range_x=[0, 0.5]
)

Total logit diff contribution above threshold: 0.61


#### NMH Knockout

##### All Heads

In [69]:
heads_to_ablate = nmh_candidates

print(f"Heads to ablate: {heads_to_ablate}")
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    """Zero the output of head `head_idx` at the final position.

    Args:
        z: hook_z activation, indexed as [batch, pos, head_index, d_head]; modified in place.
        hook: the hook point invoking this function (unused).
        head_idx: index of the head to knock out.

    Returns:
        The same tensor `z` with the selected head zeroed at position -1.
    """
    z[:, -1, head_idx, :] = 0
    return z

# Start from a clean slate so hooks left behind by earlier cells cannot stack with these.
model.reset_hooks()
# Adds the ablation hooks into global model state.
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
try:
    ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
finally:
    # Remove the ablation hooks explicitly: whether run_with_cache also strips hooks
    # that were added manually depends on the TransformerLens version, and a leaked
    # hook would silently ablate these heads in every later cell.
    model.reset_hooks()
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")

Heads to ablate: [(7, 9), (9, 4)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.8511136174201965


In [70]:
# Direct logit attribution of every head, recomputed from the ablated run's cache.
per_head_ablated_residual, labels = ablated_cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
flat_ablated_logit_diffs = residual_stack_to_logit_diff(
    per_head_ablated_residual, logit_diff_directions, prompts, ablated_cache
)
per_head_ablated_logit_diffs = flat_ablated_logit_diffs.reshape(
    model.cfg.n_layers, model.cfg.n_heads
)
imshow(
    per_head_ablated_logit_diffs,
    labels={"x":"Head", "y":"Layer"},
    zmin=-1.5,
    zmax=1.5,
    title="Post-Ablation Direct Logit Attribution of Heads",
)
# Compare each head's attribution before (y) and after (x) the ablation.
l_scatter(
    y=per_head_logit_diffs.flatten(),
    x=per_head_ablated_logit_diffs.flatten(),
    hover_name=head_labels,
    range_x=(-3, 3),
    range_y=(-3, 3),
    xaxis="Ablated",
    yaxis="Original",
    title="Original vs Post-Ablation Direct Logit Attribution of Heads",
)

Tried to stack head results when they weren't cached. Computing head results now


In [71]:
# Heads to leave out of the backup-head plot.
exclusions = [(6, 7), (7, 6)] # + [(9, 1), (9, 5)]
# Change in each head's direct logit attribution caused by the ablation.
delta = per_head_ablated_logit_diffs - per_head_logit_diffs
# Zero the excluded heads in `delta` itself. The original zeroed
# per_head_ablated_logit_diffs *after* delta had been computed, which left the plot
# unchanged and corrupted the ablated attributions for later cells.
for layer, head in exclusions:
    delta[layer, head] = 0

plot_attention_heads(
    delta/clean_logit_diff,
    title="Logit Diff Contribution From Backup Heads",
    top_n=15,
    range_x=[0, 0.5]
)

Total logit diff contribution above threshold: 0.07


##### Individual Heads

In [103]:
# Get indices of all heads whose direct logit attribution rose by more than 0.05 after the ablation
delta = per_head_ablated_logit_diffs - per_head_logit_diffs
# .tolist() yields plain Python ints, so the tuples compare and print like the literals in `exclusions`.
backup_nmh_candidates = [
    tuple(index_pair)
    for index_pair in np.argwhere(delta.cpu().detach().numpy() > 0.05).tolist()
]
backup_nmh_candidates = [h for h in backup_nmh_candidates if h not in exclusions]
print(f"Backup NMH Candidates: {backup_nmh_candidates}")
for candidate_layer, candidate_head in backup_nmh_candidates:
    # Re-install the ablation hooks from a clean slate for every candidate so hooks never stack.
    model.reset_hooks()
    for layer, head in heads_to_ablate:
        ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
        model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
    try:
        scatter_attention_and_contribution(model, (candidate_layer, candidate_head), prompts, io_positions, s_positions, answer_residual_directions)
    finally:
        # Do not rely on the helper's internal model run to strip manually added hooks;
        # that behaviour depends on the TransformerLens version.
        model.reset_hooks()

Backup NMH Candidates: [(9, 6)]
Tried to stack head results when they weren't cached. Computing head results now


In [None]:
# Show the backup heads' attention patterns with the ablation hooks installed.
model.reset_hooks()  # clean slate so ablation hooks cannot stack with leftovers
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
try:
    tokens, attn, names = get_attn_head_patterns(model, prompts[0], backup_nmh_candidates)
finally:
    # Remove the hooks explicitly so the ablation cannot leak into later cells.
    model.reset_hooks()
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

#### Path Patching for NMH Candidate Receivers

In [37]:
receiver_heads = nmh_candidates

# Patch into the query input (hook_q) of every receiver head. The list does not depend
# on the sender, so it is built once; distinct loop names avoid shadowing `head_idx`.
receiver_hooks = [
    (f"blocks.{receiver_layer}.attn.hook_q", receiver_head)
    for receiver_layer, receiver_head in receiver_heads
]

# Allocate on the model's own device rather than a hard-coded 'cuda:0' so the cell also runs on CPU.
metric_delta_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=model.cfg.device)

# Treat every head in turn as the single sender, path-patched at the final position.
for layer in range(model.cfg.n_layers):
    for head_idx in range(model.cfg.n_heads):
        pass_d_hooks = path_patching(
            model=model,
            patch_tokens=corrupted_tokens,
            orig_tokens=clean_tokens,
            sender_heads=[(layer, head_idx)],
            receiver_hooks=receiver_hooks,
            positions=-1
        )
        path_patched_logits = model.run_with_hooks(clean_tokens, fwd_hooks=pass_d_hooks)
        ioi_metric_res = ioi_metric(path_patched_logits)
        # Relative change of the IOI metric versus the clean baseline (negative = metric dropped).
        metric_delta_results[layer, head_idx] = -(clean_baseline_ioi - ioi_metric_res) / clean_baseline_ioi

In [38]:
# Heatmap of the relative IOI-metric change per sender head (rows: layers, columns: heads).
imshow(metric_delta_results, title="IOI Metric Change From Each Head Through Receivers", zmin=-0.3, zmax=0.3)

In [39]:
# Negate the deltas so that a drop in the metric shows up as a positive bar.
plot_attention_heads(-metric_delta_results, title="Logit Diff Drop with Corrupted Path Patch", top_n=15, range_x=[0, 0.50])

Total logit diff contribution above threshold: 0.54


### Second Level

#### Attention Pattern for Second-Level Heads

In [40]:
# (layer, head) pairs examined as second-level heads in the rest of this section.
second_level_positive_heads = [(6, 7), (7, 6)]

# Attention patterns of those heads on the first prompt.
tokens, attn, names = get_attn_head_patterns(model, prompts[0], second_level_positive_heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

# NOTE(review): disabled code below calls visualize_attention_patterns, which is not defined or imported in the visible notebook.
#second_level_negative_heads = [(7, 8), (8, 10)]
#visualize_attention_patterns(torch.tensor([l*12+h for l, h in second_level_negative_heads]), title=f"Top Negative Second Level IOI Metric Heads")

In [43]:
# One "L<layer>H<head>" hover label per head, in the same order as the flattened results.
head_labels = [
    f"L{layer_idx}H{head_idx}"
    for layer_idx in range(model.cfg.n_layers)
    for head_idx in range(model.cfg.n_heads)
]
value_patch_flat = utils.to_numpy(attn_head_v_all_pos_act_patch_results.flatten())
output_patch_flat = utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten())
# Colour every point by its layer: each layer index is repeated once per head.
layer_of_each_head = einops.repeat(
    np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads
)
l_scatter(
    x=value_patch_flat,
    y=output_patch_flat,
    xaxis="Value Patch",
    yaxis="Output Patch",
    #caxis="Layer",
    hover_name=head_labels,
    color=layer_of_each_head,
    range_x=(-1.5, 1.5),
    range_y=(-1.5, 1.5),
    title="Scatter plot of output patching vs value patching",
)

In [44]:
# S2-inhibition (S2I) candidates: reuse the second-level heads identified above.
s2i_candidates = second_level_positive_heads
# Alternative single-head candidate, currently disabled.
#s2i_candidates = [(8, 9)]

#### S2I Knockout

##### All Heads

In [45]:
heads_to_ablate = s2i_candidates

print(f"Heads to ablate: {heads_to_ablate}")
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    """Zero the output of head `head_idx` at the final position.

    Args:
        z: hook_z activation, indexed as [batch, pos, head_index, d_head]; modified in place.
        hook: the hook point invoking this function (unused).
        head_idx: index of the head to knock out.

    Returns:
        The same tensor `z` with the selected head zeroed at position -1.
    """
    z[:, -1, head_idx, :] = 0
    return z

# Start from a clean slate so hooks left behind by earlier cells cannot stack with these.
model.reset_hooks()
# Adds the ablation hooks into global model state.
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
try:
    ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
finally:
    # Remove the ablation hooks explicitly: whether run_with_cache also strips hooks
    # that were added manually depends on the TransformerLens version, and a leaked
    # hook would silently ablate these heads in every later cell.
    model.reset_hooks()
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")
#print(f"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item()}")
#print(f"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item()}")

Heads to ablate: [(6, 7), (7, 6)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.5281655788421631


In [70]:
# Direct logit attribution of every head, recomputed from the S2I-ablated run's cache.
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
# Pass the logit-diff directions and prompts like every other call in this notebook;
# the original call passed only (residual, cache), putting the cache in the directions slot.
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, logit_diff_directions, prompts, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"})
# Compare each head's attribution before (y) and after (x) the ablation.
l_scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

Tried to stack head results when they weren't cached. Computing head results now


#### Path Patching for S2-Inhibition Candidates

In [None]:
# List the string tokens of clean prompt 3 to read off token positions (the repeated name sits at index 10).
model.to_str_tokens(clean_tokens[3])

['<|endoftext|>',
 'When',
 ' Tom',
 ' and',
 ' James',
 ' went',
 ' to',
 ' the',
 ' park',
 ',',
 ' Tom',
 ' gave',
 ' the',
 ' ball',
 ' to']

In [71]:
receiver_heads = second_level_positive_heads

# Patch into the value input (hook_v) of every receiver head. The list does not depend
# on the sender, so it is built once; distinct loop names avoid shadowing `head_idx`.
receiver_hooks = [
    (f"blocks.{receiver_layer}.attn.hook_v", receiver_head)
    for receiver_layer, receiver_head in receiver_heads
]

# Allocate on the model's own device rather than a hard-coded 'cuda:0' so the cell also runs on CPU.
metric_delta_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=model.cfg.device)

# Treat every head in turn as the single sender, path-patched at token position 10.
# NOTE(review): 10 is the index of the repeated name in the prompt printed above;
# presumably every prompt shares that layout — confirm against the dataset.
for layer in range(model.cfg.n_layers):
    for head_idx in range(model.cfg.n_heads):
        pass_d_hooks = path_patching(
            model=model,
            patch_tokens=corrupted_tokens,
            orig_tokens=clean_tokens,
            sender_heads=[(layer, head_idx)],
            receiver_hooks=receiver_hooks,
            positions=10
        )
        path_patched_logits = model.run_with_hooks(clean_tokens, fwd_hooks=pass_d_hooks)
        ioi_metric_res = ioi_metric(path_patched_logits)
        # Relative change of the IOI metric versus the clean baseline (negative = metric dropped).
        metric_delta_results[layer, head_idx] = -(clean_baseline_ioi - ioi_metric_res) / clean_baseline_ioi

In [72]:
# Heatmap of the relative IOI-metric change per sender head; the colour range is left automatic.
imshow(metric_delta_results, title="IOI Metric Change From Each Head Through Receivers")#, zmin=-0.02, zmax=0.02)

### Third Level

#### Attention Patterns for Third-Level Heads

We have a mix of induction heads and duplicate token heads here, as well as two heads that focus on S2 at S2.

In [None]:
# (layer, head) pairs examined as third-level heads.
third_level_positive_heads = [(4, 6), (4, 11), (5, 10)]

# Visualise the third-level heads on the first prompt. The original referenced the
# undefined name `example_prompt` and passed the second-level head list instead.
tokens, attn, names = get_attn_head_patterns(model, prompts[0], third_level_positive_heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)