In [1]:
from rich.table import Table
from rich import print as rprint
from circuit_finder.pretrained import load_attn_saes, load_resid_saes, load_model
from circuit_finder.utils import (
    get_answer_tokens, 
    logits_to_ave_logit_diff,
    get_cache_fwd_and_bwd,
    
)
from circuit_finder.data import ioi

clean_prompts, corrupt_prompts, answers = ioi.get_ioi_data()

# Show every clean prompt next to its correct (answer[0]) and incorrect (answer[1]) completion.
table = Table("Prompt", "Correct", "Incorrect", title="Prompts & Answers:")
for prompt, answer in zip(clean_prompts, answers):
    correct_repr = repr(answer[0])
    incorrect_repr = repr(answer[1])
    table.add_row(prompt, correct_repr, incorrect_repr)
rprint(table)

In [2]:
# Initialize SAEs: attention SAEs and residual-stream SAEs, each returned as a dict
# (only the values are used here).
attn_saes = load_attn_saes()
resid_saes = load_resid_saes()
all_saes = [*attn_saes.values(), *resid_saes.values()]

# Presumably makes each spliced-in SAE add back its reconstruction error, so attaching
# the SAEs does not change model outputs (sae_lens `use_error_term`; TODO confirm).
for sae in all_saes:
    sae.cfg.use_error_term = True

# Load model (the log below reports GPT-2 loaded into a HookedTransformer).
model = load_model()

100%|██████████| 1/1 [00:01<00:00,  1.42s/it]
100%|██████████| 1/1 [00:01<00:00,  1.37s/it]
100%|██████████| 1/1 [00:00<00:00,  1.00it/s]
100%|██████████| 1/1 [00:01<00:00,  1.28s/it]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it]
100%|██████████| 1/1 [00:01<00:00,  1.56s/it]
100%|██████████| 1/1 [00:01<00:00,  1.48s/it]
100%|██████████| 1/1 [00:00<00:00,  1.15it/s]
100%|██████████| 1/1 [00:01<00:00,  1.40s/it]
100%|██████████| 1/1 [00:01<00:00,  1.12s/it]
100%|██████████| 1/1 [00:00<00:00,  1.02it/s]
100%|██████████| 1/1 [00:01<00:00,  1.20s/it]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it]


Loaded pretrained model gpt2 into HookedTransformer


# Layer 5 Induction Features

In [3]:
prompt_tokens = model.to_tokens(clean_prompts)
answer_tokens = get_answer_tokens(answers, model)


def metric_fn(logits):
    """Scalar metric: average logit difference between the answer tokens (per_prompt=False)."""
    return logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False)


def filter_fn(hook_name):
    """Cache only SAE post-activation hooks."""
    return "hook_sae_acts_post" in hook_name


# Run forward + backward with all SAEs attached, caching activations and gradients
# at the filtered hooks; `value` is the metric on the clean prompts.
with model.saes(all_saes):
    value, clean_act_cache, clean_grad_cache = get_cache_fwd_and_bwd(
        model=model,
        tokens=prompt_tokens,
        metric_fn=metric_fn,
        filter_fn=filter_fn,
    )

In [4]:
# Metric value, then the number of cached activation and gradient hooks (one per line).
print(value, len(clean_act_cache), len(clean_grad_cache), sep="\n")

3.2691798210144043
25
25


In [5]:
from transformer_lens import utils

# Inspect the SAE post-activations on the layer-5 attention `hook_z` output.
layer = 5
hook_name = f"{utils.get_act_name('z', layer)}.hook_sae_acts_post"

# Dense clean-run activations and metric gradients at that hook.
# (Attributions, act delta x grad, are computed in the Attribution Patching section.)
clean_acts, clean_grads = clean_act_cache[hook_name], clean_grad_cache[hook_name]



In [23]:
# Locate the two occurrences of the subject name (S1, S2) in the first prompt.
p_toks = prompt_tokens[0]
# NOTE: the wrong answer is the subject. Take prompt 0's OWN incorrect answer.
# `answer_tokens` is assumed to be (batch, 2) = [correct, incorrect] per prompt (TODO confirm
# against get_answer_tokens). Under that shape the previous `answer_tokens[1]` selected
# prompt 1's answer pair, which only matched prompt 0's subject because of dataset ordering.
wrong_ans_tok = answer_tokens[0, 1:2]  # keep a 1-D slice so to_str_tokens sees a sequence
p_str_tokens = model.to_str_tokens(p_toks)
wa_str_token = model.to_str_tokens(wrong_ans_tok)[0]
print(p_str_tokens)
print(wa_str_token)

# S1 = first occurrence of the subject token, S2 = the next occurrence after it.
# list.index raises ValueError if the subject does not occur twice.
s1_position = p_str_tokens.index(wa_str_token)
s2_position = p_str_tokens.index(wa_str_token, s1_position + 1)
print(s1_position)
print(s2_position)

['<|endoftext|>', 'When', ' John', ' and', ' Mary', ' went', ' to', ' the', ' shops', ',', ' Mary', ' gave', ' the', ' bag', ' to']
 Mary
4
10


## Position 10 (S2: second occurrence of the subject name)

In [12]:
from circuit_finder.plotting import imshow

# Layer-`layer` SAE feature activations at the S2 token, one row per prompt
# (slice assumes clean_acts is (prompt, position, feature), matching the [:, pos, :] index).
clean_s2_acts = clean_acts[:, s2_position, :]

# Keep only features that fire (> 0) on at least one prompt.
live_feature_mask = clean_s2_acts > 0
live_feature_union = live_feature_mask.any(dim=0)
selected_acts = clean_s2_acts[:, live_feature_union]
live_feature_ids = live_feature_union.nonzero().flatten().tolist()

imshow(
    selected_acts,
    title=f"Activations of Live SAE features at L{layer} S2 position per prompt",
    xaxis="Feature Id",
    yaxis="Prompt",
    x=[str(feature_id) for feature_id in live_feature_ids],
)

In [19]:
import torch
from circuit_finder.neuronpedia import get_neuronpedia_url_for_quick_list

# ABBA format (prompt 0): the 5 most active features, linked as a Neuronpedia quick list.
topk = clean_s2_acts[0, :].topk(k=5)
topk_features = topk.indices.tolist()
print(topk_features)
get_neuronpedia_url_for_quick_list(
    layer=layer,
    features=topk_features,
    sae_family="att-kk",
)

[13777, 44256, 24091, 10894, 3047]


'https://neuronpedia.org/quick-list/?name=temporary_list&features=%5B%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2213777%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2244256%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2224091%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2210894%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%223047%22%7D%5D'

In [20]:
# BABA format (prompt 1): the 5 most active features, linked as a Neuronpedia quick list.
prompt1_acts = clean_s2_acts[1]
topk = prompt1_acts.topk(5)
topk_features = topk.indices.tolist()
print(topk_features)
get_neuronpedia_url_for_quick_list(
    layer=layer,
    features=topk_features,
    sae_family="att-kk",
)

[7515, 29482, 27535, 9463, 18210]


'https://neuronpedia.org/quick-list/?name=temporary_list&features=%5B%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%227515%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2229482%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2227535%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%229463%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2218210%22%7D%5D'

Remarks:
- We found induction features matching those in Kissane's Colab demo:
  - L5.7515 and L5.27535 in BABA prompts
  - L5.44256 in ABBA prompts
- This gives me confidence that the code is working as intended!

## Position 12 (the `' the'` token after `' gave'` in prompt 0)

In [21]:
from circuit_finder.plotting import imshow

# Token position to inspect; in prompt 0 this is ' the' (the token after ' gave').
token_position = 12

# NOTE: kept under the name `clean_s2_acts` because the following top-k cell reads that
# name, but these are activations at `token_position`, not at S2.
clean_s2_acts = clean_acts[:, token_position, :]

# Keep only features that fire (> 0) on at least one prompt.
live_feature_mask = clean_s2_acts > 0
live_feature_union = live_feature_mask.any(dim=0)
selected_acts = clean_s2_acts[:, live_feature_union]

imshow(
    selected_acts,
    # Title previously said "S2 position", which mislabelled this plot.
    title=f"Activations of Live SAE features at L{layer} position {token_position} per prompt",
    xaxis="Feature Id",
    yaxis="Prompt",
    x=[str(i) for i in live_feature_union.nonzero().flatten().tolist()],
)

In [22]:
# Prompt 0: the 5 most active features at the position selected above, as a Neuronpedia list.
topk = torch.topk(clean_s2_acts[0], 5)
topk_features = topk.indices.tolist()
print(topk_features)
get_neuronpedia_url_for_quick_list(
    layer=layer,
    features=topk_features,
    sae_family="att-kk",
)

[36109, 28197, 10823, 13063, 44610]


'https://neuronpedia.org/quick-list/?name=temporary_list&features=%5B%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2236109%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2228197%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2210823%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2213063%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2244610%22%7D%5D'

## Position 4 (S1: first occurrence of the subject name)

In [26]:
from circuit_finder.plotting import imshow

# SAE feature activations at S1 (first occurrence of the subject name), one row per prompt.
# NOTE: kept under the name `clean_s2_acts` because the top-k cell below reads that name;
# the values here are S1 activations.
clean_s2_acts = clean_acts[:, s1_position, :]

# Keep only features that fire (> 0) on at least one prompt.
live_feature_mask = clean_s2_acts > 0
live_feature_union = live_feature_mask.any(dim=0)
selected_acts = clean_s2_acts[:, live_feature_union]
print(selected_acts.shape)

imshow(
    selected_acts,
    # Title previously said "S2 position" although this slices s1_position.
    title=f"Activations of Live SAE features at L{layer} S1 position per prompt",
    xaxis="Feature Id",
    yaxis="Prompt",
    x=[str(i) for i in live_feature_union.nonzero().flatten().tolist()],
)

torch.Size([2, 25])


In [25]:
# Prompt 0: the 5 most active features in `clean_s2_acts`, as a Neuronpedia quick list.
first_prompt_acts = clean_s2_acts[0, :]
topk = torch.topk(first_prompt_acts, k=5)
topk_features = topk.indices.tolist()
print(topk_features)
get_neuronpedia_url_for_quick_list(
    layer=layer,
    features=topk_features,
    sae_family="att-kk",
)

[9504, 44452, 28819, 37069, 44631]


'https://neuronpedia.org/quick-list/?name=temporary_list&features=%5B%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%229504%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2244452%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2228819%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2237069%22%7D%2C%20%7B%22modelId%22%3A%20%22gpt2-small%22%2C%20%22layer%22%3A%20%225-att-kk%22%2C%20%22index%22%3A%20%2244631%22%7D%5D'

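# Attribution Patching

We estimate each SAE feature's effect on the metric with attribution patching, a first-order (linear) approximation. For every active feature element, `attrib = act_delta * grad`, where `grad` is the gradient of the metric with respect to that feature's activation on the clean run. No corrupt run is cached in this notebook, so the corrupt activations are taken to be zero (zero-ablation) and `act_delta = -clean_act`.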




In [7]:
# Attribution target: SAE post-activations on the layer-9 attention hook_z output.
attrib_layer = 9
hook_name: str = f"blocks.{attrib_layer}.attn.hook_z.hook_sae_acts_post"
clean_acts = clean_act_cache[hook_name].to_sparse().coalesce()

# No corrupt run is cached, so corrupt activations are taken to be zero (zero-ablation):
# delta = corrupt - clean = -clean.
clean_corrupt_act_delta = (-clean_acts).coalesce()
clean_grads = clean_grad_cache[hook_name].to_sparse().coalesce()

for sparse_tensor in (clean_corrupt_act_delta, clean_grads):
    print(sparse_tensor.shape)



torch.Size([4, 19, 24576])
torch.Size([4, 19, 24576])


In [8]:
import pandas as pd
import torch

# One row per nonzero entry of the sparse activation delta, with its first-order attribution:
#     attrib = act_delta * grad
delta_indices = clean_corrupt_act_delta.indices()  # (3, nnz): example, token, feature
delta_values = clean_corrupt_act_delta.values()

# Gather the gradient at every delta coordinate in one vectorized lookup. This replaces
# per-element scalar indexing into a sparse tensor, which is very slow for ~27k entries.
grads_dense = clean_grads.to_dense() if clean_grads.is_sparse else clean_grads
grad_values = grads_dense[delta_indices[0], delta_indices[1], delta_indices[2]]

# float64 on CPU: matches the values and products the previous `.item()` arithmetic produced.
act_delta = delta_values.detach().cpu().double()
grad = grad_values.detach().cpu().double()
idx = delta_indices.cpu().numpy()

# Column order is kept from the original per-row dicts; the scalar hook_name broadcasts.
# With no nonzero entries this still yields an empty frame with the expected columns.
df = pd.DataFrame({
    'hook_name': hook_name,
    'example_idx': idx[0],
    'token_idx': idx[1],
    'feature_idx': idx[2],
    'act_delta': act_delta.numpy(),
    'grad': grad.numpy(),
    'attrib': (act_delta * grad).numpy(),
})
df

Unnamed: 0,hook_name,example_idx,token_idx,feature_idx,act_delta,grad,attrib
0,blocks.9.attn.hook_z.hook_sae_acts_post,0,0,13021,-1.171209,-6.545376e-19,7.666006e-19
1,blocks.9.attn.hook_z.hook_sae_acts_post,0,1,13021,-1.169506,0.000000e+00,-0.000000e+00
2,blocks.9.attn.hook_z.hook_sae_acts_post,0,2,30,-0.853277,-1.564868e-19,1.335266e-19
3,blocks.9.attn.hook_z.hook_sae_acts_post,0,2,31,-0.305261,-1.852233e-19,5.654141e-20
4,blocks.9.attn.hook_z.hook_sae_acts_post,0,2,41,-1.766256,3.984899e-19,-7.038353e-19
...,...,...,...,...,...,...,...
27433,blocks.9.attn.hook_z.hook_sae_acts_post,3,18,20863,-0.034942,-8.188876e-05,2.861369e-06
27434,blocks.9.attn.hook_z.hook_sae_acts_post,3,18,21560,-0.162702,-2.312821e-06,3.763005e-07
27435,blocks.9.attn.hook_z.hook_sae_acts_post,3,18,21793,-0.084693,-9.371792e-06,7.937231e-07
27436,blocks.9.attn.hook_z.hook_sae_acts_post,3,18,23715,-0.032405,1.065277e-05,-3.452006e-07


In [9]:
import plotly_express as px

# Which prompt (example_idx) to plot.
plot_example_idx = 0
plot_df = df[df['example_idx'] == plot_example_idx]

# Attribution of every active feature element in that prompt, by token position.
fig = px.scatter(
    plot_df,
    x='token_idx',
    y='attrib',
    color='feature_idx',
    hover_data=['feature_idx', 'act_delta', 'grad'],
    labels={'token_idx': 'Token position', 'attrib': 'Attribution (act_delta × grad)'},
    title=f"Feature attributions per token position (example {plot_example_idx})",
)
fig.show()

Remarks

- **TODO**: interpret the attribution scatter above.

In [10]:
# Activation delta (corrupt - clean; corrupt taken as zero) of each active feature element,
# by token position, for the prompt selected in `plot_df`.
fig = px.scatter(
    plot_df,
    x='token_idx',
    y='act_delta',
    color='feature_idx',
    hover_data=['feature_idx', 'grad', 'attrib'],
    labels={'token_idx': 'Token position', 'act_delta': 'Activation delta'},
    title="Feature activation deltas per token position",
)
fig.show()