In [10]:
%load_ext autoreload
%autoreload 2

In [11]:
import argparse
import os
import torch
import pickle
from functools import partial

from transformer_lens import HookedTransformer, HookedTransformerConfig
from huggingface_hub import hf_hub_download

from MIB_circuit_track.dataset import HFEAPDataset
from eap.graph import Graph
from eap.attribute import attribute
from eap.attribute_node import attribute_node
from MIB_circuit_track.metrics import get_metric
from MIB_circuit_track.utils import MODEL_NAME_TO_FULLNAME, TASKS_TO_HF_NAMES, COL_MAPPING

In [47]:
model_name = "gpt2"
model = HookedTransformer.from_pretrained(model_name)
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True
model.cfg.ungroup_grouped_query_attention = True

task = 'ioi'
graph = Graph.from_model(model, neuron_level=True, node_scores=False)

hf_task_name = f'mib-bench/{TASKS_TO_HF_NAMES[task]}'
dataset = HFEAPDataset(hf_task_name, model.tokenizer, split="validation", task=task, model_name=model_name, num_examples=1)

dataloader = dataset.to_dataloader(batch_size=1)
metric = get_metric('logit_diff', task, model.tokenizer, model)
attribution_metric = partial(metric, mean=True, loss=True)


Loaded pretrained model gpt2 into HookedTransformer


In [48]:
attribute_node(model, graph, dataloader, attribution_metric, "EAP", 
                    "patching", neuron=True, ig_steps=5,
                    optimal_ablation_path=None,
                    intervention_dataloader=dataloader)

100%|██████████| 1/1 [00:00<00:00,  5.24it/s]


In [45]:
len(graph.nodes), graph.neurons_scores.shape, graph.neurons_in_graph.shape, graph.nodes_in_graph.shape

(158, torch.Size([157, 768]), torch.Size([157, 768]), torch.Size([157]))

In [27]:
from eap.evaluate import evaluate_graph, evaluate_baseline

baseline = evaluate_baseline(model, dataloader, attribution_metric).mean().item()
results = evaluate_graph(model, graph, dataloader, attribution_metric).mean().item()
print(f"Original performance was {baseline}; the circuit's performance is {results}")

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00, 10.58it/s]
100%|██████████| 1/1 [00:00<00:00, 10.42it/s]


Original performance was -3.20346736907959; the circuit's performance is 3.045830726623535


In [46]:
for node_name, node in graph.nodes.items():
    print(f"{node_name}: {graph.forward_index(node)}")

input: 0
a0.h0: slice(1, 13, None)
a0.h1: slice(1, 13, None)
a0.h2: slice(1, 13, None)
a0.h3: slice(1, 13, None)
a0.h4: slice(1, 13, None)
a0.h5: slice(1, 13, None)
a0.h6: slice(1, 13, None)
a0.h7: slice(1, 13, None)
a0.h8: slice(1, 13, None)
a0.h9: slice(1, 13, None)
a0.h10: slice(1, 13, None)
a0.h11: slice(1, 13, None)
m0: 13
a1.h0: slice(14, 26, None)
a1.h1: slice(14, 26, None)
a1.h2: slice(14, 26, None)
a1.h3: slice(14, 26, None)
a1.h4: slice(14, 26, None)
a1.h5: slice(14, 26, None)
a1.h6: slice(14, 26, None)
a1.h7: slice(14, 26, None)
a1.h8: slice(14, 26, None)
a1.h9: slice(14, 26, None)
a1.h10: slice(14, 26, None)
a1.h11: slice(14, 26, None)
m1: 26
a2.h0: slice(27, 39, None)
a2.h1: slice(27, 39, None)
a2.h2: slice(27, 39, None)
a2.h3: slice(27, 39, None)
a2.h4: slice(27, 39, None)
a2.h5: slice(27, 39, None)
a2.h6: slice(27, 39, None)
a2.h7: slice(27, 39, None)
a2.h8: slice(27, 39, None)
a2.h9: slice(27, 39, None)
a2.h10: slice(27, 39, None)
a2.h11: slice(27, 39, None)
m2: 39
a3.h

In [49]:
graph.to_json(f'circuits/debug/importances.json')

Node input in graph: False
Node a0.h0 in graph: False
Node a0.h1 in graph: False
Node a0.h2 in graph: False
Node a0.h3 in graph: False
Node a0.h4 in graph: False
Node a0.h5 in graph: False
Node a0.h6 in graph: False
Node a0.h7 in graph: False
Node a0.h8 in graph: False
Node a0.h9 in graph: False
Node a0.h10 in graph: False
Node a0.h11 in graph: False
Node m0 in graph: False
Node a1.h0 in graph: False
Node a1.h1 in graph: False
Node a1.h2 in graph: False
Node a1.h3 in graph: False
Node a1.h4 in graph: False
Node a1.h5 in graph: False
Node a1.h6 in graph: False
Node a1.h7 in graph: False
Node a1.h8 in graph: False
Node a1.h9 in graph: False
Node a1.h10 in graph: False
Node a1.h11 in graph: False
Node m1 in graph: False
Node a2.h0 in graph: False
Node a2.h1 in graph: False
Node a2.h2 in graph: False
Node a2.h3 in graph: False
Node a2.h4 in graph: False
Node a2.h5 in graph: False
Node a2.h6 in graph: False
Node a2.h7 in graph: False
Node a2.h8 in graph: False
Node a2.h9 in graph: False
Nod

In [51]:
graph = Graph.from_json('circuits/debug/importances.json')

In [52]:
from eap.evaluate import evaluate_graph, evaluate_baseline

baseline = evaluate_baseline(model, dataloader, attribution_metric).mean().item()
results = evaluate_graph(model, graph, dataloader, attribution_metric).mean().item()
print(f"Original performance was {baseline}; the circuit's performance is {results}")

100%|██████████| 1/1 [00:00<00:00, 10.12it/s]
100%|██████████| 1/1 [00:00<00:00, 12.37it/s]

Original performance was -3.20346736907959; the circuit's performance is 3.045830726623535



