# Set up models for edge or weight masking

## Workflow:
- Load model
- Use Task with clean and corrupt data, use ACDCPP and get the ACDCPP-style edges
- Convert ACDCPP-style edges to edge mask, get either edge superset of node superset
- Apply these masks to the mask training, either by limiting edge mask to only edge superset, node superset, or by limiting weight mask to node superset

- Also need to test other baselines, like regular finetuning

In [2]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append('acdcpp/Automatic-Circuit-Discovery/')
sys.path.append('acdcpp/')
from acdc import TLACDCExperiment
from acdcpp.ACDCPPExperiment import ACDCPPExperiment

In [3]:
import os
import sys
import re

# import acdc
from acdc.TLACDCExperiment import TLACDCExperiment
from acdc.acdc_utils import TorchIndex, EdgeType
import numpy as np
import torch as t
from torch import Tensor
import einops
import itertools

from transformer_lens import HookedTransformer, ActivationCache

import tqdm.notebook as tqdm
import plotly
from rich import print as rprint
from rich.table import Table

from jaxtyping import Float, Bool
from typing import Callable, Tuple, Union, Dict, Optional

device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
print(f'Device: {device}')

Device: cuda


In [4]:
# set up pipeline from acdcpp to edge mask
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)



Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
from tasks.ioi.IOITask import IOITask_old, IOITask
ioi_task = IOITask(batch_size=5, tokenizer=model.tokenizer, device=device, prep_acdcpp=True, acdcpp_N=25)
ioi_task.set_logit_diffs(model)

In [6]:
ioi_metric = lambda logits: ioi_task.get_acdcpp_metric(logits, model=model)
negative_abs_ioi_metric = lambda logits: -abs(ioi_task.get_acdcpp_metric(logits, model=model))

In [7]:
print(ioi_task.clean_data.toks.shape)
print(ioi_task.corr_data.toks.shape)
print(ioi_task.clean_logit_diff)
print(ioi_task.corrupted_logit_diff)

torch.Size([100, 16])
torch.Size([100, 16])
3.131037473678589
1.190014123916626


In [8]:
ioi_task.corr_data.ioi_prompts

[{'[PLACE]': 'garden',
  '[OBJECT]': 'computer',
  'text': 'Then, Joseph and Brandon went to the garden. Connor gave a computer to',
  'IO': 'Brandon',
  'S': 'Connor',
  'TEMPLATE_IDX': 0},
 {'[PLACE]': 'station',
  '[OBJECT]': 'drink',
  'text': 'Then, Daniel and Andrew went to the station. Simon gave a drink to',
  'IO': 'Andrew',
  'S': 'Simon',
  'TEMPLATE_IDX': 0},
 {'[PLACE]': 'house',
  '[OBJECT]': 'drink',
  'text': 'Then, Tyler and Kyle went to the house. Grant gave a drink to',
  'IO': 'Kyle',
  'S': 'Grant',
  'TEMPLATE_IDX': 0},
 {'[PLACE]': 'school',
  '[OBJECT]': 'kiss',
  'text': 'Then, Ford and Amy went to the school. Carter gave a kiss to',
  'IO': 'Amy',
  'S': 'Carter',
  'TEMPLATE_IDX': 0},
 {'[PLACE]': 'house',
  '[OBJECT]': 'basketball',
  'text': 'Then, Ryan and Max went to the house. Kate gave a basketball to',
  'IO': 'Max',
  'S': 'Kate',
  'TEMPLATE_IDX': 0},
 {'[PLACE]': 'school',
  '[OBJECT]': 'snack',
  'text': 'Then, Anthony and Daniel went to the school

In [9]:
def ave_logit_diff(
    logits: Float[Tensor, 'batch seq d_vocab'],
    ioi_dataset,
    per_prompt: bool = False
):
    '''
        Return average logit difference between correct and incorrect answers
    '''
    # Get logits for indirect objects
    print(f"{logits.shape=}, {ioi_dataset.word_idx['end'].shape=}, {len(ioi_dataset.io_tokenIDs)=}")
    io_logits = logits[range(logits.size(0)), ioi_dataset.word_idx['end'], ioi_dataset.io_tokenIDs]
    s_logits = logits[range(logits.size(0)), ioi_dataset.word_idx['end'], ioi_dataset.s_tokenIDs]
    # Get logits for subject
    logit_diff = io_logits - s_logits
    return logit_diff if per_prompt else logit_diff.mean()

with t.no_grad():
    clean_logits = model(ioi_task.clean_data.toks)
    corrupt_logits = model(ioi_task.corr_data.toks)
    clean_logit_diff = ave_logit_diff(clean_logits, ioi_task.clean_data).item()
    corrupt_logit_diff = ave_logit_diff(corrupt_logits, ioi_task.corr_data).item()

logits.shape=torch.Size([100, 16, 50257]), ioi_dataset.word_idx['end'].shape=torch.Size([100]), len(ioi_dataset.io_tokenIDs)=100
logits.shape=torch.Size([100, 16, 50257]), ioi_dataset.word_idx['end'].shape=torch.Size([100]), len(ioi_dataset.io_tokenIDs)=100


In [10]:
corrupt_logit_diff

1.190014123916626

In [11]:
import torch
other_code_clean_toks, other_code_corrupted_toks = torch.load("acdcpp/ioi_task/ioi_dataset_toks.pt")

In [12]:
ioi_task.corrupted_logit_diff

1.190014123916626

In [13]:
from ACDCPPExperiment import ACDCPPExperiment
THRESHOLDS = [0.08, .15]#np.arange(0.005, 0.155, 0.005)
RUN_NAME = 'abs_edge'

acdcpp_exp = ACDCPPExperiment(
    model=model,
    clean_data=ioi_task.clean_data.toks,
    corr_data=ioi_task.corr_data.toks,
    acdc_metric=negative_abs_ioi_metric,
    acdcpp_metric=ioi_metric,
    thresholds=THRESHOLDS,
    run_name=RUN_NAME,
    verbose=False,
    attr_absolute_val=True,
    save_graphs_after=-100,
    pruning_mode='edge',
    no_pruned_nodes_attr=1,
    run_acdc=False,
    run_acdcpp=True,
)
# e=acdcpp_exp.setup_exp(0.0)

pruned_heads, num_passes, acdcpp_pruned_attrs, acdc_pruned_attrs, edges_after_acdcpp, edges_after_acdc = acdcpp_exp.run()



self.current_node=TLACDCInterpNode(blocks.11.hook_resid_post, [:])


  0%|          | 0/2 [00:16<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 58.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 21.06 MiB is free. Process 38445 has 2.57 GiB memory in use. Process 39195 has 25.81 GiB memory in use. Including non-PyTorch memory, this process has 50.72 GiB memory in use. Of the allocated memory 48.87 GiB is allocated by PyTorch, and 1.36 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
edges_after_acdcpp

{0.08: set(), 0.15: set()}

In [1]:
import torch
torch.cuda.max_memory_allocated() / 1e9

0.0

In [10]:
len(edges_after_acdcpp[0.08])

0