In [1]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("/root/circuit-finder")
import pandas as pd
import json
import torch
from dataclasses import dataclass
import transformer_lens as tl
from circuit_finder.patching.eap_graph import EAPGraph
from circuit_finder.utils import clear_memory
from circuit_finder.constants import device
from circuit_finder.pretrained import (
    load_attn_saes,
    load_hooked_mlp_transcoders,
)
from circuit_finder.metrics import batch_avg_answer_diff
from circuit_finder.patching.causal_interventions import run_with_ablations
import os
os.chdir("/root/circuit-finder/")

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fec2ddf02e0>

In [2]:
# Initialize SAEs
attn_saes = load_attn_saes(use_error_term=True)
mlp_transcoders = load_hooked_mlp_transcoders(use_error_term=True)

# Load model
model = tl.HookedSAETransformer.from_pretrained("gpt2").cuda()

Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]



Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda


In [3]:
from eindex import eindex
def correct_logit(logits, correct_ids, wrong_ids):
    return eindex(logits[:,-1], correct_ids, "batch [batch]").mean()

def wrong_logit(logits, correct_ids, wrong_ids):
    return eindex(logits[:,-1], wrong_ids, "batch [batch]").mean()

def logit_diff(logits, correct_ids, wrong_ids):
    return correct_logit(logits, correct_ids, wrong_ids) - wrong_logit(logits, correct_ids, wrong_ids)

In [15]:
clean_prompts = ["The doctor is ready, you can go see"]
corrupt_prompts = ["The nurse is ready, you can go see"]
correct_ids = model.to_tokens(" him")[:,1]
wrong_ids = model.to_tokens(" her")[:,1]

logits, clean_cache = run_with_ablations(
    model,
    clean_prompts,
    attn_saes,
    mlp_transcoders,
    ablation_list = [("attn", 1, 22089, -1), ],
    patch_list = [],
    cache_names_filter = [f'blocks.{layer}.attn.hook_pattern' for layer in range(12)]
)

print("clean logit diff = ", logit_diff(logits, correct_ids, wrong_ids))

clean logit diff =  tensor(0.7484, device='cuda:0')


In [13]:
model(clean_prompts)

tensor([[[ 7.5261, 11.1214,  7.8919,  ..., -3.1299, -3.3873,  8.5934],
         [-0.7078,  1.4936, -0.0422,  ..., -0.3339, -3.6979,  1.6328],
         [ 5.9632,  6.7539,  1.4267,  ..., -1.1700,  4.4327,  4.9453],
         ...,
         [ 9.4609,  8.2143,  2.5712,  ..., -2.5130, -1.0821,  6.8435],
         [13.8023,  9.0044,  4.3323,  ..., -0.7796, -0.7950,  8.8284],
         [11.2392,  6.8841,  2.0912,  ..., -1.3461, -2.3885,  6.7478]]],
       device='cuda:0')