# Set up models for edge or weight masking

## Workflow:
- Load model
- Use Task with clean and corrupt data, use ACDCPP and get the ACDCPP-style edges
- Convert ACDCPP-style edges to edge mask, get either edge superset of node superset
- Apply these masks to the mask training, either by limiting edge mask to only edge superset, node superset, or by limiting weight mask to node superset

- Also need to test other baselines, like regular finetuning

In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append('acdcpp/Automatic-Circuit-Discovery/')
sys.path.append('acdcpp/')
from acdc import TLACDCExperiment
from acdcpp.ACDCPPExperiment import ACDCPPExperiment

In [2]:
import os
import sys
import re

# import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t

from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

import torch
import pickle

device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
print(f'Device: {device}')

Device: cuda


In [3]:
# set up pipeline from acdcpp to edge mask
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

gpt2_tokenizer = model.tokenizer

reference_gpt2 = model

Loaded pretrained model gpt2-small into HookedTransformer


## Check DemoTransformer Implementations Correct

In [4]:
from cb_utils.transformers.gpt2.edge_masked_transformer import DemoTransformer as GPT2EdgeDemoTransformer, Config as GPT2Config

from cb_utils.models import tl_config_to_demo_config
with open("models/gpt2_weights.pkl", "rb") as f:
    gpt2_weights = pickle.load(f)
demo_edge_gpt2 = GPT2EdgeDemoTransformer(GPT2Config(debug=False, n_layers=12, n_heads=12), means=False)
demo_edge_gpt2.load_state_dict(gpt2_weights, strict=False)
demo_edge_gpt2.cuda()


from cb_utils.transformers.gpt2.weight_masked_transformer import DemoTransformer as GPT2WeightDemoTransformer, Config as GPT2Config
demo_weight_gpt2 = GPT2WeightDemoTransformer(GPT2Config(debug=False, n_layers=12, n_heads=12))
demo_weight_gpt2.load_state_dict(gpt2_weights, strict=False)
demo_weight_gpt2.cuda()

with torch.no_grad():
    test_input = t.tensor(gpt2_tokenizer.encode("The quick brown fox jumps over the lazy")).unsqueeze(0).cuda()
    print((demo_edge_gpt2(test_input)[0][0, -1] - reference_gpt2(test_input)[0, -1]).std())
    print((demo_weight_gpt2(test_input)[0][0, -1] - reference_gpt2(test_input)[0, -1]).std())
    print((demo_edge_gpt2(test_input)[0][0, -1] - demo_weight_gpt2(test_input)[0][0, -1]).std())
    # print((model(test_input)[0][0, -1] - edge_masked_model(test_input)[0][0, -1]).std())
    # print((reference_pythia(test_input)[0, -1] - edge_masked_model(test_input)[0][0, -1]).std())

Using device: cuda:0
tensor(7.2482e-06, device='cuda:0')
tensor(7.2237e-06, device='cuda:0')
tensor(6.9957e-06, device='cuda:0')


## Setup ACDCPP edges

In [5]:
from tasks.ioi.IOITask import IOITask_old, IOITask
# ioi_task = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25)
ioi_task = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25, nb_templates=1, prompt_type="ABBA")
ioi_task.set_logit_diffs(model)

ioi_metric = ioi_task.get_acdcpp_metric()
def negative_abs_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    return -abs(ioi_metric(logits))

with t.no_grad():
    clean_logits = model(ioi_task.clean_data.toks)
    corrupt_logits = model(ioi_task.corr_data.toks)
    clean_logit_diff = ioi_task.ave_logit_diff(clean_logits, ioi_task.clean_data).item()
    corrupt_logit_diff = ioi_task.ave_logit_diff(corrupt_logits, ioi_task.corr_data).item()
    print(f'Clean logit diff: {clean_logit_diff:.3f}, Corrupt logit diff: {corrupt_logit_diff:.3f}')

OpenAI API key not found, will not be able to run evaluations on HPSAQ Task
OpenAI API key not found, will not be able to run evaluations on HPSAQ Task
Clean logit diff: 3.040117025375366, Corrupted logit diff: 1.2651995420455933
Clean logit diff: 3.040, Corrupt logit diff: 1.265


In [6]:
ioi_metric(clean_logits)

tensor(1., device='cuda:0')

In [8]:
ioi_metric(corrupt_logits, ioi_dataset=ioi_task.corr_data)

tensor(0., device='cuda:0')

In [9]:
clean_logit_diff, corrupt_logit_diff

(3.040117025375366, 1.2651995420455933)

In [11]:
ioi_task_2 = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25, nb_templates=1, prompt_type="BABA", template_start_idx=0)
ioi_task_2.get_batch()

{'[PLACE]': ['station',
  'restaurant',
  'restaurant',
  'restaurant',
  'restaurant'],
 '[OBJECT]': ['ring', 'computer', 'necklace', 'bone', 'computer'],
 'text': ['Then, William and Richard went to the station. William gave a ring to',
  'Then, Charles and Jeremy went to the restaurant. Charles gave a computer to',
  'Then, Simon and Clark went to the restaurant. Simon gave a necklace to',
  'Then, Jacob and Scott went to the restaurant. Jacob gave a bone to',
  'Then, Steven and Sullivan went to the restaurant. Steven gave a computer to'],
 'IO': ['Richard', 'Jeremy', 'Clark', 'Scott', 'Sullivan'],
 'S': ['William', 'Charles', 'Simon', 'Jacob', 'Steven'],
 'TEMPLATE_IDX': tensor([0, 0, 0, 0, 0])}

In [13]:
from ACDCPPExperiment import ACDCPPExperiment
from cb_utils.mask_utils import get_masks_from_acdcpp_exp
THRESHOLDS = [0.08, .15]#np.arange(0.005, 0.155, 0.005)
RUN_NAME = 'abs_edge'

acdcpp_exp = ACDCPPExperiment(
    model=model,
    clean_data=ioi_task.clean_data.toks,
    corr_data=ioi_task.corr_data.toks,
    acdc_metric=negative_abs_ioi_metric,
    acdcpp_metric=ioi_metric,
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    attr_absolute_val=True,
    save_graphs_after=-100,
    pruning_mode='edge',
    no_pruned_nodes_attr=1,
    run_acdc=False,
    run_acdcpp=True,
)
# e=acdcpp_exp.setup_exp(0.0)

# pruned_heads, num_passes, acdcpp_pruned_attrs, acdc_pruned_attrs, edges_after_acdcpp, edges_after_acdc = acdcpp_exp.run()
acdcpp_nodes, acdcpp_edges, acdcpp_mask_dict, acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = get_masks_from_acdcpp_exp(acdcpp_exp, threshold=0.08)



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 15299.74it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 258.14it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 281818.85it/s]


dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])




self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 15224.87it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 253.97it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 306343.88it/s]
100%|██████████| 2/2 [00:15<00:00,  7.86s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])





In [21]:
acdcpp_edges

{((0, 'a0.10'), (-1, 'embed')),
 ((0, 'a0.9'), (-1, 'embed')),
 ((0, 'm0'), (-1, 'embed')),
 ((0, 'm0'), (0, 'a0.10')),
 ((1, 'm1'), (0, 'm0')),
 ((3, 'a3.0'), (0, 'm0')),
 ((3, 'm3'), (0, 'm0')),
 ((3, 'm3'), (1, 'm1')),
 ((3, 'm3'), (3, 'a3.0')),
 ((4, 'm4'), (0, 'm0')),
 ((4, 'm4'), (3, 'a3.0')),
 ((5, 'a5.9'), (0, 'm0')),
 ((5, 'a5.9'), (1, 'm1')),
 ((5, 'a5.9'), (3, 'a3.0')),
 ((5, 'a5.9'), (3, 'm3')),
 ((5, 'm5'), (3, 'a3.0')),
 ((5, 'm5'), (5, 'a5.5')),
 ((6, 'm6'), (3, 'a3.0')),
 ((6, 'm6'), (5, 'm5')),
 ((7, 'a7.9'), (0, 'm0')),
 ((7, 'm7'), (6, 'm6')),
 ((8, 'a8.10'), (4, 'm4')),
 ((8, 'a8.10'), (5, 'a5.5')),
 ((8, 'a8.10'), (5, 'a5.9')),
 ((8, 'a8.6'), (5, 'a5.5')),
 ((9, 'a9.9'), (8, 'a8.10')),
 ((10, 'a10.0'), (8, 'a8.10')),
 ((10, 'a10.7'), (0, 'm0')),
 ((10, 'a10.7'), (6, 'm6')),
 ((10, 'a10.7'), (9, 'a9.6')),
 ((10, 'a10.7'), (9, 'a9.9')),
 ((11, 'a11.10'), (0, 'm0')),
 ((11, 'a11.10'), (6, 'm6')),
 ((11, 'a11.10'), (8, 'a8.10')),
 ((11, 'a11.10'), (9, 'a9.6')),
 ((11, 

In [3]:
threshold = 0.001
import pickle
with open('localizations/eap/eap_sports/1000_graph.pkl', 'rb') as f:
    graph = pickle.load(f)

eap_edges = graph.top_edges(n=1000, threshold=threshold)
# eap_edges = set()
# for i in range(eap_scores.shape[0]):
#     for j in range(eap_scores.shape[1]):
#         if eap_scores[i, j] > threshold:
            # eap_edges.add((get_node_name(graph.node_names[i], show_full_index=False), get_node_name(graph.node_names[j, show_full_index=False))))
print(eap_edges)

[('head.0.2', 'mlp.0', 0.009443754330277443), ('head.0.14', 'mlp.0', 0.006188404746353626), ('mlp.6', 'mlp.8', 0.00573932658880949), ('mlp.0', 'mlp.2', 0.005653967149555683), ('head.16.20', 'mlp.16', -0.005345507059246302), ('head.16.20', 'head.17.30.v', 0.005315450485795736), ('mlp.0', 'head.1.16.k', -0.005109565332531929), ('head.14.14', 'mlp.15', -0.004921694286167622), ('mlp.15', 'head.16.20.k', -0.004847021773457527), ('mlp.6', 'mlp.15', -0.004780464340001345), ('mlp.10', 'head.16.20.k', 0.004766407422721386), ('mlp.6', 'head.16.20.k', 0.004757486749440432), ('mlp.0', 'head.1.15.k', -0.004494336899369955), ('mlp.0', 'mlp.5', -0.004315529949963093), ('mlp.0', 'mlp.4', -0.004094669129699469), ('head.16.20', 'mlp.18', 0.004019735846668482), ('head.0.30', 'mlp.0', 0.0039080469869077206), ('mlp.0', 'mlp.6', -0.0038432518485933542), ('mlp.8', 'mlp.9', 0.00376768596470356), ('head.16.20', 'head.21.9.v', 0.003543522208929062), ('mlp.9', 'mlp.11', 0.003415714716538787), ('mlp.6', 'head.21.

In [4]:
len(graph.eap_scores.flatten())

3277824

In [5]:
"""
Convert format:
want: {((3, 'm3'), (3, 'a3.0')),
 ((4, 'm4'), (0, 'm0')),
 ((4, 'm4'), (3, 'a3.0')),
 ((5, 'a5.9'), (0, 'm0')),}

have:
[('mlp.0', 'mlp.2', 0.005653967149555683),
('head.0.14', 'mlp.0', 0.006188404746353626),]
...
"""
from cb_utils.mask_utils import get_formatted_edges_from_eap, get_masks_from_eap_exp
# formatted_eap_edges = get_formatted_edges_from_eap(eap_edges)
# formatted_eap_edges
with open('localizations/eap/eap_sports/1000_graph.pkl', 'rb') as f:
    graph = pickle.load(f)
acdcpp_nodes, acdcpp_edges, acdcpp_mask_dict, acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = get_masks_from_eap_exp(graph, threshold=0.001, num_layers=32, num_heads=32)

In [6]:
(acdcpp_mask_dict['m31'] == 0).sum()

tensor(5)

In [47]:
from tasks import InductionTask
ind_task = InductionTask(batch_size=16, tokenizer=model.tokenizer, prep_acdcpp=True, seq_len=10, acdcpp_metric="ave_logit_diff")
ind_task.set_logit_diffs(model)
print(ind_task.clean_logit_diff, ind_task.corrupted_logit_diff)

12.292320251464844 1.4027974605560303


In [48]:
ind_metric = ind_task.get_acdcpp_metric()
def negative_abs_ind_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    return -abs(ind_metric(logits))

with t.no_grad():
    clean_logits = model(ind_task.clean_data.cuda())
    corrupt_logits = model(ind_task.corr_data.cuda())
    clean_logit_diff = ind_task.ave_logit_diff(clean_logits, ind_task.clean_data).item()
    corrupt_logit_diff = ind_task.ave_logit_diff(corrupt_logits, ind_task.corr_data).item()
    
print(ind_metric(clean_logits))
print(ind_metric(corrupt_logits))

In [51]:
from ACDCPPExperiment import ACDCPPExperiment
from cb_utils.mask_utils import get_masks_from_acdcpp_exp
THRESHOLDS = [0.05]#np.arange(0.005, 0.155, 0.005)
RUN_NAME = 'abs_edge'

acdcpp_exp = ACDCPPExperiment(
    model=model,
    clean_data=ind_task.clean_data,
    corr_data=ind_task.corr_data,
    acdc_metric=negative_abs_ind_metric,
    acdcpp_metric=ind_metric,
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    attr_absolute_val=True,
    save_graphs_after=-100,
    pruning_mode='edge',
    no_pruned_nodes_attr=1,
    run_acdc=False,
    run_acdcpp=True,
)
# e=acdcpp_exp.setup_exp(0.0)

# pruned_heads, num_passes, acdcpp_pruned_attrs, acdc_pruned_attrs, edges_after_acdcpp, edges_after_acdc = acdcpp_exp.run()
acdcpp_nodes, acdcpp_edges, acdcpp_mask_dict, acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = get_masks_from_acdcpp_exp(acdcpp_exp, threshold=THRESHOLDS[0])



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 15171.34it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 252.85it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 304067.19it/s]
100%|██████████| 1/1 [00:08<00:00,  8.35s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 5, 0, 1, 2, 3, 4, 6, 12])





## Loading models

In [13]:
from cb_utils.models import load_demo_pythia, load_demo_gpt2
threshold = 0.05
with open(f"localizations/eap/ioi/gpt2_{threshold=}.pkl", "rb") as f:
    acdcpp_nodes, acdcpp_edges, acdcpp_mask_dict, acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = pickle.load(f)

# model = load_demo_pythia(means=False, model_name="pythia-2.8b", edge_masks=True, mask_dict_superset=acdcpp_mask_dict)

model = load_demo_gpt2(means=False, edge_mask=True, weight_mask=False, mask_dict_superset=acdcpp_mask_dict)

# model = load_demo_gpt2(means=False, edge_mask=False, weight_mask=True, weight_mask_attn_dict=acdcpp_weight_mask_attn_dict, weight_mask_mlp_dict=acdcpp_weight_mask_mlp_dict)

Loaded edge-masked transformer


In [6]:
reference_gpt = HookedTransformer.from_pretrained(
        'gpt2-small',
        fold_ln=False,
        center_writing_weights=False,
        center_unembed=False,
        # default_padding_side="left",
        # device='cuda'
        device='cuda'
    )
tokenizer = reference_gpt.tokenizer
# pythia_tokenizer = reference_pythia.tokenizer
# reference_pythia.to("cuda")

Loaded pretrained model gpt2-small into HookedTransformer


## Run some tests

In [8]:
# compare the two models
# compare model outputs
with torch.no_grad():
    test_input = t.tensor(tokenizer.encode("The quick brown fox jumps over the lazy")).unsqueeze(0).cuda()
    print((model(test_input)[0][0, -1] - reference_gpt(test_input)[0, -1]).std())
    # print((model(test_input)[0][0, -1] - edge_masked_model(test_input)[0][0, -1]).std())
    # print((reference_pythia(test_input)[0, -1] - edge_masked_model(test_input)[0][0, -1]).std())


tensor(6.8494e-06, device='cuda:0')


In [9]:
from tasks import SportsTask, InductionTask, IOITask, InductionTask_Uniform, OWTTask#, SportsTask_Uniform
from tasks.facts.SportsTask import SportsTask_Uniform
sports_task = SportsTask(batch_size=32, tokenizer=tokenizer, prep_acdcpp=False)
ioi_task = IOITask(batch_size=32, tokenizer=tokenizer, prep_acdcpp=False)
ind_task = InductionTask(batch_size=32, tokenizer=tokenizer, prep_acdcpp=False)
ind_uniform_task = InductionTask_Uniform(batch_size=16, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15, uniform_over="rep_tokens")
owt_task = OWTTask(batch_size=32, tokenizer=tokenizer, device=device, ctx_length=30)
sports_uniform_task = SportsTask_Uniform(batch_size=32, tokenizer=tokenizer, uniform_over="sports_tokens")
print(sports_task.get_test_accuracy(model))#, sports_task.get_test_accuracy(reference_pythia))
print(ioi_task.get_test_accuracy(model))#, ioi_task.get_test_accuracy(reference_pythia))
print(ind_task.get_test_accuracy(model))#, ind_task.get_test_accuracy(reference_pythia))
print(ind_uniform_task.get_test_loss(model))#, ind_uniform_task.get_test_loss(reference_pythia))
print(owt_task.get_test_loss(model))#, owt_task.get_test_accuracy(reference_pythia))
print(sports_uniform_task.get_test_loss(model))#, sports_uniform_task.get_test_accuracy(reference_pythia))

  table = cls._concat_blocks(blocks, axis=0)


0.5
0.96875
1.0
tensor(15.3725, device='cuda:0')
tensor(3.5195, device='cuda:0')
tensor(3.1784, device='cuda:0')


In [6]:
sports_uniform_task = SportsTask_Uniform(batch_size=32, tokenizer=tokenizer, uniform_over="all_tokens")
print(sports_uniform_task.get_test_loss(model))#, sports_uniform_task.get_test_accuracy(reference_pythia))

tensor(20.2142, device='cuda:0')


In [7]:
from tasks import LimitedSportsTask
forget_task = LimitedSportsTask(batch_size=32, tokenizer=tokenizer, start_index=0, stop_index=64, make_complementary_task=True)
remember_task = forget_task.complementary_task
forget_task.get_test_accuracy(model), remember_task.get_test_accuracy(model)

(1.0, 1.0)

In [10]:
torch.cuda.memory_allocated(device=device) / 1e9

2.828438528

In [9]:
# test max batch size for sports, owt
owt_task = OWTTask(batch_size=1, tokenizer=tokenizer, ctx_length=30)
sports_task = SportsTask(batch_size=1, tokenizer=tokenizer, shuffle=False)

  table = cls._concat_blocks(blocks, axis=0)


In [12]:
sports_task = SportsTask(batch_size=1, tokenizer=tokenizer, shuffle=False)
tot_loss = 0
print(torch.cuda.memory_allocated(device=device) / 1e9)
for i in range(4):
    loss = sports_task.get_train_loss(model)
    tot_loss += loss
    print(torch.cuda.memory_allocated(device=device) / 1e9)
tot_loss.backward()
print(torch.cuda.memory_allocated(device=device) / 1e9)

31.304587776
41.535725568
51.745850368
61.98750208
72.183995392
31.304588288


In [70]:
tot_loss = 0
print(torch.cuda.memory_allocated(device=device) / 1e9)
model.zero_grad()
for i in range(20):
    loss = sports_task.get_train_loss(model)
    print(torch.cuda.memory_allocated(device=device) / 1e9)
    print(model.blocks[2].edge_mask_mlp.grad)
    loss.backward()
print(torch.cuda.memory_allocated(device=device) / 1e9)

11.808385536
19.528289792
None
19.150077952
tensor([-0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0000, -0.0000,  0.0000,
         0.0000,  0.0000, -0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0000,
         0.0000,  0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,  0.0000,
        -0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
        -0.0000, -0.1926,  0.0000,  0.0000,  0.0000, -0.0000, -0.0000,  0.0000,
         0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000, -0.0000, -0.0000,
         0.0000,  0.0000,  0.0000, -0.0430,  0.0000,  0.0000,  0.0000, -0.0000,
        -0.0000,  0.0000, -0.0000, -0.0041,  0.0266,  0.0000,  0.0000,  0.0000,
        -0.0000,  0.0000,  0.0028], device='cuda:0')
19.530473472
tensor([-0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0000,  0.0000,  0.0000,
         0.0000,  0.0000, -0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0000,
         0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000,  0.0000,  0.0000,
        -0

In [68]:
model.blocks[2].edge_mask_mlp.grad

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000, -0.4634,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.1553,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.2797,  0.0654,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.3760], device='cuda:0')

## Test that gradients flow correctly

### Check reg term

In [16]:
weight_reg_term, tot_weight_params = model.get_edge_reg()
print(weight_reg_term)
print(tot_weight_params)

tensor(95., device='cuda:0', grad_fn=<AddBackward0>)
tensor(95., device='cuda:0')


In [15]:
print(acdcpp_weight_mask_attn_dict)
print(acdcpp_weight_mask_mlp_dict)

{0: [1, 6, 9, 10], 1: [7, 11], 2: [], 3: [0, 4], 4: [11], 5: [5, 8, 9], 6: [0], 7: [3, 9], 8: [6, 10], 9: [6, 7, 8, 9], 10: [0, 1, 2, 6, 7, 10], 11: [2, 3, 10]}
{0: True, 1: True, 2: True, 3: True, 4: True, 5: True, 6: True, 7: True, 8: True, 9: True, 10: True, 11: False}


In [19]:
# step
optim = t.optim.Adam(model.parameters(), lr=1e-3)
for i in range(1000):
    weight_reg_term, tot_weight_params = model.get_edge_reg()
    loss = weight_reg_term / tot_weight_params
    print(loss)
    loss.backward()
    optim.step()

tensor(0.2483, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2493, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2503, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2513, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2523, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2533, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2543, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2553, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2563, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2573, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2582, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2592, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2602, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2612, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2622, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2632, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2642, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2652, device='cuda:0', grad_fn=<DivBack

KeyboardInterrupt: 

In [88]:
model.blocks[10].mlp.weight_mask_W_in

Parameter containing:
tensor([[0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
        [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
        [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
        ...,
        [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
        [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
        [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994]],
       device='cuda:0', requires_grad=True)

In [82]:
model.blocks[10].attn.weight_mask_W_Q

Parameter containing:
tensor([[[0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         ...,
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994]],

        [[0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         ...,
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994]],

        [[0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         [0.9994, 0.9994, 0.9994,  ..., 0.9994, 0.9994, 0.9994],
         [0.9994, 0.9994, 0.9994,  .

In [50]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape, param.requires_grad)
    # print(name, param.shape, param.requires_grad)

blocks.0.attn.weight_mask_W_Q torch.Size([12, 768, 64]) True
blocks.0.attn.weight_mask_W_K torch.Size([12, 768, 64]) True
blocks.0.attn.weight_mask_W_V torch.Size([12, 768, 64]) True
blocks.0.attn.weight_mask_W_O torch.Size([12, 64, 768]) True
blocks.0.mlp.weight_mask_W_in torch.Size([768, 3072]) True
blocks.0.mlp.weight_mask_W_out torch.Size([3072, 768]) True
blocks.0.mlp.weight_mask_b_in torch.Size([3072]) True
blocks.0.mlp.weight_mask_b_out torch.Size([768]) True
blocks.1.attn.weight_mask_W_Q torch.Size([12, 768, 64]) True
blocks.1.attn.weight_mask_W_K torch.Size([12, 768, 64]) True
blocks.1.attn.weight_mask_W_V torch.Size([12, 768, 64]) True
blocks.1.attn.weight_mask_W_O torch.Size([12, 64, 768]) True
blocks.1.mlp.weight_mask_W_in torch.Size([768, 3072]) True
blocks.1.mlp.weight_mask_W_out torch.Size([3072, 768]) True
blocks.1.mlp.weight_mask_b_in torch.Size([3072]) True
blocks.1.mlp.weight_mask_b_out torch.Size([768]) True
blocks.2.attn.weight_mask_W_Q torch.Size([12, 768, 64]) Tr

In [18]:
model.train()
batch_size = 2
ioi = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False)
loss = ioi.get_train_loss(model)
loss.backward()

In [19]:
torch.cuda.memory_allocated(device=device) / 1e9

32.17700352

In [20]:
param_names = []
model_params = []
for name, param in model.named_parameters():
    if param.requires_grad: # and "edge" in name:
        # check if param.grad is all zeros
        if param.grad is not None and param.grad.sum() != 0:
            print(f"{name} grad is not all zeros, {param.grad.norm()=}")
        else:
            print(f"{name} grad is all zeros")

blocks.0.attn.weight_mask_W_Q grad is not all zeros, param.grad.norm()=tensor(0.0216, device='cuda:0')
blocks.0.attn.weight_mask_W_K grad is not all zeros, param.grad.norm()=tensor(1.5815, device='cuda:0')
blocks.0.attn.weight_mask_W_V grad is not all zeros, param.grad.norm()=tensor(1.3946, device='cuda:0')
blocks.0.attn.weight_mask_W_O grad is not all zeros, param.grad.norm()=tensor(1.7858, device='cuda:0')
blocks.0.mlp.weight_mask_W_in grad is not all zeros, param.grad.norm()=tensor(3.0081, device='cuda:0')
blocks.0.mlp.weight_mask_W_out grad is not all zeros, param.grad.norm()=tensor(5.3979, device='cuda:0')
blocks.0.mlp.weight_mask_b_in grad is not all zeros, param.grad.norm()=tensor(0.5360, device='cuda:0')
blocks.0.mlp.weight_mask_b_out grad is not all zeros, param.grad.norm()=tensor(0.6562, device='cuda:0')
blocks.1.attn.weight_mask_W_Q grad is not all zeros, param.grad.norm()=tensor(0.0523, device='cuda:0')
blocks.1.attn.weight_mask_W_K grad is not all zeros, param.grad.norm()=

In [21]:
acdcpp_nodes

{(0, 'a0.21'),
 (0, 'a0.30'),
 (0, 'm0'),
 (1, 'a1.0'),
 (1, 'a1.10'),
 (1, 'a1.11'),
 (1, 'a1.12'),
 (1, 'a1.14'),
 (1, 'a1.15'),
 (1, 'a1.16'),
 (1, 'a1.17'),
 (1, 'a1.18'),
 (1, 'a1.25'),
 (1, 'a1.26'),
 (1, 'a1.29'),
 (1, 'a1.30'),
 (1, 'a1.4'),
 (1, 'a1.6'),
 (1, 'm1'),
 (2, 'a2.11'),
 (2, 'a2.12'),
 (2, 'm2'),
 (3, 'a3.23'),
 (3, 'a3.29'),
 (3, 'm3'),
 (4, 'a4.10'),
 (4, 'a4.13'),
 (4, 'a4.16'),
 (4, 'a4.20'),
 (4, 'a4.25'),
 (4, 'a4.4'),
 (4, 'm4'),
 (5, 'a5.0'),
 (5, 'a5.17'),
 (5, 'a5.20'),
 (5, 'm5'),
 (6, 'a6.19'),
 (6, 'a6.6'),
 (6, 'm6'),
 (7, 'a7.14'),
 (7, 'a7.15'),
 (7, 'a7.20'),
 (7, 'a7.8'),
 (7, 'a7.9'),
 (7, 'm7'),
 (8, 'a8.11'),
 (8, 'a8.14'),
 (8, 'a8.26'),
 (8, 'a8.6'),
 (8, 'm8'),
 (9, 'a9.0'),
 (9, 'a9.15'),
 (9, 'a9.19'),
 (9, 'a9.3'),
 (9, 'a9.8'),
 (9, 'a9.9'),
 (9, 'm9'),
 (10, 'a10.1'),
 (10, 'a10.10'),
 (10, 'a10.11'),
 (10, 'a10.14'),
 (10, 'a10.21'),
 (10, 'a10.26'),
 (10, 'a10.29'),
 (10, 'm10'),
 (11, 'a11.10'),
 (11, 'a11.12'),
 (11, 'a11.14'),
 (11,

In [22]:
for name, param in model.named_parameters():
    if param.requires_grad and "blocks.29.attn.weight_mask_W_O" in name:
        zero_indices = torch.nonzero(param.grad != 0, as_tuple=True)
        print(f"{name=}, nonzero indices: {zero_indices}")

name='blocks.29.attn.weight_mask_W_O', nonzero indices: (tensor([13, 13, 13,  ..., 13, 13, 13], device='cuda:0'), tensor([ 0,  0,  0,  ..., 79, 79, 79], device='cuda:0'), tensor([   0,    1,    2,  ..., 2557, 2558, 2559], device='cuda:0'))


### Check if correct MLPs flow

In [29]:
for node in acdcpp_nodes:
    if "m" in node[1]:
        print(node)

(2, 'm2')
(10, 'm10')
(5, 'm5')
(3, 'm3')
(7, 'm7')
(1, 'm1')
(9, 'm9')
(8, 'm8')
(4, 'm4')
(0, 'm0')
(-1, 'embed')
(6, 'm6')


In [31]:
param_names = []
model_params = []
for name, param in model.named_parameters():
    if "mlp" in name:
        # check if param.grad is all zeros
        if param.grad is not None and param.grad.sum() != 0:
            print(f"{name} grad is not all zeros, {param.grad.norm()=}")
        else:
            print(f"{name} grad is all zeros")

blocks.0.edge_mask_mlp grad is all zeros
blocks.0.edge_mask_mlp_baseline grad is all zeros
blocks.0.edge_mask_mlp_frozen grad is all zeros
blocks.0.mlp.W_in grad is not all zeros, param.grad.norm()=tensor(1.5009, device='cuda:0')
blocks.0.mlp.b_in grad is not all zeros, param.grad.norm()=tensor(0.3840, device='cuda:0')
blocks.0.mlp.W_out grad is not all zeros, param.grad.norm()=tensor(1.9580, device='cuda:0')
blocks.0.mlp.b_out grad is not all zeros, param.grad.norm()=tensor(0.2079, device='cuda:0')
blocks.1.edge_mask_mlp grad is all zeros
blocks.1.edge_mask_mlp_baseline grad is all zeros
blocks.1.edge_mask_mlp_frozen grad is all zeros
blocks.1.mlp.W_in grad is not all zeros, param.grad.norm()=tensor(2.1881, device='cuda:0')
blocks.1.mlp.b_in grad is not all zeros, param.grad.norm()=tensor(0.4780, device='cuda:0')
blocks.1.mlp.W_out grad is not all zeros, param.grad.norm()=tensor(2.0281, device='cuda:0')
blocks.1.mlp.b_out grad is not all zeros, param.grad.norm()=tensor(0.2563, device=

### Check if individual attention heads have gradients

In [76]:
for node in acdcpp_nodes:
    if "a9" in node[1]:
        print(node)

(9, 'a9.7')
(9, 'a9.9')
(9, 'a9.6')
(9, 'a9.8')


In [81]:
param_names = []
model_params = []
for name, param in model.named_parameters():
    if param.requires_grad and "9.attn.weight" in name:
    # if param.requires_grad and "9.attn.b" in name:
        print(param.shape)
        # if param.grad is not None and param.grad.sum() != 0:
        #     print(f"{name} grad is not all zeros, {param.grad.norm()=}")
        # else:
        #     print(f"{name} grad is all zeros")
        print(f"{param.grad[3].norm()=}")
        print(f"{param.grad[4].norm()=}")
        print(f"{param.grad[7].norm()=}")
        print(f"{param.grad[8].norm()=}")
        print(f"{param.grad[11].norm()=}")
        print()


torch.Size([12, 768, 64])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.1051, device='cuda:0')
param.grad[8].norm()=tensor(0.0287, device='cuda:0')
param.grad[11].norm()=tensor(0., device='cuda:0')

torch.Size([12, 768, 64])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.1273, device='cuda:0')
param.grad[8].norm()=tensor(0.0335, device='cuda:0')
param.grad[11].norm()=tensor(0., device='cuda:0')

torch.Size([12, 768, 64])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.0570, device='cuda:0')
param.grad[8].norm()=tensor(0.0485, device='cuda:0')
param.grad[11].norm()=tensor(0., device='cuda:0')

torch.Size([12, 64, 768])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.030

## Start Mask Learning

In [None]:
from tasks import IOITask, SportsTask, OWTTask
batch_size = 64
ioi = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, prompt_type="ABBA", nb_templates=1, template_start_idx=0)
sports = SportsTask(batch_size=batch_size, tokenizer=tokenizer, device=device)
owt = OWTTask(batch_size=batch_size, tokenizer=tokenizer, device=device)

ioi_ood = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, prompt_type="ABBA", nb_templates=1, template_start_idx=1) # different template

train_tasks = {"ioi": ioi, "owt": owt}
task_weights = {"ioi": -.2, "owt": 1} # I think means preserve OWT, corrupt IOI
eval_tasks = {"ioi": ioi, "sports": sports, "owt": owt}

  table = cls._concat_blocks(blocks, axis=0)


In [None]:
mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)

In [None]:
from cb_utils.learn_mask import train_masks

epochs_left = 500
steps_per_epoch = 10
lr = .05 # free
weight_decay = 0
evaluate_every = 1
discretize_every = 50 # 5 # free
threshold = 0.5
use_wandb = False
edge_mask_reg_strength = None
weight_mask_reg_strength = 10

wandb_config = {"edge_masks": edge_masks, "weight_masks_attn": weight_masks_attn, "weight_masks_mlp": weight_masks_mlp, "epochs": epochs_left, "steps_per_epoch": steps_per_epoch, "lr": lr, "weight_decay": weight_decay, "evaluate_every": evaluate_every, "discretize_every": discretize_every, "threshold": threshold, "edge_mask_reg_strength": edge_mask_reg_strength, "weight_mask_reg_strength": weight_mask_reg_strength}

optimizer = torch.optim.AdamW(mask_params, lr=lr, weight_decay=weight_decay)
train_masks(model, tasks=train_tasks, optimizer=optimizer, num_epochs=epochs_left, steps_per_epoch=steps_per_epoch,
            # param_names=param_names, mask_params=mask_params, 
            task_weights=task_weights, eval_tasks=eval_tasks, evaluate_every=evaluate_every, discretize_every=discretize_every, threshold=threshold, edge_mask_reg_strength=edge_mask_reg_strength, weight_mask_reg_strength=None, verbose=False, use_wandb=use_wandb, wandb_config=wandb_config)

[34m[1mwandb[0m: Currently logged in as: [33mphilliphguo[0m. Use [1m`wandb login --relogin`[0m to force relogin


  6%|▋         | 32/501 [10:48<2:38:21, 20.26s/it]


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_93609/3197613862.py", line 15, in <module>
    train_masks(model, tasks=train_tasks, optimizer=optimizer, num_epochs=epochs_left, steps_per_epoch=steps_per_epoch,
  File "/data/phillip_guo/mechanistic-unlearning/cb_utils/learn_mask.py", line 178, in train_masks
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call l

In [None]:
import pickle
with open(f"masks/trained_mask_params_{epochs_left=}_{edge_mask_reg_strength=}.pkl", "wb") as f:
    pickle.dump(mask_params, f)

In [None]:
for name, p in zip(param_names, mask_params):
    if p.requires_grad:
        # print(name, p)
        # count how many zeros in p
        print(torch.sum(p == 0))

tensor(8, device='cuda:0')
tensor(2, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(6, device='cuda:0')
tensor(0, device='cuda:0')
tensor(12, device='cuda:0')
tensor(0, device='cuda:0')
