In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append('acdcpp/Automatic-Circuit-Discovery/')
sys.path.append('acdcpp/')
from acdc import TLACDCExperiment
from acdcpp.ACDCPPExperiment import ACDCPPExperiment

import os
import sys
import re

# import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

from tqdm import tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

import torch
import pickle

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


## Model Loading

In [2]:
from cb_utils.models import load_demo_gpt2, tokenizer

def load_gpt2_model(mask_type, localization_path):
    """Load the demo GPT-2 with edge masks, weight masks, or no masks.

    Args:
        mask_type: "edge" loads an edge-masked model whose masks are restricted to
            the localization's edge superset; "weight" loads a weight-masked model
            restricted to the localization's attn/mlp weight-mask dicts; any other
            value loads the model with all masking options turned off.
        localization_path: path to a pickle holding the 5-tuple
            (nodes, edges, mask_dict, weight_mask_attn_dict, weight_mask_mlp_dict),
            or None for no localization.

    Returns:
        The model returned by ``load_demo_gpt2``.

    If the localization file cannot be read or unpacked, a warning is emitted and
    the model is loaded without localization (best-effort, as before).
    """
    import warnings

    mask_dict_superset = None
    weight_mask_attn_dict = None
    weight_mask_mlp_dict = None

    if localization_path is not None:
        try:
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            with open(localization_path, "rb") as f:
                (_nodes, _edges, acdcpp_mask_dict,
                 acdcpp_weight_mask_attn_dict, acdcpp_weight_mask_mlp_dict) = pickle.load(f)
        except Exception as e:
            # Keep the fallback to "no localization", but make it visible instead of
            # silently swallowing every error (a bare except also ate KeyboardInterrupt).
            warnings.warn(
                f"Could not load localization from {localization_path!r} ({e!r}); "
                "falling back to no localization."
            )
        else:
            # Only the dicts relevant to the requested mask type are forwarded.
            if mask_type == "edge":
                mask_dict_superset = acdcpp_mask_dict
            elif mask_type == "weight":
                weight_mask_attn_dict = acdcpp_weight_mask_attn_dict
                weight_mask_mlp_dict = acdcpp_weight_mask_mlp_dict

    # Only consumed by the unmasked branch, which loads with train_base_weights=False.
    base_weight_attn_dict = None
    base_weight_mlp_dict = None

    edge_masks = mask_type == 'edge'
    weight_masks = mask_type == 'weight'
    if mask_type == 'edge':
        model = load_demo_gpt2(means=False, edge_mask=True, weight_mask=False,
                               edge_masks=edge_masks, mask_dict_superset=mask_dict_superset)
    elif mask_type == 'weight':
        model = load_demo_gpt2(means=False, edge_mask=False, weight_mask=True,
                               weight_masks_attn=weight_masks, weight_masks_mlp=weight_masks,
                               weight_mask_attn_dict=weight_mask_attn_dict,
                               weight_mask_mlp_dict=weight_mask_mlp_dict)
    else:
        model = load_demo_gpt2(means=False, edge_mask=False, weight_mask=False,
                               edge_masks=edge_masks, mask_dict_superset=mask_dict_superset,
                               weight_masks_attn=weight_masks, weight_masks_mlp=weight_masks,
                               weight_mask_attn_dict=weight_mask_attn_dict,
                               weight_mask_mlp_dict=weight_mask_mlp_dict,
                               train_base_weights=False,
                               base_weight_attn_dict=base_weight_attn_dict,
                               base_weight_mlp_dict=base_weight_mlp_dict)
    return model

def load_params(model, mask_params_path):
    """Overwrite the model's trainable parameters with mask values saved on disk.

    Args:
        model: a torch module; every parameter with ``requires_grad=True`` has its
            ``.data`` replaced by the saved tensor of the same name.
        mask_params_path: pickle file containing ``(param_names, param_tensors)``,
            two parallel sequences.

    Returns:
        ``(model, new_param_names, new_mask_params)``: the same model object, the
        saved names, and a dict mapping each saved name to its tensor.

    Raises:
        ValueError: if the saved name and tensor sequences differ in length.
        KeyError: if a trainable parameter of ``model`` has no saved value.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted checkpoints.
    with open(mask_params_path, "rb") as f:
        new_param_names, saved_params = pickle.load(f)

    # zip() would silently drop unmatched entries; fail loudly instead.
    if len(new_param_names) != len(saved_params):
        raise ValueError(
            f"{mask_params_path!r}: {len(new_param_names)} names but "
            f"{len(saved_params)} tensors"
        )
    new_mask_params = dict(zip(new_param_names, saved_params))

    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name not in new_mask_params:
            raise KeyError(
                f"No saved value for trainable parameter {name!r} in {mask_params_path!r}"
            )
        p.data = new_mask_params[name]

    return model, new_param_names, new_mask_params

Using device: cuda:0


In [3]:
# Random-mask baseline: replace trained masks with random binary masks of comparable sparsity.
def create_random_params(mask_params, prop_0s=None):
    """Draw random binary masks shaped like ``mask_params``.

    Args:
        mask_params: dict mapping parameter name -> floating-point mask tensor
            (``torch.rand_like`` is called on each value).
        prop_0s: probability that each entry of a new mask is 0. If None, it is
            inferred as the fraction of exact 0s among all entries of
            ``mask_params`` that are exactly 0 or 1 (other values are ignored).

    Returns:
        dict with the same keys; each value is a float32 tensor of 0s and 1s with
        the same shape and device as the original, with ``requires_grad=True``.

    Raises:
        ValueError: if ``prop_0s`` is None and no entry of ``mask_params`` is 0 or 1.
    """
    if not mask_params:
        return {}

    if prop_0s is None:
        # Iterate the tensors (.values()), not the dict itself, which yields names.
        tot_0s = sum(int((param == 0).sum()) for param in mask_params.values())
        tot_1s = sum(int((param == 1).sum()) for param in mask_params.values())
        total = tot_0s + tot_1s
        if total == 0:
            raise ValueError("cannot infer prop_0s: mask_params has no entries equal to 0 or 1")
        prop_0s = tot_0s / total

    new_params = {}
    for param_name, param in mask_params.items():
        # rand_like is uniform on [0, 1), so each entry is kept (1) with
        # probability 1 - prop_0s; it already lives on param's device.
        keep = t.rand_like(param) > prop_0s
        new_param = keep.float()
        new_param.requires_grad = True
        new_params[param_name] = new_param

    return new_params


def load_random_params(model, mask_params_path, prop_0s=None):
    """Replace the model's trainable mask parameters with random binary masks.

    The saved masks at ``mask_params_path`` serve only as templates. They supply
    the names, the shapes and devices, and, when ``prop_0s`` is None, the
    inferred fraction of zeros. ``create_random_params`` draws the new values.

    Args:
        model: a torch module; every parameter with ``requires_grad=True`` is replaced.
        mask_params_path: pickle file containing ``(param_names, param_tensors)``.
        prop_0s: probability that each new mask entry is 0 (forwarded to
            ``create_random_params``).

    Returns:
        ``(model, new_param_names, new_mask_params)``: the same model object, the
        saved names, and the dict of random masks that were installed.

    Raises:
        ValueError: if the saved name and tensor sequences differ in length.
        KeyError: if a trainable parameter of ``model`` has no saved template.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted checkpoints.
    with open(mask_params_path, "rb") as f:
        new_param_names, saved_params = pickle.load(f)

    # zip() would silently drop unmatched entries; fail loudly instead.
    if len(new_param_names) != len(saved_params):
        raise ValueError(
            f"{mask_params_path!r}: {len(new_param_names)} names but "
            f"{len(saved_params)} tensors"
        )
    template_params = dict(zip(new_param_names, saved_params))

    new_mask_params = create_random_params(template_params, prop_0s=prop_0s)

    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name not in new_mask_params:
            raise KeyError(
                f"No saved template for trainable parameter {name!r} in {mask_params_path!r}"
            )
        p.data = new_mask_params[name]
        print(f"number of 0s in {name}:", int((p.data == 0).sum()))

    return model, new_param_names, new_mask_params

In [4]:
# for induction
localization_paths = {
    "acdcpp": "localizations/eap/induction/gpt2_threshold=0.03.pkl",
    "ct": "localizations/causal_tracing/induction/gpt2_small_localizations.pkl",
    "none": None
}

# Which saved mask checkpoint to evaluate: "epoch=50", or "final" for the
# end-of-training checkpoints (mask_params_final.pkl).
MASK_CKPT_TAG = "epoch=50"

# One trained-mask checkpoint per (localization type, mask type), all following
# masks/induction/{mask_type}_masks_localize={localization}/ckpts/mask_params_{tag}.pkl
mask_params_paths = {
    localization_type: {
        mask_type: (
            f"masks/induction/{mask_type}_masks_localize={localization_type}"
            f"/ckpts/mask_params_{MASK_CKPT_TAG}.pkl"
        )
        for mask_type in ("edge", "weight")
    }
    for localization_type in localization_paths
}

In [20]:
import pandas as pd

# Random-mask baseline: for every (localization_type, mask_type) pair, swap the
# trained masks for random binary masks with a fraction `prop_0s` of zeros and
# record task metrics. Each results frame has rows = mask_type and
# columns = localization_type.

batch_size = 80

from tasks import IOITask, SportsTask, OWTTask, IOITask_Uniform, GreaterThanTask, InductionTask, InductionTask_Uniform
owt = OWTTask(batch_size=batch_size, tokenizer=tokenizer, device=device, ctx_length=40)
greaterthan = GreaterThanTask(batch_size=batch_size, tokenizer=tokenizer, device=device)
ioi = IOITask(batch_size=batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, nb_templates=4, prompt_type="ABBA")
induction = InductionTask(batch_size=batch_size, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15)

# float dtype so results are stored as numbers rather than object-dtype cells
_result_columns = ["acdcpp", "ct", "none"]
_result_index = ["edge", "weight"]
induction_df = pd.DataFrame(columns=_result_columns, index=_result_index, dtype=float)
gt_df = pd.DataFrame(columns=_result_columns, index=_result_index, dtype=float)
owt_df = pd.DataFrame(columns=_result_columns, index=_result_index, dtype=float)
ioi_df = pd.DataFrame(columns=_result_columns, index=_result_index, dtype=float)

task_dict = {"owt": (owt, owt_df), "gt": (greaterthan, gt_df), "ioi": (ioi, ioi_df), "induction": (induction, induction_df)}

prop_0s = .2   # probability that each random mask entry is 0
n_iters = 5    # test batches averaged per configuration; the random mask is drawn once per configuration
seed = 0       # makes the random masks (and any RNG used by the tasks) reproducible
t.manual_seed(seed)

for localization_type, paths_by_mask_type in mask_params_paths.items():
    for mask_type, mask_params_path in paths_by_mask_type.items():
        print(f"localization_type: {localization_type}, mask_type: {mask_type}")
        localization_path = localization_paths[localization_type]

        model = load_gpt2_model(mask_type, localization_path)
        model, new_param_names, new_mask_params = load_random_params(model, mask_params_path, prop_0s=prop_0s)

        for task_name, (task, df) in task_dict.items():
            if task_name == "owt":
                # OWT is scored by test loss (cross-entropy), not accuracy.
                total = sum(task.get_test_loss(model).item() for _ in range(n_iters))
            else:
                total = sum(task.get_test_accuracy(model) for _ in range(n_iters))
            df.loc[mask_type, localization_type] = float(total) / n_iters


  table = cls._concat_blocks(blocks, axis=0)


localization_type: acdcpp, mask_type: edge
Loaded edge-masked transformer
localization_type: acdcpp, mask_type: weight
Loaded weight-masked transformer
localization_type: ct, mask_type: edge
Loaded edge-masked transformer
localization_type: ct, mask_type: weight
Loaded weight-masked transformer
localization_type: none, mask_type: edge
Loaded edge-masked transformer
localization_type: none, mask_type: weight
Loaded weight-masked transformer


In [21]:
# NOTE(review): for the "none" column the masks are presumably not restricted to
# an induction localization — confirm this header wording.
print(f"Masking random {prop_0s*100}% of induction-localized components")

# Show each metric table under its heading, in a fixed order.
_results_to_show = [
    ("Induction Accuracy", induction_df),
    ("IOI Accuracy", ioi_df),
    ("OWT Cross-Entropy", owt_df),
    ("GreaterThan Accuracy", gt_df),
]
for _heading, _frame in _results_to_show:
    print(_heading)
    display(_frame)


Masking random 20.0% of induction-localized components
Induction Accuracy


Unnamed: 0,acdcpp,ct,none
edge,0.975,0.995,0.09
weight,0.8275,1.0,0.3975


IOI Accuracy


Unnamed: 0,acdcpp,ct,none
edge,1.0,0.9925,0.345
weight,0.75,0.99,0.2475


OWT Cross-Entropy


Unnamed: 0,acdcpp,ct,none
edge,4.150366,3.740044,9.615586
weight,4.539439,3.726847,4.874142


GreaterThan Accuracy


Unnamed: 0,acdcpp,ct,none
edge,0.9875,0.9875,0.225
weight,0.9,0.9875,0.7125
