# Automatic Circuit Discovery

### Adapted from: https://colab.research.google.com/github/ArthurConmy/Easy-Transformer/blob/main/AutomaticCircuitDiscovery.ipynb

In [None]:
import os

# Detect whether the notebook is running inside Google Colab. On Colab the
# Easy-Transformer fork this notebook depends on is installed from GitHub.
try:
    import google.colab  # noqa: F401 -- imported only to probe for Colab
except ImportError:
    # Only a missing module means "not Colab". Catching ImportError alone
    # (instead of a bare except) keeps unrelated failures, e.g. an interrupt
    # during the pip install below, from being misreported as "not Colab".
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
else:
    IN_COLAB = True
    print("Running as a Colab notebook")
    os.system("pip install git+https://github.com/ArthurConmy/Easy-Transformer.git")

Running as a Colab notebook


In [None]:
from typing import List, Tuple, Dict, Union, Optional, Callable, Any
from time import ctime
import einops
import torch
import numpy as np
from copy import deepcopy
from collections import OrderedDict
import pickle
from subprocess import call
from IPython import get_ipython

# get_ipython() returns None outside an IPython session, in which case the
# autoreload setup is skipped.
ipython = get_ipython()
if ipython is not None:
    # InteractiveShell.magic() is deprecated; run_line_magic(name, line) is
    # the supported way to invoke a line magic programmatically.
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")
from easy_transformer import EasyTransformer
from easy_transformer.utils_circuit_discovery import (
    evaluate_circuit,
    patch_all,
    direct_path_patching,
    logit_diff_io_s,
    Circuit,
    logit_diff_from_logits,
    get_datasets,
)

from easy_transformer.ioi_utils import (
    path_patching,
)

from tqdm import tqdm

from easy_transformer.experiments import (
    get_act_hook,
)
from easy_transformer.ioi_utils import (
    show_pp,
)
from easy_transformer.ioi_dataset import IOIDataset
import os

# Output files go into ./archive/ when that path exists, otherwise into the
# current working directory.
if os.path.exists("archive"):
    file_prefix = "archive/"
else:
    file_prefix = ""

from easy_transformer.experiments import (
    ExperimentMetric,
    AblationConfig,
    EasyAblation,
    EasyPatching,
    PatchingConfig,
)

from easy_transformer.ioi_circuit_extraction import (
    get_circuit_replacement_hook,
)

### Load the model

In [None]:
# Name of the pretrained checkpoint handed to EasyTransformer.from_pretrained.
# NOTE(review): get_hook_tuple later in this notebook defaults to
# model_layers=12 -- confirm that still matches if this checkpoint is changed.
model_name = "gpt2"
model = EasyTransformer.from_pretrained(model_name)

### Creating a dataset

In [None]:
# Prompt templates; each ends just before the pronoun the model must predict.
templates = [
    "So {name} is a really great friend, isn't",
    "So {name} is such a good cook, isn't",
    "So {name} is a very good athlete, isn't",
    "So {name} is a really nice person, isn't",
    "So {name} is such a funny person, isn't",
]

male_names = [
    "John",
    "David",
    "Mark",
    "Paul",
    "Ryan",
    "Gary",
    "Jack",
    "Sean",
    "Carl",
    "Joe",
]
female_names = [
    "Mary",
    "Lisa",
    "Anna",
    "Sarah",
    "Amy",
    "Carol",
    "Karen",
    "Susan",
    "Julie",
    "Judy",
]

# Candidate completions: index 0 is the male pronoun, index 1 the female one.
responses = [' he', ' she']

sentences = []
answers = []
wrongs = []

# Label every prompt from the list its name came from. The previous version
# labelled "the first half of the batch" as male, which is only correct while
# both name lists have the same length. Prompt order is unchanged: all male
# prompts first, then all female prompts.
pronoun_by_group = (
    (male_names, responses[0], responses[1]),
    (female_names, responses[1], responses[0]),
)
for names, correct, wrong in pronoun_by_group:
    for name in names:
        for template in templates:
            sentences.append(template.format(name=name))
            answers.append(correct)
            wrongs.append(wrong)

batch_size = len(sentences)

# NOTE(review): the fixed position labels defined later in the notebook
# presumably rely on every prompt tokenizing to the same length (i.e. each
# name being a single token) -- confirm before adding names.
tokens = model.to_tokens(sentences, prepend_bos=True)

# Token ids of the correct / incorrect pronoun for each prompt.
# NOTE(review): squeeze() presumably collapses a (batch_size, 1) tensor to
# (batch_size,), assuming each response is a single token -- confirm.
answers = torch.tensor(model.tokenizer(answers)["input_ids"]).squeeze()
wrongs = torch.tensor(model.tokenizer(wrongs)["input_ids"]).squeeze()

### Make the position labels

In [None]:
# Print each string token of the first prompt next to its index, so the
# fixed position labels defined in the next cell can be checked by eye.
for i, token in enumerate(model.to_str_tokens(tokens[0])):
    print(i, token)

In [None]:
# One long-typed tensor of ones per batch element; scaled below to give each
# label the same token index for every prompt in the batch.
ones = torch.ones(size=(batch_size,)).long()

# Label -> per-prompt token index, in insertion order.
positions = OrderedDict(
    (label, ones.clone() * token_index)
    for label, token_index in (
        ("name", 2),
        ("is", 3),
        ("person", 7),
        ("isn", 9),
        ("'t", 10),
    )
)

### Making a baseline dataset

In [None]:
# Build the baseline batch: one name-free prompt repeated batch_size times.
# Start from a copy so the original `tokens` tensor is left untouched.
baseline_data = tokens.clone()
# Overwrite row 0 with the tokenized baseline prompt.
# NOTE(review): this assignment presumably requires the baseline prompt to
# tokenize to the same sequence length as the rows of `tokens` -- confirm.
baseline_data[0] = model.to_tokens("That person is a really great friend, isn't", prepend_bos = True)
# Tile row 0 across the batch dimension; the other copied rows are dropped.
baseline_data = einops.repeat(baseline_data[0], "seq -> batch seq", batch = batch_size)

### Define the metric

In [None]:
def pronoun_metric(model, tokens = tokens):
    """Return the mean logit gap between the correct and the wrong pronoun.

    Runs ``model`` on ``tokens`` and reads the output at the last sequence
    position. For each of the first ``batch_size`` rows it takes the logit of
    that row's id in the notebook-level ``answers`` tensor minus the logit of
    its id in ``wrongs``, then averages over rows.

    Args:
        model: callable mapping a token tensor to logits
            (presumably shaped (batch, seq, vocab) -- TODO confirm).
        tokens: input token tensor; defaults to the notebook's ``tokens`` as
            bound when this function was defined.

    Returns:
        The mean difference as a Python float.
    """
    final_logits = model(tokens)[:, -1]
    rows = torch.arange(batch_size)
    margin = final_logits[rows, answers] - final_logits[rows, wrongs]
    return margin.mean().item()

In [None]:
# Metric of the unmodified model on the original prompts; used at the end of
# the notebook as the denominator when scoring the discovered circuit.
model_performance = pronoun_metric(model, tokens)

### Make the Circuit object

In [None]:
# Set up the circuit-discovery object. The original prompts are passed as
# both orig_data and dataset, the name-free baseline as new_data, and the
# same position labels are used for both datasets.
# NOTE(review): the meaning of `threshold` is defined inside Circuit and is
# not visible in this notebook -- confirm before tuning 0.015.
h = Circuit(
    model,
    metric = pronoun_metric,
    orig_data = tokens,
    new_data = baseline_data,
    threshold = 0.015,
    dataset = tokens,
    orig_positions = positions,
    new_positions = positions
)

In [None]:
# Step the discovery until no node is left to evaluate, writing a snapshot of
# the graph (as .dot source and as a rendered PNG) after every step.
while h.current_node is not None:
    h.eval(show_graphics=True, verbose=True)

    graph = h.show()

    # Save the graph source. The same path is handed to `dot` below;
    # previously `dot` was pointed at "hypothesis_tree.dot" without
    # file_prefix, so it read the wrong (or a missing) file whenever the
    # archive/ directory was in use.
    dot_path = file_prefix + "hypothesis_tree.dot"
    with open(dot_path, "w") as f:
        f.write(graph.source)

    # Render to PNG with Graphviz; the timestamp keeps one image per step.
    call(
        [
            "dot",
            "-Tpng",
            dot_path,
            "-o",
            file_prefix + f"gpt2_hypothesis_tree_{ctime()}.png",
            "-Gdpi=600",
        ]
    )

In [None]:
# Display the final graph; as the cell's last expression, the returned
# object is rendered by the notebook.
h.show()

## Evaluating the Circuit's Performance

In [None]:
# From: https://github.com/ArthurConmy/Automatic-Circuit-Discovery/blob/main/easy_transformer/utils_circuit_discovery.py
def get_hook_tuple(layer, head_idx, comp=None, input=False, model_layers=12):
    """Map a circuit node description to a ``(hook_name, head_idx)`` tuple.

    Warning from the original author: "very cursed", and built with 12-layer
    models in mind (hence the ``model_layers`` default).

    Args:
        layer: block index; ``-1`` selects the ``blocks.0.hook_resid_pre``
            hook and ``model_layers`` selects the last block's
            ``hook_resid_post``.
        head_idx: attention head index, or ``None`` for a non-head node.
        comp: ``None`` for a component's output, or one of ``"q"``, ``"k"``,
            ``"v"`` for the corresponding input of an attention head.
        input: for a non-head node below ``model_layers``, pick
            ``hook_resid_mid`` instead of ``hook_mlp_out``.
        model_layers: number of blocks in the model.

    Returns:
        ``(hook_name, head_idx)``; ``head_idx`` is ``None`` for non-head hooks.

    Raises:
        AssertionError: on inconsistent arguments (a head or ``comp`` given
            with ``layer == -1``, an unknown ``comp``, ``comp`` without a
            head, or a non-head ``layer`` above ``model_layers``).
    """
    # Layer -1: no head and no q/k/v component are allowed here.
    if layer == -1:
        assert head_idx is None, head_idx
        assert comp is None, comp
        return ("blocks.0.hook_resid_pre", None)

    # q/k/v: these are INPUTS to an attention head, not outputs.
    if comp is not None:
        assert comp in ["q", "k", "v"]
        assert head_idx is not None
        return (f"blocks.{layer}.attn.hook_{comp}_input", head_idx)

    # Output of a single attention head.
    if head_idx is not None:
        return (f"blocks.{layer}.attn.hook_result", head_idx)

    # Non-head node at or beyond the final layer.
    if layer >= model_layers:
        assert layer == model_layers
        return (f"blocks.{layer-1}.hook_resid_post", None)

    # Non-head node inside the model.
    hook_suffix = "hook_resid_mid" if input else "hook_mlp_out"
    return (f"blocks.{layer}.{hook_suffix}", None)

In [None]:
# From: https://github.com/ArthurConmy/Automatic-Circuit-Discovery/blob/main/easy_transformer/utils_circuit_discovery.py
def make_base_receiver_sender_objects(
    important_nodes,
):
    """Group the circuit's edges by the hook that receives them.

    Args:
        important_nodes: iterable of nodes exposing ``layer``, ``head`` and
            ``children``; each child entry is a 3-tuple whose first item is
            the sender node (with ``layer``, ``head``, ``position``) and
            whose last item is the component label of the edge.

    Returns:
        Dict mapping a receiver hook tuple (from ``get_hook_tuple``) to a
        list of ``(sender_hook_name, sender_head_idx, sender_position)``
        tuples. Edges labelled ``"q"``, ``"k"`` or ``"v"`` are keyed by the
        receiver's matching q/k/v input hook; all other edges are keyed by
        the receiver's ``input=True`` hook.
    """
    receivers_to_senders = {}

    for node in important_nodes:
        # Hook used for every edge of this node that is not a q/k/v edge.
        default_hook = get_hook_tuple(node.layer, node.head, input=True)

        for child, _, comp in node.children:
            if comp in ["v", "k", "q"]:
                receiver_hook = get_hook_tuple(node.layer, node.head, comp)
            else:
                receiver_hook = default_hook

            sender_name, sender_head = get_hook_tuple(child.layer, child.head)
            receivers_to_senders.setdefault(receiver_hook, []).append(
                (sender_name, sender_head, child.position)
            )

    return receivers_to_senders

In [None]:
# From: https://github.com/ArthurConmy/Automatic-Circuit-Discovery/blob/main/easy_transformer/utils_circuit_discovery.py
def evaluate_circuit(h):
    """Score a finished circuit with the circuit's own metric.

    Builds the receiver -> senders mapping from ``h.important_nodes``, passes
    it to ``direct_path_patching`` together with the two datasets (swapped
    relative to how ``h`` stores them), and evaluates ``h.metric`` on the
    model that call returns.

    Args:
        h: circuit object exposing ``current_node``, ``important_nodes``,
            ``model``, ``metric``, ``orig_data``, ``new_data``,
            ``orig_positions`` and ``new_positions``.

    Returns:
        Whatever ``h.metric`` returns for the patched model.

    Raises:
        NotImplementedError: if discovery has not finished
            (``h.current_node`` is not ``None``).
        AssertionError: if no node has a child at layer -1, or if the
            original and new positions differ for any label.
    """
    if h.current_node is not None:
        raise NotImplementedError("Make circuit full")

    receivers_to_senders = make_base_receiver_sender_objects(h.important_nodes)

    # Make sure that the ONLY embed objects set to their values on the
    # original dataset are the ones that are in the circuit: one entry per
    # edge whose sender sits at layer -1, tagged with the receiving node's
    # position.
    embed_hook = "blocks.0.hook_resid_pre"
    initial_receivers_to_senders: List[
        Tuple[Tuple[str, Optional[int]], Tuple[str, Optional[int], str]]
    ] = [
        ((embed_hook, None), (embed_hook, None, node.position))
        for node in h.important_nodes
        for child, _, _2 in node.children
        if child.layer == -1
    ]
    assert initial_receivers_to_senders, "Need at least one embedding present!!!"

    # Several edges can yield the same entry; keep each one once.
    initial_receivers_to_senders = list(set(initial_receivers_to_senders))

    for pos in h.orig_positions:
        orig_pos, new_pos = h.orig_positions[pos], h.new_positions[pos]
        assert torch.allclose(orig_pos, new_pos), "Data must be the same for all positions"

    patched_model = direct_path_patching(
        model=h.model,
        orig_data=h.new_data,  # NOTE these are different
        new_data=h.orig_data,
        initial_receivers_to_senders=initial_receivers_to_senders,
        receivers_to_senders=receivers_to_senders,
        orig_positions=h.orig_positions,  # tensor of shape (batch_size,)
        new_positions=h.new_positions,
        orig_cache=None,
        new_cache=None,
    )
    return h.metric(patched_model)

In [None]:
# Metric of the patched model returned for the discovered circuit. This calls
# the evaluate_circuit defined in this notebook, which replaces the function
# of the same name imported at the top.
circuit_performance = evaluate_circuit(h)

In [None]:
# Express the circuit's metric as a percentage of the full model's metric.
# NOTE(review): the count printed is len(h.important_nodes); the message
# calls them "heads" -- confirm every important node is an attention head.
circuit_percentage = 100*(circuit_performance / model_performance)
print(f"Circuit performs {circuit_percentage:.2f}% as well as the model on this task and has {len(h.important_nodes)} heads.")