In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("/root/circuit-finder")
import pandas as pd
import json
import torch
from dataclasses import dataclass
import transformer_lens as tl
from circuit_finder.patching.eap_graph import EAPGraph
from circuit_finder.utils import clear_memory
from circuit_finder.constants import device
from circuit_finder.pretrained import (
    load_attn_saes,
    load_hooked_mlp_transcoders,
)
from circuit_finder.metrics import batch_avg_answer_diff
from circuit_finder.patching.causal_interventions import run_with_ablations
import os
os.chdir("/root/circuit-finder/")

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f8081cf2fe0>

In [2]:
# Initialize SAEs
attn_saes = load_attn_saes(use_error_term=True)
mlp_transcoders = load_hooked_mlp_transcoders(use_error_term=True)

# Load model
model = tl.HookedSAETransformer.from_pretrained("gpt2").cuda()

Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]



Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda


In [3]:
def load_batch(file_path, batch_size=10, first_batch=0):
    with open(file_path, 'r') as file:
        data = json.load(file)

    word_idxs = data['word_idxs']

    prompts = data['prompts'][first_batch:first_batch+batch_size]

    clean_prompts = [prompt['clean'] for prompt in prompts]
    corrupt_prompts = [prompt['corrupt'] for prompt in prompts]
    correct_answers = [prompt['answers'][0] for prompt in prompts]
    wrong_answers = [prompt['wrong_answers'][0] for prompt in prompts]
   

    correct_ids = model.to_tokens(correct_answers)[:,1]
    wrong_ids = model.to_tokens(wrong_answers)[:,1]

    return word_idxs, clean_prompts, corrupt_prompts, correct_ids, wrong_ids

file_path = 'datasets/ioi/ioi_BABA_template_0_prompts.json'
word_idxs, clean_prompts, corrupt_prompts, correct_ids, wrong_ids = load_batch(file_path, batch_size=10)

print("Clean Prompts:", clean_prompts)
print("Corrupt Prompts:", corrupt_prompts)
print("Correct IDs:", correct_ids)
print("Wrong IDs:", wrong_ids)


S1 = word_idxs['S1']+1
S2 = word_idxs['S2']+1

Clean Prompts: ['Then, Jeremy and Scott went to the hospital. Jeremy gave a snack to', 'Then, Victoria and Stephen went to the station. Victoria gave a necklace to', 'Then, Carter and Kyle went to the restaurant. Carter gave a kiss to', 'Then, Robert and Lisa went to the restaurant. Robert gave a computer to', 'Then, Andy and Ruby went to the office. Andy gave a snack to', 'Then, John and Jacob went to the school. John gave a snack to', 'Then, Emily and Leon went to the store. Emily gave a kiss to', 'Then, Tyler and Robert went to the store. Tyler gave a snack to', 'Then, Ruby and Edward went to the garden. Ruby gave a basketball to', 'Then, Sullivan and Dean went to the station. Sullivan gave a kiss to']
Corrupt Prompts: ['Then, Michael and Anderson went to the hospital. Rachel gave a snack to', 'Then, Charlie and Jay went to the station. Elizabeth gave a necklace to', 'Then, Emily and River went to the restaurant. Brian gave a kiss to', 'Then, Daniel and Austin went to the restaurant

In [4]:
from eindex import eindex
def correct_logit(logits, correct_ids, wrong_ids):
    return eindex(logits[:,-1], correct_ids, "batch [batch]").mean()

def wrong_logit(logits, correct_ids, wrong_ids):
    return eindex(logits[:,-1], wrong_ids, "batch [batch]").mean()

def logit_diff(logits, correct_ids, wrong_ids):
    return correct_logit(logits, correct_ids, wrong_ids) - wrong_logit(logits, correct_ids, wrong_ids)

In [6]:
from circuit_finder.patching.indirect_leap import IndirectLeap, LeapConfig
model.reset_hooks()
from circuit_finder.plotting import make_html_graph
cfg = LEAPConfig(threshold=0.04,
                 contrast_pairs=True, 
                 qk_enabled=True,
                 chained_attribs=True,
                 store_error_attribs=True,
                 allow_neg_feature_acts=True)
metric = partial(logit_diff, correct_ids=correct_ids, wrong_ids=wrong_ids)
leap = IndirectLEAP(
    cfg, clean_prompts, model, attn_saes, mlp_transcoders, metric, corrupt_tokens=corrupt_prompts
)
leap.run()
print("num edges: ", len(leap.graph))
graph = EAPGraph(leap.graph)
types = graph.get_edge_types()
print("num ov edges: ", len([t for t in types if t=="ov"]))
print("num q edges: ", len([t for t in types if t=="q"]))
print("num k edges: ", len([t for t in types if t=="k"]))
print("num mlp edges: ", len([t for t in types if t==None]))
make_html_graph(leap, attrib_type="em", node_offset=8.0, show_error_nodes=False)

ImportError: cannot import name 'IndirectLeap' from 'circuit_finder.patching.indirect_leap' (/root/circuit-finder/circuit_finder/patching/indirect_leap.py)

First, we look at the graph for John and Mary, which is in BABA form. There are 3 feature-clusters:

C1. Important attn features at the S2 position, around layer 5-6
C2. An important attn feature at the final token position at layer 8.
C3. A couple attn features at the final token position in layer 9-10, which strongly boost "Mary".

We'd like to understand how these things compose. Namely:

- How do the layer 9-10 attention heads "know" to attend more to Mary than John?
    - They have high key-attribution to C2 features. But how exactly do these C2 features tell the model where to attend to? Do they mainly boost attention to tokens with the IO property? Or do they suppress attention to tokens with the S property? (Hunch: IO and S isn't really the right way to think here - it's really about "name that was repeated" and "name that wasn't repeated".)
- How are the C2 features computed?
    - Hunch: L8 attention uses the "proper noun should come next" feature as a query, the C1 features act as a key, and the OV circuit applied to C1 gives C2.

Let's start by looking at the 3 main C1 features:

- https://www.neuronpedia.org/gpt2-small/5-att-kk/7515
- https://www.neuronpedia.org/gpt2-small/6-att-kk/13836
- https://www.neuronpedia.org/gpt2-small/6-att-kk/17410

All three features are basically "token that was previously followed by 'and'" features.

e.g. "The shirt is blue and white. Her eyes are |blue|"

Let's ablate all three of them and see what happens to the logit diff.

In [None]:
logits, clean_cache = run_with_ablations(
    model,
    clean_prompts,
    attn_saes,
    mlp_transcoders,
    ablation_list = [],
    patch_list = [],
    cache_names_filter = [f'blocks.{layer}.attn.hook_pattern' for layer in range(12)]
)
clean_logit_diff = logit_diff(logits, correct_ids, wrong_ids).item()
logits, ablated_cache = run_with_ablations(
    model,
    clean_prompts,
    attn_saes,
    mlp_transcoders,
    ablation_list = [("attn", 5, 7515, S2),
                     ("attn", 6, 17410, S2),
                     ("attn", 6, 13836, S2)],
    patch_list = [],
    cache_names_filter = [f'blocks.{layer}.attn.hook_pattern' for layer in range(12)]
)
ablated_logit_diff = logit_diff(logits, correct_ids, wrong_ids).item()
print()
print(f"clean logit diff = {clean_logit_diff:.2f}")
print(f"ablated logit diff = {ablated_logit_diff:.2f}")
print(f"percent logit diff reduction = {100*(clean_logit_diff-ablated_logit_diff)/clean_logit_diff :.1f}%")

Ablating these three nodes reduces the metric by more than a third! Now, let's work out *why*. First, we can look at how these ablations affect the attention pattern from the last token back to the previous positions, in the layers that matter (8, 9 and 10) 

In [None]:
import plotly.express as px
from collections import Counter
for layer in [7,8,9,10,11]:
    A = clean_cache[f'blocks.{layer}.attn.hook_pattern'].mean([0,1])[-1].cpu()
    B = ablated_cache[f'blocks.{layer}.attn.hook_pattern'].mean([0,1])[-1].cpu()
    C = model.to_str_tokens(clean_prompts[0])

    # Create unique labels
    counter = Counter(C)
    unique_labels = []
    for label in C:
        count = counter[label]
        if count > 1:
            unique_labels.append(f"{label} ")
            counter[label] -= 1
        else:
            unique_labels.append(label)

    # Convert tensors to pandas DataFrame
    df = pd.DataFrame({
        'tokens': unique_labels,
        'clean': A.numpy(),
        'ablated': B.numpy()
    })

    # Melt the DataFrame to long format
    df_melted = df.melt(id_vars='tokens', value_vars=['clean', 'ablated'], var_name='Tensor', value_name='attention prob')

    # Create the bar chart with barmode set to 'group'
    fig = px.bar(df_melted, x='tokens', y='attention prob', color='Tensor', barmode='group', title=f'Layer {layer} Attention')

    # Show the plot
    fig.show()

The C1 ablations change the layer 9 and 10 patterns by a bit, bit not much: they only change the attention probs on the name tokens by ~10%.

We can also subtract off the C1 feature components at various points in the residual stream, and check the effect on the logit diff:

In [None]:
percent_logit_diff_reductions = []
layers = [7,8,9,10,11]
for patch_layer in layers:
    model.reset_hooks()
    logits, ablated_cache = run_with_ablations(
        model,
        clean_prompts,
        attn_saes,
        mlp_transcoders,
        grab_and_delete_list = [
            ("attn", 5, 7515, S2, f'blocks.{patch_layer}.hook_resid_pre'),
            ("attn", 6, 17410, S2, f'blocks.{patch_layer}.hook_resid_pre'),
            ("attn", 6, 13836, S2, f'blocks.{patch_layer}.hook_resid_pre')
        ],
        cache_names_filter = [f'blocks.{layer}.attn.hook_pattern' for layer in range(12)]
    )
    reduction = (clean_logit_diff - logit_diff(logits, correct_ids, wrong_ids).item()) / clean_logit_diff
    percent_logit_diff_reductions += [100*reduction]


df = pd.DataFrame({
    'layer': layers,
    'logit diff percent reduction': percent_logit_diff_reductions
})
px.line(df, x='layer', y='logit diff percent reduction', 
        title="Effect of deleting C1 features at resid_pre",
        )



BAAB S2 features:

https://www.neuronpedia.org/gpt2-small/8-att-kk/5580


BAAB final pos key features:

https://www.neuronpedia.org/gpt2-small/8-att-kk/13676 works with both orderings

https://www.neuronpedia.org/gpt2-small/7-att-kk/3607 another CS feature

https://www.neuronpedia.org/gpt2-small/8-att-kk/5580  e.g. A said "..... |,"|. Model predicts A.
