# Set up models for edge or weight masking

## Workflow:
- Load model
- Use Task with clean and corrupt data, use ACDCPP and get the ACDCPP-style edges
- Convert ACDCPP-style edges to edge mask, get either edge superset of node superset
- Apply these masks to the mask training, either by limiting edge mask to only edge superset, node superset, or by limiting weight mask to node superset

- Also need to test other baselines, like regular finetuning

In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append('acdcpp/Automatic-Circuit-Discovery/')
sys.path.append('acdcpp/')
from acdc import TLACDCExperiment
from acdcpp.ACDCPPExperiment import ACDCPPExperiment

In [2]:
import os
import sys
import re

# import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

import torch

device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
print(f'Device: {device}')

Device: cuda


In [3]:
# set up pipeline from acdcpp to edge mask
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)



Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
torch.cuda.max_memory_allocated(device=device) / 1e9

0.666549248

## Setup ACDCPP edges

In [5]:
from tasks.ioi.IOITask import IOITask_old, IOITask
# ioi_task = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25)
ioi_task = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25, nb_templates=2, prompt_type="mixed")
ioi_task.set_logit_diffs(model)

In [6]:
ioi_metric = ioi_task.get_acdcpp_metric()
def negative_abs_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    return -abs(ioi_metric(logits))

with t.no_grad():
    clean_logits = model(ioi_task.clean_data.toks)
    corrupt_logits = model(ioi_task.corr_data.toks)
    clean_logit_diff = ioi_task.ave_logit_diff(clean_logits, ioi_task.clean_data).item()
    corrupt_logit_diff = ioi_task.ave_logit_diff(corrupt_logits, ioi_task.corr_data).item()

In [7]:
# Get clean and corrupt logit differences
with t.no_grad():
    clean_metric = ioi_metric(clean_logits, corrupt_logit_diff, clean_logit_diff, ioi_task.clean_data)
    corrupt_metric = ioi_metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, ioi_task.corr_data)

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

Clean direction: 2.9320199489593506, Corrupt direction: 1.4523407220840454
Clean metric: 1.0, Corrupt metric: 0.0


In [8]:
from ACDCPPExperiment import ACDCPPExperiment
THRESHOLDS = [0.08, .15]#np.arange(0.005, 0.155, 0.005)
RUN_NAME = 'abs_edge'

acdcpp_exp = ACDCPPExperiment(
    model=model,
    clean_data=ioi_task.clean_data.toks,
    corr_data=ioi_task.corr_data.toks,
    acdc_metric=negative_abs_ioi_metric,
    acdcpp_metric=ioi_metric,
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    attr_absolute_val=True,
    save_graphs_after=-100,
    pruning_mode='edge',
    no_pruned_nodes_attr=1,
    run_acdc=False,
    run_acdcpp=True,
)
# e=acdcpp_exp.setup_exp(0.0)

pruned_heads, num_passes, acdcpp_pruned_attrs, acdc_pruned_attrs, edges_after_acdcpp, edges_after_acdc = acdcpp_exp.run()



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 14971.80it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 249.47it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 284779.72it/s]
 50%|█████     | 1/2 [00:08<00:08,  8.01s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])




self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 14953.78it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 249.26it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 283847.79it/s]
100%|██████████| 2/2 [00:15<00:00,  7.94s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])





In [9]:
import pickle
with open("masks/ioi_acdcpp_edges.pkl", "wb") as f:
    pickle.dump(edges_after_acdcpp, f)

In [11]:
from cb_utils.mask_utils import get_node_name

acdcpp_edges = set()
for edge in edges_after_acdcpp[0.08]:
    # split the edge into two nodes, e.g. blocks.1.attn.hook_result[:, :, 10]blocks.0.hook_mlp_in[:] into blocks.1.attn.hook_result[:, :, 10] and blocks.0.hook_mlp_in[:]
    node_1 = get_node_name(edge.split("]")[0]+"]", show_full_index=False)
    node_2 = get_node_name(edge.split("]")[1]+"]", show_full_index=False)
    if node_1 != node_2:
        acdcpp_edges.add((node_1, node_2))

In [14]:
from cb_utils.mask_utils import get_edge_mask_template, get_mask_from_edges, convert_mask_dict_to_params, get_nodes_from_edges, get_mask_components
edge_mask_template = get_edge_mask_template()
acdcpp_mask_dict = get_mask_from_edges(acdcpp_edges, edge_mask_template=edge_mask_template, num_layers=12, num_heads=12)
acdcpp_nodes = get_nodes_from_edges(acdcpp_edges)
acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = get_mask_components(acdcpp_nodes, num_layers=12, num_heads=12)

In [16]:
acdcpp_weight_mask_mlp_dict

{0: True,
 1: True,
 2: True,
 3: True,
 4: True,
 5: True,
 6: True,
 7: True,
 8: False,
 9: False,
 10: False,
 11: True}

In [25]:
from cb_utils.transformer import DemoTransformer
from cb_utils.models import load_demo_gpt2, tokenizer
means_ioi = True
if means_ioi:
    with open("data/gpt2_ioi_abc_means.pkl", "rb") as f:
        means = pickle.load(f)[0]
else:
    with open("data/gpt2_means.pkl", "rb") as f:
        means = pickle.load(f)[0]

edge_masks = False
weight_masks_attn = True
weight_masks_mlp = True
# model = load_demo_gpt2(means=means, mask_dict_superset=acdcpp_mask_dict)
model = load_demo_gpt2(means=False, edge_masks=edge_masks, weight_masks_attn=weight_masks_attn, weight_masks_mlp=weight_masks_mlp, weight_mask_attn_dict=acdcpp_weight_mask_attn_dict, weight_mask_mlp_dict=acdcpp_weight_mask_mlp_dict)

In [27]:
for name, param in model.named_parameters():
    # if param.requires_grad
    if "baseline" in name:
        print(name, param.shape)

output_mask_baseline torch.Size([157])
blocks.0.edge_mask_attentions_baseline torch.Size([1, 12])
blocks.0.edge_mask_mlp_baseline torch.Size([13])
blocks.0.attn.weight_mask_W_Q_baseline torch.Size([12, 768, 64])
blocks.0.attn.weight_mask_W_K_baseline torch.Size([12, 768, 64])
blocks.0.attn.weight_mask_W_V_baseline torch.Size([12, 768, 64])
blocks.0.attn.weight_mask_W_O_baseline torch.Size([12, 64, 768])
blocks.1.edge_mask_attentions_baseline torch.Size([14, 12])
blocks.1.edge_mask_mlp_baseline torch.Size([26])
blocks.1.attn.weight_mask_W_Q_baseline torch.Size([12, 768, 64])
blocks.1.attn.weight_mask_W_K_baseline torch.Size([12, 768, 64])
blocks.1.attn.weight_mask_W_V_baseline torch.Size([12, 768, 64])
blocks.1.attn.weight_mask_W_O_baseline torch.Size([12, 64, 768])
blocks.2.edge_mask_attentions_baseline torch.Size([27, 12])
blocks.2.edge_mask_mlp_baseline torch.Size([39])
blocks.2.attn.weight_mask_W_Q_baseline torch.Size([12, 768, 64])
blocks.2.attn.weight_mask_W_K_baseline torch.Size(

### Test that gradients flow correctly

## Start Mask Learning

In [None]:
from tasks import IOITask, SportsTask, OWTTask
batch_size = 64
ioi = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False)
sports = SportsTask(batch_size=batch_size, tokenizer=tokenizer, device=device)
owt = OWTTask(batch_size=batch_size, tokenizer=tokenizer, device=device)

train_tasks = {"ioi": ioi, "owt": owt}
task_weights = {"ioi": -.2, "owt": 1} # I think means preserve OWT, corrupt IOI
eval_tasks = {"ioi": ioi, "sports": sports, "owt": owt}

  table = cls._concat_blocks(blocks, axis=0)


In [None]:
mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)

In [None]:
print(torch.cuda.max_memory_allocated(device=device) / 1e9)

11.337863168


In [None]:
from cb_utils.learn_mask import train_masks

epochs_left = 500
steps_per_epoch = 10
lr = .05 # free
weight_decay = 0
evaluate_every = 1
discretize_every = 50 # 5 # free
threshold = 0.5
use_wandb = False
edge_mask_reg_strength = None
weight_mask_reg_strength = 10

wandb_config = {"edge_masks": edge_masks, "weight_masks_attn": weight_masks_attn, "weight_masks_mlp": weight_masks_mlp, "epochs": epochs_left, "steps_per_epoch": steps_per_epoch, "lr": lr, "weight_decay": weight_decay, "evaluate_every": evaluate_every, "discretize_every": discretize_every, "threshold": threshold, "edge_mask_reg_strength": edge_mask_reg_strength, "weight_mask_reg_strength": weight_mask_reg_strength}

optimizer = torch.optim.AdamW(mask_params, lr=lr, weight_decay=weight_decay)
train_masks(model, tasks=train_tasks, optimizer=optimizer, num_epochs=epochs_left, steps_per_epoch=steps_per_epoch,
            # param_names=param_names, mask_params=mask_params, 
            task_weights=task_weights, eval_tasks=eval_tasks, evaluate_every=evaluate_every, discretize_every=discretize_every, threshold=threshold, edge_mask_reg_strength=edge_mask_reg_strength, weight_mask_reg_strength=None, verbose=False, use_wandb=use_wandb, wandb_config=wandb_config)

[34m[1mwandb[0m: Currently logged in as: [33mphilliphguo[0m. Use [1m`wandb login --relogin`[0m to force relogin


  6%|▋         | 32/501 [10:48<2:38:21, 20.26s/it]


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_93609/3197613862.py", line 15, in <module>
    train_masks(model, tasks=train_tasks, optimizer=optimizer, num_epochs=epochs_left, steps_per_epoch=steps_per_epoch,
  File "/data/phillip_guo/mechanistic-unlearning/cb_utils/learn_mask.py", line 178, in train_masks
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call l

In [None]:
import pickle
with open(f"masks/trained_mask_params_{epochs_left=}_{edge_mask_reg_strength=}.pkl", "wb") as f:
    pickle.dump(mask_params, f)

In [None]:
for name, p in zip(param_names, mask_params):
    if p.requires_grad:
        # print(name, p)
        # count how many zeros in p
        print(torch.sum(p == 0))

tensor(8, device='cuda:0')
tensor(2, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(6, device='cuda:0')
tensor(0, device='cuda:0')
tensor(12, device='cuda:0')
tensor(0, device='cuda:0')
