# Indirect Object Identification Circuit in Pythia

In [1]:
!ls

README.md	    quick_start_pytorch.ipynb
circuits-over-time  quick_start_pytorch_images


In [8]:
%cd circuits-over-time/circuit_sketches

/notebooks/circuits-over-time/circuit_sketches


In [3]:
!pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
!pip install transformer_lens
!pip install jaxtyping==0.2.13
!pip install einops
!pip install protobuf==3.20.*
!pip install plotly
!pip install torchtyping
!pip install git+https://github.com/neelnanda-io/neel-plotly.git

Collecting git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
  Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-2ya4d3h2
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-2ya4d3h2
  Resolved https://github.com/callummcdougall/CircuitsVis.git to commit 1ec4a8f8fa4368e95500fe3a188b367500af5f98
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting torch<3.0,>=2.0
  Downloading torch-2.0.1-cp39-cp39-manylinux1_x86_64.whl (619.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m619.9/619.9 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting importlib-metadata<6.0.0,>=5.1.0
  Downloading importlib_metadata-5.2.0-py3-none-any.whl (21 kB)
Collecting nvidia-cusparse-cu11==11.7.4.91
  Downloading nvid

In [4]:
from IPython import get_ipython
from IPython.display import clear_output, display

# Enable autoreload so edits to local helper modules are picked up without a kernel restart.
# `InteractiveShell.magic` is deprecated (it emitted the warning recorded in this cell's
# output); `run_line_magic(name, line)` is the supported equivalent.
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [9]:
import os
import pathlib
from typing import List, Optional, Union

import torch
from torch import Tensor
import numpy as np
import yaml
import pickle
import einops
from fancy_einsum import einsum


import circuitsvis as cv

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import transformer_lens.patching as patching

from torch import Tensor
from tqdm.notebook import tqdm
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set
from rich import print as rprint

from typing import List, Union
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import re

from functools import partial

from torchtyping import TensorType as TT

from utils.path_patching import Node, IterNode, path_patch, act_patch
from neel_plotly import imshow as imshow_n

from utils.visualization import get_attn_head_patterns, imshow_p, plot_attention_heads, scatter_attention_and_contribution_simple
from utils.visualization import get_attn_pattern, plot_attention
from utils.circuit_analysis import get_logit_diff, logit_diff_denoising, logit_diff_noising

from visualization_utils import (
    plot_attention_heads,
    scatter_attention_and_contribution
)

# Use the GPU indexed by LOCAL_RANK (default 0) when CUDA is present, otherwise the CPU.
device = int(os.environ.get("LOCAL_RANK", 0)) if torch.cuda.is_available() else "cpu"

In [10]:
import circuitsvis as cv
# Smoke test: render the CircuitsVis example to confirm the package imports and displays.
cv.examples.hello("Reader!")


In [17]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a Plotly heatmap with a diverging colour scale centred on 0.

    Args:
        tensor: 2D data, converted with `utils.to_numpy`.
        renderer: passed to `Figure.show`.
        xaxis, yaxis: axis titles.
        **kwargs: forwarded to `px.imshow`. A `labels` dict is merged over the
            axis titles instead of colliding with them, and the colour-scale
            defaults can be overridden.
    """
    # Merge caller-supplied labels; previously `labels=` in kwargs raised
    # "got multiple values for keyword argument 'labels'".
    axis_labels = {"x": xaxis, "y": yaxis}
    axis_labels.update(kwargs.pop("labels", None) or {})
    kwargs.setdefault("color_continuous_midpoint", 0.0)
    kwargs.setdefault("color_continuous_scale", "RdBu")
    px.imshow(utils.to_numpy(tensor), labels=axis_labels, **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot `tensor` as a single Plotly line and show it with `renderer`."""
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def two_lines(tensor1, tensor2, renderer=None, **kwargs):
    """Plot two series on one Plotly line chart and show it with `renderer`."""
    series = [utils.to_numpy(tensor1), utils.to_numpy(tensor2)]
    figure = px.line(y=series, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a Plotly scatter plot of `y` against `x`.

    Args:
        x, y: coordinates, converted with `utils.to_numpy`.
        xaxis, yaxis, caxis: titles for the x axis, y axis and colour axis.
        renderer: passed to `Figure.show`.
        **kwargs: forwarded to `px.scatter`. A `labels` dict is merged over
            the axis titles instead of raising a duplicate-keyword TypeError.
    """
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    axis_labels.update(kwargs.pop("labels", None) or {})
    px.scatter(y=y, x=x, labels=axis_labels, **kwargs).show(renderer)

In [11]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f7540133e50>

## Model Setup

In [108]:
# Interactive Hugging Face Hub login (renders a widget, see the VBox output).
# NOTE(review): this needs manual input, so it blocks an unattended "Run All";
# confirm whether authentication is actually required for the model loaded below.
import huggingface_hub
huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [81]:
# Load Pythia-160m into a HookedTransformer with LayerNorm folded and the
# writing weights / unembedding centred.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-160m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    #refactor_factored_attn_matrices=True,
)
# Enable the `hook_mlp_in` hook point.
# NOTE(review): presumably required by the path-patching helpers — confirm.
model.set_use_hook_mlp_in(True)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-160m into HookedTransformer


## Data Setup

In [82]:
from circuit_utils import read_data, set_up_data

prompts, answers = read_data("../data/simple_dataset.txt")

# Hand-specified token positions, one entry per prompt.
# NOTE(review): presumably the indices of the indirect-object and subject names — confirm.
io_positions = [4, 2, 2, 4, 2, 4, 2, 4]
s_positions = [2, 4, 4, 2, 4, 2, 4, 2]
assert len(io_positions) == len(s_positions) == len(prompts), "one position per prompt expected"

# The earlier `set_up_data(model, prompts, answers)` call was dropped: all three of its
# results were immediately overwritten by the explicit construction below.
clean_tokens = model.to_tokens(prompts)

# Prompts come in adjacent pairs that differ only in the repeated name, so the corrupted
# batch is the clean batch with each pair swapped. This needs an even number of prompts;
# with an odd count the last index would run off the end.
assert len(clean_tokens) % 2 == 0, "prompts must come in adjacent clean/corrupted pairs"
swapped_order = [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
corrupted_tokens = clean_tokens[swapped_order]
print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

# One row per prompt, built from the first two entries of each `answers` item.
answer_tokens = torch.tensor(
    [[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))],
    device=model.cfg.device,
)
print("Answer token indices", answer_tokens)

Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to
Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to
Answer token indices tensor([[ 6393,  2516],
        [ 2516,  6393],
        [ 6270,  5490],
        [ 5490,  6270],
        [ 5682, 24752],
        [24752,  5682],
        [ 8698, 22138],
        [22138,  8698]], device='cuda:0')


## Tool Setup

### Activation Patching

In [83]:
from circuit_utils import get_logit_diff, ioi_metric

# Run both batches once, keeping logits and activation caches for later patching.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Baseline logit differences for each batch.
clean_logit_diff = get_logit_diff(clean_logits, answer_tokens).item()
corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_tokens).item()

print(f"Clean logit diff: {clean_logit_diff:.4f}")
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 4.7131
Corrupted logit diff: -4.7131


In [84]:
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff

# Bind the baselines and answer tokens so later cells can call `ioi_metric(logits)`.
# Re-running this cell re-wraps the already-bound partial with the same keyword
# values, which is harmless.
ioi_metric = partial(ioi_metric, clean_baseline=CLEAN_BASELINE, corrupted_baseline=CORRUPTED_BASELINE, answer_token_indices=answer_tokens)

clean_baseline_ioi = ioi_metric(clean_logits)
corrupted_baseline_ioi = ioi_metric(corrupted_logits)

# Reuse the values computed above instead of evaluating the metric a second time.
print(f"Clean Baseline is 1: {clean_baseline_ioi.item():.4f}")
print(f"Corrupted Baseline is 0: {corrupted_baseline_ioi.item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [85]:
def logit_diff_denoising(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Int[Tensor, "batch 2"] = answer_tokens,
    flipped_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    return_tensor: bool = False,
) -> Union[float, Float[Tensor, ""]]:
    '''
    Linear function of logit diff, calibrated so that it equals 0 when performance is
    same as on flipped input, and 1 when performance is same as on clean input.

    Args:
        logits: model output to score.
        answer_tokens: token ids passed to `get_logit_diff` (integer ids, not floats).
        flipped_logit_diff: logit diff on the corrupted batch (maps to 0).
        clean_logit_diff: logit diff on the clean batch (maps to 1).
        return_tensor: if True return the 0-d tensor, otherwise a Python float.

    Note: the defaults are captured when this cell runs; re-run it after
    recomputing the baselines. This definition also shadows the same-named
    import from `utils.circuit_analysis`.
    '''
    patched_logit_diff = get_logit_diff(logits, answer_tokens)
    ld = (patched_logit_diff - flipped_logit_diff) / (clean_logit_diff - flipped_logit_diff)
    return ld if return_tensor else ld.item()


def logit_diff_noising(
    logits: Float[Tensor, "batch seq d_vocab"],
    clean_logit_diff: float = clean_logit_diff,
    corrupted_logit_diff: float = corrupted_logit_diff,
    answer_tokens: Int[Tensor, "batch 2"] = answer_tokens,
    return_tensor: bool = False,
) -> Union[float, Float[Tensor, ""]]:
    '''
    We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as
    on the clean input), and -1 when performance has been destroyed (i.e. same as on the
    corrupted input).

    Args:
        logits: model output to score.
        clean_logit_diff: logit diff on the clean batch (maps to 0).
        corrupted_logit_diff: logit diff on the corrupted batch (maps to -1).
        answer_tokens: token ids passed to `get_logit_diff` (integer ids, not floats).
        return_tensor: if True return the 0-d tensor, otherwise a Python float.

    Note: the defaults are captured when this cell runs; re-run it after
    recomputing the baselines. This definition also shadows the same-named
    import from `utils.circuit_analysis`.
    '''
    patched_logit_diff = get_logit_diff(logits, answer_tokens)
    ld = (patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff)
    return ld if return_tensor else ld.item()

logit_diff_denoising_tensor = partial(logit_diff_denoising, return_tensor=True)
logit_diff_noising_tensor = partial(logit_diff_noising, return_tensor=True)

In [86]:
# Whether to do the runs by head and by position, which are much slower
DO_SLOW_RUNS = True

## Direct Logit Attribution

In [87]:
# Map each answer token to its residual-stream direction, then take first minus second
# along the pair axis to get one logit-diff direction per prompt.
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
first_answer_dirs = answer_residual_directions[:, 0]
second_answer_dirs = answer_residual_directions[:, 1]
logit_diff_directions = first_answer_dirs - second_answer_dirs
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([8, 2, 768])
Logit difference directions shape: torch.Size([8, 768])


In [88]:
answer_residual_directions.shape

torch.Size([8, 2, 768])

In [89]:
# Cache indexing is (activation_name, layer_index[, sub_layer_type]); "resid_post" with
# layer -1 is the residual stream after the final layer.
final_residual_stream = clean_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)

# Keep only the last token of each prompt, then apply the final LayerNorm scaling
# (pos_slice=-1 tells the cache which position's scale factors to use).
final_token_residual_stream = final_residual_stream[:, -1, :]
scaled_final_token_residual_stream = clean_cache.apply_ln_to_stack(
    final_token_residual_stream,
    layer = -1,
    pos_slice=-1,
)

# Project onto the logit-diff directions and average over prompts; this should
# reproduce the clean logit diff measured from the logits.
summed_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
)
average_logit_diff = summed_logit_diff/len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",clean_logit_diff)

Final residual stream shape: torch.Size([8, 15, 768])
Calculated average logit diff: 4.713052749633789
Original logit difference: 4.713052272796631


### Logit Lens

In [90]:
from circuit_utils import residual_stack_to_logit_diff

# Logit lens: take the accumulated residual stream at the final token up to each layer
# and convert each snapshot to a logit diff.
accumulated_residual, labels = clean_cache.accumulated_resid(layer=-1, incl_mid=False, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, logit_diff_directions, prompts, clean_cache)
# The x axis has n_layers + 1 points.
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers+1), hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

### Layer Attribution

In [91]:
# Decompose the final-token residual stream into per-component contributions and
# convert each to a logit diff. Note this rebinds `labels` from the previous cell.
per_layer_residual, labels = clean_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, logit_diff_directions, prompts, clean_cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

### Head Attribution

In [92]:
# Stack each head's final-token output, convert to logit diffs, and fold the flat
# (layer * head) axis into a [layer, head] grid for plotting.
per_head_residual, labels = clean_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, logit_diff_directions, prompts, clean_cache)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(per_head_logit_diffs, xaxis="Head", yaxis="Layer", title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


In [36]:
# Per-head attribution expressed as a fraction of the clean logit diff.
plot_attention_heads(per_head_logit_diffs/clean_logit_diff, top_n=15, range_x=[0, 1])

Total logit diff contribution above threshold: 1.19


### Attention Analysis

In [37]:
from visualization_utils import get_attn_head_patterns

# Show attention patterns of the heads with the largest positive logit-diff attribution.
top_k = 3
top_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices.cpu().numpy()
# A flat index over the [layer, head] grid splits into (layer, head) via divmod.
heads = [divmod(head, model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [38]:
# Candidate heads as (layer, head) pairs.
# NOTE(review): "NMH" presumably stands for name mover head — confirm.
nmh_candidates = [(8, 10), (10, 7)]

In [40]:
# Inspect the first candidate head using the hand-specified IO / S positions.
scatter_attention_and_contribution(model, nmh_candidates[0], prompts, io_positions, s_positions, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


In [41]:
# Show attention patterns of the heads with the most negative logit-diff attribution
# (top-k of the negated scores).
top_k = 2
top_heads = torch.topk(-per_head_logit_diffs.flatten(), k=top_k).indices.cpu().numpy()
# A flat index over the [layer, head] grid splits into (layer, head) via divmod.
heads = [divmod(head, model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [124]:
# Same inspection for head (layer 8, head 9).
scatter_attention_and_contribution(model, (8, 9), prompts, io_positions, s_positions, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


## Activation Patching for Model Component Importance

### Attention Heads

In [93]:
# Denoising activation patching: run on the corrupted prompts and patch in activations
# from the clean cache, one head output ("z") at a time.
results = act_patch(
    model=model,
    orig_input=corrupted_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode("z"), # iterating over all heads' output in all layers
    patching_metric=logit_diff_denoising,
    verbose=True,
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [94]:
# Heatmap of the per-head patching metric, scaled to percent.
imshow_p(
    results['z'] * 100,
    title="Patching output of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

### Head Output by Component

In [44]:
# Same denoising patch, now iterating over each head's output, query, key, value and
# attention pattern in all layers. Note: this rebinds `results`, replacing the
# previous patching output.
results = act_patch(
    model=model,
    orig_input=corrupted_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode(["z", "q", "k", "v", "pattern"]),
    patching_metric=logit_diff_denoising,
    verbose=True,
)

  0%|          | 0/720 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)
results['q'].shape = (layer=12, head=12)
results['k'].shape = (layer=12, head=12)
results['v'].shape = (layer=12, head=12)
results['pattern'].shape = (layer=12, head=12)


In [45]:
# Stack the panels in an explicit key order so each facet label is guaranteed to sit on
# the matching result (previously this relied on dict iteration order).
component_order = ["z", "q", "k", "v", "pattern"]
assert results.keys() == set(component_order)
#assert all([r.shape == (12, 12) for r in results.values()])

imshow_p(
    torch.stack([results[name] for name in component_order]) * 100,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Patching output of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1500,
    margin={"r": 100, "l": 100}
)

### Residual Stream & Layer Outputs

In [46]:
# Denoising patch of resid_pre, attn_out and mlp_out at each sequence position
# (seq_pos="each"). Note: this rebinds `results` again.
results = act_patch(
    model=model,
    orig_input=corrupted_tokens,
    new_cache=clean_cache,
    patching_nodes=IterNode(["resid_pre", "attn_out", "mlp_out"], seq_pos="each"),
    patching_metric=logit_diff_denoising,
    verbose=True,
)

  0%|          | 0/540 [00:00<?, ?it/s]

results['resid_pre'].shape = (seq_pos=15, layer=12)
results['attn_out'].shape = (seq_pos=15, layer=12)
results['mlp_out'].shape = (seq_pos=15, layer=12)


In [52]:
# Stack the panels in an explicit key order so each facet label is guaranteed to sit on
# the matching result (previously this relied on dict iteration order).
component_order = ["resid_pre", "attn_out", "mlp_out"]
assert results.keys() == set(component_order)
labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]
imshow_p(
    torch.stack([results[name].T for name in component_order]) * 100, # we transpose so layer is on the y-axis
    facet_col=0,
    facet_labels=component_order,
    title="Patching at resid stream & layer outputs (corrupted -> clean)",
    labels={"x": "Sequence position", "y": "Layer", "color": "Logit diff variation"},
    x=labels,
    xaxis_tickangle=45,
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1400,
    height=600,
    margin={"r": 100, "l": 100}
)

## Circuit Sketching

### First Level

#### Heads Influencing Logit Diff

In [99]:
# Noising path patch: run on the clean prompts, take sender activations from the
# corrupted prompts, and measure the effect through the final residual stream.
path_patch_resid_post = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode('z'), # This means iterate over all heads in all layers
    receiver_nodes=Node('resid_post', 11), # This is resid_post at layer 11
    patching_metric=logit_diff_noising,
    verbose=True
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [100]:
# Heatmap of each head's path-patching effect on the final residual stream, in percent.
imshow_p(
    path_patch_resid_post['z'] * 100,
    title="Direct effect on logit diff (patch from head output -> final resid)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

In [101]:
# Negate so that heads whose corruption hurts the metric rank highest. Move the tensor to
# the model's device rather than hard-coding `.cuda()`, which fails on CPU-only machines.
plot_attention_heads(-path_patch_resid_post['z'].to(model.cfg.device), top_n=15, range_x=[0, 1.0])

Total logit diff contribution above threshold: 1.28


In [103]:
# Attention patterns of the heads with the largest direct effect on the logit diff.
# This reads `path_patch_resid_post` (computed in this section) instead of the generic
# `results`, which at this point holds the resid_pre/attn_out/mlp_out patching output
# and has no 'z' entry on a fresh top-to-bottom run.
top_k = 3
top_heads = torch.topk(-path_patch_resid_post['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [105]:
# (layer, head) pairs picked from the direct-effect analysis above.
# NOTE(review): presumably DE = direct effect, NMH = name mover heads,
# S2I = S2-inhibition heads — confirm naming.
DE_NMH = [(8, 10), (10, 7)]
DE_S2I = [(8, 9)]

In [106]:
# V-weighted version of the attention plot for the DE_NMH heads on the first prompt.
plot_attention(
    model, 
    prompts[0],
    DE_NMH,
    clean_cache,
    weighted=True)

In [63]:
from visualization_utils import l_scatter
# Labels for the flattened [layer, head] grid, in layer-major order.
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
# NOTE(review): `attn_head_pattern_all_pos_act_patch_results` and
# `attn_head_out_all_pos_act_patch_results` are never assigned anywhere in this
# notebook, so this cell fails on a fresh Run All. They look like leftovers from an
# earlier session — define them or switch to the patching results computed above.
l_scatter(
    x=utils.to_numpy(attn_head_pattern_all_pos_act_patch_results['z'].flatten()), 
    y=utils.to_numpy(attn_head_out_all_pos_act_patch_results['z'].flatten()), 
    hover_name = head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching")

In [67]:
# `attn_head_pattern_all_pos_act_patch_results` is never defined in this notebook. The
# title and the recorded total (1.28, identical to the direct-effect plot above) indicate
# the intended input is the resid_post path-patching result, so that is used here.
plot_attention_heads(
    -path_patch_resid_post['z'].to(model.cfg.device),
    title="Head Importance for Path Patching to Resid Post Receiver",
    top_n=15,
    range_x=[0, 1.0]
)

Total logit diff contribution above threshold: 1.28


#### NMH Knockout

##### All Heads

In [107]:
heads_to_ablate = DE_NMH

print(f"Heads to ablate: {heads_to_ablate}")
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    """Zero the output of head `head_idx` at the final sequence position (in place)."""
    z[:, -1, head_idx, :] = 0
    return z

# The ablation hooks are added to global model state. Depending on the TransformerLens
# version, `run_with_cache` may only remove its own caching hooks, so start from a clean
# slate and always reset afterwards to keep the ablation out of later cells.
model.reset_hooks()
try:
    for layer, head in heads_to_ablate:
        ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
        model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
    ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
finally:
    model.reset_hooks()
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")

Heads to ablate: [(8, 10), (10, 7)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.9398294687271118


In [108]:
# Recompute per-head direct logit attribution from the ablated run and compare it with
# the original attribution. Uses `l_scatter` and `head_labels` from an earlier cell.
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, logit_diff_directions, prompts, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow_n(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"}, zmin=-1.5, zmax=1.5, title="Post-Ablation Direct Logit Attribution of Heads")
l_scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

Tried to stack head results when they weren't cached. Computing head results now


In [126]:
# Heads left out of the backup-head plot, as (layer, head) pairs.
exclusions = [(6, 6), (7, 9), (8, 9)] + [(9, 1), (9, 5)]
delta = per_head_ablated_logit_diffs - per_head_logit_diffs

# Zero the excluded heads on a copy: mutating `per_head_ablated_logit_diffs` in place
# would make this cell non-idempotent and silently change what later cells read.
backup_head_logit_diffs = per_head_ablated_logit_diffs.clone()
for layer, head in exclusions:
    backup_head_logit_diffs[layer, head] = 0

plot_attention_heads(
    backup_head_logit_diffs/clean_logit_diff,
    title="Logit Diff Contribution From Backup Heads",
    top_n=15,
    range_x=[0, 0.5]
)

Total logit diff contribution above threshold: 0.37


##### Individual Heads

In [73]:
# Get indices of all heads where the ablation had a positive effect
delta = per_head_ablated_logit_diffs - per_head_logit_diffs
backup_nmh_candidates = np.argwhere(delta.cpu().detach().numpy() > 0.05)
# Plain-int tuples compare and print cleanly against `exclusions`.
backup_nmh_candidates = [tuple(int(i) for i in h) for h in backup_nmh_candidates]
backup_nmh_candidates = [h for h in backup_nmh_candidates if h not in exclusions]
print(f"Backup NMH Candidates: {backup_nmh_candidates}")
for l, h in backup_nmh_candidates:
    # Clear any hooks left on the model before adding the ablation hooks, so they do not
    # pile up across iterations or mix with hooks from earlier patching runs.
    # NOTE(review): the RuntimeError recorded for this cell (batch size 8 vs 1) is
    # consistent with such a stale hook — confirm this resolves it.
    model.reset_hooks()
    try:
        for layer, head in heads_to_ablate:
            ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
            model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
        scatter_attention_and_contribution(model, (l, h), prompts, io_positions, s_positions, answer_residual_directions)
    finally:
        model.reset_hooks()

Backup NMH Candidates: [(9, 6)]
Tried to stack head results when they weren't cached. Computing head results now


RuntimeError: The expanded size of the tensor (1) must match the existing size (8) at non-singleton dimension 0.  Target sizes: [1, 15, 64].  Tensor sizes: [8, 15, 64]

In [76]:
# The original name (`attn_head_pattern_all_pos_act_patch_results`) is never defined in
# this notebook; the resid_post path-patching result has the recorded [12, 12] shape.
path_patch_resid_post['z'].shape

torch.Size([12, 12])

In [77]:
# Attention patterns of the five heads with the largest direct effect. Reads
# `path_patch_resid_post` because the name used previously is never defined in this
# notebook, and passes a single prompt like every other `get_attn_head_patterns` call.
top_k = 5
top_heads = torch.topk(-path_patch_resid_post['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [78]:
# V-weighted version of the attention plot for `nmh_candidates` on the first prompt.
plot_attention(
    model, 
    prompts[0],
    nmh_candidates,
    clean_cache,
    weighted=True)

### Contributors to NMHs

#### Attention Out by Position

In [115]:
# Path patch from each layer's attn_out at each position into the queries of the DE_NMH
# heads. The raw output is kept under its own name so `results` is only ever bound to
# the final [layer, seq] tensor, not first a dict and then a tensor.
attn_out_to_nmh_q_patch = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode(node_names=["attn_out"], seq_pos="each"),
    receiver_nodes=[Node("q", layer, head=head) for layer, head in DE_NMH],
    patching_metric=logit_diff_noising,
    verbose=True,
)
results = einops.rearrange(attn_out_to_nmh_q_patch['attn_out'], "seq layer -> layer seq")

  0%|          | 0/180 [00:00<?, ?it/s]

results['attn_out'].shape = (seq_pos=15, layer=12)


In [116]:
# The receivers of this path patch are the DE heads' queries (Node("q", ...)), so the
# title says "queries"; it previously said "values".
imshow_n(
        results * 100,
        title="Direct effect on DE Heads' queries",
        xaxis="Pos",
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=[f"Layer {layer}" for layer in range(model.cfg.n_layers)],
        width=1500,
        height=600,
    )

#### Attention Out by Head

In [119]:
# Path patch from every head's output ("z") into the queries of the DE_NMH heads.
# Note: rebinds `results` to the dict returned by `path_patch`.
results = path_patch(
    model,
    orig_input=clean_tokens,
    new_input=corrupted_tokens,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("q", layer, head=head) for layer, head in DE_NMH],
    patching_metric=logit_diff_noising,
    verbose=True,
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [120]:
# Only the first 10 layers are shown ([:10]).
# NOTE(review): presumably because later layers cannot feed the layer-8/10 receivers — confirm.
imshow_p(
        results["z"][:10] * 100,
        title=f"Direct effect on NMH' queries",
        labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
        coloraxis=dict(colorbar_ticksuffix = "%"),
        border=True,
        width=700,
        margin={"r": 100, "l": 100}
    )

In [122]:
# Move the tensor to the model's device rather than hard-coding `.cuda()`, which fails
# on CPU-only machines.
plot_attention_heads(-results['z'].to(model.cfg.device), top_n=30, range_x=[0, 1.0])

Total logit diff contribution above threshold: 0.59


In [123]:
# Attention patterns of the five heads with the strongest effect through the NMH queries.
top_k = 5
top_heads = torch.topk(-results['z'].flatten(), k=top_k).indices.cpu().numpy()
# A flat index over the [layer, head] grid splits into (layer, head) via divmod.
heads = [divmod(head, model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, prompts[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

### Second Level

#### Attention Pattern for Second-Level Heads

In [149]:
# Second-level heads as (layer, head) pairs; show their attention on the first prompt.
second_level_positive_heads = [(6, 6), (7, 9), (8, 9)]

tokens, attn, names = get_attn_head_patterns(model, prompts[0], second_level_positive_heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

# NOTE(review): the two lines below are disabled code calling `visualize_attention_patterns`,
# which is not defined in this notebook — delete or restore.
#second_level_negative_heads = [(7, 8), (8, 10)]
#visualize_attention_patterns(torch.tensor([l*12+h for l, h in second_level_negative_heads]), title=f"Top Negative Second Level IOI Metric Heads")

In [66]:
# Labels for the flattened [layer, head] grid, in layer-major order.
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
# NOTE(review): `attn_head_v_all_pos_act_patch_results` and
# `attn_head_out_all_pos_act_patch_results` are never assigned anywhere in this notebook,
# so this cell fails on a fresh Run All — define them or use the patching results above.
l_scatter(
    x=utils.to_numpy(attn_head_v_all_pos_act_patch_results.flatten()), 
    y=utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten()), 
    xaxis="Value Patch",
    yaxis="Output Patch",
    #caxis="Layer",
    hover_name = head_labels,
    # Colour each point by its layer index (repeated once per head).
    color=einops.repeat(np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads),
    range_x=(-1.5, 1.5),
    range_y=(-1.5, 1.5),
    title="Scatter plot of output patching vs value patching")

In [65]:
# Candidate heads for the S2I knockout below, as (layer, head) pairs.
s2i_candidates = [(6, 6), (7, 2), (7, 9), (8, 9)]
#s2i_candidates = [(8, 9)]

#### S2I Knockout

##### All Heads

In [69]:
heads_to_ablate = s2i_candidates

print(f"Heads to ablate: {heads_to_ablate}")
# `ablate_top_head_hook` is defined once, in the NMH knockout cell earlier in the
# notebook; it is reused here rather than redefined with an identical body.

# The ablation hooks are added to global model state. Depending on the TransformerLens
# version, `run_with_cache` may only remove its own caching hooks, so start from a clean
# slate and always reset afterwards to keep the ablation out of later cells.
model.reset_hooks()
try:
    for layer, head in heads_to_ablate:
        ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
        model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
    ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
finally:
    model.reset_hooks()
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")

Heads to ablate: [(6, 6), (7, 2), (7, 9), (8, 9)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.5408056974411011


In [70]:
# Recompute per-head direct logit attribution from the S2I-ablated run.
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
# Same four-argument call as every other use of `residual_stack_to_logit_diff` in this
# notebook; the previous call omitted the logit-diff directions and the prompts.
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, logit_diff_directions, prompts, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
# The local `imshow` wrapper builds its own `labels` dict, so axis titles go through
# xaxis/yaxis; passing `labels=` raised a duplicate-keyword TypeError.
imshow(per_head_ablated_logit_diffs, xaxis="Head", yaxis="Layer")
l_scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

Tried to stack head results when they weren't cached. Computing head results now


#### Path Patching for S2-Inhibition Candidates

In [None]:
model.to_str_tokens(clean_tokens[3])

['<|endoftext|>',
 'When',
 ' Tom',
 ' and',
 ' James',
 ' went',
 ' to',
 ' the',
 ' park',
 ',',
 ' Tom',
 ' gave',
 ' the',
 ' ball',
 ' to']

In [71]:
receiver_heads = second_level_positive_heads

# Allocate on the model's device instead of a hard-coded 'cuda:0', so the CPU fallback works.
metric_delta_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=model.cfg.device)

# NOTE(review): `path_patching` is not defined or imported anywhere in this notebook
# (only `path_patch` from utils.path_patching is), so this cell fails on a fresh Run All.
# It looks like an older API — port it to `path_patch` or import the legacy helper.
for layer in range(model.cfg.n_layers):
    for head_idx in range(model.cfg.n_heads):
        pass_d_hooks = path_patching(
            model=model,
            patch_tokens=corrupted_tokens,
            orig_tokens=clean_tokens,
            sender_heads=[(layer, head_idx)],
            # Distinct names for the receiver pair avoid shadowing the sender loop variables.
            receiver_hooks=[(f"blocks.{recv_layer}.attn.hook_v", recv_head) for recv_layer, recv_head in receiver_heads],
            positions=10
        )
        path_patched_logits = model.run_with_hooks(clean_tokens, fwd_hooks=pass_d_hooks)
        iot_metric_res = ioi_metric(path_patched_logits)
        # Relative change of the IOI metric against the clean baseline.
        metric_delta_results[layer, head_idx] = -(clean_baseline_ioi - iot_metric_res) / clean_baseline_ioi

In [72]:
# Heatmap of the per-sender-head metric change computed in the previous cell.
imshow(metric_delta_results, title="IOI Metric Change From Each Head Through Receivers")#, zmin=-0.02, zmax=0.02)

### Third Level

#### Attention Patterns for Third-Level Heads

We have a mix of induction heads and duplicate token heads here, as well as two heads that focus on S2 at S2.

In [None]:
# Third-level heads as (layer, head) pairs.
third_level_positive_heads = [(4, 6), (4, 11), (5, 10)]

# Plot the third-level heads defined just above (the call previously passed the
# second-level list) on the first prompt (`example_prompt` is never defined here).
tokens, attn, names = get_attn_head_patterns(model, prompts[0], third_level_positive_heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)