# Set up models for edge or weight masking

## Workflow:
- Load model
- Use Task with clean and corrupt data, use ACDCPP and get the ACDCPP-style edges
- Convert ACDCPP-style edges to edge mask, get either edge superset of node superset
- Apply these masks to the mask training, either by limiting edge mask to only edge superset, node superset, or by limiting weight mask to node superset

- Also need to test other baselines, like regular finetuning

In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import pickle
sys.path.append('acdcpp/Automatic-Circuit-Discovery/')
sys.path.append('acdcpp/')
from acdc import TLACDCExperiment
from acdcpp.ACDCPPExperiment import ACDCPPExperiment
import os
import sys
import re

from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools
from transformer_lens import HookedTransformer, ActivationCache
import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table
from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional
import torch

device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
print(f'Device: {device}')
from ACDCPPExperiment import ACDCPPExperiment
from cb_utils.mask_utils import get_masks_from_acdcpp_exp

Device: cuda


In [2]:
import json
# Load the configuration file
config_dir = "masks/debug/weight_mask_none_threshold=0.9"
# config_dir = "masks/debug/edge_mask_k=25"
with open(config_dir+"/config.json", 'r') as f:
    config = json.load(f)


use_pythia = config.get('use_pythia', False)

# Now you can use these arguments in your code
edge_masks = config.get('edge_masks', False)
weight_masks_attn = config.get('weight_masks_attn', False)
weight_masks_mlp = config.get('weight_masks_mlp', False)
train_base_weights = config.get('train_base_weights', False)
localize_acdcpp = config.get('localize_acdcpp', False)
localize_ct = config.get('localize_ct', False)

assert not (localize_acdcpp and localize_ct), "Cannot localize with both acdcpp and ct"

# localization_method = config.get('localization_method', None)
# assert "acdcpp" == localization_method or 
localize_task = config.get('localize_task', "induction")

use_uniform = config.get('use_uniform', False)
uniform_type = config.get('uniform_type', "all_tokens")
exclude_correct = config.get('exclude_correct', True)

unlrn_task_weight = config.get('unlrn_task_weight', -0.2)
epochs_left = config.get('epochs_left', 200)
steps_per_epoch = config.get('steps_per_epoch', 20)
accum_grad_steps = config.get('accum_grad_steps', 1)
lr = config.get('lr', 1e-3)
weight_decay = config.get('weight_decay', 0)
evaluate_every = config.get('evaluate_every', 2)
discretize_every = config.get('discretize_every', 40)
threshold = config.get('threshold', 0.5)
mask_k = config.get('mask_k', None)

use_wandb = config.get('use_wandb', True)
edge_mask_reg_strength = config.get('edge_mask_reg_strength', 100)
weight_mask_reg_strength = config.get('weight_mask_reg_strength', 100)
num_eval_steps = config.get('num_eval_steps', 10)
save_every = config.get('save_every', None)
# For 'save_path', since the default is not provided in the JSON, assuming None as default
save_path = config.get('save_path', None)
save_efficient = config.get('save_efficient', True)
# Assuming 'scale_reg_strength' is also a parameter you want to load with a default value
scale_reg_strength = config.get('scale_reg_strength', False)
localization_dir_path = config.get('localization_dir_path', None)
# If save_path is None, set it to the directory of the config file
if config['save_path'] is None:
    save_path = config_dir + f"/ckpts"

if localization_dir_path is None:
    localization_method = None
    if localize_acdcpp:
        localization_method = "acdcpp"
    elif localize_ct:
        localization_method = "ct"
    localization_dir_path = f"localizations/{localize_task}/{localization_method}/"

## Load Localizations and model

In [3]:
if localize_acdcpp or localize_ct:
    with open(f"{localization_dir_path}", "rb") as f:
        acdcpp_nodes, acdcpp_edges, acdcpp_mask_dict, acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = pickle.load(f)

    mask_dict_superset = acdcpp_mask_dict if edge_masks else None
    weight_mask_attn_dict = acdcpp_weight_mask_attn_dict if weight_masks_attn else None
    weight_mask_mlp_dict = acdcpp_weight_mask_mlp_dict if weight_masks_mlp else None
    base_weight_attn_dict = acdcpp_weight_mask_attn_dict if train_base_weights else None
    base_weight_mlp_dict = acdcpp_weight_mask_mlp_dict if train_base_weights else None

else:
    acdcpp_nodes = None
    acdcpp_edges = None
    acdcpp_mask_dict = None
    acdcpp_weight_mask_attn_dict = None
    acdcpp_weight_mask_mlp_dict = None

    mask_dict_superset = None
    weight_mask_attn_dict = None
    weight_mask_mlp_dict = None
    base_weight_attn_dict = None
    base_weight_mlp_dict = None


print(acdcpp_edges)

None


In [4]:
from cb_utils.transformer import DemoTransformer
from cb_utils.models import load_demo_gpt2, tokenizer, load_demo_pythia

if use_pythia:
    if edge_masks:
        model = load_demo_pythia(means=False, model_name="pythia-2.8b", 
                                #  edge_masks=edge_masks, 
                                mask_dict_superset=mask_dict_superset,)
    elif weight_masks_attn or weight_masks_mlp:
        model = load_demo_pythia(means=False, model_name="pythia-2.8b", edge_mask=False, weight_mask=True, 
                                #  weight_masks_attn=True, weight_masks_mlp=True, 
                                weight_mask_attn_dict=weight_mask_attn_dict, weight_mask_mlp_dict=weight_mask_mlp_dict)

else:
    if edge_masks:
        model = load_demo_gpt2(means=False, edge_mask=True, weight_mask=False,
                        #    edge_masks=edge_masks, 
                        mask_dict_superset=mask_dict_superset)
    elif weight_masks_attn or weight_masks_mlp:
        model = load_demo_gpt2(means=False, edge_mask=False, weight_mask=True,
                        #    weight_masks_attn=weight_masks_attn, weight_masks_mlp=weight_masks_mlp, 
                        weight_mask_attn_dict=weight_mask_attn_dict, weight_mask_mlp_dict=weight_mask_mlp_dict)
    else:
        model = load_demo_gpt2(means=False, edge_mask=False, weight_mask=False,
                        edge_masks=edge_masks, mask_dict_superset=mask_dict_superset, weight_masks_attn=weight_masks_attn, weight_masks_mlp=weight_masks_mlp, weight_mask_attn_dict=weight_mask_attn_dict, weight_mask_mlp_dict=weight_mask_mlp_dict, train_base_weights=train_base_weights, base_weight_attn_dict=base_weight_attn_dict, base_weight_mlp_dict=base_weight_mlp_dict)

Using device: cuda:0
Loaded weight-masked transformer


In [5]:
for name, param in model.blocks[0].attn.named_parameters():
    print(name, param.shape, param.requires_grad)
print()
for name, param in model.blocks[0].mlp.named_parameters():
    print(name, param.shape, param.requires_grad)

W_Q torch.Size([12, 768, 64]) False
b_Q torch.Size([12, 64]) False
W_K torch.Size([12, 768, 64]) False
b_K torch.Size([12, 64]) False
W_V torch.Size([12, 768, 64]) False
b_V torch.Size([12, 64]) False
W_O torch.Size([12, 64, 768]) False
b_O torch.Size([768]) False
weight_mask_W_Q torch.Size([12, 768, 64]) True
weight_mask_W_K torch.Size([12, 768, 64]) True
weight_mask_W_V torch.Size([12, 768, 64]) True
weight_mask_W_O torch.Size([12, 64, 768]) True

W_in torch.Size([768, 3072]) False
b_in torch.Size([3072]) False
W_out torch.Size([3072, 768]) False
b_out torch.Size([768]) False
weight_mask_W_in torch.Size([768, 3072]) True
weight_mask_W_out torch.Size([3072, 768]) True
weight_mask_b_in torch.Size([3072]) True
weight_mask_b_out torch.Size([768]) True


In [6]:
for block in model.blocks:
    print(f"{block.attn.weight_mask=}, {block.attn.mask_heads=}, {block.mlp.weight_mask=}")

block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_mask=True
block.attn.weight_mask=True, block.attn.mask_heads=None, block.mlp.weight_ma

In [7]:
if use_pythia:
    from tasks import IOITask, SportsTask, OWTTask, IOITask_Uniform, GreaterThanTask, InductionTask, InductionTask_Uniform, SportsTask_Uniform
    test_batch_size = 32
    sports = SportsTask(batch_size=test_batch_size, tokenizer=tokenizer, device=device)
    owt = OWTTask(batch_size=test_batch_size, tokenizer=tokenizer, device=device, ctx_length=30)
    ioi = IOITask(batch_size=test_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, nb_templates=4, prompt_type="ABBA")
    induction = InductionTask(batch_size=test_batch_size, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15)

    train_batch_size=4
    owt_train = OWTTask(batch_size=3, tokenizer=tokenizer, device=device, ctx_length=30)
    if localize_task == "ioi":

        ioi_task_2 = IOITask(batch_size=test_batch_size, tokenizer=tokenizer, device=device, nb_templates=1, prompt_type="ABBA", template_start_idx=4) # slightly different template

        ioi_task_3 = IOITask(batch_size=test_batch_size, tokenizer=tokenizer, device=device, nb_templates=1, prompt_type="BABA", template_start_idx=0) # different name format

        # train_tasks = {"ioi": ioi, "owt": owt}
        if use_uniform:
            ioi_uniform = IOITask_Uniform(batch_size=train_batch_size, tokenizer=tokenizer, device=device, uniform_over=uniform_type, nb_templates=4, prompt_type="ABBA")
            train_tasks = {"ioi_uniform": ioi_uniform, "owt": owt_train}
            task_weights = {"ioi_uniform": unlrn_task_weight, "owt": 1} # I think means preserve OWT, corrupt IOI
        else: 
            ioi_train = IOITask(batch_size=train_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, nb_templates=4, prompt_type="ABBA")
            train_tasks = {"ioi": ioi_train, "owt": owt_train}
            task_weights = {"ioi": unlrn_task_weight, "owt": 1}

        eval_tasks = {"ioi": ioi, "induction": induction, "owt": owt, "ioi_2": ioi_task_2, "ioi_3": ioi_task_3, "sports": sports}

    elif localize_task == "induction":
        if use_uniform:
            induction_uniform = InductionTask_Uniform(batch_size=train_batch_size, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15, uniform_over=uniform_type)
            train_tasks = {"induction_uniform": induction_uniform, "owt": owt_train}
            task_weights = {"induction_uniform": unlrn_task_weight, "owt": 1}

        else:
            induction_train = InductionTask(batch_size=train_batch_size, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15)
            train_tasks = {"induction": induction_train, "owt": owt_train}
            task_weights = {"induction": unlrn_task_weight, "owt": 1}

        eval_tasks = {"ioi": ioi, "induction": induction, "owt": owt, "sports": sports}

    elif localize_task == "sports":
        if use_uniform:
            sports_uniform = SportsTask_Uniform(batch_size=train_batch_size, tokenizer=tokenizer, uniform_over=uniform_type)
            train_tasks = {"sports_uniform": sports_uniform, "owt": owt_train}
            task_weights = {"sports_uniform": unlrn_task_weight, "owt": 1}
        
        else:
            sports_train = SportsTask(batch_size=train_batch_size, tokenizer=tokenizer)
            train_tasks = {"sports": sports_train, "owt": owt_train}
            task_weights = {"sports": unlrn_task_weight, "owt": 1}

        eval_tasks = {"ioi": ioi, "induction": induction, "owt": owt, "sports": sports}

    elif localize_task == "sports_limited":
        maintain_sports = SportsTask(batch_size=train_batch_size, tokenizer=tokenizer, start_index=64, stop_index=-128, train_test_split=False)
        if use_uniform:
            forget_sports_uniform = SportsTask_Uniform(batch_size=train_batch_size, tokenizer=tokenizer, uniform_over=uniform_type, start_index=0, stop_index=64, train_test_split=False)
            train_tasks = {"forget_sports_uniform": forget_sports_uniform, "maintain_sports": maintain_sports, "owt": owt_train}
            task_weights = {"forget_sports_uniform": unlrn_task_weight, "maintain_sports": 1, "owt": 1}

        else:
            forget_sports = SportsTask(batch_size=train_batch_size, tokenizer=tokenizer, start_index=0, stop_index=64, train_test_split=False)
            train_tasks = {"forget_sports": forget_sports, "maintain_sports": maintain_sports, "owt": owt_train}
            task_weights = {"forget_sports": unlrn_task_weight, "maintain_sports": 1, "owt": 1}

        forget_sports_eval = SportsTask(batch_size=test_batch_size, tokenizer=tokenizer, start_index=0, stop_index=64, train_test_split=False)
        maintain_sports_eval = SportsTask(batch_size=test_batch_size, tokenizer=tokenizer, start_index=64, stop_index=-128, train_test_split=False)
        other_sports = SportsTask(batch_size=test_batch_size, tokenizer=tokenizer, start_index=-128, train_test_split=False)
        eval_tasks = {"ioi": ioi, "induction": induction, "owt": owt, "forget_sports": forget_sports_eval, "maintain_sports": maintain_sports_eval, "other_sports": other_sports}

else:

    from tasks import IOITask, SportsTask, OWTTask, IOITask_Uniform, GreaterThanTask, InductionTask, InductionTask_Uniform
    batch_size = 80
    # sports = SportsTask(batch_size=batch_size*2, tokenizer=tokenizer, device=device)
    owt = OWTTask(batch_size=batch_size, tokenizer=tokenizer, device=device, ctx_length=40)
    greaterthan = GreaterThanTask(batch_size=batch_size, tokenizer=tokenizer, device=device)
    ioi = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, nb_templates=4, prompt_type="ABBA")
    induction = InductionTask(batch_size=batch_size, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15)

    if localize_task == "ioi":
        ioi_uniform = IOITask_Uniform(batch_size=batch_size, tokenizer=tokenizer, device=device, uniform_over=uniform_type, nb_templates=4, prompt_type="ABBA", exclude_correct=exclude_correct)

        ioi_task_2 = IOITask(batch_size=batch_size*2, tokenizer=tokenizer, device=device, nb_templates=1, prompt_type="ABBA", template_start_idx=4) # slightly different template

        ioi_task_3 = IOITask(batch_size=batch_size*2, tokenizer=tokenizer, device=device, nb_templates=1, prompt_type="BABA", template_start_idx=0) # different name format

        # train_tasks = {"ioi": ioi, "owt": owt}
        if use_uniform:
            train_tasks = {"ioi_uniform": ioi_uniform, "owt": owt}
            task_weights = {"ioi_uniform": unlrn_task_weight, "owt": 1} # I think means preserve OWT, corrupt IOI
        else:
            train_tasks = {"ioi": ioi, "owt": owt}
            task_weights = {"ioi": unlrn_task_weight, "owt": 1}

        eval_tasks = {"ioi": ioi, "induction": induction, "owt": owt, "ioi_2": ioi_task_2, "ioi_3": ioi_task_3, "greaterthan": greaterthan}

    elif localize_task == "induction":
        induction_uniform = InductionTask_Uniform(batch_size=batch_size, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15, uniform_over=uniform_type, exclude_correct=exclude_correct)
        
        if use_uniform:
            train_tasks = {"induction_uniform": induction_uniform, "owt": owt}
            task_weights = {"induction_uniform": unlrn_task_weight, "owt": 1}

        else:
            train_tasks = {"induction": induction, "owt": owt}
            task_weights = {"induction": unlrn_task_weight, "owt": 1}

        eval_tasks = {"ioi": ioi, "induction": induction, "owt": owt, "greaterthan": greaterthan}


  table = cls._concat_blocks(blocks, axis=0)


In [8]:
mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)

print(param_names)

['blocks.0.attn.weight_mask_W_Q', 'blocks.0.attn.weight_mask_W_K', 'blocks.0.attn.weight_mask_W_V', 'blocks.0.attn.weight_mask_W_O', 'blocks.0.mlp.weight_mask_W_in', 'blocks.0.mlp.weight_mask_W_out', 'blocks.0.mlp.weight_mask_b_in', 'blocks.0.mlp.weight_mask_b_out', 'blocks.1.attn.weight_mask_W_Q', 'blocks.1.attn.weight_mask_W_K', 'blocks.1.attn.weight_mask_W_V', 'blocks.1.attn.weight_mask_W_O', 'blocks.1.mlp.weight_mask_W_in', 'blocks.1.mlp.weight_mask_W_out', 'blocks.1.mlp.weight_mask_b_in', 'blocks.1.mlp.weight_mask_b_out', 'blocks.2.attn.weight_mask_W_Q', 'blocks.2.attn.weight_mask_W_K', 'blocks.2.attn.weight_mask_W_V', 'blocks.2.attn.weight_mask_W_O', 'blocks.2.mlp.weight_mask_W_in', 'blocks.2.mlp.weight_mask_W_out', 'blocks.2.mlp.weight_mask_b_in', 'blocks.2.mlp.weight_mask_b_out', 'blocks.3.attn.weight_mask_W_Q', 'blocks.3.attn.weight_mask_W_K', 'blocks.3.attn.weight_mask_W_V', 'blocks.3.attn.weight_mask_W_O', 'blocks.3.mlp.weight_mask_W_in', 'blocks.3.mlp.weight_mask_W_out', 'b

## Do Mask Learning

In [9]:
if scale_reg_strength:
    orig_edge_mask_reg_strength = edge_mask_reg_strength
    orig_weight_mask_reg_strength = weight_mask_reg_strength
    edge_mask_reg_strength = lambda epoch: orig_edge_mask_reg_strength * (epoch - 1/10*epochs_left)
    weight_mask_reg_strength = lambda epoch: orig_weight_mask_reg_strength * (epoch - 1/10*epochs_left)


In [10]:
from cb_utils.learn_mask import *
wandb_config = config
use_wandb = False

optimizer = torch.optim.AdamW(mask_params, lr=lr, weight_decay=weight_decay)
train_losses, test_losses = train_masks(model, 
                                        tasks=train_tasks, 
                                        optimizer=optimizer, 
                                        num_epochs=epochs_left, 

                                        steps_per_epoch=steps_per_epoch, 
                                        accum_grad_steps=accum_grad_steps,

                                        task_weights=task_weights, 
                                        eval_tasks=eval_tasks, 

                                        evaluate_every=evaluate_every, 
                                        discretize_every=discretize_every, 
                                        save_every=save_every,

                                        threshold=threshold, 
                                        mask_k=mask_k,
                                        edge_mask_reg_strength=edge_mask_reg_strength, 
                                        weight_mask_reg_strength=weight_mask_reg_strength, 

                                        verbose=True, 
                                        use_wandb=use_wandb, 
                                        wandb_config=wandb_config, 
                                        save_dir=save_path, 

                                        save_efficient=save_efficient, 
                                        refresh_memory=use_pythia) # only refresh memory is pythia is being used

  0%|          | 0/21 [00:00<?, ?it/s]

Epoch 0, step 0
Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight

  5%|▍         | 1/21 [00:09<03:05,  9.27s/it]

Loss on ioi_3: 0.7931265234947205
Evaluating on greaterthan
Loss on greaterthan: 3.911612033843994
Epoch 1, step 0
Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=Non

 10%|▉         | 2/21 [00:15<02:27,  7.77s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 14%|█▍        | 3/21 [00:22<02:12,  7.34s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 19%|█▉        | 4/21 [00:29<02:00,  7.11s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 24%|██▍       | 5/21 [00:37<01:56,  7.28s/it]

Loss on ioi_3: 0.7243724465370178
Evaluating on greaterthan
Loss on greaterthan: 3.911612033843994
Epoch 5, step 0
Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=Non

 29%|██▊       | 6/21 [00:43<01:44,  6.95s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 33%|███▎      | 7/21 [00:49<01:34,  6.75s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 38%|███▊      | 8/21 [00:56<01:26,  6.65s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 43%|████▎     | 9/21 [01:03<01:23,  6.95s/it]

Loss on ioi_3: 0.7887482047080994
Evaluating on greaterthan
Loss on greaterthan: 3.911648988723755
Epoch 9, step 0
Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=Non

 48%|████▊     | 10/21 [01:10<01:14,  6.75s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 52%|█████▏    | 11/21 [01:16<01:06,  6.69s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 57%|█████▋    | 12/21 [01:23<00:59,  6.64s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 62%|██████▏   | 13/21 [01:30<00:55,  6.88s/it]

Loss on ioi_3: 0.8537848591804504
Evaluating on greaterthan
Loss on greaterthan: 3.911863088607788
Epoch 13, step 0
Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=No

 67%|██████▋   | 14/21 [01:36<00:46,  6.69s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 71%|███████▏  | 15/21 [01:43<00:39,  6.64s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 76%|███████▌  | 16/21 [01:49<00:32,  6.58s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 81%|████████  | 17/21 [01:57<00:27,  6.94s/it]

Loss on ioi_3: 0.7934633493423462
Evaluating on greaterthan
Loss on greaterthan: 3.9118459224700928
Epoch 17, step 0
Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=N

 86%|████████▌ | 18/21 [02:04<00:20,  6.92s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 90%|█████████ | 19/21 [02:11<00:13,  6.82s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

 95%|█████████▌| 20/21 [02:17<00:06,  6.76s/it]

Calculating weight_reg
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
self.attn.weight_mask=True, self.attn.mask_heads=None, self.mlp.weight_mask=True
weigh

100%|██████████| 21/21 [02:25<00:00,  6.93s/it]

Loss on ioi_3: 0.85821932554245
Evaluating on greaterthan
Loss on greaterthan: 3.9120635986328125





In [11]:
train_losses

defaultdict(list,
            {'ioi': [(0, 0, 0.8509278297424316),
              (0, 1, 5.451411247253418),
              (0, 2, 11.070261001586914),
              (0, 3, 15.22369384765625),
              (0, 4, 20.594860076904297),
              (0, 5, 26.3444766998291),
              (0, 6, 29.318954467773438),
              (0, 7, 35.21002197265625),
              (0, 8, 37.39036560058594),
              (0, 9, 41.46675109863281),
              (1, 0, 0.9649065136909485),
              (1, 1, 1.9460601806640625),
              (1, 2, 5.629096508026123),
              (1, 3, 10.17425537109375),
              (1, 4, 17.89014434814453),
              (1, 5, 28.62545394897461),
              (1, 6, 36.456703186035156),
              (1, 7, 43.86286926269531),
              (1, 8, 48.4648323059082),
              (1, 9, 51.050193786621094),
              (2, 0, 53.51465606689453),
              (2, 1, 55.42345428466797),
              (2, 2, 60.24665069580078),
              (2, 3, 65.27

In [20]:
import wandb
wandb.finish()



VBox(children=(Label(value='0.014 MB of 0.014 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
edge_reg_term,▁▁▁▁▁▁▁▁
train_loss_ioi,█▆▅▄▄▃▂▂▁
train_loss_owt,▄█▅▁▅█▄▃
weight_mask_reg,▁▁▁▁▁▁▁▁

0,1
edge_reg_term,0.0
train_loss_ioi,-10.9533
train_loss_owt,3.68572
weight_mask_reg,0.0


## Check how many edges masked

In [13]:
model

DemoTransformer(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (ln_final): LayerNorm()
  (unembed): Unembed()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm()
      (attn): Attention()
      (ln2): LayerNorm()
      (mlp): MLP()
    )
  )
)

In [14]:
model.get_weight_reg()

(tensor(57824992., device='cuda:0', grad_fn=<AddBackward0>),
 tensor(57844992., device='cuda:0'))

In [16]:
acdcpp_nodes

{(-1, 'embed'),
 (0, 'a0.1'),
 (0, 'a0.10'),
 (0, 'a0.6'),
 (0, 'a0.9'),
 (0, 'm0'),
 (1, 'a1.11'),
 (1, 'a1.7'),
 (1, 'm1'),
 (2, 'm2'),
 (3, 'a3.0'),
 (3, 'a3.4'),
 (3, 'm3'),
 (4, 'a4.11'),
 (4, 'm4'),
 (5, 'a5.5'),
 (5, 'a5.8'),
 (5, 'a5.9'),
 (5, 'm5'),
 (6, 'a6.0'),
 (6, 'm6'),
 (7, 'a7.3'),
 (7, 'a7.9'),
 (7, 'm7'),
 (8, 'a8.10'),
 (8, 'a8.6'),
 (8, 'm8'),
 (9, 'a9.6'),
 (9, 'a9.7'),
 (9, 'a9.8'),
 (9, 'a9.9'),
 (9, 'm9'),
 (10, 'a10.0'),
 (10, 'a10.1'),
 (10, 'a10.10'),
 (10, 'a10.2'),
 (10, 'a10.6'),
 (10, 'a10.7'),
 (10, 'm10'),
 (11, 'a11.10'),
 (11, 'a11.2'),
 (11, 'a11.3'),
 (12, 'output')}

In [22]:
# check mlp 1
print(model.blocks[1].mlp.weight_mask_W_in.shape)
print(model.blocks[1].mlp.weight_mask_W_in.sum())
print((model.blocks[1].mlp.weight_mask_W_in == 0).sum())

torch.Size([768, 3072])
tensor(2359000., device='cuda:0', grad_fn=<SumBackward0>)
tensor(296, device='cuda:0')


In [24]:
tot_zeros = 0
for layer, node in acdcpp_nodes:
    if 'embed' in node:
        continue
    elif 'm' in node:
        tot_zeros += (model.blocks[layer].mlp.weight_mask_W_in == 0).sum()
        tot_zeros += (model.blocks[layer].mlp.weight_mask_W_out == 0).sum()
        tot_zeros += (model.blocks[layer].mlp.weight_mask_b_in == 0).sum()
        tot_zeros += (model.blocks[layer].mlp.weight_mask_b_out == 0).sum()
    elif 'a' in node:
        head = int(node.split('.')[1])
        tot_zeros += (model.blocks[layer].attn.weight_mask_W_Q[head] == 0).sum()
        tot_zeros += (model.blocks[layer].attn.weight_mask_W_K[head] == 0).sum()
        tot_zeros += (model.blocks[layer].attn.weight_mask_W_V[head] == 0).sum()
        tot_zeros += (model.blocks[layer].attn.weight_mask_W_O[head] == 0).sum()
    
print(tot_zeros)

tensor(19999, device='cuda:0')


# Debug Transformer

## Check DemoTransformer Implementations Correct

In [13]:
from cb_utils.transformers.gpt2.edge_masked_transformer import DemoTransformer as GPT2EdgeDemoTransformer, Config as GPT2Config

from cb_utils.models import tl_config_to_demo_config
with open("models/gpt2_weights.pkl", "rb") as f:
    gpt2_weights = pickle.load(f)
demo_edge_gpt2 = GPT2EdgeDemoTransformer(GPT2Config(debug=False, n_layers=12, n_heads=12), means=False)
demo_edge_gpt2.load_state_dict(gpt2_weights, strict=False)
demo_edge_gpt2.cuda()


from cb_utils.transformers.gpt2.weight_masked_transformer import DemoTransformer as GPT2WeightDemoTransformer, Config as GPT2Config
demo_weight_gpt2 = GPT2WeightDemoTransformer(GPT2Config(debug=False, n_layers=12, n_heads=12))
demo_weight_gpt2.load_state_dict(gpt2_weights, strict=False)
demo_weight_gpt2.cuda()

with torch.no_grad():
    test_input = t.tensor(gpt2_tokenizer.encode("The quick brown fox jumps over the lazy")).unsqueeze(0).cuda()
    print((demo_edge_gpt2(test_input)[0][0, -1] - reference_gpt2(test_input)[0, -1]).std())
    print((demo_weight_gpt2(test_input)[0][0, -1] - reference_gpt2(test_input)[0, -1]).std())
    print((demo_edge_gpt2(test_input)[0][0, -1] - demo_weight_gpt2(test_input)[0][0, -1]).std())
    # print((model(test_input)[0][0, -1] - edge_masked_model(test_input)[0][0, -1]).std())
    # print((reference_pythia(test_input)[0, -1] - edge_masked_model(test_input)[0][0, -1]).std())

tensor(7.2482e-06, device='cuda:0')
tensor(7.2237e-06, device='cuda:0')
tensor(6.9957e-06, device='cuda:0')


## Setup ACDCPP edges

In [5]:
from tasks.ioi.IOITask import IOITask_old, IOITask
# ioi_task = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25)
ioi_task = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25, nb_templates=1, prompt_type="ABBA")
ioi_task.set_logit_diffs(model)

ioi_metric = ioi_task.get_acdcpp_metric()
def negative_abs_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    return -abs(ioi_metric(logits))

with t.no_grad():
    clean_logits = model(ioi_task.clean_data.toks)
    corrupt_logits = model(ioi_task.corr_data.toks)
    clean_logit_diff = ioi_task.ave_logit_diff(clean_logits, ioi_task.clean_data).item()
    corrupt_logit_diff = ioi_task.ave_logit_diff(corrupt_logits, ioi_task.corr_data).item()
    print(f'Clean logit diff: {clean_logit_diff:.3f}, Corrupt logit diff: {corrupt_logit_diff:.3f}')

OpenAI API key not found, will not be able to run evaluations on HPSAQ Task
OpenAI API key not found, will not be able to run evaluations on HPSAQ Task
Clean logit diff: 3.040117025375366, Corrupted logit diff: 1.2651995420455933
Clean logit diff: 3.040, Corrupt logit diff: 1.265


In [6]:
ioi_metric(clean_logits)

tensor(1., device='cuda:0')

In [8]:
ioi_metric(corrupt_logits, ioi_dataset=ioi_task.corr_data)

tensor(0., device='cuda:0')

In [9]:
clean_logit_diff, corrupt_logit_diff

(3.040117025375366, 1.2651995420455933)

In [11]:
ioi_task_2 = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25, nb_templates=1, prompt_type="BABA", template_start_idx=0)
ioi_task_2.get_batch()

{'[PLACE]': ['station',
  'restaurant',
  'restaurant',
  'restaurant',
  'restaurant'],
 '[OBJECT]': ['ring', 'computer', 'necklace', 'bone', 'computer'],
 'text': ['Then, William and Richard went to the station. William gave a ring to',
  'Then, Charles and Jeremy went to the restaurant. Charles gave a computer to',
  'Then, Simon and Clark went to the restaurant. Simon gave a necklace to',
  'Then, Jacob and Scott went to the restaurant. Jacob gave a bone to',
  'Then, Steven and Sullivan went to the restaurant. Steven gave a computer to'],
 'IO': ['Richard', 'Jeremy', 'Clark', 'Scott', 'Sullivan'],
 'S': ['William', 'Charles', 'Simon', 'Jacob', 'Steven'],
 'TEMPLATE_IDX': tensor([0, 0, 0, 0, 0])}

In [13]:
from ACDCPPExperiment import ACDCPPExperiment
from cb_utils.mask_utils import get_masks_from_acdcpp_exp
THRESHOLDS = [0.08, .15]#np.arange(0.005, 0.155, 0.005)
RUN_NAME = 'abs_edge'

acdcpp_exp = ACDCPPExperiment(
    model=model,
    clean_data=ioi_task.clean_data.toks,
    corr_data=ioi_task.corr_data.toks,
    acdc_metric=negative_abs_ioi_metric,
    acdcpp_metric=ioi_metric,
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    attr_absolute_val=True,
    save_graphs_after=-100,
    pruning_mode='edge',
    no_pruned_nodes_attr=1,
    run_acdc=False,
    run_acdcpp=True,
)
# e=acdcpp_exp.setup_exp(0.0)

# pruned_heads, num_passes, acdcpp_pruned_attrs, acdc_pruned_attrs, edges_after_acdcpp, edges_after_acdc = acdcpp_exp.run()
acdcpp_nodes, acdcpp_edges, acdcpp_mask_dict, acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = get_masks_from_acdcpp_exp(acdcpp_exp, threshold=0.08)



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 15299.74it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 258.14it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 281818.85it/s]


dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])




self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 15224.87it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 253.97it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 306343.88it/s]
100%|██████████| 2/2 [00:15<00:00,  7.86s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])





In [21]:
acdcpp_edges

{((0, 'a0.10'), (-1, 'embed')),
 ((0, 'a0.9'), (-1, 'embed')),
 ((0, 'm0'), (-1, 'embed')),
 ((0, 'm0'), (0, 'a0.10')),
 ((1, 'm1'), (0, 'm0')),
 ((3, 'a3.0'), (0, 'm0')),
 ((3, 'm3'), (0, 'm0')),
 ((3, 'm3'), (1, 'm1')),
 ((3, 'm3'), (3, 'a3.0')),
 ((4, 'm4'), (0, 'm0')),
 ((4, 'm4'), (3, 'a3.0')),
 ((5, 'a5.9'), (0, 'm0')),
 ((5, 'a5.9'), (1, 'm1')),
 ((5, 'a5.9'), (3, 'a3.0')),
 ((5, 'a5.9'), (3, 'm3')),
 ((5, 'm5'), (3, 'a3.0')),
 ((5, 'm5'), (5, 'a5.5')),
 ((6, 'm6'), (3, 'a3.0')),
 ((6, 'm6'), (5, 'm5')),
 ((7, 'a7.9'), (0, 'm0')),
 ((7, 'm7'), (6, 'm6')),
 ((8, 'a8.10'), (4, 'm4')),
 ((8, 'a8.10'), (5, 'a5.5')),
 ((8, 'a8.10'), (5, 'a5.9')),
 ((8, 'a8.6'), (5, 'a5.5')),
 ((9, 'a9.9'), (8, 'a8.10')),
 ((10, 'a10.0'), (8, 'a8.10')),
 ((10, 'a10.7'), (0, 'm0')),
 ((10, 'a10.7'), (6, 'm6')),
 ((10, 'a10.7'), (9, 'a9.6')),
 ((10, 'a10.7'), (9, 'a9.9')),
 ((11, 'a11.10'), (0, 'm0')),
 ((11, 'a11.10'), (6, 'm6')),
 ((11, 'a11.10'), (8, 'a8.10')),
 ((11, 'a11.10'), (9, 'a9.6')),
 ((11, 

In [3]:
threshold = 0.001
import pickle
with open('localizations/eap/eap_sports/1000_graph.pkl', 'rb') as f:
    graph = pickle.load(f)

eap_edges = graph.top_edges(n=1000, threshold=threshold)
# eap_edges = set()
# for i in range(eap_scores.shape[0]):
#     for j in range(eap_scores.shape[1]):
#         if eap_scores[i, j] > threshold:
            # eap_edges.add((get_node_name(graph.node_names[i], show_full_index=False), get_node_name(graph.node_names[j, show_full_index=False))))
print(eap_edges)

[('head.0.2', 'mlp.0', 0.009443754330277443), ('head.0.14', 'mlp.0', 0.006188404746353626), ('mlp.6', 'mlp.8', 0.00573932658880949), ('mlp.0', 'mlp.2', 0.005653967149555683), ('head.16.20', 'mlp.16', -0.005345507059246302), ('head.16.20', 'head.17.30.v', 0.005315450485795736), ('mlp.0', 'head.1.16.k', -0.005109565332531929), ('head.14.14', 'mlp.15', -0.004921694286167622), ('mlp.15', 'head.16.20.k', -0.004847021773457527), ('mlp.6', 'mlp.15', -0.004780464340001345), ('mlp.10', 'head.16.20.k', 0.004766407422721386), ('mlp.6', 'head.16.20.k', 0.004757486749440432), ('mlp.0', 'head.1.15.k', -0.004494336899369955), ('mlp.0', 'mlp.5', -0.004315529949963093), ('mlp.0', 'mlp.4', -0.004094669129699469), ('head.16.20', 'mlp.18', 0.004019735846668482), ('head.0.30', 'mlp.0', 0.0039080469869077206), ('mlp.0', 'mlp.6', -0.0038432518485933542), ('mlp.8', 'mlp.9', 0.00376768596470356), ('head.16.20', 'head.21.9.v', 0.003543522208929062), ('mlp.9', 'mlp.11', 0.003415714716538787), ('mlp.6', 'head.21.

In [4]:
len(graph.eap_scores.flatten())

3277824

In [5]:
"""
Convert format:
want: {((3, 'm3'), (3, 'a3.0')),
 ((4, 'm4'), (0, 'm0')),
 ((4, 'm4'), (3, 'a3.0')),
 ((5, 'a5.9'), (0, 'm0')),}

have:
[('mlp.0', 'mlp.2', 0.005653967149555683),
('head.0.14', 'mlp.0', 0.006188404746353626),]
...
"""
from cb_utils.mask_utils import get_formatted_edges_from_eap, get_masks_from_eap_exp
# formatted_eap_edges = get_formatted_edges_from_eap(eap_edges)
# formatted_eap_edges
with open('localizations/eap/eap_sports/1000_graph.pkl', 'rb') as f:
    graph = pickle.load(f)
acdcpp_nodes, acdcpp_edges, acdcpp_mask_dict, acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = get_masks_from_eap_exp(graph, threshold=0.001, num_layers=32, num_heads=32)

In [6]:
(acdcpp_mask_dict['m31'] == 0).sum()

tensor(5)

In [47]:
from tasks import InductionTask
ind_task = InductionTask(batch_size=16, tokenizer=model.tokenizer, prep_acdcpp=True, seq_len=10, acdcpp_metric="ave_logit_diff")
ind_task.set_logit_diffs(model)
print(ind_task.clean_logit_diff, ind_task.corrupted_logit_diff)

12.292320251464844 1.4027974605560303


In [48]:
ind_metric = ind_task.get_acdcpp_metric()
def negative_abs_ind_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    return -abs(ind_metric(logits))

with t.no_grad():
    clean_logits = model(ind_task.clean_data.cuda())
    corrupt_logits = model(ind_task.corr_data.cuda())
    clean_logit_diff = ind_task.ave_logit_diff(clean_logits, ind_task.clean_data).item()
    corrupt_logit_diff = ind_task.ave_logit_diff(corrupt_logits, ind_task.corr_data).item()
    
print(ind_metric(clean_logits))
print(ind_metric(corrupt_logits))

In [51]:
from ACDCPPExperiment import ACDCPPExperiment
from cb_utils.mask_utils import get_masks_from_acdcpp_exp
THRESHOLDS = [0.05]#np.arange(0.005, 0.155, 0.005)
RUN_NAME = 'abs_edge'

acdcpp_exp = ACDCPPExperiment(
    model=model,
    clean_data=ind_task.clean_data,
    corr_data=ind_task.corr_data,
    acdc_metric=negative_abs_ind_metric,
    acdcpp_metric=ind_metric,
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    attr_absolute_val=True,
    save_graphs_after=-100,
    pruning_mode='edge',
    no_pruned_nodes_attr=1,
    run_acdc=False,
    run_acdcpp=True,
)
# e=acdcpp_exp.setup_exp(0.0)

# pruned_heads, num_passes, acdcpp_pruned_attrs, acdc_pruned_attrs, edges_after_acdcpp, edges_after_acdc = acdcpp_exp.run()
acdcpp_nodes, acdcpp_edges, acdcpp_mask_dict, acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = get_masks_from_acdcpp_exp(acdcpp_exp, threshold=THRESHOLDS[0])



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 15171.34it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 252.85it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 304067.19it/s]
100%|██████████| 1/1 [00:08<00:00,  8.35s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 5, 0, 1, 2, 3, 4, 6, 12])





## Loading models

In [4]:
import pickle
from cb_utils.transformer import DemoTransformer
from cb_utils.models import load_demo_gpt2, tokenizer
means_ioi = True
if means_ioi:
    with open("data/gpt2_ioi_abc_means.pkl", "rb") as f:
        means = pickle.load(f)[0]
else:
    with open("data/gpt2_means.pkl", "rb") as f:
        means = pickle.load(f)[0]

edge_masks = True
weight_masks_attn = True
weight_masks_mlp = True
train_base_weights = True
localize_acdcpp = True

# if edge_masks is True, then have mask_dict_superset be acdcpp_mask_dict
mask_dict_superset = None if not edge_masks else acdcpp_mask_dict
# model = load_demo_gpt2(means=means, mask_dict_superset=acdcpp_mask_dict)
if localize_acdcpp:
    weight_mask_attn_dict = acdcpp_weight_mask_attn_dict if weight_masks_attn else None
    weight_mask_mlp_dict = acdcpp_weight_mask_mlp_dict if weight_masks_mlp else None
    base_weight_attn_dict = acdcpp_weight_mask_attn_dict if train_base_weights else None
    base_weight_mlp_dict = acdcpp_weight_mask_mlp_dict if train_base_weights else None

else:
    weight_mask_attn_dict = None
    weight_mask_mlp_dict = None
    base_weight_attn_dict = None
    base_weight_mlp_dict = None

# model = load_demo_gpt2(means=False, edge_masks=edge_masks, mask_dict_superset=mask_dict_superset, weight_masks_attn=weight_masks_attn, weight_masks_mlp=weight_masks_mlp, weight_mask_attn_dict=weight_mask_attn_dict, weight_mask_mlp_dict=weight_mask_mlp_dict, train_base_weights=train_base_weights, base_weight_attn_dict=base_weight_attn_dict, base_weight_mlp_dict=base_weight_mlp_dict)

Using device: cuda:0


NameError: name 'acdcpp_mask_dict' is not defined

In [5]:
from cb_utils.models import load_demo_pythia
threshold = 0.0005
with open(f"localizations/eap/eap_sports/pythia-2.8b_{threshold=}.pkl", "rb") as f:
    acdcpp_nodes, acdcpp_edges, acdcpp_mask_dict, acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict = pickle.load(f)

# model = load_demo_pythia(means=False, model_name="pythia-2.8b", edge_masks=True, mask_dict_superset=acdcpp_mask_dict)
model = load_demo_pythia(means=False, model_name="pythia-2.8b", edge_mask=False, weight_mask=True, weight_masks_attn=True, weight_masks_mlp=True, weight_mask_attn_dict=acdcpp_weight_mask_attn_dict, weight_mask_mlp_dict=acdcpp_weight_mask_mlp_dict)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-2.8b into HookedTransformer


In [4]:
reference_pythia = HookedTransformer.from_pretrained(
        'pythia-2.8b',
        fold_ln=False,
        center_writing_weights=False,
        center_unembed=False,
        # default_padding_side="left",
        # device='cuda'
        device='cpu'
    )
tokenizer = reference_pythia.tokenizer
pythia_tokenizer = reference_pythia.tokenizer
# reference_pythia.to("cuda")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-2.8b into HookedTransformer


In [6]:
# compare the two models
# compare model outputs
with torch.no_grad():
    test_input = t.tensor(pythia_tokenizer.encode("The quick brown fox jumps over the lazy")).unsqueeze(0).cuda()
    # print((model(test_input)[0][0, -1] - reference_pythia(test_input)[0, -1]).std())
    # print((model(test_input)[0][0, -1] - edge_masked_model(test_input)[0][0, -1]).std())
    # print((reference_pythia(test_input)[0, -1] - edge_masked_model(test_input)[0][0, -1]).std())


In [23]:
from tasks import SportsTask, InductionTask, IOITask, InductionTask_Uniform, OWTTask#, SportsTask_Uniform
from tasks.facts.SportsTask import SportsTask_Uniform
sports_task = SportsTask(batch_size=32, tokenizer=tokenizer, prep_acdcpp=False)
ioi_task = IOITask(batch_size=32, tokenizer=tokenizer, prep_acdcpp=False)
ind_task = InductionTask(batch_size=32, tokenizer=tokenizer, prep_acdcpp=False)
ind_uniform_task = InductionTask_Uniform(batch_size=16, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15, uniform_over="rep_tokens")
owt_task = OWTTask(batch_size=32, tokenizer=tokenizer, device=device, ctx_length=30)
sports_uniform_task = SportsTask_Uniform(batch_size=32, tokenizer=tokenizer, uniform_over="sports_tokens")
print(sports_task.get_test_accuracy(model))#, sports_task.get_test_accuracy(reference_pythia))
print(ioi_task.get_test_accuracy(model))#, ioi_task.get_test_accuracy(reference_pythia))
print(ind_task.get_test_accuracy(model))#, ind_task.get_test_accuracy(reference_pythia))
print(ind_uniform_task.get_test_loss(model))#, ind_uniform_task.get_test_loss(reference_pythia))
print(owt_task.get_test_loss(model))#, owt_task.get_test_accuracy(reference_pythia))
print(sports_uniform_task.get_test_loss(model))#, sports_uniform_task.get_test_accuracy(reference_pythia))

  table = cls._concat_blocks(blocks, axis=0)


1.0
1.0
0.96875
tensor(12.0880, device='cuda:0')
tensor(2.8271, device='cuda:0')
tensor(3.4641, device='cuda:0')


In [6]:
sports_uniform_task = SportsTask_Uniform(batch_size=32, tokenizer=tokenizer, uniform_over="all_tokens")
print(sports_uniform_task.get_test_loss(model))#, sports_uniform_task.get_test_accuracy(reference_pythia))

tensor(20.2142, device='cuda:0')


In [7]:
from tasks import LimitedSportsTask
forget_task = LimitedSportsTask(batch_size=32, tokenizer=tokenizer, start_index=0, stop_index=64, make_complementary_task=True)
remember_task = forget_task.complementary_task
forget_task.get_test_accuracy(model), remember_task.get_test_accuracy(model)

(1.0, 1.0)

In [8]:
torch.cuda.memory_allocated(device=device) / 1e9

21.968525312

In [9]:
# test max batch size for sports, owt
owt_task = OWTTask(batch_size=1, tokenizer=tokenizer, ctx_length=30)
sports_task = SportsTask(batch_size=1, tokenizer=tokenizer, shuffle=False)

  table = cls._concat_blocks(blocks, axis=0)


In [12]:
sports_task = SportsTask(batch_size=1, tokenizer=tokenizer, shuffle=False)
tot_loss = 0
print(torch.cuda.memory_allocated(device=device) / 1e9)
for i in range(4):
    loss = sports_task.get_train_loss(model)
    tot_loss += loss
    print(torch.cuda.memory_allocated(device=device) / 1e9)
tot_loss.backward()
print(torch.cuda.memory_allocated(device=device) / 1e9)

31.304587776
41.535725568
51.745850368
61.98750208
72.183995392
31.304588288


In [70]:
tot_loss = 0
print(torch.cuda.memory_allocated(device=device) / 1e9)
model.zero_grad()
for i in range(20):
    loss = sports_task.get_train_loss(model)
    print(torch.cuda.memory_allocated(device=device) / 1e9)
    print(model.blocks[2].edge_mask_mlp.grad)
    loss.backward()
print(torch.cuda.memory_allocated(device=device) / 1e9)

11.808385536
19.528289792
None
19.150077952
tensor([-0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0000, -0.0000,  0.0000,
         0.0000,  0.0000, -0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -0.0000,
         0.0000,  0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,  0.0000,
        -0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
        -0.0000, -0.1926,  0.0000,  0.0000,  0.0000, -0.0000, -0.0000,  0.0000,
         0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000, -0.0000, -0.0000,
         0.0000,  0.0000,  0.0000, -0.0430,  0.0000,  0.0000,  0.0000, -0.0000,
        -0.0000,  0.0000, -0.0000, -0.0041,  0.0266,  0.0000,  0.0000,  0.0000,
        -0.0000,  0.0000,  0.0028], device='cuda:0')
19.530473472
tensor([-0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0000,  0.0000,  0.0000,
         0.0000,  0.0000, -0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0000,
         0.0000,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000,  0.0000,  0.0000,
        -0

In [68]:
model.blocks[2].edge_mask_mlp.grad

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000, -0.4634,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.1553,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.2797,  0.0654,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.3760], device='cuda:0')

## Test that gradients flow correctly

In [14]:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name, param.shape, param.requires_grad)
    # print(name, param.shape, param.requires_grad)

blocks.0.attn.weight_mask_W_Q torch.Size([32, 2560, 80]) True
blocks.0.attn.weight_mask_W_K torch.Size([32, 2560, 80]) True
blocks.0.attn.weight_mask_W_V torch.Size([32, 2560, 80]) True
blocks.0.attn.weight_mask_W_O torch.Size([32, 80, 2560]) True
blocks.0.mlp.weight_mask_W_in torch.Size([2560, 10240]) True
blocks.0.mlp.weight_mask_W_out torch.Size([10240, 2560]) True
blocks.0.mlp.weight_mask_b_in torch.Size([10240]) True
blocks.0.mlp.weight_mask_b_out torch.Size([2560]) True
blocks.1.attn.weight_mask_W_Q torch.Size([32, 2560, 80]) True
blocks.1.attn.weight_mask_W_K torch.Size([32, 2560, 80]) True
blocks.1.attn.weight_mask_W_V torch.Size([32, 2560, 80]) True
blocks.1.attn.weight_mask_W_O torch.Size([32, 80, 2560]) True
blocks.1.mlp.weight_mask_W_in torch.Size([2560, 10240]) True
blocks.1.mlp.weight_mask_W_out torch.Size([10240, 2560]) True
blocks.1.mlp.weight_mask_b_in torch.Size([10240]) True
blocks.1.mlp.weight_mask_b_out torch.Size([2560]) True
blocks.2.attn.weight_mask_W_Q torch.Si

In [18]:
model.train()
batch_size = 2
ioi = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False)
loss = ioi.get_train_loss(model)
loss.backward()

In [19]:
torch.cuda.memory_allocated(device=device) / 1e9

32.17700352

In [20]:
param_names = []
model_params = []
for name, param in model.named_parameters():
    if param.requires_grad: # and "edge" in name:
        # check if param.grad is all zeros
        if param.grad is not None and param.grad.sum() != 0:
            print(f"{name} grad is not all zeros, {param.grad.norm()=}")
        else:
            print(f"{name} grad is all zeros")

blocks.0.attn.weight_mask_W_Q grad is not all zeros, param.grad.norm()=tensor(0.0216, device='cuda:0')
blocks.0.attn.weight_mask_W_K grad is not all zeros, param.grad.norm()=tensor(1.5815, device='cuda:0')
blocks.0.attn.weight_mask_W_V grad is not all zeros, param.grad.norm()=tensor(1.3946, device='cuda:0')
blocks.0.attn.weight_mask_W_O grad is not all zeros, param.grad.norm()=tensor(1.7858, device='cuda:0')
blocks.0.mlp.weight_mask_W_in grad is not all zeros, param.grad.norm()=tensor(3.0081, device='cuda:0')
blocks.0.mlp.weight_mask_W_out grad is not all zeros, param.grad.norm()=tensor(5.3979, device='cuda:0')
blocks.0.mlp.weight_mask_b_in grad is not all zeros, param.grad.norm()=tensor(0.5360, device='cuda:0')
blocks.0.mlp.weight_mask_b_out grad is not all zeros, param.grad.norm()=tensor(0.6562, device='cuda:0')
blocks.1.attn.weight_mask_W_Q grad is not all zeros, param.grad.norm()=tensor(0.0523, device='cuda:0')
blocks.1.attn.weight_mask_W_K grad is not all zeros, param.grad.norm()=

In [21]:
acdcpp_nodes

{(0, 'a0.21'),
 (0, 'a0.30'),
 (0, 'm0'),
 (1, 'a1.0'),
 (1, 'a1.10'),
 (1, 'a1.11'),
 (1, 'a1.12'),
 (1, 'a1.14'),
 (1, 'a1.15'),
 (1, 'a1.16'),
 (1, 'a1.17'),
 (1, 'a1.18'),
 (1, 'a1.25'),
 (1, 'a1.26'),
 (1, 'a1.29'),
 (1, 'a1.30'),
 (1, 'a1.4'),
 (1, 'a1.6'),
 (1, 'm1'),
 (2, 'a2.11'),
 (2, 'a2.12'),
 (2, 'm2'),
 (3, 'a3.23'),
 (3, 'a3.29'),
 (3, 'm3'),
 (4, 'a4.10'),
 (4, 'a4.13'),
 (4, 'a4.16'),
 (4, 'a4.20'),
 (4, 'a4.25'),
 (4, 'a4.4'),
 (4, 'm4'),
 (5, 'a5.0'),
 (5, 'a5.17'),
 (5, 'a5.20'),
 (5, 'm5'),
 (6, 'a6.19'),
 (6, 'a6.6'),
 (6, 'm6'),
 (7, 'a7.14'),
 (7, 'a7.15'),
 (7, 'a7.20'),
 (7, 'a7.8'),
 (7, 'a7.9'),
 (7, 'm7'),
 (8, 'a8.11'),
 (8, 'a8.14'),
 (8, 'a8.26'),
 (8, 'a8.6'),
 (8, 'm8'),
 (9, 'a9.0'),
 (9, 'a9.15'),
 (9, 'a9.19'),
 (9, 'a9.3'),
 (9, 'a9.8'),
 (9, 'a9.9'),
 (9, 'm9'),
 (10, 'a10.1'),
 (10, 'a10.10'),
 (10, 'a10.11'),
 (10, 'a10.14'),
 (10, 'a10.21'),
 (10, 'a10.26'),
 (10, 'a10.29'),
 (10, 'm10'),
 (11, 'a11.10'),
 (11, 'a11.12'),
 (11, 'a11.14'),
 (11,

In [22]:
for name, param in model.named_parameters():
    if param.requires_grad and "blocks.29.attn.weight_mask_W_O" in name:
        zero_indices = torch.nonzero(param.grad != 0, as_tuple=True)
        print(f"{name=}, nonzero indices: {zero_indices}")

name='blocks.29.attn.weight_mask_W_O', nonzero indices: (tensor([13, 13, 13,  ..., 13, 13, 13], device='cuda:0'), tensor([ 0,  0,  0,  ..., 79, 79, 79], device='cuda:0'), tensor([   0,    1,    2,  ..., 2557, 2558, 2559], device='cuda:0'))


### Check if correct MLPs flow

In [29]:
for node in acdcpp_nodes:
    if "m" in node[1]:
        print(node)

(2, 'm2')
(10, 'm10')
(5, 'm5')
(3, 'm3')
(7, 'm7')
(1, 'm1')
(9, 'm9')
(8, 'm8')
(4, 'm4')
(0, 'm0')
(-1, 'embed')
(6, 'm6')


In [31]:
param_names = []
model_params = []
for name, param in model.named_parameters():
    if "mlp" in name:
        # check if param.grad is all zeros
        if param.grad is not None and param.grad.sum() != 0:
            print(f"{name} grad is not all zeros, {param.grad.norm()=}")
        else:
            print(f"{name} grad is all zeros")

blocks.0.edge_mask_mlp grad is all zeros
blocks.0.edge_mask_mlp_baseline grad is all zeros
blocks.0.edge_mask_mlp_frozen grad is all zeros
blocks.0.mlp.W_in grad is not all zeros, param.grad.norm()=tensor(1.5009, device='cuda:0')
blocks.0.mlp.b_in grad is not all zeros, param.grad.norm()=tensor(0.3840, device='cuda:0')
blocks.0.mlp.W_out grad is not all zeros, param.grad.norm()=tensor(1.9580, device='cuda:0')
blocks.0.mlp.b_out grad is not all zeros, param.grad.norm()=tensor(0.2079, device='cuda:0')
blocks.1.edge_mask_mlp grad is all zeros
blocks.1.edge_mask_mlp_baseline grad is all zeros
blocks.1.edge_mask_mlp_frozen grad is all zeros
blocks.1.mlp.W_in grad is not all zeros, param.grad.norm()=tensor(2.1881, device='cuda:0')
blocks.1.mlp.b_in grad is not all zeros, param.grad.norm()=tensor(0.4780, device='cuda:0')
blocks.1.mlp.W_out grad is not all zeros, param.grad.norm()=tensor(2.0281, device='cuda:0')
blocks.1.mlp.b_out grad is not all zeros, param.grad.norm()=tensor(0.2563, device=

### Check if individual attention heads have gradients

In [76]:
for node in acdcpp_nodes:
    if "a9" in node[1]:
        print(node)

(9, 'a9.7')
(9, 'a9.9')
(9, 'a9.6')
(9, 'a9.8')


In [81]:
param_names = []
model_params = []
for name, param in model.named_parameters():
    if param.requires_grad and "9.attn.weight" in name:
    # if param.requires_grad and "9.attn.b" in name:
        print(param.shape)
        # if param.grad is not None and param.grad.sum() != 0:
        #     print(f"{name} grad is not all zeros, {param.grad.norm()=}")
        # else:
        #     print(f"{name} grad is all zeros")
        print(f"{param.grad[3].norm()=}")
        print(f"{param.grad[4].norm()=}")
        print(f"{param.grad[7].norm()=}")
        print(f"{param.grad[8].norm()=}")
        print(f"{param.grad[11].norm()=}")
        print()


torch.Size([12, 768, 64])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.1051, device='cuda:0')
param.grad[8].norm()=tensor(0.0287, device='cuda:0')
param.grad[11].norm()=tensor(0., device='cuda:0')

torch.Size([12, 768, 64])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.1273, device='cuda:0')
param.grad[8].norm()=tensor(0.0335, device='cuda:0')
param.grad[11].norm()=tensor(0., device='cuda:0')

torch.Size([12, 768, 64])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.0570, device='cuda:0')
param.grad[8].norm()=tensor(0.0485, device='cuda:0')
param.grad[11].norm()=tensor(0., device='cuda:0')

torch.Size([12, 64, 768])
param.grad[3].norm()=tensor(0., device='cuda:0')
param.grad[4].norm()=tensor(0., device='cuda:0')
param.grad[7].norm()=tensor(0.030

## Start Mask Learning

In [None]:
from tasks import IOITask, SportsTask, OWTTask
batch_size = 64
ioi = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, prompt_type="ABBA", nb_templates=1, template_start_idx=0)
sports = SportsTask(batch_size=batch_size, tokenizer=tokenizer, device=device)
owt = OWTTask(batch_size=batch_size, tokenizer=tokenizer, device=device)

ioi_ood = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, prompt_type="ABBA", nb_templates=1, template_start_idx=1) # different template

train_tasks = {"ioi": ioi, "owt": owt}
task_weights = {"ioi": -.2, "owt": 1} # I think means preserve OWT, corrupt IOI
eval_tasks = {"ioi": ioi, "sports": sports, "owt": owt}

  table = cls._concat_blocks(blocks, axis=0)


In [None]:
mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)

In [None]:
from cb_utils.learn_mask import train_masks

epochs_left = 500
steps_per_epoch = 10
lr = .05 # free
weight_decay = 0
evaluate_every = 1
discretize_every = 50 # 5 # free
threshold = 0.5
use_wandb = False
edge_mask_reg_strength = None
weight_mask_reg_strength = 10

wandb_config = {"edge_masks": edge_masks, "weight_masks_attn": weight_masks_attn, "weight_masks_mlp": weight_masks_mlp, "epochs": epochs_left, "steps_per_epoch": steps_per_epoch, "lr": lr, "weight_decay": weight_decay, "evaluate_every": evaluate_every, "discretize_every": discretize_every, "threshold": threshold, "edge_mask_reg_strength": edge_mask_reg_strength, "weight_mask_reg_strength": weight_mask_reg_strength}

optimizer = torch.optim.AdamW(mask_params, lr=lr, weight_decay=weight_decay)
train_masks(model, tasks=train_tasks, optimizer=optimizer, num_epochs=epochs_left, steps_per_epoch=steps_per_epoch,
            # param_names=param_names, mask_params=mask_params, 
            task_weights=task_weights, eval_tasks=eval_tasks, evaluate_every=evaluate_every, discretize_every=discretize_every, threshold=threshold, edge_mask_reg_strength=edge_mask_reg_strength, weight_mask_reg_strength=None, verbose=False, use_wandb=use_wandb, wandb_config=wandb_config)

[34m[1mwandb[0m: Currently logged in as: [33mphilliphguo[0m. Use [1m`wandb login --relogin`[0m to force relogin


  6%|▋         | 32/501 [10:48<2:38:21, 20.26s/it]


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_93609/3197613862.py", line 15, in <module>
    train_masks(model, tasks=train_tasks, optimizer=optimizer, num_epochs=epochs_left, steps_per_epoch=steps_per_epoch,
  File "/data/phillip_guo/mechanistic-unlearning/cb_utils/learn_mask.py", line 178, in train_masks
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/data/phillip_guo/miniconda3/envs/unlrn/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call l

In [None]:
import pickle
with open(f"masks/trained_mask_params_{epochs_left=}_{edge_mask_reg_strength=}.pkl", "wb") as f:
    pickle.dump(mask_params, f)

In [None]:
for name, p in zip(param_names, mask_params):
    if p.requires_grad:
        # print(name, p)
        # count how many zeros in p
        print(torch.sum(p == 0))

tensor(8, device='cuda:0')
tensor(2, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(0, device='cuda:0')
tensor(2, device='cuda:0')
tensor(3, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(2, device='cuda:0')
tensor(0, device='cuda:0')
tensor(1, device='cuda:0')
tensor(0, device='cuda:0')
tensor(6, device='cuda:0')
tensor(0, device='cuda:0')
tensor(12, device='cuda:0')
tensor(0, device='cuda:0')
