# Auto Circuits Exploration

## Background from Paper
https://github.com/UFO-101/auto-circuit/blob/main/Transformer%20Circuit%20Metrics%20are%20not%20Robust.pdf


6 Degrees of Freedom when conducting ablations
1. granularity of computational graph
    - Attention heads and MLPs
    - Attention heads separated into Q, K, V for inputs
2. type of component being ablated 
    - Nodes
    - Edges
    - Branches - this is from causal scrubbing, don't understand
    
    paper focuses on edges
3. activation value used to ablate
    - Zero Ablation
    - Gaussian Noise
    - Resample Ablation - from corrupted
    - Mean ablation - mean on some distribution
    
    paper focuses on resample ablation and mean ablation
4. which token positions are ablated
    - can choose what token positions to ablate
5. ablation direction (destroy or restore signal) and set of components
![image.png](attachment:image.png)


In [1]:
import os

# Set ahead of the other imports in this notebook.
# Presumably turns off HuggingFace tokenizers parallelism (and its fork warning) — TODO confirm.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [1]:
import torch
from auto_circuit.data import load_datasets_from_json
from auto_circuit.experiment_utils import load_tl_model
from auto_circuit.prune_algos.mask_gradient import mask_gradient_prune_scores
from auto_circuit.prune_algos.edge_attribution_patching import edge_attribution_patching_prune_scores
from auto_circuit.types import AblationType, PatchType, PruneScores, CircuitOutputs
from auto_circuit.utils.ablation_activations import src_ablations, batch_src_ablations
from auto_circuit.utils.graph_utils import patch_mode, patchable_model
from auto_circuit.utils.misc import repo_path_to_abs_path
from auto_circuit.visualize import draw_seq_graph

In [2]:
# Run on CPU for now; the "mps" backend raised an error that still needs debugging.
device = "cpu" # TODO: debug mps error
# Load GPT-2 small via auto-circuit's TransformerLens loader.
model = load_tl_model("gpt2-small", device)



Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
# Resolve the IOI template-prompt JSON to an absolute path (kept in `path`; it is reopened later in the notebook).
path = repo_path_to_abs_path("datasets/ioi/ioi_vanilla_template_prompts.json")

# Loader settings gathered in one place so they are easy to tweak.
ioi_loader_options = dict(
    device=device,
    prepend_bos=True,
    batch_size=16,
    train_test_size=(128, 128),
)
train_loader, test_loader = load_datasets_from_json(
    model=model, path=path, **ioi_loader_options
)

In [4]:
# Wrap the model so individual edges can be patched; Q, K and V inputs get separate edges.
# NOTE(review): this rebinds `model` to its own wrapped version, so the cell is presumably
# not re-runnable without reloading the model (the same call on task.model later fails
# with "Model is already patchable").
model = patchable_model(
    model,
    factorized=True,
    slice_output="last_seq",
    separate_qkv=True,
    device=device,
)

In [32]:
# Edge attribution scores from mask gradients on the IOI training prompts.
mask_gradient_options = dict(
    official_edges=None,
    grad_function="logit",
    answer_function="avg_diff",
    mask_val=0.0,
)
attribution_scores: PruneScores = mask_gradient_prune_scores(
    model=model, dataloader=train_loader, **mask_gradient_options
)

VBox(children=(          | 0/1 [00:00<?, ?it/s],))

In [38]:
def prod(x):
    """Return the product of all items in ``x`` (1 for an empty iterable).

    Used in this notebook to turn a tensor shape into its element count.
    """
    # Function-scope import keeps this cell self-contained.
    from math import prod as _math_prod

    return _math_prod(x)

In [44]:
# Number of source nodes and destination nodes exposed by the patchable model.
len(model.srcs), len(model.dests)

(157, 445)

In [41]:
# Shape of the last score tensor stored in attribution_scores.
list(attribution_scores.values())[-1].shape

torch.Size([12, 79])

In [40]:
# Sanity checks on the score layout: the score tensors together hold one entry per edge ...
assert sum(prod(score.shape) for score in attribution_scores.values()) == model.n_edges
# ... and their first dimensions together cover every destination node.
# NOTE(review): the original line was cut off after `len(model`; `len(model.dests)` is assumed
# because each score tensor is a (dest, src) matrix — confirm.
assert sum(score.shape[0] for score in attribution_scores.values()) == len(model.dests)

## Basic Code Structure

In [None]:
# get edges to "ablate with" first (e.g. mean ablations, resample ablations from corrupted)
# add ablation edges to "mask" that interpolates from clean (0) to ablated (1)
# run forward pass on clean distribution, compute loss, compute gradients with respect to mask
    # gradients are attribution scores
# returns dest wrapper scores, which are (dest, src) matrices per module

# Ablate All but topk edges

In [45]:
# Compute the activations used to ablate edges on the test prompts
# (token-wise mean over clean and corrupt prompts, going by the ablation type name).
# NOTE(review): `ablations` is not referenced again in the visible cells.
ablations = src_ablations(
    model, 
    test_loader,
    ablation_type=AblationType.TOKENWISE_MEAN_CLEAN_AND_CORRUPT
)

In [52]:
from auto_circuit.utils.tensor_ops import prune_scores_threshold
from auto_circuit.prune import run_circuits
from auto_circuit.metrics.prune_metrics.kl_div import measure_kl_div

In [51]:
# Run the model on the test prompts once per edge count in [5, 10, 20],
# presumably keeping only that many top-scoring edges and ablating the rest — confirm against run_circuits.
circuit_outs = run_circuits(
    model, 
    test_loader, 
    [5, 10, 20],
    attribution_scores,
    patch_type=PatchType.TREE_PATCH,
    ablation_type=AblationType.TOKENWISE_MEAN_CLEAN_AND_CORRUPT,
    reverse_clean_corrupt=False,
)

VBox(children=(          | 0/8 [00:00<?, ?it/s],))

In [53]:
# KL divergence for each circuit size; the result is a list of (edge count, KL) pairs.
measure_kl_div(model, test_loader, circuit_outs)

VBox(children=(          | 0/3 [00:00<?, ?it/s],))

[(5, 3.166714668273926), (10, 3.0366463661193848), (20, 2.9997355937957764)]

# Edge Pruning Detector

In [None]:
# compute Prune scores on trusted distribution using mean ablation over entire dataset 
# compute patches on untrusted distribution (can vary using mean ablation from trusted, untrusted, combined)
# compute kl divergence between model and ablated model

In [6]:
from cupbearer import tasks
from elk_experiments.tiny_natural_mechanisms_utils import get_task_subset

In [57]:
# Pull a single batch from the loader for inspection.
# NOTE(review): `x` is not used by any later visible cell.
x = next(iter(train_loader))

In [7]:
device = "cpu"
model_name = "gpt2-small"
# Load the "hex" tiny-natural-mechanisms task and keep only a small subset of it.
# NOTE(review): the positional 16, 8, 8 are presumably subset sizes
# (trusted / normal test / anomalous test) — confirm against get_task_subset.
task = get_task_subset(tasks.tiny_natural_mechanisms("hex", device, model_name), 16, 8, 8)

Loaded pretrained model attn-only-1l into HookedTransformer
Moving model to device:  cpu




Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cpu


In [8]:
# NOTE(review): `set_model` is neither defined nor imported in the visible notebook,
# so this cell raises NameError as written; import it or remove the cell.
set_model(task.model)

NameError: name 'set_model' is not defined

In [117]:
import json

# `path` is still the IOI prompt JSON resolved earlier in the notebook;
# read it raw to inspect its schema.
with open(path) as prompts_file:
    data = json.loads(prompts_file.read())

In [118]:
# Top-level keys of the prompt JSON.
data.keys()

dict_keys(['seq_labels', 'word_idxs', 'prompts'])

In [120]:
# First trusted example; per the recorded output it is a (token-id tensor, int) pair.
task.trusted_data[0]

(tensor([    2,     2,     2,    67,    24,    65,    15,    12, 17457,  6659,
            12,    19, 15711,    12,  1350,    22]),
 19)

In [15]:
from auto_circuit.data import PromptDataLoader, PromptDataset

In [12]:
from cupbearer.tasks.tiny_natural_mechanisms import get_effect_tokens

In [13]:
# Token ids for the "hex" task; make_prompt_dataset uses them as the correct answers.
effect_tokens = get_effect_tokens("hex", task.model)

In [16]:
def make_prompt_dataset(data, effect_tokens, vocab_size):
    """Wrap already-tokenized examples in an auto-circuit ``PromptDataset``.

    Args:
        data: Sequence of examples whose first element is a tensor of token
            ids. All prompts must share one shape so they can be stacked;
            the remaining elements of each example are ignored.
        effect_tokens: Iterable of token ids counted as correct answers for
            every prompt (a list of ints or a 1-D integer tensor).
        vocab_size: Vocabulary size; every id in ``range(vocab_size)`` that
            is not an effect token is treated as a wrong answer.

    Returns:
        A ``PromptDataset`` built from the stacked clean prompts, all-zero
        corrupt prompts of the same shape, and one answer tensor and one
        wrong-answer tensor per prompt.
    """
    clean_prompts = torch.stack([example[0] for example in data], dim=0)
    # Corrupt prompts are placeholders: token id 0 at every position.
    corrupt_prompts = torch.zeros_like(clean_prompts, dtype=torch.long)

    # The answers are the same for every prompt, so build the vocab-sized
    # complement once instead of once per prompt. Converting to plain ints
    # also makes the set difference work when effect_tokens is a tensor.
    effect_ids = [int(token) for token in effect_tokens]
    wrong_ids = sorted(set(range(vocab_size)) - set(effect_ids))
    answer = torch.tensor(effect_ids, dtype=torch.long)
    wrong_answer = torch.tensor(wrong_ids, dtype=torch.long)

    # Clone so each prompt owns its own tensors rather than sharing storage.
    n_prompts = clean_prompts.size(0)
    answers = [answer.clone() for _ in range(n_prompts)]
    wrong_answers = [wrong_answer.clone() for _ in range(n_prompts)]

    return PromptDataset(clean_prompts, corrupt_prompts, answers, wrong_answers)



In [17]:
# Build the training dataset from the task's trusted examples.
train_set = make_prompt_dataset(task.trusted_data, effect_tokens, task.model.tokenizer.vocab_size)

In [24]:
# Size of dimension 1 of the stacked clean prompts (the per-prompt token count).
train_set.clean_prompts.size(1)

16

In [182]:
# Wrap the prompt dataset in a loader; the prompts are already token ids, so nothing is tokenized here.
# NOTE(review): this rebinds `train_loader`, replacing the IOI loader created earlier in the notebook.
train_loader = PromptDataLoader(
    prompt_dataset=train_set, 
    seq_len=16, 
    diverge_idx=0
)

In [141]:
# Make the task's model patchable in place.
# NOTE(review): not idempotent — re-running this cell on the already-wrapped model fails
# with "AssertionError: Model is already patchable" (see the recorded output).
task.model = patchable_model(
    task.model,
    factorized=True,
    slice_output="last_seq",
    separate_qkv=True,
    device=device,
)

AssertionError: Model is already patchable

In [None]:
from tqdm import tqdm

In [162]:
from auto_circuit.utils.graph_utils import (
    patch_mode,
    set_all_masks,
    train_mask_mode,
)
from auto_circuit.utils.tensor_ops import batch_avg_answer_diff, batch_avg_answer_val

In [170]:
# Scratch copies of the mask_gradient_prune_scores settings used in this notebook.
integrated_grad_samples = None
grad_function = "logit"
mask_val = 0.0
# NOTE(review): the original line was cut off at `answer_func`; completed to match the
# answer_function="avg_diff" passed to mask_gradient_prune_scores elsewhere — confirm.
answer_function = "avg_diff"

In [164]:
from tqdm import tqdm

In [161]:
import torch as t
from torch.nn.functional import log_softmax

In [187]:
# Score edges on the trusted prompts, ablating to the token-wise mean of the clean prompts.
# With return_src_outs=True the call returns a second value, `src_outs`, which is reused
# as the patch activations when running circuits below.
# NOTE(review): this rebinds `attribution_scores` from the earlier IOI run.
attribution_scores, src_outs = mask_gradient_prune_scores(
    model=task.model,
    dataloader=train_loader,
    official_edges=None,
    grad_function="logit",
    answer_function="avg_diff",
    ablation_type=AblationType.TOKENWISE_MEAN_CLEAN,
    clean_corrupt=None,
    mask_val=0.0,
    return_src_outs=True
)

VBox(children=(          | 0/1 [00:00<?, ?it/s],))

In [None]:
# compute individual kl scores for each element in trusted and untrusted

In [184]:
# Build one PromptDataset each for the normal and the anomalous test examples.
vocab_size = task.model.tokenizer.vocab_size
clean_test, anomalous_test = (
    make_prompt_dataset(split, effect_tokens, vocab_size)
    for split in (task.test_data.normal_data, task.test_data.anomalous_data)
)

In [220]:
# Both test loaders share the same settings; batch_size=1 gives one prompt per batch.
single_prompt_loader_options = dict(seq_len=16, diverge_idx=0, batch_size=1)

clean_loader = PromptDataLoader(
    prompt_dataset=clean_test, **single_prompt_loader_options
)
anomalous_loader = PromptDataLoader(
    prompt_dataset=anomalous_test, **single_prompt_loader_options
)

In [197]:
# Check that every tensor in src_outs equals the first one.
all(torch.equal(list(src_outs.values())[0], out) for out in src_outs.values())

True

In [196]:
# Shape of one src_outs entry.
next(iter(src_outs.values())).shape

torch.Size([157, 1, 16, 768])

In [217]:
# Run the circuits at each edge count on the normal test prompts.
# All src_outs entries were checked to be equal, so the first one is passed as the patch activations.
# NOTE(review): this rebinds `circuit_outs` from the earlier IOI run.
circuit_outs = run_circuits(
    task.model, 
    clean_loader, 
    [5, 10, 20],
    attribution_scores,
    patch_type=PatchType.TREE_PATCH,
    ablation_type=AblationType.TOKENWISE_MEAN_CLEAN,
    patch_src_outs=next(iter(src_outs.values())),
)
            


VBox(children=(          | 0/8 [00:00<?, ?it/s],))

In [246]:
# KL divergence on the normal test prompts; reduce=None presumably keeps one value
# per prompt instead of averaging (the recorded results hold a list per edge count).
meas_clean = measure_kl_div(task.model, clean_loader, circuit_outs, reduce=None)

VBox(children=(          | 0/3 [00:00<?, ?it/s],))

In [221]:
# Same circuit run as for the normal prompts, but on the anomalous test prompts.
circuit_outs_anom = run_circuits(
    task.model, 
    anomalous_loader, 
    [5, 10, 20],
    attribution_scores,
    patch_type=PatchType.TREE_PATCH,
    ablation_type=AblationType.TOKENWISE_MEAN_CLEAN,
    patch_src_outs=next(iter(src_outs.values())),
)
       

VBox(children=(          | 0/8 [00:00<?, ?it/s],))

In [240]:
# KL divergence on the anomalous test prompts, unreduced (reduce=None).
meas_anom = measure_kl_div(task.model, anomalous_loader, circuit_outs_anom, reduce=None)

VBox(children=(          | 0/3 [00:00<?, ?it/s],))

In [252]:
# Pair the i-th normal prompt with the i-th anomalous prompt at every edge count,
# then report the fraction of pairs where the normal prompt has the strictly smaller KL.
kl_pairs = [
    (clean_kl, anom_kl)
    for clean_entry, anom_entry in zip(meas_clean, meas_anom)
    for clean_kl, anom_kl in zip(clean_entry[1], anom_entry[1])
]
count = len(kl_pairs)
correct = sum(1 for clean_kl, anom_kl in kl_pairs if clean_kl < anom_kl)
print(correct / count)

0.875


In [257]:
# Decode the clean prompt at index 7 of the normal test set back to text.
task.model.tokenizer.decode(clean_test[7].clean)

'###org/cpython/rev/8c03fe2318'

In [None]:
# Inspect the first example of the normal test set.
clean_test[0]

In [259]:
# Per-prompt KL values at the second edge count (10 edges): normal vs. anomalous prompts.
meas_clean[1], meas_anom[1]

((10,
  [0.42782506346702576,
   0.396804541349411,
   0.9605979323387146,
   0.47550007700920105,
   0.32582902908325195,
   3.6614952087402344,
   1.5085805654525757,
   0.3816421627998352]),
 (10,
  [2.068248987197876,
   2.634714365005493,
   3.8851945400238037,
   1.82752525806427,
   1.5908031463623047,
   2.1713614463806152,
   2.162188768386841,
   2.4708309173583984]))

In [None]:
# TODO: make into a detector (preferably very general, with a score function...)