# Set up models for edge or weight masking

## Workflow:
- Load model
- Use Task with clean and corrupt data, use ACDCPP and get the ACDCPP-style edges
- Convert ACDCPP-style edges to edge mask, get either edge superset of node superset
- Apply these masks to the mask training, either by limiting edge mask to only edge superset, node superset, or by limiting weight mask to node superset

- Also need to test other baselines, like regular finetuning

In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append('acdcpp/Automatic-Circuit-Discovery/')
sys.path.append('acdcpp/')
from acdc import TLACDCExperiment
from acdcpp.ACDCPPExperiment import ACDCPPExperiment

In [2]:
import os
import sys
import re

# import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

import torch
import pickle

device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
print(f'Device: {device}')

Device: cuda


In [3]:
# set up pipeline from acdcpp to edge mask
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)



Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
torch.cuda.max_memory_allocated(device=device) / 1e9

0.666549248

## Setup ACDCPP edges

In [5]:
from tasks.ioi.IOITask import IOITask_old, IOITask
# ioi_task = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25)
ioi_task = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25, nb_templates=1, prompt_type="ABBA")
ioi_task.set_logit_diffs(model)

ioi_metric = ioi_task.get_acdcpp_metric()
def negative_abs_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    return -abs(ioi_metric(logits))

with t.no_grad():
    clean_logits = model(ioi_task.clean_data.toks)
    corrupt_logits = model(ioi_task.corr_data.toks)
    clean_logit_diff = ioi_task.ave_logit_diff(clean_logits, ioi_task.clean_data).item()
    corrupt_logit_diff = ioi_task.ave_logit_diff(corrupt_logits, ioi_task.corr_data).item()

In [6]:
ioi_task.get_batch()

{'[PLACE]': ['station', 'restaurant', 'station', 'restaurant', 'school'],
 '[OBJECT]': ['bone', 'snack', 'kiss', 'ring', 'drink'],
 'text': ['Then, Madison and David went to the station. David gave a bone to',
  'Then, Jamie and Roman went to the restaurant. Roman gave a snack to',
  'Then, Laura and Stephen went to the station. Stephen gave a kiss to',
  'Then, Clark and George went to the restaurant. George gave a ring to',
  'Then, Jamie and Maria went to the school. Maria gave a drink to'],
 'IO': ['Madison', 'Jamie', 'Laura', 'Clark', 'Jamie'],
 'S': ['David', 'Roman', 'Stephen', 'George', 'Maria'],
 'TEMPLATE_IDX': tensor([0, 0, 0, 0, 0])}

In [7]:
ioi_task_2 = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25, nb_templates=1, prompt_type="ABBA", template_start_idx=1)
ioi_task_2.get_batch()

{'[PLACE]': ['house', 'office', 'restaurant', 'hospital', 'garden'],
 '[OBJECT]': ['computer', 'basketball', 'necklace', 'bone', 'bone'],
 'text': ['Then, Brian and Madison had a lot of fun at the house. Madison gave a computer to',
  'Then, Tyler and Ruby had a lot of fun at the office. Ruby gave a basketball to',
  'Then, Georgia and Amy had a lot of fun at the restaurant. Amy gave a necklace to',
  'Then, Sullivan and Robert had a lot of fun at the hospital. Robert gave a bone to',
  'Then, Frank and Alan had a lot of fun at the garden. Alan gave a bone to'],
 'IO': ['Brian', 'Tyler', 'Georgia', 'Sullivan', 'Frank'],
 'S': ['Madison', 'Ruby', 'Amy', 'Robert', 'Alan'],
 'TEMPLATE_IDX': tensor([0, 0, 0, 0, 0])}

In [8]:
ioi_task_2 = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25, nb_templates=1, prompt_type="BABA", template_start_idx=0)
ioi_task_2.get_batch()

{'[PLACE]': ['station', 'office', 'office', 'restaurant', 'station'],
 '[OBJECT]': ['ring', 'bone', 'necklace', 'kiss', 'kiss'],
 'text': ['Then, Ford and Charlie went to the station. Ford gave a ring to',
  'Then, Mark and George went to the office. Mark gave a bone to',
  'Then, Charlie and Cole went to the office. Charlie gave a necklace to',
  'Then, Marco and Emily went to the restaurant. Marco gave a kiss to',
  'Then, Jack and Madison went to the station. Jack gave a kiss to'],
 'IO': ['Charlie', 'George', 'Cole', 'Emily', 'Madison'],
 'S': ['Ford', 'Mark', 'Charlie', 'Marco', 'Jack'],
 'TEMPLATE_IDX': tensor([0, 0, 0, 0, 0])}

In [9]:
from ACDCPPExperiment import ACDCPPExperiment
from cb_utils.mask_utils import get_masks_from_acdcpp_exp
THRESHOLDS = [0.08, .15]#np.arange(0.005, 0.155, 0.005)
RUN_NAME = 'abs_edge'

acdcpp_exp = ACDCPPExperiment(
    model=model,
    clean_data=ioi_task.clean_data.toks,
    corr_data=ioi_task.corr_data.toks,
    acdc_metric=negative_abs_ioi_metric,
    acdcpp_metric=ioi_metric,
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    attr_absolute_val=True,
    save_graphs_after=-100,
    pruning_mode='edge',
    no_pruned_nodes_attr=1,
    run_acdc=False,
    run_acdcpp=True,
)
# e=acdcpp_exp.setup_exp(0.0)

# pruned_heads, num_passes, acdcpp_pruned_attrs, acdc_pruned_attrs, edges_after_acdcpp, edges_after_acdc = acdcpp_exp.run()
acdcpp_nodes, acdcpp_edges, acdcpp_mask_dict, acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = get_masks_from_acdcpp_exp(acdcpp_exp, threshold=0.08)



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 15446.21it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 251.35it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 303004.98it/s]
 50%|█████     | 1/2 [00:07<00:07,  7.83s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])




self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 15573.12it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 254.40it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 321109.90it/s]
100%|██████████| 2/2 [00:15<00:00,  7.53s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])





In [72]:
import pickle
from cb_utils.transformer import DemoTransformer
from cb_utils.models import load_demo_gpt2, tokenizer
means_ioi = True
if means_ioi:
    with open("data/gpt2_ioi_abc_means.pkl", "rb") as f:
        means = pickle.load(f)[0]
else:
    with open("data/gpt2_means.pkl", "rb") as f:
        means = pickle.load(f)[0]

edge_masks = True
weight_masks_attn = True
weight_masks_mlp = True
train_base_weights = True
localize_acdcpp = True

# if edge_masks is True, then have mask_dict_superset be acdcpp_mask_dict
mask_dict_superset = None if not edge_masks else acdcpp_mask_dict
# model = load_demo_gpt2(means=means, mask_dict_superset=acdcpp_mask_dict)
if localize_acdcpp:
    weight_mask_attn_dict = acdcpp_weight_mask_attn_dict if weight_masks_attn else None
    weight_mask_mlp_dict = acdcpp_weight_mask_mlp_dict if weight_masks_mlp else None
    base_weight_attn_dict = acdcpp_weight_mask_attn_dict if train_base_weights else None
    base_weight_mlp_dict = acdcpp_weight_mask_mlp_dict if train_base_weights else None

else:
    weight_mask_attn_dict = None
    weight_mask_mlp_dict = None
    base_weight_attn_dict = None
    base_weight_mlp_dict = None

model = load_demo_gpt2(means=False, edge_masks=edge_masks, mask_dict_superset=mask_dict_superset, weight_masks_attn=weight_masks_attn, weight_masks_mlp=weight_masks_mlp, weight_mask_attn_dict=weight_mask_attn_dict, weight_mask_mlp_dict=weight_mask_mlp_dict, train_base_weights=train_base_weights, base_weight_attn_dict=base_weight_attn_dict, base_weight_mlp_dict=base_weight_mlp_dict)

In [64]:
torch.cuda.memory_allocated(device=device) / 1e9

16.789679616

## Test that gradients flow correctly

In [73]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape, param.requires_grad)
    # print(name, param.shape, param.requires_grad)

output_mask torch.Size([157]) True
blocks.0.edge_mask_attentions torch.Size([1, 12]) True
blocks.0.edge_mask_mlp torch.Size([13]) True
blocks.0.attn.W_Q torch.Size([12, 768, 64]) True
blocks.0.attn.b_Q torch.Size([12, 64]) True
blocks.0.attn.W_K torch.Size([12, 768, 64]) True
blocks.0.attn.b_K torch.Size([12, 64]) True
blocks.0.attn.W_V torch.Size([12, 768, 64]) True
blocks.0.attn.b_V torch.Size([12, 64]) True
blocks.0.attn.W_O torch.Size([12, 64, 768]) True
blocks.0.attn.b_O torch.Size([768]) True
blocks.0.attn.weight_mask_W_Q torch.Size([12, 768, 64]) True
blocks.0.attn.weight_mask_W_K torch.Size([12, 768, 64]) True
blocks.0.attn.weight_mask_W_V torch.Size([12, 768, 64]) True
blocks.0.attn.weight_mask_W_O torch.Size([12, 64, 768]) True
blocks.0.mlp.W_in torch.Size([768, 3072]) True
blocks.0.mlp.b_in torch.Size([3072]) True
blocks.0.mlp.W_out torch.Size([3072, 768]) True
blocks.0.mlp.b_out torch.Size([768]) True
blocks.0.mlp.weight_mask_W_in torch.Size([768, 3072]) True
blocks.0.mlp.w

In [74]:
model.train()
batch_size = 8
ioi = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False)
loss = ioi.get_train_loss(model)
loss.backward()

In [75]:
losses = []
for i in range(10):
    losses.append(ioi.get_train_loss(model))
    print(torch.cuda.memory_allocated(device=device) / 1e9)

7.701484032
9.50002944
11.296608768
13.095366144
14.886672896
16.662592
18.435150336
20.20689152
21.983021568
23.762801152


In [68]:
param_names = []
model_params = []
for name, param in model.named_parameters():
    if param.requires_grad: # and "edge" in name:
        # check if param.grad is all zeros
        if param.grad is not None and param.grad.sum() != 0:
            print(f"{name} grad is not all zeros, {param.grad.norm()=}")
        else:
            print(f"{name} grad is all zeros")

output_mask grad is all zeros
blocks.0.edge_mask_attentions grad is all zeros
blocks.0.edge_mask_mlp grad is all zeros
blocks.0.attn.W_Q grad is not all zeros, param.grad.norm()=tensor(0.1548, device='cuda:0')
blocks.0.attn.b_Q grad is not all zeros, param.grad.norm()=tensor(0.0549, device='cuda:0')
blocks.0.attn.W_K grad is not all zeros, param.grad.norm()=tensor(0.2216, device='cuda:0')
blocks.0.attn.b_K grad is not all zeros, param.grad.norm()=tensor(1.4179e-08, device='cuda:0')
blocks.0.attn.W_V grad is not all zeros, param.grad.norm()=tensor(1.2717, device='cuda:0')
blocks.0.attn.b_V grad is not all zeros, param.grad.norm()=tensor(0.7394, device='cuda:0')
blocks.0.attn.W_O grad is not all zeros, param.grad.norm()=tensor(2.7787, device='cuda:0')
blocks.0.attn.b_O grad is not all zeros, param.grad.norm()=tensor(1.2196, device='cuda:0')
blocks.0.mlp.W_in grad is not all zeros, param.grad.norm()=tensor(1.8217, device='cuda:0')
blocks.0.mlp.b_in grad is not all zeros, param.grad.norm()

### Check if correct MLPs flow

In [29]:
for node in acdcpp_nodes:
    if "m" in node[1]:
        print(node)

(2, 'm2')
(10, 'm10')
(5, 'm5')
(3, 'm3')
(7, 'm7')
(1, 'm1')
(9, 'm9')
(8, 'm8')
(4, 'm4')
(0, 'm0')
(-1, 'embed')
(6, 'm6')


In [31]:
param_names = []
model_params = []
for name, param in model.named_parameters():
    if "mlp" in name:
        # check if param.grad is all zeros
        if param.grad is not None and param.grad.sum() != 0:
            print(f"{name} grad is not all zeros, {param.grad.norm()=}")
        else:
            print(f"{name} grad is all zeros")

blocks.0.edge_mask_mlp grad is all zeros
blocks.0.edge_mask_mlp_baseline grad is all zeros
blocks.0.edge_mask_mlp_frozen grad is all zeros
blocks.0.mlp.W_in grad is not all zeros, param.grad.norm()=tensor(1.5009, device='cuda:0')
blocks.0.mlp.b_in grad is not all zeros, param.grad.norm()=tensor(0.3840, device='cuda:0')
blocks.0.mlp.W_out grad is not all zeros, param.grad.norm()=tensor(1.9580, device='cuda:0')
blocks.0.mlp.b_out grad is not all zeros, param.grad.norm()=tensor(0.2079, device='cuda:0')
blocks.1.edge_mask_mlp grad is all zeros
blocks.1.edge_mask_mlp_baseline grad is all zeros
blocks.1.edge_mask_mlp_frozen grad is all zeros
blocks.1.mlp.W_in grad is not all zeros, param.grad.norm()=tensor(2.1881, device='cuda:0')
blocks.1.mlp.b_in grad is not all zeros, param.grad.norm()=tensor(0.4780, device='cuda:0')
blocks.1.mlp.W_out grad is not all zeros, param.grad.norm()=tensor(2.0281, device='cuda:0')
blocks.1.mlp.b_out grad is not all zeros, param.grad.norm()=tensor(0.2563, device=

### Check if individual attention heads have gradients

In [76]:
for node in acdcpp_nodes:
    if "a9" in node[1]:
        print(node)

(9, 'a9.7')
(9, 'a9.9')
(9, 'a9.6')
(9, 'a9.8')


In [81]:
param_names = []
model_params = []
for name, param in model.named_parameters():
    if param.requires_grad and "9.attn.weight" in name:
    # if param.requires_grad and "9.attn.b" in name:
        print(param.shape)
        # if param.grad is not None and param.grad.sum() != 0:
        #     print(f"{name} grad is not all zeros, {param.grad.norm()=}")
        # else:
        #     print(f"{name} grad is all zeros")
        print(f"{param.grad[3].norm()=}")
        print(f"{param.grad[4].norm()=}")
        print(f"{param.grad[7].norm()=}")
        print(f"{param.grad[8].norm()=}")
        print(f"{param.grad[11].norm()=}")
        print()


torch.Size([12, 768, 64])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.1051, device='cuda:0')
param.grad[8].norm()=tensor(0.0287, device='cuda:0')
param.grad[11].norm()=tensor(0., device='cuda:0')

torch.Size([12, 768, 64])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.1273, device='cuda:0')
param.grad[8].norm()=tensor(0.0335, device='cuda:0')
param.grad[11].norm()=tensor(0., device='cuda:0')

torch.Size([12, 768, 64])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.0570, device='cuda:0')
param.grad[8].norm()=tensor(0.0485, device='cuda:0')
param.grad[11].norm()=tensor(0., device='cuda:0')

torch.Size([12, 64, 768])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.030

## Start Mask Learning

In [None]:
from tasks import IOITask, SportsTask, OWTTask
batch_size = 64
ioi = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, prompt_type="ABBA", nb_templates=1, template_start_idx=0)
sports = SportsTask(batch_size=batch_size, tokenizer=tokenizer, device=device)
owt = OWTTask(batch_size=batch_size, tokenizer=tokenizer, device=device)

ioi_ood = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, prompt_type="ABBA", nb_templates=1, template_start_idx=1) # different template

train_tasks = {"ioi": ioi, "owt": owt}
task_weights = {"ioi": -.2, "owt": 1} # I think means preserve OWT, corrupt IOI
eval_tasks = {"ioi": ioi, "sports": sports, "owt": owt}

  table = cls._concat_blocks(blocks, axis=0)


In [None]:
mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)

In [None]:
from cb_utils.learn_mask import train_masks

epochs_left = 500
steps_per_epoch = 10
lr = .05 # free
weight_decay = 0
evaluate_every = 1
discretize_every = 50 # 5 # free
threshold = 0.5
use_wandb = False
edge_mask_reg_strength = None
weight_mask_reg_strength = 10

wandb_config = {"edge_masks": edge_masks, "weight_masks_attn": weight_masks_attn, "weight_masks_mlp": weight_masks_mlp, "epochs": epochs_left, "steps_per_epoch": steps_per_epoch, "lr": lr, "weight_decay": weight_decay, "evaluate_every": evaluate_every, "discretize_every": discretize_every, "threshold": threshold, "edge_mask_reg_strength": edge_mask_reg_strength, "weight_mask_reg_strength": weight_mask_reg_strength}

optimizer = torch.optim.AdamW(mask_params, lr=lr, weight_decay=weight_decay)
train_masks(model, tasks=train_tasks, optimizer=optimizer, num_epochs=epochs_left, steps_per_epoch=steps_per_epoch,
            # param_names=param_names, mask_params=mask_params, 
            task_weights=task_weights, eval_tasks=eval_tasks, evaluate_every=evaluate_every, discretize_every=discretize_every, threshold=threshold, edge_mask_reg_strength=edge_mask_reg_strength, weight_mask_reg_strength=None, verbose=False, use_wandb=use_wandb, wandb_config=wandb_config)

[34m[1mwandb[0m: Currently logged in as: [33mphilliphguo[0m. Use [1m`wandb login --relogin`[0m to force relogin


  6%|▋         | 32/501 [10:48<2:38:21, 20.26s/it]


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_93609/3197613862.py", line 15, in <module>
    train_masks(model, tasks=train_tasks, optimizer=optimizer, num_epochs=epochs_left, steps_per_epoch=steps_per_epoch,
  File "/data/phillip_guo/mechanistic-unlearning/cb_utils/learn_mask.py", line 178, in train_masks
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call l

In [None]:
import pickle
with open(f"masks/trained_mask_params_{epochs_left=}_{edge_mask_reg_strength=}.pkl", "wb") as f:
    pickle.dump(mask_params, f)

In [None]:
for name, p in zip(param_names, mask_params):
    if p.requires_grad:
        # print(name, p)
        # count how many zeros in p
        print(torch.sum(p == 0))

tensor(8, device='cuda:0')
tensor(2, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(6, device='cuda:0')
tensor(0, device='cuda:0')
tensor(12, device='cuda:0')
tensor(0, device='cuda:0')
