In [1]:
%load_ext autoreload
%autoreload 2

from functools import partial

import torch
from transformer_lens import HookedTransformer

from eap.graph import Graph
from eap.evaluate import evaluate_graph, evaluate_baseline
from eap.attribute import attribute

from dataset import HFEAPDataset
from metrics import get_metric
from mib_evaluations import evaluate_area_under_curve

In [2]:
# Load GPT-2 small on the GPU when one is available, otherwise fall back to CPU
# so the notebook still runs (slowly) on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "gpt2-small"
model = HookedTransformer.from_pretrained(model_name, device=DEVICE)

# Enable the extra hook points that the EAP graph presumably attaches to
# (per-head q/k/v inputs, per-head attention outputs, MLP inputs).
# NOTE(review): confirm the exact requirements in eap.graph.Graph.from_model.
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True
# No-op for GPT-2 (standard multi-head attention); kept so this cell also
# works unchanged for grouped-query-attention models.
model.cfg.ungroup_grouped_query_attention = True

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
# IOI task data from the MIB benchmark on the Hugging Face hub.
# 100 examples in batches of 20 gives 5 batches per pass over the data.
IOI_NUM_EXAMPLES = 100
BATCH_SIZE = 20

dataset = HFEAPDataset(
    "mech-interp-bench/ioi",
    model.tokenizer,
    task="ioi",
    num_examples=IOI_NUM_EXAMPLES,
)
dataloader = dataset.to_dataloader(BATCH_SIZE)

# Logit-difference metric for IOI; later cells specialize it with `loss=` / `mean=`.
metric_fn = get_metric("logit_diff", "ioi", model.tokenizer, model)

In [4]:
# Metric variant used during attribution: `loss=True, mean=True` presumably
# yields a single scalar objective to back-propagate through (confirm in metrics.py).
attribution_metric = partial(metric_fn, loss=True, mean=True)

# Build the model's computational graph, then score its edges with EAP.
# The return value of `attribute` is unused; scores are presumably written
# onto `g` in place, since later cells rank edges via `g.apply_topn`.
g = Graph.from_model(model)
attribute(model, g, dataloader, attribution_metric, 'EAP')

100%|██████████| 5/5 [00:03<00:00,  1.54it/s]


In [5]:
# Keep only the highest-scoring edges, then compare the task metric of that
# circuit with the full model's on the same data.
N_TOP_EDGES = 300

# Per-example metric (`mean=False`); averaged explicitly below with `.mean()`.
# Built once and shared by both evaluations.
eval_metric = partial(metric_fn, loss=False, mean=False)

# NOTE(review): the second positional argument is presumably `absolute`
# (rank edges by |score|); confirm against eap.graph.Graph.apply_topn.
g.apply_topn(N_TOP_EDGES, True)

baseline = evaluate_baseline(model, dataloader, eval_metric).mean().item()
# Named distinctly from the later `results` (AUC sweep output) so the two
# values are never confused after out-of-order re-execution.
circuit_score = evaluate_graph(model, g, dataloader, eval_metric).mean().item()

# Raw ratio circuit / full model, not normalized against a corrupted or
# empty-circuit baseline. Report NaN rather than crash on a zero baseline.
faithfulness = circuit_score / baseline if baseline != 0 else float("nan")
print(f"Faithfulness: {faithfulness}. Original {baseline}, new {circuit_score}")

 20%|██        | 1/5 [00:00<00:00,  5.55it/s]

100%|██████████| 5/5 [00:01<00:00,  4.98it/s]
100%|██████████| 5/5 [00:01<00:00,  4.56it/s]

Faithfulness: 0.09384159553886247. Original 3.0810790061950684, new 0.28913336992263794





In [6]:
# Sweep circuit sizes (0.1% up to 100% of the graph's edges, per the log
# below) and summarize the metric across them. The exact structure of the
# returned value is defined in mib_evaluations; confirm before unpacking.
auc_metric = partial(metric_fn, loss=False, mean=False)
results = evaluate_area_under_curve(
    model,
    g,
    dataloader,
    auc_metric,
)

 20%|██        | 1/5 [00:00<00:00,  5.68it/s]

100%|██████████| 5/5 [00:01<00:00,  4.99it/s]
100%|██████████| 5/5 [00:01<00:00,  4.96it/s]


Computing results for 0.1% of edges (N=32)


100%|██████████| 5/5 [00:01<00:00,  4.95it/s]


Computing results for 0.2% of edges (N=64)


100%|██████████| 5/5 [00:01<00:00,  4.69it/s]


Computing results for 0.5% of edges (N=162)


100%|██████████| 5/5 [00:01<00:00,  4.58it/s]


Computing results for 1.0% of edges (N=324)


100%|██████████| 5/5 [00:01<00:00,  4.55it/s]


Computing results for 2.0% of edges (N=649)


100%|██████████| 5/5 [00:01<00:00,  4.52it/s]


Computing results for 5.0% of edges (N=1624)


100%|██████████| 5/5 [00:01<00:00,  4.52it/s]


Computing results for 10.0% of edges (N=3249)


100%|██████████| 5/5 [00:01<00:00,  4.51it/s]


Computing results for 20.0% of edges (N=6498)


100%|██████████| 5/5 [00:01<00:00,  4.52it/s]


Computing results for 50.0% of edges (N=16245)


100%|██████████| 5/5 [00:01<00:00,  4.49it/s]


Computing results for 100% of edges (N=32491)


100%|██████████| 5/5 [00:01<00:00,  3.49it/s]
