# Set up models for edge or weight masking

## Workflow:
- Load model
- Use Task with clean and corrupt data, use ACDCPP and get the ACDCPP-style edges
- Convert ACDCPP-style edges to an edge mask, yielding either an edge superset or a node superset
- Apply these masks during mask training: either restrict the edge mask to the edge superset or the node superset, or restrict the weight mask to the node superset

- Also need to test other baselines, like regular finetuning

In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
# Extend the import search path so the `acdc` and `acdcpp` imports below can be
# resolved. Both entries are relative paths, so they are looked up relative to
# the current working directory.
# NOTE(review): presumably the notebook is meant to be run from the repository
# root, next to the `acdcpp/` checkout — confirm.
sys.path.append('acdcpp/Automatic-Circuit-Discovery/')
sys.path.append('acdcpp/')
# NOTE(review): a later cell re-imports TLACDCExperiment from
# `acdc.TLACDCExperiment`, which rebinds the name imported here.
from acdc import TLACDCExperiment
from acdcpp.ACDCPPExperiment import ACDCPPExperiment

In [2]:
import os
import sys
import re

# import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

import torch

# Use the GPU when torch reports one as available, otherwise fall back to CPU.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


In [3]:
# First stage of the ACDC++ -> edge-mask pipeline: load GPT-2 small as a
# HookedTransformer with weight centering and LayerNorm folding switched off.
pretrained_kwargs = dict(
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model = HookedTransformer.from_pretrained('gpt2-small', **pretrained_kwargs)

# Turn on the extra hook points (MLP input, split q/k/v input, attention result).
# NOTE(review): presumably these are required by the ACDC++ edge attribution — confirm.
for enable_hook in (
    model.set_use_hook_mlp_in,
    model.set_use_split_qkv_input,
    model.set_use_attn_result,
):
    enable_hook(True)



Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
from tasks.ioi.IOITask import IOITask_old, IOITask

# Build the IOI task with ACDC++ data preparation enabled, reusing the
# HookedTransformer's tokenizer, then let it compute its logit diffs on `model`.
ioi_task = IOITask(
    batch_size=5,
    tokenizer=model.tokenizer,
    device=device,
    prep_acdcpp=True,
    acdcpp_N=25,
)
ioi_task.set_logit_diffs(model)

In [5]:
ioi_metric = ioi_task.get_acdcpp_metric()


def negative_abs_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    """Return the negated absolute value of ``ioi_metric`` evaluated on ``logits``."""
    metric_value = ioi_metric(logits)
    return -abs(metric_value)


# Baseline forward passes on the clean and corrupted prompts; no gradients are
# needed, only the average logit difference of each run as a Python float.
with t.no_grad():
    clean_logits, corrupt_logits = (
        model(ioi_task.clean_data.toks),
        model(ioi_task.corr_data.toks),
    )
    clean_logit_diff = ioi_task.ave_logit_diff(
        clean_logits, ioi_task.clean_data
    ).item()
    corrupt_logit_diff = ioi_task.ave_logit_diff(
        corrupt_logits, ioi_task.corr_data
    ).item()

In [6]:
# Score the clean and corrupted logits with the ACDC++ metric, using the
# baseline logit differences computed earlier as reference points.
def _score(logits, data):
    """Evaluate ``ioi_metric`` on ``logits`` against the clean/corrupt baselines."""
    return ioi_metric(logits, corrupt_logit_diff, clean_logit_diff, data)


with t.no_grad():
    clean_metric = _score(clean_logits, ioi_task.clean_data)
    corrupt_metric = _score(corrupt_logits, ioi_task.corr_data)

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

Clean direction: 3.040117025375366, Corrupt direction: 1.578493595123291
Clean metric: 0.9999999403953552, Corrupt metric: 0.0


In [7]:
from ACDCPPExperiment import ACDCPPExperiment

# Attribution thresholds to evaluate; a finer sweep would be
# np.arange(0.005, 0.155, 0.005).
THRESHOLDS = [0.08, 0.15]
RUN_NAME = 'abs_edge'

# Edge-level pruning with absolute-value attributions. Only the ACDC++ stage
# is enabled (run_acdcpp=True, run_acdc=False).
experiment_config = dict(
    model=model,
    clean_data=ioi_task.clean_data.toks,
    corr_data=ioi_task.corr_data.toks,
    acdc_metric=negative_abs_ioi_metric,
    acdcpp_metric=ioi_metric,
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    attr_absolute_val=True,
    save_graphs_after=-100,
    pruning_mode='edge',
    no_pruned_nodes_attr=1,
    run_acdc=False,
    run_acdcpp=True,
)
acdcpp_exp = ACDCPPExperiment(**experiment_config)
# e=acdcpp_exp.setup_exp(0.0)

(
    pruned_heads,
    num_passes,
    acdcpp_pruned_attrs,
    acdc_pruned_attrs,
    edges_after_acdcpp,
    edges_after_acdc,
) = acdcpp_exp.run()



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 14907.37it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 245.11it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 277597.79it/s]
 50%|█████     | 1/2 [00:08<00:08,  8.30s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])




self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 14949.09it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 246.04it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 289687.42it/s]
100%|██████████| 2/2 [00:16<00:00,  8.08s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])





In [8]:
import os
import pickle

# Persist the ACDC++ result `edges_after_acdcpp` so later runs can reuse it
# without repeating the experiment. The output directory is created first
# because open(..., "wb") raises FileNotFoundError when "masks/" is missing.
os.makedirs("masks", exist_ok=True)
with open("masks/ioi_acdcpp_edges.pkl", "wb") as f:
    pickle.dump(edges_after_acdcpp, f)

In [9]:
from cb_utils.mask_utils import get_node_name

# Collect the edges kept at threshold 0.08 as (node, node) name pairs.
# Each edge string is two bracket-terminated node specs glued together, e.g.
#   blocks.1.attn.hook_result[:, :, 10]blocks.0.hook_mlp_in[:]
# so splitting on "]" and re-appending it recovers the two halves.
acdcpp_edges = set()
for edge in edges_after_acdcpp[0.08]:
    pieces = edge.split("]")
    node_1 = get_node_name(pieces[0] + "]", show_full_index=False)
    node_2 = get_node_name(pieces[1] + "]", show_full_index=False)
    # Skip edges whose two endpoints resolve to the same node name.
    if node_1 == node_2:
        continue
    acdcpp_edges.add((node_1, node_2))

Using device: cuda:0


In [10]:
from cb_utils.mask_utils import get_edge_mask_template, get_mask_from_edges, convert_mask_dict_to_params

# Turn the ACDC++ edge set into a mask dict for a 12-layer, 12-head model,
# starting from the default edge-mask template.
edge_mask_template = get_edge_mask_template()
acdcpp_mask_dict = get_mask_from_edges(
    acdcpp_edges,
    edge_mask_template=edge_mask_template,
    num_layers=12,
    num_heads=12,
)

In [11]:
# Display the mask dict converted to its parameter form for a visual sanity
# check; the result is shown by the notebook and not stored.
# NOTE(review): the shown output is a list of 0/1 tensors — which of the two
# values marks an edge inside the ACDC++ superset should be confirmed in cb_utils.mask_utils.
convert_mask_dict_to_params(acdcpp_mask_dict)

[tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1.,
         0., 0., 1., 1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1.,
         1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1.]),
 tensor([[1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1.]]),
 tensor([0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.]),
 tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1., 1., 1., 1

In [12]:
from cb_utils.transformer import DemoTransformer
from cb_utils.models import load_demo_gpt2, tokenizer

# Pick which pickled means file to load: the IOI one when `means_ioi` is set,
# the general GPT-2 one otherwise. Only the first element of the pickled
# object is used.
means_ioi = True
means_path = "data/gpt2_ioi_abc_means.pkl" if means_ioi else "data/gpt2_means.pkl"
with open(means_path, "rb") as f:
    means = pickle.load(f)[0]

# Rebinds `model`: the HookedTransformer loaded earlier is replaced by the
# demo GPT-2 built with the ACDC++ mask dict as its mask superset.
model = load_demo_gpt2(means=means, mask_dict_superset=acdcpp_mask_dict)

In [13]:
from tasks import IOITask, SportsTask, OWTTask

# All three tasks share the same batch size, tokenizer and device.
batch_size = 64
shared_task_kwargs = dict(batch_size=batch_size, tokenizer=tokenizer, device=device)
ioi = IOITask(**shared_task_kwargs, prep_acdcpp=False)
sports = SportsTask(**shared_task_kwargs)
owt = OWTTask(**shared_task_kwargs)

# Train on IOI and OWT; evaluate on those plus the sports task.
train_tasks = {"ioi": ioi, "owt": owt}
# NOTE(review): presumably the negative IOI weight degrades IOI while the
# positive OWT weight preserves OWT — confirm against train_masks.
task_weights = {"ioi": -0.2, "owt": 1}
eval_tasks = {"ioi": ioi, "sports": sports, "owt": owt}

  table = cls._concat_blocks(blocks, axis=0)


In [14]:
# Gather every trainable parameter of the model together with its name.
# The two lists are index-aligned.
trainable = [
    (param_name, param)
    for param_name, param in model.named_parameters()
    if param.requires_grad
]
param_names = [param_name for param_name, _ in trainable]
mask_params = [param for _, param in trainable]

In [15]:
# Report torch's peak allocated memory on `device`, scaled from bytes by 1e9.
peak_allocated_bytes = torch.cuda.max_memory_allocated(device=device)
print(peak_allocated_bytes / 1e9)

11.337863168


In [16]:
from cb_utils.learn_mask import train_masks

# Hyperparameters for the mask-training run.
epochs_left = 500
steps_per_epoch = 10
lr = .05 # free
weight_decay = 0
evaluate_every = 1
discretize_every = 50 # 5 # free
threshold = 0.5
use_wandb = True
edge_mask_reg_strength = 10
weight_mask_reg_strength = None

# Optimize only the trainable mask parameters collected earlier.
optimizer = torch.optim.AdamW(mask_params, lr=lr, weight_decay=weight_decay)

# The call is kept as the cell's last expression so the notebook still
# displays its return value. `weight_mask_reg_strength` is passed from the
# variable above rather than a literal None, so changing the variable takes effect.
train_masks(
    model,
    tasks=train_tasks,
    optimizer=optimizer,
    num_epochs=epochs_left,
    steps_per_epoch=steps_per_epoch,
    # param_names=param_names, mask_params=mask_params,
    task_weights=task_weights,
    eval_tasks=eval_tasks,
    evaluate_every=evaluate_every,
    discretize_every=discretize_every,
    threshold=threshold,
    edge_mask_reg_strength=edge_mask_reg_strength,
    weight_mask_reg_strength=weight_mask_reg_strength,
    verbose=False,
    use_wandb=use_wandb,
)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mphilliphguo[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 501/501 [2:37:06<00:00, 18.81s/it]  


VBox(children=(Label(value='0.014 MB of 0.014 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
edge_reg_term,█▇▇▇▇▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁
num_ablated_edges,▁▆█████████
test_loss_ioi,▁▁▁▁▆▇▆▆▇▆▆▆▇▇▇▇▇▆▆▆▆▇▅▆▆▆▇██▆▇▆▇▆▇▇▆▇▇▇
test_loss_owt,▁▁▁▁█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇▇▇▇▇▇▇
test_loss_sports,▂▁▁▁█▆▆▆▆▆▆▆▆▆▅▅▆▅▆▅▆▆▆▅▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
total_loss,▂▂▂▃█▂▂▁▄▂▂▂▂▃▂▂▂▂▂▂▂▂▁▁▂▃▂▂▂▂▂▂▃▂▂▁▂▂▂▂
train_loss_ioi,▁▁▁▁▄███▆▇▇█▇▇███████▇██▇▇█▇▇▇█▇▇███▇▇█▇
train_loss_owt,▁▁▁▂███▇█▇▇███▇▇█▇█▇█▇▇█▇█▇▇▇▇█▇▇█▇▇█▇█▇

0,1
edge_reg_term,0.99725
num_ablated_edges,41.0
test_loss_ioi,21.26536
test_loss_owt,8.3449
test_loss_sports,7.61511
total_loss,-102.03638
train_loss_ioi,20.8978
train_loss_owt,8.41252


(defaultdict(list,
             {'ioi': [(0, 0, 0.6011006832122803),
               (0, 1, 0.855120837688446),
               (0, 2, 1.2980576753616333),
               (0, 3, 1.576408863067627),
               (0, 4, 2.249493360519409),
               (0, 5, 2.7327346801757812),
               (0, 6, 3.2827675342559814),
               (0, 7, 3.665067434310913),
               (0, 8, 4.137988567352295),
               (0, 9, 4.75350284576416),
               (1, 0, 10.113905906677246),
               (1, 1, 10.042095184326172),
               (1, 2, 10.332114219665527),
               (1, 3, 10.29022216796875),
               (1, 4, 10.100848197937012),
               (1, 5, 10.472070693969727),
               (1, 6, 10.507486343383789),
               (1, 7, 10.553999900817871),
               (1, 8, 10.47105884552002),
               (1, 9, 10.924793243408203),
               (2, 0, 10.84225845336914),
               (2, 1, 10.751439094543457),
               (2, 2, 10.6478071212768

In [17]:
import os
import pickle

# Save the trained mask parameters. The file name embeds the epoch count and
# edge regularization strength via the f-string "=" specifier.
# The directory is created first so a missing "masks/" folder cannot discard
# the result of a long training run with a FileNotFoundError.
# NOTE(review): the parameters are pickled as they are (the zero-count output
# below shows them on cuda:0), so loading on a CPU-only machine may need extra care.
os.makedirs("masks", exist_ok=True)
with open(f"masks/trained_mask_params_{epochs_left=}_{edge_mask_reg_strength=}.pkl", "wb") as f:
    pickle.dump(mask_params, f)

In [18]:
# For each still-trainable mask parameter, print how many of its entries are
# exactly zero.
for _param_name, param in zip(param_names, mask_params):
    if not param.requires_grad:
        continue
    print(torch.sum(param == 0))

tensor(8, device='cuda:0')
tensor(2, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(6, device='cuda:0')
tensor(0, device='cuda:0')
tensor(12, device='cuda:0')
tensor(0, device='cuda:0')
