# Set up models for edge or weight masking

## Workflow:
- Load model
- Build a Task with clean and corrupt data, run ACDCPP on it, and collect the ACDCPP-style edges
- Convert the ACDCPP-style edges to an edge mask, yielding either an edge superset or a node superset
- Apply these masks during mask training, either by restricting the edge mask to the edge superset or the node superset, or by restricting the weight mask to the node superset

- Also test other baselines, such as regular fine-tuning

In [1]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append('acdcpp/Automatic-Circuit-Discovery/')
sys.path.append('acdcpp/')
from acdc import TLACDCExperiment
from acdcpp.ACDCPPExperiment import ACDCPPExperiment

In [2]:
import os
import sys
import re

# import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')
print(f'Device: {device}')

Device: cuda


In [3]:
# Model feeding the ACDCPP -> edge-mask pipeline: GPT-2 small loaded with
# writing-weight centring, unembed centring and LayerNorm folding all disabled.
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    device=device,
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)

# Enable the finer-grained hook points (hook_mlp_in, the per-head
# hook_q_input/hook_k_input/hook_v_input, and attn.hook_result); the ACDCPP
# edge names shown further down are expressed in terms of these hooks.
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)



Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
from tasks.ioi.IOITask import IOITask_old, IOITask

# IOI task on the model's tokenizer. prep_acdcpp / acdcpp_N presumably build the
# clean/corrupt batches read below (ioi_task.clean_data / ioi_task.corr_data)
# -- TODO confirm against IOITask.
ioi_task = IOITask(
    batch_size=5,
    tokenizer=model.tokenizer,
    device=device,
    prep_acdcpp=True,
    acdcpp_N=25,
)
ioi_task.set_logit_diffs(model)

In [6]:
ioi_metric = ioi_task.get_acdcpp_metric()


def negative_abs_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    """Return the negated absolute value of ``ioi_metric(logits)``.

    Handed to the ACDCPP experiment below as ``acdc_metric``; the unwrapped
    ``ioi_metric`` is handed over as ``acdcpp_metric``.
    """
    metric_value = ioi_metric(logits)
    return -abs(metric_value)


# Reference forward passes on the clean and corrupt batches; no autograd graph
# is needed, and the average logit differences are pulled out as Python floats.
with t.no_grad():
    clean_logits = model(ioi_task.clean_data.toks)
    corrupt_logits = model(ioi_task.corr_data.toks)

    clean_logit_diff = ioi_task.ave_logit_diff(
        clean_logits, ioi_task.clean_data
    ).item()
    corrupt_logit_diff = ioi_task.ave_logit_diff(
        corrupt_logits, ioi_task.corr_data
    ).item()

In [11]:
# Sanity check: score the clean and corrupt runs with ioi_metric, passing the
# corrupt/clean logit differences computed above as its two reference points
# (the printed output reads 1.0 for clean and 0.0 for corrupt).
with t.no_grad():
    clean_metric, corrupt_metric = (
        ioi_metric(run_logits, corrupt_logit_diff, clean_logit_diff, run_data)
        for run_logits, run_data in (
            (clean_logits, ioi_task.clean_data),
            (corrupt_logits, ioi_task.corr_data),
        )
    )

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

Clean direction: 3.040117025375366, Corrupt direction: 1.3159735202789307
Clean metric: 1.0, Corrupt metric: 0.0


In [12]:
from ACDCPPExperiment import ACDCPPExperiment

# Thresholds swept by ACDCPP. A finer sweep that was used before:
# np.arange(0.005, 0.155, 0.005)
THRESHOLDS = [0.08, 0.15]
RUN_NAME = 'abs_edge'

acdcpp_exp = ACDCPPExperiment(
    # model plus the clean / corrupt token batches
    model=model,
    clean_data=ioi_task.clean_data.toks,
    corr_data=ioi_task.corr_data.toks,
    # ACDC scores with -|ioi_metric|, ACDCPP with ioi_metric itself
    acdc_metric=negative_abs_ioi_metric,
    acdcpp_metric=ioi_metric,
    # sweep and bookkeeping
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    save_graphs_after=-100,  # NOTE(review): meaning of -100 not visible here -- confirm
    # attribution / pruning options
    attr_absolute_val=True,
    pruning_mode='edge',
    no_pruned_nodes_attr=1,
    # only the ACDCPP pass is run; plain ACDC is skipped
    run_acdc=False,
    run_acdcpp=True,
)

# With run_acdc=False the ACDC-side results are presumably empty -- TODO confirm.
(
    pruned_heads,
    num_passes,
    acdcpp_pruned_attrs,
    acdc_pruned_attrs,
    edges_after_acdcpp,
    edges_after_acdc,
) = acdcpp_exp.run()



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 15075.78it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 248.65it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 277740.02it/s]


dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])




self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 14782.12it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:04<00:00, 248.14it/s]
Edge pruning: 100%|██████████| 1034/1034 [00:00<00:00, 289881.05it/s]
100%|██████████| 2/2 [00:16<00:00,  8.04s/it]

dict_keys([-1, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 12])





In [13]:
# Display the edges remaining after ACDCPP, one set per threshold (the output
# is a dict keyed by the THRESHOLDS values, e.g. 0.08).
# NOTE(review): each string appears to be the receiving hook name+index
# followed directly by the sending hook name+index -- confirm in ACDCPPExperiment.
edges_after_acdcpp

{0.08: {'blocks.0.attn.hook_k[:, :, 10]blocks.0.hook_k_input[:, :, 10]',
  'blocks.0.attn.hook_q[:, :, 10]blocks.0.hook_q_input[:, :, 10]',
  'blocks.0.attn.hook_result[:, :, 10]blocks.0.attn.hook_k[:, :, 10]',
  'blocks.0.attn.hook_result[:, :, 10]blocks.0.attn.hook_q[:, :, 10]',
  'blocks.0.attn.hook_result[:, :, 10]blocks.0.attn.hook_v[:, :, 10]',
  'blocks.0.attn.hook_v[:, :, 1]blocks.0.hook_v_input[:, :, 1]',
  'blocks.0.hook_k_input[:, :, 10]blocks.0.hook_resid_pre[:]',
  'blocks.0.hook_k_input[:, :, 9]blocks.0.hook_resid_pre[:]',
  'blocks.0.hook_mlp_in[:]blocks.0.attn.hook_result[:, :, 10]',
  'blocks.0.hook_mlp_in[:]blocks.0.hook_resid_pre[:]',
  'blocks.0.hook_mlp_out[:]blocks.0.hook_mlp_in[:]',
  'blocks.0.hook_q_input[:, :, 10]blocks.0.hook_resid_pre[:]',
  'blocks.1.attn.hook_k[:, :, 11]blocks.1.hook_k_input[:, :, 11]',
  'blocks.1.attn.hook_q[:, :, 11]blocks.1.hook_q_input[:, :, 11]',
  'blocks.1.hook_mlp_in[:]blocks.0.hook_mlp_out[:]',
  'blocks.1.hook_mlp_out[:]blocks.1

In [14]:
import torch

# Peak memory allocated by tensors on the current CUDA device since the
# program started, reported in decimal gigabytes (1e9 bytes).
torch.cuda.max_memory_allocated(device=None) / 1e9

11.337863168