# Indirect Object Identification Circuit in Pythia

In [1]:
from IPython import get_ipython
from IPython.display import clear_output, display

# Enable autoreload so edits to local modules are picked up without
# restarting the kernel.
ipython = get_ipython()
if ipython is not None:  # get_ipython() returns None outside an IPython session
    # run_line_magic() replaces the deprecated InteractiveShell.magic().
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [5]:
import os

import torch
from torch import Tensor
import numpy as np
import einops
from fancy_einsum import einsum
import circuitsvis as cv

import transformer_lens.utils as utils

from transformer_lens import HookedTransformer
import transformer_lens.patching as patching

from torch import Tensor
from jaxtyping import Float
import plotly.express as px

from functools import partial

from torchtyping import TensorType as TT

from path_patching_cm.path_patching import Node, IterNode, path_patch, act_patch
from path_patching_cm.ioi_dataset import IOIDataset, NAMES
from neel_plotly import imshow as imshow_n
from utils.head_metrics import S2I_head_metrics, BatchIOIDataset

from utils.visualization import imshow_p, plot_attention_heads

from utils.visualization_utils import (
    plot_attention_heads,
    scatter_attention_and_contribution,
    get_attn_head_patterns
)
from utils.backup_analysis import (
    load_model, 
    compute_copy_score,
    setup,
    get_metrics_and_attributions,
    run_ablated_model
)
from utils.cspa_main import (
    get_cspa_per_checkpoint,
    display_cspa_grids
)
from utils.metrics import logit_diff_denoising, logit_diff_noising

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a red/blue plotly heatmap whose colour scale is centred on 0.

    `xaxis` / `yaxis` become the axis labels, extra keyword arguments are
    forwarded to `px.imshow`, and `renderer` is passed to the figure's `show`.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Show `tensor` as a plotly line chart (values on the y axis).

    Extra keyword arguments are forwarded to `px.line`; `renderer` is passed
    to the figure's `show`.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def two_lines(tensor1, tensor2, renderer=None, **kwargs):
    """Show two tensors as two traces on one plotly line chart.

    Extra keyword arguments are forwarded to `px.line`; `renderer` is passed
    to the figure's `show`.
    """
    series = [utils.to_numpy(values) for values in (tensor1, tensor2)]
    figure = px.line(y=series, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a plotly scatter plot of `y` against `x`.

    `xaxis`, `yaxis` and `caxis` label the x axis, y axis and colour axis.
    Extra keyword arguments are forwarded to `px.scatter`; `renderer` is
    passed to the figure's `show`.
    """
    x_values, y_values = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [7]:
# Turn autograd off globally; none of the cells shown in this notebook call backward().
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f10ca222da0>

## Model Setup

In [8]:
# Experiment configuration shared by the cells below.
TASK = 'ioi'
PERFORMANCE_METRIC = 'logit_diff'
BASE_MODEL = "pythia-160m"  # model name passed to load_model() and setup()
VARIANT = None  # optional variant name; when falsy, MODEL_SHORTNAME falls back to BASE_MODEL
# Drops the first 11 characters of VARIANT — presumably a fixed-length prefix; TODO confirm.
MODEL_SHORTNAME = BASE_MODEL if not VARIANT else VARIANT[11:]
CHECKPOINT = 143000  # checkpoint passed to load_model() and setup()
CACHE = "model_cache"  # passed to load_model() — presumably a cache directory; TODO confirm
DATASET_SIZE = 70  # number of prompts (N) requested from the dataset builders

In [9]:
# Load the configured model/checkpoint; `device` is forwarded to load_model.
model = load_model(BASE_MODEL, VARIANT, CHECKPOINT, CACHE, device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-160m into HookedTransformer


## Data Setup

In [10]:
from utils.data_utils import generate_data_and_caches, _logits_to_mean_logit_diff

# Build the two prompt sets used throughout: `ioi_dataset` is run as the clean input and
# `abc_dataset` as the corrupted input below (presumably name-swapped ABC prompts — TODO confirm).
ioi_dataset, abc_dataset = generate_data_and_caches(model, N=DATASET_SIZE, verbose=True)

In [11]:
# Clean run: IOI prompts, keeping the activation cache for later patching.
clean_logits, clean_cache = model.run_with_cache(ioi_dataset.toks)
clean_logit_diff = _logits_to_mean_logit_diff(clean_logits, ioi_dataset)

# Corrupted run: ABC prompts, scored against the IOI dataset so both
# baselines are measured with the same dataset argument.
corrupted_logits, corrupted_cache = model.run_with_cache(abc_dataset.toks)
corrupted_logit_diff = _logits_to_mean_logit_diff(corrupted_logits, ioi_dataset)

print(f"Clean logit diff: {clean_logit_diff:.4f}")
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 4.1341
Corrupted logit diff: -4.0757


In [12]:
# Constant-style aliases for the two baseline logit diffs computed above.
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff

In [13]:
# Both patching metrics are bound to the same baselines and dataset, so the
# patching helpers below can call them with the remaining arguments only.
_ioi_metric_kwargs = dict(
    clean_logit_diff=clean_logit_diff,
    flipped_logit_diff=corrupted_logit_diff,
    dataset=ioi_dataset,
)
logit_diff_denoising_ioi = partial(logit_diff_denoising, **_ioi_metric_kwargs)
logit_diff_noising_ioi = partial(logit_diff_noising, **_ioi_metric_kwargs)

## Activation Patching for Model Component Importance

### Attention Heads

In [14]:
# Activation patching: run on the corrupted (ABC) tokens while patching in activations
# from the clean cache, one head output ("z") at a time, scored with the denoising metric.
results = act_patch(
    model=model,
    orig_input=abc_dataset.toks,
    new_cache=clean_cache,
    patching_nodes=IterNode("z"), # iterating over all heads' output in all layers
    patching_metric=logit_diff_denoising_ioi,
    verbose=True,
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [15]:
# Layer-by-head heatmap of the patching result; values are scaled by 100 to read as percentages.
imshow_p(
    results['z'] * 100,
    title="Patching output of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

### Head Output by Component

In [16]:
# Same activation patching as above, but iterating over five per-head components
# (output z, query, key, value, attention pattern) for every head in every layer.
# NOTE: this rebinds `results`, replacing the z-only result from the previous section.
results = act_patch(
    model=model,
    orig_input=abc_dataset.toks,
    new_cache=clean_cache,
    patching_nodes=IterNode(["z", "q", "k", "v", "pattern"]),
    patching_metric=logit_diff_denoising_ioi,
    verbose=True,
)

  0%|          | 0/720 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)
results['q'].shape = (layer=12, head=12)
results['k'].shape = (layer=12, head=12)
results['v'].shape = (layer=12, head=12)
results['pattern'].shape = (layer=12, head=12)


In [17]:
# Stack the per-component results in an explicit order so every facet lines up
# with its label, regardless of the key order act_patch happens to return.
component_order = ["z", "q", "k", "v", "pattern"]
assert results.keys() == set(component_order)

imshow_p(
    torch.stack([results[component] for component in component_order]) * 100,
    facet_col=0,
    # Labels are listed in the same order as component_order.
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Patching output of attention heads (corrupted -> clean)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1500,
    margin={"r": 100, "l": 100}
)

## Path Patching

### Direct Effect

#### All Direct-Effect Heads

In [21]:
# Path patching from every head output to the final residual stream (resid_post, layer 11).
# Runs on the clean (IOI) tokens with the ABC tokens as the patch source, scored with the
# noising metric.
path_patch_resid_post = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode('z'), # This means iterate over all heads in all layers
    receiver_nodes=Node('resid_post', 11), # This is resid_post at layer 11
    patching_metric=logit_diff_noising_ioi,
    verbose=True
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [22]:
# Layer-by-head heatmap of each head's direct effect, scaled by 100 to read as percentages.
imshow_p(
    path_patch_resid_post['z'] * 100,
    title="Direct effect on logit diff (patch from head output -> final resid)",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

##### Top Positive

In [26]:
# Negated so the heads with the most negative patching score rank highest.
# Use the notebook-wide `device` rather than .cuda() so this also runs on CPU-only machines.
plot_attention_heads(-path_patch_resid_post['z'].to(device), top_n=15, range_x=[0, 1.0])

Total logit diff contribution above threshold: 1.15


In [27]:
top_k = 10
DISPLAY_IDX = 5

# Flat indices of the top_k heads by negated direct effect, unflattened to (layer, head).
top_heads = torch.topk(-path_patch_resid_post['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [divmod(flat_idx, model.cfg.n_heads) for flat_idx in top_heads]

# Prompt DISPLAY_IDX, truncated just after its "end" position.
end_pos = ioi_dataset.word_idx["end"][DISPLAY_IDX]
prompt_text = model.to_string(ioi_dataset.toks[DISPLAY_IDX][:end_pos + 1])
tokens, attn, names = get_attn_head_patterns(model, prompt_text, heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

##### Top Negative

In [28]:
# Un-negated scores: ranks the heads with the most positive patching score.
# Use the notebook-wide `device` rather than .cuda() so this also runs on CPU-only machines.
plot_attention_heads(path_patch_resid_post['z'].to(device), top_n=15, range_x=[0, 1.0])

Total logit diff contribution above threshold: 0.23


In [34]:
top_k = 2
DISPLAY_IDX = 0

# Flat indices of the top_k heads by (un-negated) direct effect, unflattened to (layer, head).
top_heads = torch.topk(path_patch_resid_post['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [divmod(flat_idx, model.cfg.n_heads) for flat_idx in top_heads]

# Prompt DISPLAY_IDX, truncated just after its "end" position.
end_pos = ioi_dataset.word_idx["end"][DISPLAY_IDX]
prompt_text = model.to_string(ioi_dataset.toks[DISPLAY_IDX][:end_pos + 1])
tokens, attn, names = get_attn_head_patterns(model, prompt_text, heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

##### Output vs. Attention

In [23]:
# Variants of the two metrics with return_tensor=True, for the TransformerLens
# patching helpers used below.
_ioi_tensor_metric_kwargs = dict(
    clean_logit_diff=clean_logit_diff,
    flipped_logit_diff=corrupted_logit_diff,
    dataset=ioi_dataset,
    return_tensor=True,
)
logit_diff_denoising_ioi_t = partial(logit_diff_denoising, **_ioi_tensor_metric_kwargs)
logit_diff_noising_ioi_t = partial(logit_diff_noising, **_ioi_tensor_metric_kwargs)

ioi_metric = logit_diff_denoising_ioi

# Patch each head's attention pattern, then each head's output, across all positions.
attn_head_pattern_all_pos_act_patch_results = patching.get_act_patch_attn_head_pattern_all_pos(
    model, abc_dataset.toks, clean_cache, logit_diff_denoising_ioi_t
)
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(
    model, abc_dataset.toks, clean_cache, logit_diff_denoising_ioi_t
)

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

In [24]:
from utils.visualization_utils import l_scatter
# One hover label per head, in the same (layer-major) order as the flattened results.
head_labels = [
    f"L{layer_idx}H{head_idx}"
    for layer_idx in range(model.cfg.n_layers)
    for head_idx in range(model.cfg.n_heads)
]

attention_patch_scores = utils.to_numpy(attn_head_pattern_all_pos_act_patch_results.flatten())
output_patch_scores = utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten())

l_scatter(
    x=attention_patch_scores,
    y=output_patch_scores,
    hover_name=head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching",
)

#### Name Mover Heads

In [25]:
# Copy score for every head. The grid size comes from the model config instead
# of a hard-coded 12x12, matching how the rest of the notebook reads n_layers / n_heads.
n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
copy_scores = torch.zeros((n_layers, n_heads))
for layer in range(n_layers):
    for head in range(n_heads):
        copy_scores[layer, head] = compute_copy_score(model, layer, head, ioi_dataset)

# 1.0 where the score exceeds the threshold, 0.0 elsewhere.
# NOTE(review): compute_copy_score logs accuracies as percentages; confirm that the
# returned value is on the scale this 0.75 threshold assumes.
copy_score_masks = (copy_scores > 0.75).float()

Copy circuit for head 4.10 (sign=1) : Top 5 accuracy: 8.571428571428571%
Copy circuit for head 5.0 (sign=1) : Top 5 accuracy: 100.0%
Copy circuit for head 5.9 (sign=1) : Top 5 accuracy: 23.333333333333332%
Copy circuit for head 6.6 (sign=1) : Top 5 accuracy: 2.857142857142857%
Copy circuit for head 7.8 (sign=1) : Top 5 accuracy: 33.80952380952381%
Copy circuit for head 7.11 (sign=1) : Top 5 accuracy: 0.4761904761904762%
Copy circuit for head 8.1 (sign=1) : Top 5 accuracy: 39.523809523809526%
Copy circuit for head 8.2 (sign=1) : Top 5 accuracy: 98.57142857142858%
Copy circuit for head 8.8 (sign=1) : Top 5 accuracy: 2.857142857142857%
Copy circuit for head 8.10 (sign=1) : Top 5 accuracy: 100.0%
Copy circuit for head 9.2 (sign=1) : Top 5 accuracy: 0.9523809523809524%
Copy circuit for head 9.4 (sign=1) : Top 5 accuracy: 20.952380952380953%
Copy circuit for head 9.6 (sign=1) : Top 5 accuracy: 6.190476190476191%
Copy circuit for head 9.7 (sign=1) : Top 5 accuracy: 2.380952380952381%
Copy cir

In [26]:
# Direct-effect heatmap restricted to heads whose copy-score mask is 1.
imshow_p(
    # mask the heads with insufficient copy score
    path_patch_resid_post['z'] * 100 * copy_score_masks,
    title="Direct effect heads with high copy score",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

#### Negative Name Mover Heads

In [25]:
# Negative copy score (compute_copy_score with neg=True) for every head. The grid
# size comes from the model config instead of a hard-coded 12x12.
n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
neg_copy_scores = torch.zeros((n_layers, n_heads))
for layer in range(n_layers):
    for head in range(n_heads):
        neg_copy_scores[layer, head] = compute_copy_score(model, layer, head, ioi_dataset, neg=True)

# 1.0 where the score exceeds the threshold, 0.0 elsewhere.
# NOTE(review): confirm the scale of the returned score against this 0.75 threshold.
neg_copy_score_masks = (neg_copy_scores > 0.75).float()

Copy circuit for head 6.5 (sign=-1) : Top 5 accuracy: 0.4761904761904762%
Copy circuit for head 8.9 (sign=-1) : Top 5 accuracy: 100.0%
Copy circuit for head 9.5 (sign=-1) : Top 5 accuracy: 20.476190476190474%


In [None]:
# Direct-effect heatmap restricted to heads with a high *negative* copy score.
# This cell previously reused the positive mask and title, which left
# neg_copy_score_masks unused and the plot identical to the Name Mover one.
imshow_p(
    # mask the heads with insufficient negative copy score
    path_patch_resid_post['z'] * 100 * neg_copy_score_masks,
    title="Direct effect heads with high negative copy score",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=600,
    margin={"r": 100, "l": 100}
)

#### DES2I/Copy Suppression Heads

In [None]:
# NOTE(review): unfinished cell — the call passes no arguments and was never executed
# (no execution count). Supply the arguments get_cspa_per_checkpoint expects — TODO.
get_cspa_per_checkpoint(
 
)

#### Direct Effect Head Classification

In [27]:
# Hand-entered (layer, head) lists for the direct-effect heads used as receivers below.
DE_NMH = [(8, 1), (8, 2), (8, 10), (9, 4), (9, 6), (10, 7)]  # direct-effect name mover heads
DE_S2I = [(8, 9)]  # direct-effect S2I head

##### Backup Head Activity

In [28]:
# NOTE: setup() returns a model, so this rebinds `model` and replaces the instance loaded
# by load_model() earlier; every later cell uses this one.
model, logit_diff_directions = setup(BASE_MODEL, VARIANT, dataset=ioi_dataset, checkpoint=CHECKPOINT)
# Baseline: unablated run and its per-head logit-diff attributions.
orig_logits, orig_cache = model.run_with_cache(ioi_dataset.toks.long())
logit_diff, per_head_logit_diffs = get_metrics_and_attributions(model, orig_logits, orig_cache, ioi_dataset, logit_diff_directions=logit_diff_directions)

# Same measurements with the direct-effect name mover heads ablated.
ablated_logits, ablated_cache = run_ablated_model(model, ioi_dataset, ablation_targets=DE_NMH)
ablated_logit_diff, per_head_ablated_logit_diffs = get_metrics_and_attributions(model, ablated_logits, ablated_cache, ioi_dataset, logit_diff_directions=logit_diff_directions)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-160m into HookedTransformer


In [29]:
# Relative change in logit diff caused by the ablation, as a percentage of the original.
logit_diff_pct_change = ((ablated_logit_diff - logit_diff) / logit_diff) * 100

print(f"Checkpoint {CHECKPOINT}:")
print(f"Heads ablated:            {DE_NMH}")
print(f"Original logit diff:      {logit_diff:.10f}")
print(f"Post ablation logit diff: {ablated_logit_diff:.10f}")
print(f"Logit diff % change:      {logit_diff_pct_change:.2f}%")

Checkpoint 143000:
Heads ablated:            [(8, 1), (8, 2), (8, 10), (9, 4), (9, 6), (10, 7)]
Original logit diff:      4.1340785027
Post ablation logit diff: 3.5585453510
Logit diff % change:      -13.92%


In [30]:
# Per-head change in logit-diff attribution after the ablation (ablated minus original).
delta = per_head_ablated_logit_diffs - per_head_logit_diffs

# Plotted as a fraction of the original (unablated) logit diff.
normalized_delta = delta / logit_diff
plot_attention_heads(
    normalized_delta,
    title="Logit Diff Contribution From Backup Heads",
    top_n=15,
    range_x=[0, 0.5],
)

Total logit diff contribution above threshold: 0.15


### Contributors to Direct Effect Heads

#### Contributors to DE Name Mover Heads

##### Attention Out by Head

In [44]:
# Path patching from every head output into the query inputs of the DE name mover heads.
# NOTE: rebinds `results`; the earlier activation-patching results are no longer available.
results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("q", layer, head=head) for layer, head in DE_NMH],
    patching_metric=logit_diff_noising_ioi,
    verbose=True,
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [45]:
# Only the first 10 layers are shown (rows 0-9 of the layer-by-head grid).
nmh_query_effect_pct = results["z"][:10] * 100
imshow_p(
    nmh_query_effect_pct,
    title="Direct effect on NMH' queries",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix="%"),
    border=True,
    width=700,
    margin={"r": 100, "l": 100},
)

In [46]:
# Negated so the heads with the most negative patching score rank highest.
# Use the notebook-wide `device` rather than .cuda() so this also runs on CPU-only machines.
plot_attention_heads(-results['z'].to(device), top_n=10, range_x=[0, 1.0])

Total logit diff contribution above threshold: 0.79


In [47]:
top_k = 2
DISPLAY_IDX = 34

# Flat indices of the top_k heads by negated patching score, unflattened to (layer, head).
top_heads = torch.topk(-results['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [divmod(flat_idx, model.cfg.n_heads) for flat_idx in top_heads]

# Prompt DISPLAY_IDX, truncated just after its "end" position.
end_pos = ioi_dataset.word_idx["end"][DISPLAY_IDX]
prompt_text = model.to_string(ioi_dataset.toks[DISPLAY_IDX][:end_pos + 1])
tokens, attn, names = get_attn_head_patterns(model, prompt_text, heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [96]:
# (layer, head) pairs evaluated as candidate S2-inhibition heads in the next cell.
S2I_CANDIDATES = [(7, 4), (7, 6)]

In [99]:
# Evaluate each S2I candidate on its own and print three summary numbers for it.
for head in S2I_CANDIDATES:
    # The dataset is rebuilt with identical arguments (fixed seed) on every iteration.
    # NOTE(review): it could be hoisted out of the loop if S2I_head_metrics does not
    # mutate it — confirm before changing.
    s2i_ioi_dataset = BatchIOIDataset(
        prompt_type="mixed",
        N=DATASET_SIZE,
        tokenizer=model.tokenizer,
        prepend_bos=False,
        seed=1234,
        device=str(device)
    )

    s2i_results = S2I_head_metrics(model, s2i_ioi_dataset, potential_s2i_list=[head], NMH_list=DE_NMH, batch_size=32)

    # Averaged over dim 0 (presumably the prompt dimension — TODO confirm).
    s2i_s2_attention = s2i_results['end_s2_attention_values'].mean(0)

    # logit diff change (lower is better)
    logit_diff_change = (s2i_results['new_logit_diffs'] - s2i_results['baseline_logit_diffs'].unsqueeze(1)).mean(0)

    # NMH s1 attention change (higher is better)
    nmh_s1_attention_change = (s2i_results['new_nmh_s1_attention_values'] - s2i_results['baseline_nmh_s1_attention_values'].unsqueeze(1)).mean(0).mean(-1)

    print(f"S2I Candidate: {head}:")
    print(f"S2I Head S2 Attention Score:          {s2i_s2_attention.item():.3f}")
    print(f"Logit Diff Change After Ablation:     {logit_diff_change.item():.3f}")
    print(f"NMHs' Mean S1 Attention Score Change: {nmh_s1_attention_change.item():.3f}")
    print("\n")

S2I Candidate: (7, 4):
S2I Head S2 Attention Score:          0.798
Logit Diff Change After Ablation:     -4.391
NMHs' Mean S1 Attention Score Change: 0.269


S2I Candidate: (7, 6):
S2I Head S2 Attention Score:          0.291
Logit Diff Change After Ablation:     -2.280
NMHs' Mean S1 Attention Score Change: 0.152




#### Contributors to S2 Inhibitors

##### Attention Out by Head

In [48]:
# Path patching from every head output into the value inputs of the DE S2I head(s).
# NOTE: rebinds `results`, replacing the NMH-query result from the previous section.
results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in DE_S2I],
    patching_metric=logit_diff_noising_ioi,
    verbose=True,
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [49]:
# Heatmap of the path-patching result into the DE S2I heads' value inputs (first 10 layers).
# The title previously read "NMH' queries", copied from the name-mover section, although
# the receivers here are Node("v", ...) for DE_S2I.
imshow_p(
    results["z"][:10] * 100,
    title="Direct effect on S2Is' values",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=700,
    margin={"r": 100, "l": 100}
)

In [50]:
# Negated so the heads with the most negative patching score rank highest.
# Use the notebook-wide `device` rather than .cuda() so this also runs on CPU-only machines.
plot_attention_heads(-results['z'].to(device), top_n=10, range_x=[0, 0.1])

Total logit diff contribution above threshold: 0.24


In [51]:
top_k = 5
DISPLAY_IDX = 45

# Flat indices of the top_k heads by negated patching score, unflattened to (layer, head).
top_heads = torch.topk(-results['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [divmod(flat_idx, model.cfg.n_heads) for flat_idx in top_heads]

# Prompt DISPLAY_IDX, truncated just after its "end" position.
end_pos = ioi_dataset.word_idx["end"][DISPLAY_IDX]
prompt_text = model.to_string(ioi_dataset.toks[DISPLAY_IDX][:end_pos + 1])
tokens, attn, names = get_attn_head_patterns(model, prompt_text, heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

### Second Level

#### Path Patching for S2-Inhibition Heads

In [62]:
# NOTE(review): `circuit_summary` is not defined in any cell visible in this notebook, so
# this cell raises NameError after Restart & Run All. Presumably it came from a removed
# cell; define it (or substitute the intended head list) before this point — TODO.
receiver_heads = circuit_summary["DE_NMH"]["q_contributors"]["S2I"]

# Path patching from every head output into the value inputs of the receiver heads.
results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=IterNode("z"),
    receiver_nodes=[Node("v", layer, head=head) for layer, head in receiver_heads],
    patching_metric=logit_diff_noising_ioi,
    verbose=True,
)

  0%|          | 0/144 [00:00<?, ?it/s]

results['z'].shape = (layer=12, head=12)


In [63]:
# Only the first 10 layers are shown (rows 0-9 of the layer-by-head grid).
s2i_value_effect_pct = results["z"][:10] * 100
imshow_p(
    s2i_value_effect_pct,
    title="Direct effect on S2Is' values",
    labels={"x": "Head", "y": "Layer", "color": "Logit diff variation"},
    coloraxis=dict(colorbar_ticksuffix="%"),
    border=True,
    width=700,
    margin={"r": 100, "l": 100},
)

In [64]:
# Negated so the heads with the most negative patching score rank highest.
# Use the notebook-wide `device` rather than .cuda() so this also runs on CPU-only machines.
plot_attention_heads(-results['z'].to(device), top_n=10, range_x=[0, 1.0])

Total logit diff contribution above threshold: 0.35


In [65]:
top_k = 2
DISPLAY_IDX = 45

# Flat indices of the top_k heads by negated patching score, unflattened to (layer, head).
top_heads = torch.topk(-results['z'].flatten(), k=top_k).indices.cpu().numpy()
heads = [divmod(flat_idx, model.cfg.n_heads) for flat_idx in top_heads]

# Prompt DISPLAY_IDX, truncated just after its "end" position.
end_pos = ioi_dataset.word_idx["end"][DISPLAY_IDX]
prompt_text = model.to_string(ioi_dataset.toks[DISPLAY_IDX][:end_pos + 1])
tokens, attn, names = get_attn_head_patterns(model, prompt_text, heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)