In [1]:
import torch
import sys
import transformers
import torch
import circuitsvis as cv
import torch.nn as nn
import numpy as np
import einops
from copy import deepcopy
from fancy_einsum import einsum
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer, FactoredMatrix, HookedTransformerConfig
from jaxtyping import Float, Int
from torch import Tensor
import huggingface_hub
from tqdm import tqdm
import torch.nn.functional as F
from transformer_lens.ActivationCache import ActivationCache
import re


from utils.metrics import compare_token_probability, kl_divergence, compare_token_logit
from utils.miscellanea import get_top_k_contributors, IOI_head_types
from utils.component_contributions import contribution_mlp, contribution_attn


# IPython autoreload in mode 2: re-import edited modules before each cell
# executes, so changes to the local `utils` package are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

# Only show transformers log messages at error level.
transformers.logging.set_verbosity_error()
# Disabled: switching the global default dtype to bfloat16.
# torch.set_default_dtype(torch.bfloat16)

In [2]:
from utils.nodes import MLP_Node, EMBED_Node, FINAL_Node,Node, ATTN_Node
from utils.graph_search import path_message, evaluate_path, breadth_first_search

In [3]:
import dotenv
import os
# Populate os.environ from a .env file (if one is found) before reading from it.
dotenv.load_dotenv()

# Access token read from the environment; None when the variable is not set.
TOKEN = os.environ.get("TOKEN")

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The notebook only runs forward passes, so gradient tracking is switched off globally.
torch.set_grad_enabled(False)

# Authenticate only when a token was actually configured. Calling login() with
# token=None falls back to an interactive prompt, which stalls a
# "Restart Kernel -> Run All" execution.
if TOKEN:
    huggingface_hub.login(token=TOKEN)

# Note: eventually can set fold_ln=False, center_unembed=False, center_writing_weights=False
model = HookedTransformer.from_pretrained('gpt2-small', device=DEVICE, torch_dtype=torch.float32)

# target_idx selects which element of each `answers` tuple is used as the
# ground-truth token: 0 is the first entry, 1 the second entry (used when
# searching for the subject-inhibition behaviour).
find_subject_inibition = False
target_idx = 1 if find_subject_inibition else 0

Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
# Prompts come in pairs that differ only in which of the two names is repeated.
prompts = [
    'When John and Mary went to the shops, John gave the bag to',
    'When John and Mary went to the shops, Mary gave the bag to',
    'When Tom and James went to the park, Tom gave the ball to',
    'When Tom and James went to the park, James gave the ball to',
    'When Dan and Sid went to the shops, Dan gave an apple to',
    'When Dan and Sid went to the shops, Sid gave an apple to',
    'After Martin and Amy went to the park, Martin gave a drink to',
    'After Martin and Amy went to the park, Amy gave a drink to',
]
# One tuple per prompt: (name that completes the sentence, the repeated name).
answers = [
    (' Mary', ' John'),
    (' John', ' Mary'),
    (' James', ' Tom'),
    (' Tom', ' James'),
    (' Sid', ' Dan'),
    (' Dan', ' Sid'),
    (' Amy', ' Martin'),
    (' Martin', ' Amy'),
]

# Keep every other prompt, i.e. the ones where the second *name* mentioned is
# the expected completion. The path search works on fixed input positions, so
# the names must sit at the same positions in every prompt of the batch.
prompts_fixed_pos = prompts[::2]
answers_fixed_pos = answers[::2]

# NOTE: this indexes the unfiltered `prompts` list, not `prompts_fixed_pos`.
example_idx = 2

tokens = model.to_tokens(prompts[example_idx])
# Forward pass over the whole fixed-position batch, keeping the activation cache.
logits, cache = model.run_with_cache(prompts_fixed_pos)

# Model's top prediction at the last position of the first prompt in the batch.
model_token = logits[0][-1].argmax(dim=-1)
# Ground-truth token id per prompt: last token of the selected answer string.
correct_tokens = [
    model.to_tokens(str(answer_pair[target_idx]))[0][-1].item()
    for answer_pair in answers_fixed_pos
]

# Model dimensions kept as module-level shortcuts.
n_layers, d_model = model.cfg.n_layers, model.cfg.d_model
n_heads, d_heads = model.cfg.n_heads, model.cfg.d_head

In [6]:
# Metric handed to the path search below; its __name__ is also written into
# the saved metadata and the output filename.
default_metric = compare_token_logit

In [7]:
# Passed to breadth_first_search as `min_contribution` and recorded in the
# saved metadata. The trailing values are presumably earlier settings that
# were tried -- TODO confirm.
min_treshold = 0.15 #0.25, #0.25, 2, 0.025

In [8]:
# The search starts from the final node in the last layer. Position 14 is
# hard-coded -- presumably the last token of each prompt; TODO confirm against
# the tokenized prompt length.
search_root = [FINAL_Node(layer=model.cfg.n_layers - 1, position=14)]

complete_paths, incomplete_paths = breadth_first_search(
    model,
    cache,
    default_metric,
    start_node=search_root,
    ground_truth_tokens=correct_tokens,
    max_depth=100,  # max number of components in the path (max number of nodes -2)
    max_branching_factor=2048,
    min_contribution=min_treshold,
    min_contribution_percentage=0.,  # 2, 5, 0.5
    inibition_task=find_subject_inibition,
)
print(f"Found {len(complete_paths)} complete paths and {len(incomplete_paths)} incomplete paths.")

BFS search:   0%|          | 0/100 [00:00<?, ?it/s]

Exploring depth 1 with 1 paths in the frontier
    Frontier: [(91.54667663574219, [FINAL_Node(layer=11, position=14)])](total 1)


BFS search:   1%|          | 1/100 [00:02<04:09,  2.52s/it, completed_paths=0, frontier_size=117]

Exploring depth 2 with 117 paths in the frontier
    Frontier: [(19.770381927490234, [ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (18.88329315185547, [ATTN_Node(layer=9, head=9, position=14, keyvalue_position=4, patch_query=False, patch_keyvalue=True), FINAL_Node(layer=11, position=14)]), (12.274027824401855, [ATTN_Node(layer=9, head=6, position=14, keyvalue_position=4, patch_query=False, patch_keyvalue=True), FINAL_Node(layer=11, position=14)]), (12.248641967773438, [ATTN_Node(layer=9, head=6, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)])]... ](total 117)


BFS search:   2%|▏         | 2/100 [00:52<49:26, 30.27s/it, completed_paths=11, frontier_size=949]

Exploring depth 3 with 949 paths in the frontier
    Frontier: [(10.142624855041504, [MLP_Node(layer=2, position=14), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (8.015645980834961, [ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (7.049868583679199, [ATTN_Node(layer=8, head=6, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (4.810174942016602, [MLP_Node(layer=0, position=14), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)])]... ](total 

BFS search:   3%|▎         | 3/100 [04:28<3:06:30, 115.37s/it, completed_paths=52, frontier_size=3255]

Exploring depth 4 with 3255 paths in the frontier
    Frontier: [(0.7095268964767456, [MLP_Node(layer=0, position=14), MLP_Node(layer=2, position=14), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (3.7176291942596436, [ATTN_Node(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (2.765381336212158, [ATTN_Node(layer=5, head=5, position=10, keyvalue_position=3, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINA

BFS search:   4%|▍         | 4/100 [09:15<4:53:02, 183.15s/it, completed_paths=169, frontier_size=3259]

Exploring depth 5 with 3259 paths in the frontier
    Frontier: [(0.5173777341842651, [ATTN_Node(layer=0, head=11, position=14, keyvalue_position=0, patch_query=False, patch_keyvalue=True), MLP_Node(layer=0, position=14), MLP_Node(layer=2, position=14), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (0.5002471208572388, [ATTN_Node(layer=0, head=11, position=14, keyvalue_position=1, patch_query=False, patch_keyvalue=True), MLP_Node(layer=0, position=14), MLP_Node(layer=2, position=14), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (0.49917590618133545, [ATTN_Node(layer=0, head=11, position=14, keyvalue_position=14, patch_query=False, patch_keyvalue=True), MLP_Node(layer=0, position=14), MLP_Node(layer=2, position=14), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, 

BFS search:   5%|▌         | 5/100 [13:13<5:21:18, 202.93s/it, completed_paths=328, frontier_size=2432]

Exploring depth 6 with 2432 paths in the frontier
    Frontier: [(0.1513633131980896, [ATTN_Node(layer=0, head=5, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), MLP_Node(layer=2, position=10), ATTN_Node(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (0.48863083124160767, [MLP_Node(layer=0, position=10), ATTN_Node(layer=3, head=0, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), ATTN_Node(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=

BFS search:   6%|▌         | 6/100 [15:23<4:39:00, 178.09s/it, completed_paths=465, frontier_size=921] 

Exploring depth 7 with 921 paths in the frontier
    Frontier: [(0.5806459784507751, [MLP_Node(layer=2, position=10), MLP_Node(layer=3, position=10), MLP_Node(layer=4, position=10), ATTN_Node(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (0.45428335666656494, [ATTN_Node(layer=3, head=0, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), MLP_Node(layer=3, position=10), MLP_Node(layer=4, position=10), ATTN_Node(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patc

BFS search:   7%|▋         | 7/100 [16:18<3:33:41, 137.87s/it, completed_paths=504, frontier_size=470]

Exploring depth 8 with 470 paths in the frontier
    Frontier: [(0.5717835426330566, [MLP_Node(layer=0, position=10), ATTN_Node(layer=3, head=0, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), MLP_Node(layer=3, position=10), MLP_Node(layer=4, position=10), ATTN_Node(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (0.19920828938484192, [ATTN_Node(layer=0, head=11, position=10, keyvalue_position=0, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=3, head=0, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), MLP_Node(layer=3, position=10), MLP_Node(layer=4, position=10), ATTN_Node(layer=5, head=5, position=10, keyvalue_position=None, patc

BFS search:   8%|▊         | 8/100 [16:33<2:31:19, 98.69s/it, completed_paths=516, frontier_size=51]  

Exploring depth 9 with 51 paths in the frontier
    Frontier: [(0.16593943536281586, [MLP_Node(layer=0, position=10), MLP_Node(layer=2, position=10), ATTN_Node(layer=3, head=0, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), MLP_Node(layer=3, position=10), MLP_Node(layer=4, position=10), ATTN_Node(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (0.3923649191856384, [MLP_Node(layer=0, position=2), ATTN_Node(layer=1, head=4, position=2, keyvalue_position=None, patch_query=True, patch_keyvalue=False), MLP_Node(layer=1, position=2), MLP_Node(layer=2, position=2), ATTN_Node(layer=3, head=7, position=3, keyvalue_position=2, patch_query=False, patch_keyvalue=True), ATTN_Node(l

BFS search:   9%|▉         | 9/100 [16:35<1:43:50, 68.47s/it, completed_paths=517, frontier_size=13]

Exploring depth 10 with 13 paths in the frontier
    Frontier: [(0.24905027449131012, [ATTN_Node(layer=0, head=9, position=10, keyvalue_position=0, patch_query=False, patch_keyvalue=True), MLP_Node(layer=0, position=10), MLP_Node(layer=2, position=10), ATTN_Node(layer=3, head=0, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), MLP_Node(layer=3, position=10), MLP_Node(layer=4, position=10), ATTN_Node(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False), ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch_keyvalue=True), ATTN_Node(layer=9, head=9, position=14, keyvalue_position=None, patch_query=True, patch_keyvalue=False), FINAL_Node(layer=11, position=14)]), (0.20433494448661804, [ATTN_Node(layer=0, head=9, position=10, keyvalue_position=3, patch_query=False, patch_keyvalue=True), MLP_Node(layer=0, position=10), MLP_Node(layer=2, position=10), ATTN_Node(layer=3, head=0, position=10,

BFS search:  10%|█         | 10/100 [16:35<2:29:23, 99.59s/it, completed_paths=517, frontier_size=0]

Found 517 complete paths and 0 incomplete paths.





In [10]:
# Preview only the first few (score, path) entries; displaying the full list
# floods the cell output with hundreds of paths.
complete_paths[:5]

[(9.29843807220459,
  [EMBED_Node(layer=0, position=4),
   MLP_Node(layer=0, position=4),
   ATTN_Node(layer=9, head=9, position=14, keyvalue_position=4, patch_query=False, patch_keyvalue=True),
   FINAL_Node(layer=11, position=14)]),
 (6.137215614318848,
  [EMBED_Node(layer=0, position=4),
   MLP_Node(layer=0, position=4),
   ATTN_Node(layer=9, head=6, position=14, keyvalue_position=4, patch_query=False, patch_keyvalue=True),
   FINAL_Node(layer=11, position=14)]),
 (3.9328300952911377,
  [EMBED_Node(layer=0, position=4),
   MLP_Node(layer=0, position=4),
   ATTN_Node(layer=10, head=0, position=14, keyvalue_position=4, patch_query=False, patch_keyvalue=True),
   FINAL_Node(layer=11, position=14)]),
 (1.5881882905960083,
  [EMBED_Node(layer=0, position=10),
   MLP_Node(layer=0, position=10),
   ATTN_Node(layer=5, head=5, position=10, keyvalue_position=None, patch_query=True, patch_keyvalue=False),
   ATTN_Node(layer=8, head=6, position=14, keyvalue_position=10, patch_query=False, patch

In [11]:
# save circuit
import json
from datetime import datetime

# Convert the complete_paths to a serializable format
def convert_path_to_dict(path_tuple):
    """Turn one ``(score, path)`` pair into a JSON-friendly dictionary.

    Args:
        path_tuple: A 2-tuple ``(score, path)``. ``score`` must be convertible
            with ``float()``; ``path`` is an iterable of node objects that
            expose ``layer`` and ``position`` attributes.

    Returns:
        A dict ``{"score": float, "nodes": [...]}`` with one dict per node
        holding its class name (``"type"``), ``"layer"`` and ``"position"``,
        plus ``"head"``, ``"keyvalue_position"``, ``"patch_query"`` and
        ``"patch_keyvalue"`` for nodes that define those attributes.
    """
    score, path = path_tuple
    # Attributes that only some node classes carry; copied when present.
    optional_fields = ("head", "keyvalue_position", "patch_query", "patch_keyvalue")

    result = {"score": float(score), "nodes": []}
    for node in path:
        entry = {
            "type": node.__class__.__name__,
            "layer": node.layer,
            "position": node.position,
        }
        for field in optional_fields:
            if hasattr(node, field):
                entry[field] = getattr(node, field)
        result["nodes"].append(entry)

    return result

# Convert all paths
serializable_paths = [convert_path_to_dict(path) for path in complete_paths]

# Take the timestamp once so the metadata and the filename refer to the same instant.
saved_at = datetime.now()

# Create metadata
metadata = {
    "model": "gpt2-small",
    "prompt": prompts[example_idx],
    "correct_answer": str(answers[example_idx][0]),
    # The search ran on the whole fixed-position batch, so record every prompt
    # and the answer string each ground-truth token was derived from.
    "prompts": list(prompts_fixed_pos),
    "target_answers": [str(answer_pair[target_idx]) for answer_pair in answers_fixed_pos],
    "target_idx": target_idx,
    "find_subject_inhibition": find_subject_inibition,
    "timestamp": saved_at.isoformat(),
    "total_paths": len(complete_paths),
    "min_treshold": min_treshold,
    "n_layers": model.cfg.n_layers,
    "d_model": model.cfg.d_model,
    "n_heads": model.cfg.n_heads,
    "metric": default_metric.__name__
}

# Combine data
output_data = {
    "metadata": metadata,
    "paths": serializable_paths
}

# Save to JSON file; create the output directory first so a fresh checkout
# does not fail with FileNotFoundError.
output_dir = "detected_paths"
os.makedirs(output_dir, exist_ok=True)
filename = f"{output_dir}/detected_circuit_gpt2_ioi_{default_metric.__name__}_{min_treshold}_{saved_at.strftime('%Y%m%d_%H%M%S')}.json"
with open(filename, 'w') as f:
    json.dump(output_data, f, indent=2)

print(f"Saved {len(complete_paths)} paths to {filename}")
# Shows the first three entries in the order returned by the search
# (assumed to be sorted by score -- TODO confirm).
print("Top 3 paths by score:")
for i, path in enumerate(serializable_paths[:3]):
    print(f"  {i+1}. Score: {path['score']:.4f}, Nodes: {len(path['nodes'])}")

Saved 517 paths to detected_paths/detected_circuit_gpt2_ioi_compare_token_logit_0.15_20250723_161117.json
Top 3 paths by score:
  1. Score: 9.2984, Nodes: 4
  2. Score: 6.1372, Nodes: 4
  3. Score: 3.9328, Nodes: 4
