In [1]:
# Load model

from sae_eap.model.load_pretrained import load_model

model = load_model('gpt2')

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model gpt2 into HookedTransformer


In [2]:
# Make data handler

from sae_eap.data.ioi import make_ioi_single

handler = make_ioi_single(model)
print(handler)

SinglePromptHandler(
    clean_prompt=When John and Mary went to the shops, John gave a bag to,
    corrupt_prompt= When Alice and Bob went to the shops, Charlie gave a bag to,
    answer= Mary,
    wrong_answer= John
)


In [3]:
# Build graph
from sae_eap.graph import build

graph = build.build_graph(model)

In [4]:
# Run attribution
from sae_eap.runner import run_attribution

attribution = run_attribution(model, graph, handler)

0it [00:00, ?it/s]

1it [00:00,  1.82it/s]
100%|██████████| 32491/32491 [00:00<00:00, 1303887.74it/s]


In [5]:
# Prune graph
from sae_eap.prune import PruningPipeline, ThresholdEdgePruner, DeadNodePruner

model_graph = graph.copy()
pipeline = PruningPipeline([
    ThresholdEdgePruner(0.01),
    DeadNodePruner()
])
pipeline.prune(graph, attribution)
circuit_graph = graph

print(len(circuit_graph.nodes))
print(len(circuit_graph.edges))
print(len(model_graph.nodes))
print(len(model_graph.edges))

601
29539
602
32491


In [6]:
# Visualize graph

In [7]:
# Run ablation
from sae_eap.runner import run_ablation

faithfulness = run_ablation(model, circuit_graph, model_graph, handler)
print(faithfulness)

[0.19197118282318115]


# Experiments

In [9]:
# Plot faithfulness curve

# Prune graph
from sae_eap.prune import PruningPipeline, TopNEdgePruner, DeadNodePruner

model_graph = build.build_graph(model)
graph = model_graph.copy()
for k_edges in (10, 50, 100, 200, 500, 1000):
    pipeline = PruningPipeline([
        TopNEdgePruner(k_edges),
        DeadNodePruner()
    ])
    pipeline.prune(graph, attribution)
    circuit_graph = graph

    faithfulness = run_ablation(model, circuit_graph, model_graph, handler)[0]
    print(f"Top {k_edges} edges: faithfulness={faithfulness:.4f}")

Top 10 edges: faithfulness=-3.9571
Top 50 edges: faithfulness=-3.9571
Top 100 edges: faithfulness=-3.9571
Top 200 edges: faithfulness=-3.9571
Top 500 edges: faithfulness=-3.9571
Top 1000 edges: faithfulness=-3.9571
