# Indirect Object Identification Circuit in Pythia

In [93]:
from IPython import get_ipython

# Enable autoreload so edits to locally imported helper modules are picked up
# without restarting the kernel.
# `InteractiveShell.magic("name args")` is deprecated; `run_line_magic(name, args)`
# is the supported way to invoke a line magic programmatically.
ipython = get_ipython()
ipython.run_line_magic("load_ext", "autoreload")
ipython.run_line_magic("autoreload", "2")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload








In [94]:
import os
from functools import partial

import torch
from torchtyping import TensorType as TT
import numpy as np
import pandas as pd

import einops
from fancy_einsum import einsum

from transformers import AutoModelForCausalLM
import transformer_lens.utils as utils
import transformer_lens.patching as patching
from transformer_lens import HookedTransformer
from neel_plotly import line, imshow, scatter

from visualization_utils import (
    plot_attention_heads,
    scatter_attention_and_contribution
)

# Choose the compute device: the CUDA ordinal taken from LOCAL_RANK (default 0)
# when a GPU is visible, otherwise the string "cpu".
device = int(os.environ.get("LOCAL_RANK", 0)) if torch.cuda.is_available() else "cpu"

In [95]:
import circuitsvis as cv
# Testing that the library works
cv.examples.hello("Reader!")

In [96]:
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f16b217fd90>

## Model Setup

In [97]:
# import huggingface_hub
# huggingface_hub.notebook_login()

In [98]:
# Load the Hugging Face checkpoint first and hand it to TransformerLens via
# `hf_model`, so the hooked model wraps these exact weights.
source_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")

# The keyword flags below turn on TransformerLens weight-processing options.
# NOTE(review): their precise numerical effect is defined by TransformerLens,
# not by this notebook -- confirm against the library docs before relying on it.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-410m",
    hf_model=source_model,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    #refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-410m into HookedTransformer


Data requirements:
- The correct token should be ranked higher than the incorrect token
- Negating "all" should flip the expected next token
- Prompts should end with "are" or "is"
- Both the positive and negative prompts should be the same number of tokens


In [7]:
# Sanity-check one hand-written syllogism before building the dataset.
# Alternative prompts kept for quick manual experiments:
#example_prompt = "John moved the chair into the dining room. Jane then pushed the chair to the deck. Jack retrieves the chair from the "
#example_prompt =  "Mary went to the school to pick up her daughter, Sarah. Mary is Sarah's"
#example_prompt = "All dogs are mammals, and no mammals are cold-blooded. Therefore, all dogs are"
example_prompt = "All flowers are angiosperms. Daisies are flowers. Therefore, daisies are"
# NOTE(review): " winged" does not follow from the premises above; presumably a
# placeholder answer -- confirm which completion should be scored.
example_answer = " winged"
# Use the prompt/answer defined in this cell. `combinations` is not defined
# anywhere earlier in the notebook, so referencing it fails on a fresh kernel.
print(len(model.to_str_tokens(example_prompt)))
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

## Data Setup

In [405]:
import openai
import re

# Read the OpenAI API key from the environment rather than embedding it in the
# notebook. Any key that was ever saved in this file must be treated as leaked
# and revoked. Raises KeyError if OPENAI_API_KEY is not set, which is preferable
# to silently sending unauthenticated requests.
openai.api_key = os.environ["OPENAI_API_KEY"]

def generate_text(prompt, num_tokens=50):
    """Send `prompt` as a single user message to gpt-3.5-turbo and return the reply.

    Args:
        prompt: Text of the user message.
        num_tokens: Accepted for interface compatibility only; it is not
            forwarded to the API call and does not limit the response length.

    Returns:
        The content of the first returned choice, with surrounding whitespace
        stripped.
    """
    user_message = {"role": "user", "content": prompt}
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[user_message],
        temperature=0.8,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content.strip()

def extract_candidates(text, desired_token_count):
    """Return the words in `text` that tokenize to `desired_token_count` tokens.

    Words are the `\\w+` runs found between word boundaries, in order of
    appearance (duplicates are kept). The token-count decision is delegated to
    `is_token_count_correct`.
    """
    kept = []
    for word in re.findall(r'\b\w+\b', text):
        if is_token_count_correct(word, desired_token_count):
            kept.append(word)
    return kept

def is_token_count_correct(input_str, desired_token_count):
    """Report whether `' ' + input_str` tokenizes to `desired_token_count` tokens.

    The string is tokenized with the notebook-level `model`; one is subtracted
    from the length (presumably to discount a prepended BOS token -- verify
    against the model's tokenization defaults). The tokenization and the
    resulting count are printed as a debugging aid.
    """
    pieces = model.to_str_tokens(' ' + input_str)
    token_count = len(pieces) - 1
    print(f"{pieces} has {token_count} tokens")
    return token_count == desired_token_count

def build_examples(categories, descriptors, capitalized_nouns, plural_nouns, num_examples=20):
    """Build `num_examples` syllogism prompts by cycling through each word list.

    Example `i` uses element `i % len(list)` of every list, so shorter lists
    simply wrap around.

    Args:
        categories: Category words (e.g. plural class names).
        descriptors: Property words attributed to the category.
        capitalized_nouns: Sentence-initial nouns for the second premise.
        plural_nouns: Nouns used in the conclusion.
        num_examples: Number of prompts to produce.

    Returns:
        A list of prompt strings ending in "... are".

    Raises:
        ValueError: If any of the four lists is empty.
    """
    pools = (categories, descriptors, capitalized_nouns, plural_nouns)
    if not all(pools):
        raise ValueError("One or more of the input lists are empty. Please generate more candidates or adjust the token count requirements.")

    def pick(pool, position):
        # Cyclic lookup so lists of different lengths can be combined.
        return pool[position % len(pool)]

    examples = []
    for i in range(num_examples):
        category = pick(categories, i)
        descriptor = pick(descriptors, i)
        capitalized_noun = pick(capitalized_nouns, i)
        plural_noun = pick(plural_nouns, i)
        examples.append(
            f"All {category} are {descriptor}. {capitalized_noun} are {category}. Therefore, {plural_noun} are"
        )
    return examples

# Ask the chat model for one raw, comma-separated candidate list per prompt slot.
# Each call hits the OpenAI API, so re-running this cell yields different text.
categories_text = generate_text("List 50 categories of animals (lowercase, separated by commas, single-word only), but not any specific species (like cats or rats):")
print(categories_text)
descriptors_text = generate_text("List 50 descriptors that could apply to animals (e.g., warm-blooded, oviparous, etc. Use more advanced words.) (lowercase, separated by commas):")
print(descriptors_text)
# NOTE(review): this prompt asks for "capitalized plural nouns" yet also says
# "(lowercase, ...)" and gives singular examples ('Cat', 'Rat'); the
# instructions conflict -- confirm the intended wording.
capitalized_nouns_text = generate_text("List 50 capitalized plural nouns, each of which should be the name of a specific animal (e.g., 'Cat' or 'Rat') (lowercase, separated by commas):")
print(capitalized_nouns_text)
# The conclusion nouns are just the lowercased form of the same reply text.
plural_nouns_text = capitalized_nouns_text.lower()
print(plural_nouns_text)

# Required token count per slot, as measured by is_token_count_correct
# (tokens of " " + word, minus one).
category_token_count = 1
descriptor_token_count = 3
capitalized_noun_token_count = 2
plural_noun_token_count = 1

# Keep only the words whose tokenization has exactly the required length.
categories = extract_candidates(categories_text, category_token_count)
descriptors = extract_candidates(descriptors_text, descriptor_token_count)
capitalized_nouns = extract_candidates(capitalized_nouns_text, capitalized_noun_token_count)
plural_nouns = extract_candidates(plural_nouns_text, plural_noun_token_count)

# Assemble prompts; build_examples raises ValueError if any list ended up empty.
examples = build_examples(categories, descriptors, capitalized_nouns, plural_nouns)
print('\n'.join(examples))

mammals, reptiles, fish, insects, birds, amphibians, arachnids, crustaceans, mollusks, echinoderms, cnidarians, poriferans, annelids, platyhelminthes, nematodes, tardigrades, onychophorans, myriapods, chelicerates, bryozoans, brachiopods, tunicates, lancelets, cephalopods, gastropods, bivalves, scorpions, centipedes, millipedes, ants, bees, butterflies, beetles, cicadas, crickets, cockroaches, dragonflies, grasshoppers, moths, flies, termites, wasps, seahorses, whales, dolphins, gorillas, chimpanzees, kangaroos, elephants, giraffes, rhinoceroses, hippopotamuses.
endothermic, ectothermic, viviparous, monotreme, marsupial, placental, herbivorous, carnivorous, omnivorous, nocturnal, diurnal, arboreal, terrestrial, aquatic, fossorial, cursorial, volant, aquatic, amphibious, venomous, poisonous, camouflaged, bipedal, quadrupedal, hexapedal, octopedal, sessile, migratory, hibernating, noctilio, noctiphilic, pelagic, poikilothermic, furtive, cryptic, territorial, gregarious, solitary, cursori

In [99]:
import itertools
def generate_combinations(categories, descriptors, capitalized_nouns, answer_first_tokens, wrong_answer_first_tokens):
    """Build every (category, descriptor, noun) prompt together with its swapped twin.

    For descriptor `d` at (first-occurrence) index `i`, the contrast term used in
    the prompt is the next descriptor in list order, wrapping at the end. The
    main prompt reads "All C are d, not contrast. ..." and the opposite prompt
    swaps the two terms. The answer pair recorded for the prompt is
    `(answer_first_tokens[i], wrong_answer_first_tokens[i])`; this function does
    not check that the wrong-answer token corresponds to the contrast term.

    Uses the notebook-level `model` for tokenization and `device` for the
    returned answer tensor.

    Args:
        categories: Category words.
        descriptors: Descriptor words (a list; must be non-empty).
        capitalized_nouns: Nouns used capitalized in the premise and lowercased
            in the conclusion.
        answer_first_tokens: Per-descriptor string token scored as correct.
        wrong_answer_first_tokens: Per-descriptor string token scored as wrong.

    Returns:
        Tuple of (prompts, opposite_prompts, answers, prompt_tokens,
        opposites_tokens, answer_tokens), where the token lists hold one
        squeezed tensor per prompt and `answer_tokens` is a tensor of
        single-token ids with one (correct, wrong) row per prompt.
    """
    # Rotate left by one: the contrast for descriptors[i] is descriptors[i + 1].
    rotated = descriptors[1:] + [descriptors[0]]
    # First-occurrence position of each descriptor (same result as list.index).
    first_index = {}
    for position, name in enumerate(descriptors):
        first_index.setdefault(name, position)

    prompts, opposite_prompts, answers = [], [], []
    for category, descriptor, capitalized_noun in itertools.product(categories, descriptors, capitalized_nouns):
        idx = first_index[descriptor]
        contrast = rotated[idx]
        lowercase_noun = capitalized_noun.lower()
        prompts.append(
            f"All {category} are {descriptor}, not {contrast}. {capitalized_noun} are {category}. Therefore, {lowercase_noun} are"
        )
        opposite_prompts.append(
            f"All {category} are {contrast}, not {descriptor}. {capitalized_noun} are {category}. Therefore, {lowercase_noun} are"
        )
        answers.append((answer_first_tokens[idx], wrong_answer_first_tokens[idx]))

    prompt_tokens = [model.to_tokens(text).squeeze() for text in prompts]
    opposites_tokens = [model.to_tokens(text).squeeze() for text in opposite_prompts]
    answer_ids = [[model.to_single_token(token) for token in pair] for pair in answers]
    answer_tokens = torch.tensor(answer_ids, device=device)
    return prompts, opposite_prompts, answers, prompt_tokens, opposites_tokens, answer_tokens

In [100]:
def check_len(items, specified_len, check_lowercase=False, lowercase_spec_len=None):
    """Filter `items` down to those with the required token length(s).

    Lengths are measured with the notebook-level `model` as the number of string
    tokens of `" " + item`, minus one.

    Args:
        items: Candidate words.
        specified_len: Required token length of the word as written.
        check_lowercase: When true, the lowercased word must also match
            `lowercase_spec_len`.
        lowercase_spec_len: Required token length of the lowercased word; only
            consulted when `check_lowercase` is true.

    Returns:
        A new list with the items that pass, in their original order.
    """
    def token_len(word):
        return len(model.to_str_tokens(" " + word)) - 1

    kept = []
    for item in items:
        length = token_len(item)
        lowercase_len = token_len(item.lower())
        lowercase_ok = (not check_lowercase) or lowercase_len == lowercase_spec_len
        if length == specified_len and lowercase_ok:
            kept.append(item)
    return kept

In [101]:
# Hand-picked word lists; each is filtered by check_len so every prompt built
# from them has the same token layout.
categories = ['birds', 'fish', 'insects', 'mammals', 'primates', 'rodents', 'predators', 'prey', 'game']
categories = check_len(categories, 1)

# Descriptors are listed as adjacent contrasting pairs (index 0/1, 2/3, ...).
descriptors = [
    'endothermic',
    'ectothermic',
    'herbivorous',
    'carnivorous',
    'bipedal',
    'quadrupedal',
    'exoskeletal',
    'endoskeletal',
    'digitigrade',
    'plantigrade'
    ]
descriptors = check_len(descriptors, 3)

# Index 1 of the string tokens is taken as the descriptor's first token
# (presumably index 0 is the prepended BOS token -- verify).
answer_first_tokens = [model.to_str_tokens(" " + d)[1] for d in descriptors]
# The wrong answer for each descriptor is its neighbour within the adjacent pair.
# NOTE(review): if check_len leaves an odd number of descriptors, the last even
# index reads `i + 1` out of range and raises IndexError; a dropped descriptor
# also shifts every later pairing.
# NOTE(review): generate_combinations uses the *next* descriptor in list order
# as the "not ..." term, which for odd indices differs from the partner chosen
# here -- confirm this is intended.
wrong_answer_first_tokens = []
for i in range(len(answer_first_tokens)):
    if i % 2 == 0:
        wrong_answer_first_tokens.append(answer_first_tokens[i + 1])
    else:
        wrong_answer_first_tokens.append(answer_first_tokens[i - 1])

# NOTE(review): 'Pigs' appears twice in this list, so every prompt containing
# it is generated twice by generate_combinations.
capitalized_nouns = ['Dogs', 'Elephants', 'Lions', 'Zebras', 'Cows', 'Horses',
'Pigs', 'Monkeys', 'Wolves', 'Deer', 'Foxes', 'Bears', 'Mice', 'Bats', 'Snakes',
'Lizards', 'Seals', 'Whales', 'Cats', 'Squid', 'Dolphins', 'Crabs', 'Snails',
'Bees', 'Wasps', 'Emus', 'Ducks', 'Owls', 'Pigs', 'Goats'
]
# Capitalized form must be 2 tokens and the lowercase form 1 token.
capitalized_nouns = check_len(capitalized_nouns, 2, check_lowercase=True, lowercase_spec_len=1)

# Build the full cross product of prompts plus their swapped counterparts.
prompts, opposites_prompts, answers, prompt_tokens, opposites_tokens, answer_tokens = generate_combinations(categories, descriptors, capitalized_nouns, answer_first_tokens, wrong_answer_first_tokens)
len(prompts)

1440

In [103]:
from circuit_utils import get_logit_diff, ioi_metric

# Sanity check: print any prompt whose token at position 4 is not the recorded
# correct-answer token (position 4 is presumably the descriptor's first token:
# BOS, "All", category, " are", descriptor... -- verify).
for i in range(len(prompts)):
    if prompt_tokens[i][4] != answer_tokens[i][0]:
        print(prompt_tokens[i])

# Accumulators for the examples that pass the logit-difference filter below.
clean_prompts = []
clean_opposites_prompts = []
clean_answer_tokens = []

# Keep an example only if the model prefers the correct token on the prompt
# (logit diff > 2) AND prefers the other token on the swapped prompt
# (logit diff < -2, measured with the same answer pair).
# The caches returned by run_with_cache are not used in this loop.
for i in range(len(prompts)):
    test_logits, test_cache = model.run_with_cache(prompt_tokens[i].unsqueeze(0).to(device))
    opposite_logits, opposite_cache = model.run_with_cache(opposites_tokens[i].unsqueeze(0).to(device))
    logit_diffs = get_logit_diff(test_logits, answer_tokens[i].unsqueeze(0), per_prompt=False)
    opposite_logit_diffs = get_logit_diff(opposite_logits, answer_tokens[i].unsqueeze(0), per_prompt=False)

    if logit_diffs.item() > 2 and opposite_logit_diffs.item() < -2:
        clean_prompts.append(prompts[i])
        clean_opposites_prompts.append(opposites_prompts[i])
        clean_answer_tokens.append(answer_tokens[i])

In [116]:
# Take a fixed-seed random sample of 56 examples from the filtered dataset.
# random.sample raises ValueError if fewer than 56 examples passed the filter.
# NOTE(review): this cell overwrites clean_prompts / clean_opposites_prompts /
# clean_answer_tokens with the sample, so it is not idempotent -- re-running it
# without re-running the filtering cell fails or re-samples the sample.

import random

random.seed(42)
dataset_size = len(clean_prompts)
random_indices = random.sample(range(dataset_size), 56)
#random_indices.sort()

clean_prompts = [clean_prompts[i] for i in random_indices]
clean_opposites_prompts = [clean_opposites_prompts[i] for i in random_indices]
# Stack the per-example answer pairs into a single tensor, one row per example.
clean_answer_tokens = torch.stack([clean_answer_tokens[i] for i in random_indices])

In [117]:
def pair_opposites(prompts_list, opposites_prompts_list, answer_tokens_list):
    """Interleave each prompt with its opposite, dropping prompts already emitted.

    For every index `i`, the prompt is appended with its answer pair, followed
    by the opposite prompt with the answer pair reversed (the opposite prompt's
    correct token is the original's wrong token). A prompt string that has
    already been emitted is skipped.

    Args:
        prompts_list: Prompt strings.
        opposites_prompts_list: Opposite prompt strings, index-aligned with
            `prompts_list`.
        answer_tokens_list: Index-aligned tensor rows (or a 2-D tensor) holding
            `[correct_id, wrong_id]` for each prompt.

    Returns:
        `(prompts, answer_tokens)`: the de-duplicated prompt list and a list of
        1-D answer tensors aligned with it.
    """
    prompts = []
    answer_tokens = []
    # Set membership keeps de-duplication linear instead of scanning the list.
    seen = set()
    for i, prompt in enumerate(prompts_list):
        answer_pair = answer_tokens_list[i]
        if prompt not in seen:
            seen.add(prompt)
            prompts.append(prompt)
            answer_tokens.append(answer_pair)
        opposite = opposites_prompts_list[i]
        if opposite not in seen:
            seen.add(opposite)
            prompts.append(opposite)
            # flip(0) returns a reversed copy on the same device/dtype as the
            # input, so no global `device` is needed to build the swapped pair.
            answer_tokens.append(answer_pair.flip(0))
    return prompts, answer_tokens

prompts, answer_tokens = pair_opposites(
    clean_prompts, 
    clean_opposites_prompts,
    clean_answer_tokens)

In [118]:
prompts[:10], answer_tokens[:10]

(['All predators are exoskeletal, not endoskeletal. Bats are predators. Therefore, bats are',
  'All predators are endoskeletal, not exoskeletal. Bats are predators. Therefore, bats are',
  'All birds are exoskeletal, not endoskeletal. Goats are birds. Therefore, goats are',
  'All birds are endoskeletal, not exoskeletal. Goats are birds. Therefore, goats are',
  'All birds are exoskeletal, not endoskeletal. Pigs are birds. Therefore, pigs are',
  'All birds are endoskeletal, not exoskeletal. Pigs are birds. Therefore, pigs are',
  'All predators are exoskeletal, not endoskeletal. Pigs are predators. Therefore, pigs are',
  'All predators are endoskeletal, not exoskeletal. Pigs are predators. Therefore, pigs are',
  'All prey are exoskeletal, not endoskeletal. Horses are prey. Therefore, horses are',
  'All prey are endoskeletal, not exoskeletal. Horses are prey. Therefore, horses are'],
 [tensor([385, 990], device='cuda:0'),
  tensor([990, 385], device='cuda:0'),
  tensor([385, 990], 

In [119]:
# Tokenize all prompts as one batch.
clean_tokens = model.to_tokens(prompts)
# The corrupted batch swaps each row with its neighbour (0<->1, 2<->3, ...), so
# every prompt is corrupted by the prompt stored next to it. This relies on
# `prompts` holding prompt/opposite pairs at adjacent even/odd positions and on
# the batch having an even number of rows (otherwise the last index is out of
# range).
corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

# One (correct, wrong) row of token ids per prompt.
answer_token_indices = torch.stack(answer_tokens)


Clean string 0 <|endoftext|>All predators are exoskeletal, not endoskeletal. Bats are predators. Therefore, bats are
Corrupted string 0 <|endoftext|>All predators are endoskeletal, not exoskeletal. Bats are predators. Therefore, bats are


In [120]:
answer_token_indices.shape

torch.Size([104, 2])

## Tool Setup

### Activation Patching

In [121]:
def get_logit_diff(logits, answer_token_indices, per_prompt=False):
    """Mean logit difference between the correct and incorrect answer tokens.

    Args:
        logits (torch.Tensor): Either (batch, vocab) logits, or (batch, pos,
            vocab) logits from which only the last position is used.
        answer_token_indices (torch.Tensor): (batch, 2) token ids; column 0 is
            the correct token and column 1 the incorrect one.
        per_prompt (bool): When true, the per-prompt differences are printed
            before averaging. The return value is the mean either way.

    Returns:
        torch.Tensor: Scalar tensor with the batch-mean of
        (correct logit - incorrect logit).
    """
    # Reduce sequence logits to the prediction at the final position.
    final_logits = logits[:, -1, :] if logits.ndim == 3 else logits

    correct_ids = answer_token_indices[:, 0:1]
    incorrect_ids = answer_token_indices[:, 1:2]
    differences = final_logits.gather(1, correct_ids) - final_logits.gather(1, incorrect_ids)

    if per_prompt:
        print(differences)
    return differences.mean()

In [122]:
# NOTE(review): this import rebinds `get_logit_diff`, shadowing the function of
# the same name defined in the previous cell; from here on the circuit_utils
# version is the one being called.
from circuit_utils import get_logit_diff, ioi_metric

# Run both batches once and keep the activation caches for later analysis.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

# Baselines: mean logit difference on the clean and on the corrupted batch,
# both scored against the clean prompts' answer pairs.
clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices, per_prompt=False).item()
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 2.6821
Corrupted logit diff: -2.6821


In [123]:
# Baselines used by the patching metric.
CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff

# Rebind `ioi_metric` to a version with the baselines and answer ids filled in,
# so patching helpers can call it with logits only.
# NOTE(review): this shadows the imported function; re-running the cell wraps
# the already-wrapped partial again unless the import cell is re-run first.
ioi_metric = partial(ioi_metric, clean_baseline=CLEAN_BASELINE, corrupted_baseline=CORRUPTED_BASELINE, answer_token_indices=answer_token_indices)

# Metric values on the unpatched clean and corrupted runs.
clean_baseline_ioi = ioi_metric(clean_logits)
corrupted_baseline_ioi = ioi_metric(corrupted_logits)

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [124]:
# Whether to do the runs by head and by position, which are much slower
DO_SLOW_RUNS = True

### Path Patching

In [125]:
from circuit_utils import patch_pos_head_vector, patch_head_vector, path_patching

## Direct Logit Attribution

In [126]:
answer_residual_directions = model.tokens_to_residual_directions(answer_token_indices)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([104, 2, 1024])
Logit difference directions shape: torch.Size([104, 1024])


In [127]:
answer_residual_directions.shape

torch.Size([104, 2, 1024])

In [128]:
# Cache lookup: ("resid_post", -1) selects the residual stream after the final
# layer.
final_residual_stream = clean_cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
# Keep only the last position of every prompt.
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply the cached final LayerNorm scaling to the last-position residual.
# NOTE(review): pos_slice=-1 presumably tells apply_ln_to_stack which position's
# scale to use, matching the slice taken above -- confirm against TransformerLens.
scaled_final_token_residual_stream = clean_cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)

# Project onto each prompt's logit-difference direction and average over the
# batch; this should reproduce clean_logit_diff printed below.
average_logit_diff = einsum("batch d_model, batch d_model -> ", scaled_final_token_residual_stream, logit_diff_directions)/len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:",clean_logit_diff)

Final residual stream shape: torch.Size([104, 22, 1024])
Calculated average logit diff: 2.6821048259735107
Original logit difference: 2.682105302810669


### Logit Lens

In [129]:
from circuit_utils import residual_stack_to_logit_diff

accumulated_residual, labels = clean_cache.accumulated_resid(layer=-1, incl_mid=False, pos_slice=-1, return_labels=True)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, logit_diff_directions, prompts, clean_cache)
line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers+1), hover_name=labels, title="Logit Difference From Accumulated Residual Stream")

### Layer Attribution

In [130]:
from visualization_utils import l_line
per_layer_residual, labels = clean_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, logit_diff_directions, prompts, clean_cache)
l_line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

### Head Attribution

In [131]:
per_head_residual, labels = clean_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, logit_diff_directions, prompts, clean_cache)
per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, "(layer head_index) -> layer head_index", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)
imshow(per_head_logit_diffs, xaxis="Head", yaxis="Layer", title="Logit Difference From Each Head")

Tried to stack head results when they weren't cached. Computing head results now


In [132]:
from visualization_utils import (
    plot_attention_heads,
    scatter_attention_and_contribution
)
plot_attention_heads(per_head_logit_diffs/clean_logit_diff, top_n=12, range_x=[0, 1])#, max_cumulative_percent=1.0)

Total logit diff contribution above threshold: 1.17


### Attention Analysis

In [67]:
from visualization_utils import get_attn_head_patterns

top_k = 12
top_heads = torch.topk(per_head_logit_diffs.flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, clean_tokens[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [133]:
potential_wmh = [(18, 0), (19, 4), (20, 2)]

In [69]:
from visualization_utils import scatter_attention_and_contribution_logic
# NOTE(review): `wrong_token_heads` is not defined anywhere earlier in this
# notebook, so this cell raises NameError on a fresh kernel. The previous cell
# defines `potential_wmh`, which may be the intended list -- confirm.
for h in wrong_token_heads:
    scatter_attention_and_contribution_logic(model, h, clean_tokens, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


In [49]:
top_k = 4
top_heads = torch.topk(-per_head_logit_diffs.flatten(), k=top_k).indices.cpu().numpy()
heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]
tokens, attn, names = get_attn_head_patterns(model, clean_tokens[0], heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

In [50]:
scatter_attention_and_contribution_logic(model, (9, 5), clean_tokens, answer_residual_directions)

Tried to stack head results when they weren't cached. Computing head results now


In [74]:
contextual_wmh = [(8, 2), (10, 10), (9, 6), (10, 8), (8, 8), (9, 8), (10, 6), (10, 11)]

## Activation Patching for Model Component Importance

### Residual Stream

In [134]:
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(
    model, 
    corrupted_tokens, 
    clean_cache, 
    ioi_metric)

imshow(resid_pre_act_patch_results, 
       yaxis="Layer", 
       xaxis="Position", 
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'resid_pre' Activation Patching")

  0%|          | 0/528 [00:00<?, ?it/s]

### MLP Layers

In [52]:
# Activation-patch `mlp_out` (corrupted tokens, clean cache) and plot the metric
# by layer and position. The result gets its own name so the residual-stream
# patching results computed earlier are not overwritten.
mlp_out_act_patch_results = patching.get_act_patch_mlp_out(
    model,
    corrupted_tokens,
    clean_cache,
    ioi_metric)

imshow(mlp_out_act_patch_results,
       yaxis="Layer",
       xaxis="Position",
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'mlp_out' Activation Patching")

  0%|          | 0/204 [00:00<?, ?it/s]

### Attention Layers

In [53]:
# Activation-patch `attn_out` (corrupted tokens, clean cache) and plot the
# metric by layer and position. The result gets its own name so the
# residual-stream patching results computed earlier are not overwritten.
attn_out_act_patch_results = patching.get_act_patch_attn_out(
    model,
    corrupted_tokens,
    clean_cache,
    ioi_metric)

imshow(attn_out_act_patch_results,
       yaxis="Layer",
       xaxis="Position",
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
       title="IOI Metric for 'attn_out' Activation Patching")

  0%|          | 0/204 [00:00<?, ?it/s]

### Attention Layers by Head

In [54]:
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(attn_head_out_all_pos_act_patch_results, 
       yaxis="Layer", 
       xaxis="Head", 
       title="IOI Metric for 'attn_head_out' Activation Patching (All Pos)")

  0%|          | 0/144 [00:00<?, ?it/s]

In [55]:
# Patch each head's value vectors (all positions) from the clean cache into the
# corrupted run and plot the metric per (layer, head).
attn_head_v_all_pos_act_patch_results = patching.get_act_patch_attn_head_v_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)
# Plot the value-patching result computed above; plotting the head-output
# result here would just repeat the previous figure under the wrong title.
imshow(attn_head_v_all_pos_act_patch_results,
       yaxis="Layer",
       xaxis="Head",
       title="IOI Metric for 'attn_head_v' Activation Patching (All Pos)")

  0%|          | 0/144 [00:00<?, ?it/s]

In [53]:
# One "L{layer}H{head}" label per head, ordered layer-major to match the
# "(layer head)" flattening used below.
ALL_HEAD_LABELS = [f"L{i}H{j}" for i in range(model.cfg.n_layers) for j in range(model.cfg.n_heads)]
# By-position patching is slow, so it only runs when DO_SLOW_RUNS is set.
if DO_SLOW_RUNS:
    attn_head_v_act_patch_results = patching.get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, ioi_metric)
    # Collapse (layer, head) into one row axis so each row is a single head.
    attn_head_v_act_patch_results = einops.rearrange(attn_head_v_act_patch_results, "layer pos head -> (layer head) pos")
    imshow(attn_head_v_act_patch_results, 
        yaxis="Head Label", 
        xaxis="Pos", 
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=ALL_HEAD_LABELS,
        title="attn_head_v Activation Patching By Pos")

  0%|          | 0/2448 [00:00<?, ?it/s]

### Head Component Output

In [58]:
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(every_head_all_pos_act_patch_result, facet_col=0, facet_labels=["Output", "Query", "Key", "Value", "Pattern"], title="Activation Patching Per Head (All Pos)", xaxis="Head", yaxis="Layer") #, zmax=1, zmin=-1)

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

In [52]:
contextual_wmh

[(8, 2), (10, 10), (9, 6), (10, 8), (8, 8), (9, 8), (10, 6), (10, 11)]

## Circuit Sketching

### First Level

#### Attention Pattern Patching

In [47]:
# Recomputes by-position value patching and plots the slice at the final
# position (index -1 on the position axis).
# NOTE(review): the title says 'attn_head_pattern' and "(All Pos)", but the data
# is value patching at a single position; either the title or the patching call
# is stale -- confirm which was intended for this "Attention Pattern Patching"
# section.
attn_head_v_by_pos_results = patching.get_act_patch_attn_head_v_by_pos(model, corrupted_tokens, clean_cache, ioi_metric)
imshow(attn_head_v_by_pos_results[:, -1, :], 
       yaxis="Layer", 
       xaxis="Head", 
       title="IOI Metric for 'attn_head_pattern' Activation Patching (All Pos)")

  0%|          | 0/2448 [00:00<?, ?it/s]

In [48]:
imshow(attn_head_v_by_pos_results[:, 4, :], 
       yaxis="Layer", 
       xaxis="Head", 
       title="IOI Metric for 'attn_head_pattern' Activation Patching (All Pos)")

In [76]:
from visualization_utils import l_scatter
# One "L{layer}H{head}" label per head, layer-major, matching `.flatten()` below.
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
# NOTE(review): `attn_head_pattern_all_pos_act_patch_results` is not assigned
# anywhere earlier in this notebook, so this cell fails on a fresh kernel; the
# cell that computed it appears to be missing -- confirm and restore it.
l_scatter(
    x=utils.to_numpy(attn_head_pattern_all_pos_act_patch_results.flatten()), 
    y=utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten()), 
    hover_name = head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching")

In [22]:
contextual_wmh = [(8, 2), (10, 10), (9, 6), (10, 8), (8, 8), (9, 8), (10, 6), (10, 11)]

In [77]:
mystery_heads

[(8, 10), (7, 8)]

In [78]:
# Zero the first six layers (rows 0-5) so early heads do not dominate the plot.
# NOTE(review): this mutates the results tensor in place, so any later use of it
# sees the zeroed rows; the tensor itself is not assigned anywhere earlier in
# this notebook (fails on a fresh kernel).
attn_head_pattern_all_pos_act_patch_results[:6, :] = 0
plot_attention_heads(
    attn_head_pattern_all_pos_act_patch_results, 
    title="Logit Diff Correction From Patching in Clean Attention Patterns", 
    top_n=15, 
    range_x=[0, 0.5]
)

Total logit diff contribution above threshold: 0.08


#### NMH Knockout

##### All Heads

In [68]:
# Zero-ablate the listed heads at the final position and measure the metric.
heads_to_ablate = contextual_wmh

print(f"Heads to ablate: {heads_to_ablate}")
# Hook that zeroes one head's `z` output at the last position only (in place).
# NOTE(review): a function with this same name is defined again in a later
# cell; the later definition silently replaces this one.
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    z[:, -1, head_idx, :] = 0
    return z
# Adds a hook into global model state
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
# Runs the model with caching while the ablation hooks are attached.
# NOTE(review): whether run_with_cache also removes hooks that were added
# manually via add_hook depends on the TransformerLens version; if it does not,
# the ablation silently persists into later cells -- confirm, or call
# model.reset_hooks() explicitly afterwards.
ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")

Heads to ablate: [(8, 2), (10, 10), (9, 6), (10, 8), (8, 8), (9, 8), (10, 6), (10, 11)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.7786003351211548


In [71]:
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, logit_diff_directions, prompts, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"}, zmin=-1.5, zmax=1.5, title="Post-Ablation Direct Logit Attribution of Heads")
#scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

In [84]:
# Heads excluded from the backup-head analysis:
# S2Is above threshold + Negative NMHs + Backup S2Is
exclusions = [(7, 4), (7, 6)] + [(7, 10)]
# Change in each head's direct logit attribution caused by the ablation.
delta = per_head_ablated_logit_diffs - per_head_logit_diffs
for layer, head in exclusions:
    # Kept for downstream cells that read per_head_ablated_logit_diffs.
    per_head_ablated_logit_diffs[layer, head] = 0
    # The plot below shows `delta`, which is already computed at this point, so
    # the exclusion has to be applied to `delta` itself to take effect.
    delta[layer, head] = 0

plot_attention_heads(
    delta/clean_logit_diff,
    title="Logit Diff Contribution From Backup Heads",
    top_n=15,
    range_x=[0, 0.5]
)

Total logit diff contribution above threshold: 0.28


##### Individual Heads

In [63]:
# Collect (layer, head) pairs whose post-ablation direct logit attribution
# exceeds 0.02.
# NOTE(review): the threshold is applied to the post-ablation attribution
# itself, not to the change caused by the ablation -- confirm this is the
# intended notion of "positive effect".

backup_nmh_candidates = np.argwhere(per_head_ablated_logit_diffs.cpu().detach().numpy() > 0.02)
backup_nmh_candidates = [tuple(h) for h in backup_nmh_candidates]
# Drop heads that were explicitly excluded above.
backup_nmh_candidates = [h for h in backup_nmh_candidates if h not in exclusions]
print(f"Backup NMH Candidates: {backup_nmh_candidates}")
for l, h in backup_nmh_candidates:
    # Re-attach the ablation hooks before every plot (presumably because the
    # plotting helper's model run clears them -- verify).
    for layer, head in heads_to_ablate:
        ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
        model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
    # NOTE(review): `io_positions` and `s_positions` are not defined anywhere
    # earlier in this notebook; this call fails on a fresh kernel.
    scatter_attention_and_contribution(model, (l, h), prompts, io_positions, s_positions, answer_residual_directions)

Backup NMH Candidates: [(7, 3), (7, 5), (9, 4), (9, 10), (10, 7), (11, 10), (11, 11)]
Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


Tried to stack head results when they weren't cached. Computing head results now


In [64]:
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
tokens, attn, names = get_attn_head_patterns(model, prompts[0], backup_nmh_candidates)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

#### Path Patching for NMH Candidate Receivers

In [90]:
# Path-patch every head (as sender) into the keys of the receiver heads and
# record the relative change in mean logit difference.
receiver_heads = contextual_wmh

# Receiver hook list is the same for every sender, so build it once. Distinct
# names avoid shadowing the loop's `head_idx`.
receiver_hooks = [
    (f"blocks.{receiver_layer}.attn.hook_k", receiver_head)
    for receiver_layer, receiver_head in receiver_heads
]

# Allocate on the notebook's selected device instead of a hard-coded 'cuda:0',
# so the cell also runs on CPU-only machines.
metric_delta_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device)

for layer in range(model.cfg.n_layers):
    for head_idx in range(model.cfg.n_heads):
        pass_d_hooks = path_patching(
            model=model,
            patch_tokens=corrupted_tokens,
            orig_tokens=clean_tokens,
            sender_heads=[(layer, head_idx)],
            #sender_heads=[(layer, None)],
            receiver_hooks=receiver_hooks,
            sender_positions=4,
            receiver_positions=4
        )
        path_patched_logits = model.run_with_hooks(clean_tokens, fwd_hooks=pass_d_hooks)
        logit_diff_res = get_logit_diff(path_patched_logits, answer_token_indices)
        # Relative change versus the clean baseline (negative = logit diff dropped).
        metric_delta_results[layer, head_idx] = -(clean_logit_diff - logit_diff_res) / clean_logit_diff

In [91]:
imshow(metric_delta_results, title="IOI Metric Change From Each Head Through Receivers")

In [92]:
plot_attention_heads(-metric_delta_results, title="Logit Diff Drop with Corrupted Path Patch", top_n=15, range_x=[0, 0.05])

Total logit diff contribution above threshold: 0.01


### Second Level

#### Attention Pattern for Second-Level Heads

In [89]:
second_level_positive_heads = [(7, 8), (2, 10), (0, 9), (3, 0), (2, 6), (4, 9), (0, 2), (0, 3)]

tokens, attn, names = get_attn_head_patterns(model, prompts[0], second_level_positive_heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)

#second_level_negative_heads = [(7, 8), (8, 10)]
#visualize_attention_patterns(torch.tensor([l*12+h for l, h in second_level_negative_heads]), title=f"Top Negative Second Level IOI Metric Heads")

In [50]:
head_labels = [f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)]
l_scatter(
    x=utils.to_numpy(attn_head_v_all_pos_act_patch_results.flatten()), 
    y=utils.to_numpy(attn_head_out_all_pos_act_patch_results.flatten()), 
    xaxis="Value Patch",
    yaxis="Output Patch",
    #caxis="Layer",
    hover_name = head_labels,
    color=einops.repeat(np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads),
    range_x=(-1.5, 1.5),
    range_y=(-1.5, 1.5),
    title="Scatter plot of output patching vs value patching")

In [51]:
s2i_candidates = second_level_positive_heads
#s2i_candidates = [(8, 9)]

#### S2I Knockout

##### All Heads

In [52]:
# Zero-ablate the second-level candidate heads at the final position and
# measure the metric.
heads_to_ablate = s2i_candidates

print(f"Heads to ablate: {heads_to_ablate}")
# NOTE(review): this redefines `ablate_top_head_hook`, which an earlier cell
# already defines with an identical body; the duplicate could be removed.
def ablate_top_head_hook(z: TT["batch", "pos", "head_index", "d_head"], hook, head_idx=0):
    z[:, -1, head_idx, :] = 0
    return z
# Adds a hook into global model state
for layer, head in heads_to_ablate:
    ablate_head_hook = partial(ablate_top_head_hook, head_idx=head)
    model.blocks[layer].attn.hook_z.add_hook(ablate_head_hook)
# Runs the model with caching while the ablation hooks are attached.
# NOTE(review): whether run_with_cache also removes manually added hooks depends
# on the TransformerLens version -- confirm, or reset hooks explicitly.
ablated_logits, ablated_cache = model.run_with_cache(clean_tokens)
print(f"Original IOI Metric: {ioi_metric(clean_logits).item():.4f}")
print(f"Post ablation IOI Metric: {ioi_metric(ablated_logits).item()}")
#print(f"Direct Logit Attribution of top name mover head: {per_head_logit_diffs.flatten()[top_name_mover].item()}")
#print(f"Naive prediction of post ablation logit diff: {original_average_logit_diff - per_head_logit_diffs.flatten()[top_name_mover].item()}")

Heads to ablate: [(7, 4), (7, 6)]
Original IOI Metric: 1.0000
Post ablation IOI Metric: 0.5208906531333923


In [53]:
per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
per_head_ablated_logit_diffs = residual_stack_to_logit_diff(per_head_ablated_residual, logit_diff_directions, prompts, ablated_cache)
per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)
imshow(per_head_ablated_logit_diffs, labels={"x":"Head", "y":"Layer"})
l_scatter(y=per_head_logit_diffs.flatten(), x=per_head_ablated_logit_diffs.flatten(), hover_name=head_labels, range_x=(-3, 3), range_y=(-3, 3), xaxis="Ablated", yaxis="Original", title="Original vs Post-Ablation Direct Logit Attribution of Heads")

Tried to stack head results when they weren't cached. Computing head results now


#### Path Patching for S2-Inhibition Candidates

In [None]:
model.to_str_tokens(clean_tokens[3])

['<|endoftext|>',
 'When',
 ' Tom',
 ' and',
 ' James',
 ' went',
 ' to',
 ' the',
 ' park',
 ',',
 ' Tom',
 ' gave',
 ' the',
 ' ball',
 ' to']

In [54]:
# Path-patch every head (as sender) into the values of the second-level
# receiver heads and record the relative change in the metric.
receiver_heads = second_level_positive_heads

# Receiver hook list is the same for every sender, so build it once. Distinct
# names avoid shadowing the loop's `head_idx`.
receiver_hooks = [
    (f"blocks.{receiver_layer}.attn.hook_v", receiver_head)
    for receiver_layer, receiver_head in receiver_heads
]

# Allocate on the notebook's selected device instead of a hard-coded 'cuda:0',
# so the cell also runs on CPU-only machines.
metric_delta_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device)

for layer in range(model.cfg.n_layers):
    for head_idx in range(model.cfg.n_heads):
        # NOTE(review): this call passes `positions=10`, while the earlier
        # path-patching cell uses `sender_positions` / `receiver_positions`;
        # confirm which keyword path_patching accepts and that position 10 is
        # the right token for these prompts.
        pass_d_hooks = path_patching(
            model=model,
            patch_tokens=corrupted_tokens,
            orig_tokens=clean_tokens,
            sender_heads=[(layer, head_idx)],
            receiver_hooks=receiver_hooks,
            positions=10
        )
        path_patched_logits = model.run_with_hooks(clean_tokens, fwd_hooks=pass_d_hooks)
        ioi_metric_res = ioi_metric(path_patched_logits)
        # Relative change versus the clean-run metric (negative = metric dropped).
        metric_delta_results[layer, head_idx] = -(clean_baseline_ioi - ioi_metric_res) / clean_baseline_ioi

In [55]:
imshow(metric_delta_results, title="IOI Metric Change From Each Head Through Receivers")#, zmin=-0.02, zmax=0.02)

### Third Level

#### Attention Patterns for Third-Level Heads

We have a mix of induction heads and duplicate token heads here, as well as two heads that focus on S2 at S2.

In [57]:
third_level_positive_heads = [(0, 1), (3, 3), (4, 8), (4, 10), (5, 10), (5, 11)]

tokens, attn, names = get_attn_head_patterns(model, prompts[0], third_level_positive_heads)
cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)