# Set up models for edge or weight masking

## Workflow:
- Load model
- Use a Task with clean and corrupt data, run ACDCPP, and collect the ACDCPP-style edges
- Convert the ACDCPP-style edges to an edge mask, giving either an edge superset or a node superset
- Apply these masks during mask training, either by restricting the edge mask to the edge superset or the node superset, or by restricting the weight mask to the node superset

- Also need to test other baselines, like regular finetuning

In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append('acdcpp/Automatic-Circuit-Discovery/')
sys.path.append('acdcpp/')
from acdc import TLACDCExperiment
from acdcpp.ACDCPPExperiment import ACDCPPExperiment

In [2]:
import os
import sys
import re

# import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

import torch

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if t.cuda.is_available():
    device = t.device('cuda')
else:
    device = t.device('cpu')
print(f'Device: {device}')

Device: cuda


In [3]:
# set up pipeline from acdcpp to edge mask
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)



Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
from tasks.ioi.IOITask import IOITask, IOITask_old

# Small IOI task prepared for ACDCPP (prep_acdcpp=True). Its clean/corrupt data
# and metric drive the circuit-discovery cells below.
ioi_task = IOITask(
    batch_size=5,
    tokenizer=model.tokenizer,
    device=device,
    prep_acdcpp=True,
    acdcpp_N=25,
)
ioi_task.set_logit_diffs(model)

In [5]:
ioi_metric = ioi_task.get_acdcpp_metric()


def negative_abs_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    """Return the negated absolute value of ``ioi_metric(logits)``.

    Passed to the ACDCPP experiment below as ``acdc_metric``; the raw
    ``ioi_metric`` is passed as ``acdcpp_metric``. Taking ``abs`` makes the
    score respond to a shift of the IOI metric in either direction.
    """
    metric_value = ioi_metric(logits)
    return -abs(metric_value)

# Reference forward passes (no gradients needed): logits and the average
# logit difference on the clean prompts, then on the corrupted prompts.
with t.no_grad():
    clean_logits = model(ioi_task.clean_data.toks)
    clean_logit_diff = ioi_task.ave_logit_diff(clean_logits, ioi_task.clean_data).item()

    corrupt_logits = model(ioi_task.corr_data.toks)
    corrupt_logit_diff = ioi_task.ave_logit_diff(corrupt_logits, ioi_task.corr_data).item()

In [6]:
# Evaluate the IOI metric on the full clean and corrupted runs, using the
# reference logit differences computed above. The recorded output shows the
# clean run scoring 1.0 and the corrupted run 0.0.
with t.no_grad():
    clean_metric = ioi_metric(
        clean_logits, corrupt_logit_diff, clean_logit_diff, ioi_task.clean_data
    )
    corrupt_metric = ioi_metric(
        corrupt_logits, corrupt_logit_diff, clean_logit_diff, ioi_task.corr_data
    )

print(
    f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}',
    f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}',
    sep='\n',
)

Clean direction: 3.040117025375366, Corrupt direction: 1.4358559846878052
Clean metric: 1.0, Corrupt metric: 0.0


In [7]:
from ACDCPPExperiment import ACDCPPExperiment

# Pruning thresholds to sweep. A denser sweep, np.arange(0.005, 0.155, 0.005),
# was the commented-out alternative.
THRESHOLDS = [0.08, 0.15]
RUN_NAME = 'abs_edge'

# Configure the ACDCPP experiment on the IOI clean/corrupt prompts.
# run_acdc=False / run_acdcpp=True: presumably only the ACDCPP pass runs.
acdcpp_exp = ACDCPPExperiment(
    model=model,
    clean_data=ioi_task.clean_data.toks,
    corr_data=ioi_task.corr_data.toks,
    acdc_metric=negative_abs_ioi_metric,
    acdcpp_metric=ioi_metric,
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    pruning_mode='edge',
    attr_absolute_val=True,
    run_acdc=False,
    run_acdcpp=True,
    save_graphs_after=-100,
    no_pruned_nodes_attr=1,
    verbose=False,
)

(
    pruned_heads,
    num_passes,
    acdcpp_pruned_attrs,
    acdc_pruned_attrs,
    edges_after_acdcpp,
    edges_after_acdc,
) = acdcpp_exp.run()



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 15076.25it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 246.80it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 278632.21it/s]
 50%|█████     | 1/2 [00:08<00:08,  8.38s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])




self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 14697.95it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 249.16it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 299531.07it/s]
100%|██████████| 2/2 [00:16<00:00,  8.11s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])





In [8]:
import pickle

# Persist the ACDCPP edges (keyed by threshold) so later sessions can reuse them
# without re-running the experiment. Create the output directory first: on a
# fresh checkout without masks/, open(..., "wb") would raise FileNotFoundError.
os.makedirs("masks", exist_ok=True)
with open("masks/ioi_acdcpp_edges.pkl", "wb") as f:
    pickle.dump(edges_after_acdcpp, f)

In [9]:
from cb_utils.mask_utils import get_node_name

# ACDCPP threshold whose surviving edges seed the mask superset. It should be
# one of the values in THRESHOLDS passed to the experiment.
ACDCPP_THRESHOLD = 0.08

# edges_after_acdcpp is keyed by threshold. Match the key within a float
# tolerance instead of by exact equality, so keys generated by e.g. np.arange
# (0.08000000000000002) still resolve. Fail with a descriptive error otherwise.
acdcpp_threshold_key = next(
    (key for key in edges_after_acdcpp if np.isclose(key, ACDCPP_THRESHOLD)),
    None,
)
if acdcpp_threshold_key is None:
    raise KeyError(
        f"No ACDCPP edges for threshold {ACDCPP_THRESHOLD}; "
        f"available thresholds: {sorted(edges_after_acdcpp)}"
    )

acdcpp_edges = set()
for edge in edges_after_acdcpp[acdcpp_threshold_key]:
    # Each edge is two node strings concatenated, e.g.
    # "blocks.1.attn.hook_result[:, :, 10]blocks.0.hook_mlp_in[:]".
    # Split on "]" once and re-append the "]" each piece loses.
    first_raw, second_raw = edge.split("]")[:2]
    node_1 = get_node_name(first_raw + "]", show_full_index=False)
    node_2 = get_node_name(second_raw + "]", show_full_index=False)
    # Skip pairs that map to the same node name (presumably possible once the
    # full index is dropped via show_full_index=False).
    if node_1 != node_2:
        acdcpp_edges.add((node_1, node_2))

Using device: cuda:0


In [10]:
from cb_utils.mask_utils import (
    convert_mask_dict_to_params,
    get_edge_mask_template,
    get_mask_from_edges,
)

# Turn the ACDCPP edge set into a mask dict based on the edge-mask template.
# num_layers / num_heads are GPT-2 small's dimensions (the model loaded above).
edge_mask_template = get_edge_mask_template()
acdcpp_mask_dict = get_mask_from_edges(
    acdcpp_edges,
    edge_mask_template=edge_mask_template,
    num_layers=12,
    num_heads=12,
)

In [11]:
# Display the ACDCPP mask dict converted to parameter tensors as a quick sanity
# check. The output is tensors of 0s and 1s; presumably 1 = kept and
# 0 = masked out (TODO confirm against cb_utils.mask_utils).
convert_mask_dict_to_params(acdcpp_mask_dict)

[tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.,
         0., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 0., 1., 0.,
         1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1.]),
 tensor([[1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1.]]),
 tensor([0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.]),
 tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1

In [12]:
from cb_utils.transformer import DemoTransformer
from cb_utils.models import load_demo_gpt2, tokenizer

# Mean activations handed to load_demo_gpt2 (presumably used to mean-ablate
# masked components — TODO confirm). True selects the IOI "ABC" means,
# False the generic GPT-2 means.
means_ioi = True
with open(
    "data/gpt2_ioi_abc_means.pkl" if means_ioi else "data/gpt2_means.pkl", "rb"
) as f:
    means = pickle.load(f)[0]

# NOTE: this rebinds `model`; from here on it is the demo GPT-2 restricted to
# the ACDCPP mask superset, not the HookedTransformer loaded earlier.
model = load_demo_gpt2(means=means, mask_dict_superset=acdcpp_mask_dict)

In [13]:
from tasks import IOITask, SportsTask, OWTTask

# NOTE: this rebinds IOITask (previously imported from tasks.ioi.IOITask).
ioi = IOITask(batch_size=64, tokenizer=tokenizer, device=device, prep_acdcpp=False)
sports = SportsTask(batch_size=64, tokenizer=tokenizer, device=device)
owt = OWTTask(batch_size=64, tokenizer=tokenizer, device=device)

train_tasks = dict(ioi=ioi, owt=owt)
# Intent: preserve OWT performance (positive weight) while corrupting IOI
# (negative weight). This presumably relies on train_masks scaling each task
# loss by its weight; the training log shows IOI loss rising.
task_weights = dict(ioi=-0.5, owt=1)
eval_tasks = dict(ioi=ioi, sports=sports, owt=owt)

  table = cls._concat_blocks(blocks, axis=0)


In [14]:
# Collect only the parameters left trainable on the masked model, presumably
# just the mask parameters. named_parameters() and parameters() yield in the
# same order, so the two lists stay aligned.
param_names = [name for name, p in model.named_parameters() if p.requires_grad]
mask_params = [p for p in model.parameters() if p.requires_grad]

In [15]:
from cb_utils.learn_mask import train_masks

# Mask-training hyperparameters.
epochs_left = 200
log_every = 1  # NOTE: not passed to train_masks below, so it has no effect here
lr = 0.05  # free hyperparameter
weight_decay = 0  # with zero weight decay, AdamW behaves like Adam
clamp_every = 50  # free hyperparameter (5 was also noted)
threshold = 0.5

optimizer = torch.optim.AdamW(mask_params, lr=lr, weight_decay=weight_decay)
train_masks(
    model,
    tasks=train_tasks,
    optimizer=optimizer,
    task_weights=task_weights,
    eval_tasks=eval_tasks,
    num_epochs=epochs_left,
    clamp_every=clamp_every,
    threshold=threshold,
    edge_mask_reg_strength=1,
    weight_mask_reg_strength=None,
    verbose=True,
)

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 0, step 0: train loss 4.297313690185547
Evaluating on ioi
Loss on ioi: 1.3480257987976074
Evaluating on sports
Loss on sports: 3.0990006923675537
Evaluating on owt
Loss on owt: 3.577061653137207
Clamping weights
done clamping weights
Epoch 0, step 10: train loss -0.39757490158081055
Evaluating on ioi
Loss on ioi: 11.589863777160645
Evaluating on sports
Loss on sports: 6.133066177368164
Evaluating on owt
Loss on owt: 3.9215543270111084
Epoch 0, step 20: train loss -6.938171863555908
Evaluating on ioi
Loss on ioi: 25.466266632080078
Evaluating on sports
Loss on sports: 11.917108535766602
Evaluating on owt
Loss on owt: 4.4212470054626465
Epoch 0, step 30: train loss -15.102596282958984
Evaluating on ioi
Loss on ioi: 45.89265823364258
Evaluating on sports
Loss on sports: 19.449020385742188
Evaluating on owt
Loss on owt: 5.2236456871032715
Epoch 0, step 40: train loss -21.158437728881836
Evaluating on ioi
Loss on ioi: 56.313514709472656
Evaluating on sports
Loss on sports: 25.23201370

  0%|          | 1/200 [03:09<10:29:14, 189.72s/it]

Epoch 1, step 0: train loss -32.287654876708984
Evaluating on ioi
Loss on ioi: 79.13150787353516
Evaluating on sports
Loss on sports: 15.99068832397461
Evaluating on owt
Loss on owt: 6.857211589813232
Clamping weights
done clamping weights
Epoch 1, step 10: train loss -28.316158294677734
Evaluating on ioi
Loss on ioi: 72.6974105834961
Evaluating on sports
Loss on sports: 21.236833572387695
Evaluating on owt
Loss on owt: 6.704435348510742
Epoch 1, step 20: train loss -31.75678253173828
Evaluating on ioi
Loss on ioi: 78.71527099609375
Evaluating on sports
Loss on sports: 24.19527816772461
Evaluating on owt
Loss on owt: 7.018429279327393
Epoch 1, step 30: train loss -32.36194610595703
Evaluating on ioi
Loss on ioi: 80.39202117919922
Evaluating on sports
Loss on sports: 17.745651245117188
Evaluating on owt
Loss on owt: 7.030704498291016
Epoch 1, step 40: train loss -33.33784866333008
Evaluating on ioi
Loss on ioi: 81.15989685058594
Evaluating on sports
Loss on sports: 11.598390579223633
Ev