In [16]:
import sys
sys.path.append("/root/circuit-finder")
import pandas as pd
import json
import torch
from dataclasses import dataclass
import transformer_lens as tl
from circuit_finder.patching.eap_graph import EAPGraph
from circuit_finder.utils import clear_memory
from circuit_finder.constants import device
from circuit_finder.pretrained import (
    load_attn_saes,
    load_hooked_mlp_transcoders,
)
from circuit_finder.metrics import batch_avg_answer_diff
from circuit_finder.patching.causal_interventions import run_with_ablations
import os
# Work from the repository root so the relative paths used later in the
# notebook (e.g. the 'datasets/...' JSON file) resolve.
os.chdir("/root/circuit-finder/")

# Turn autograd off globally for this session. As the last expression of the
# cell, the returned object is also what the cell displays.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f4bffb34070>

In [2]:
# Pretrained attention SAEs and MLP transcoders, both loaded with use_error_term=True
attn_saes = load_attn_saes(use_error_term=True)
mlp_transcoders = load_hooked_mlp_transcoders(use_error_term=True)

# GPT-2 wrapped as a HookedSAETransformer, then moved onto the GPU
model = tl.HookedSAETransformer.from_pretrained("gpt2")
model = model.cuda()

Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]



Loaded pretrained model gpt2 into HookedTransformer
Moving model to device:  cuda


In [3]:
def metric_fn(model, tokens):
    """Score ``model`` on ``tokens`` using the logits of the final position.

    The model is run on ``tokens``, the logits at the last sequence position
    are kept, and they are handed to ``batch_avg_answer_diff`` whose result
    is returned unchanged.

    NOTE(review): ``batch`` is not a parameter; it is resolved as a
    module-level global at call time, and no cell visible here defines it.
    Confirm it is assigned before this function is called.
    """
    final_position_logits = model(tokens)[:, -1, :]
    return batch_avg_answer_diff(final_position_logits, batch)

In [49]:
def load_batch(file_path, batch_size=10, first_batch=0):
    """Load a slice of clean/corrupt prompt pairs and tokenize their answers.

    Args:
        file_path: Path to a JSON file holding a ``'word_idxs'`` entry and a
            ``'prompts'`` list. Each prompt item must provide ``'clean'``,
            ``'corrupt'``, ``'answers'`` and ``'wrong_answers'``.
        batch_size: Number of prompts to take.
        first_batch: Index of the first prompt to take. Despite the name this
            is an offset into the prompt list, not a batch number.

    Returns:
        A tuple ``(word_idxs, clean_prompts, corrupt_prompts, correct_ids,
        wrong_ids)``. ``word_idxs`` is returned exactly as stored in the file.
        The two prompt lists hold the ``'clean'`` / ``'corrupt'`` strings of
        the selected prompts. ``correct_ids`` / ``wrong_ids`` are column 1 of
        ``model.to_tokens(...)`` applied to the first entry of each prompt's
        ``'answers'`` / ``'wrong_answers'``.

    Notes:
        Tokenization uses the module-level ``model``, which must already be
        defined when this function is called.
        NOTE(review): taking column 1 presumably skips a prepended BOS token
        and assumes each answer is a single token; any further tokens of a
        longer answer are ignored. Confirm against the dataset.
    """
    # JSON is UTF-8 by specification; pin the encoding instead of relying on
    # the platform/locale default used by open().
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    word_idxs = data['word_idxs']

    # Plain list slice: running past the end just yields fewer prompts.
    prompts = data['prompts'][first_batch:first_batch + batch_size]

    clean_prompts = [prompt['clean'] for prompt in prompts]
    corrupt_prompts = [prompt['corrupt'] for prompt in prompts]
    # Only the first listed answer / wrong answer of each prompt is used.
    correct_answers = [prompt['answers'][0] for prompt in prompts]
    wrong_answers = [prompt['wrong_answers'][0] for prompt in prompts]

    correct_ids = model.to_tokens(correct_answers)[:, 1]
    wrong_ids = model.to_tokens(wrong_answers)[:, 1]

    return word_idxs, clean_prompts, corrupt_prompts, correct_ids, wrong_ids

# Load prompt pairs from the IOI ABBA template-0 file using load_batch's defaults
file_path = 'datasets/ioi/ioi_ABBA_template_0_prompts.json'
(word_idxs, clean_prompts, corrupt_prompts, correct_ids, wrong_ids) = load_batch(file_path)

# Echo what was loaded, one labelled line per item
for _label, _loaded in (
    ("Clean Prompts:", clean_prompts),
    ("Corrupt Prompts:", corrupt_prompts),
    ("Correct IDs:", correct_ids),
    ("Wrong IDs:", wrong_ids),
):
    print(_label, _loaded)


Clean Prompts: ['Then, Scott and Jeremy went to the hospital. Jeremy gave a snack to', 'Then, Stephen and Victoria went to the station. Victoria gave a necklace to', 'Then, Kyle and Carter went to the restaurant. Carter gave a kiss to', 'Then, Lisa and Robert went to the restaurant. Robert gave a computer to', 'Then, Ruby and Andy went to the office. Andy gave a snack to', 'Then, Jacob and John went to the school. John gave a snack to', 'Then, Leon and Emily went to the store. Emily gave a kiss to', 'Then, Robert and Tyler went to the store. Tyler gave a snack to', 'Then, Edward and Ruby went to the garden. Ruby gave a basketball to', 'Then, Dean and Sullivan went to the station. Sullivan gave a kiss to']
Corrupt Prompts: ['Then, Michael and Anderson went to the hospital. Rachel gave a snack to', 'Then, Charlie and Jay went to the station. Elizabeth gave a necklace to', 'Then, Emily and River went to the restaurant. Brian gave a kiss to', 'Then, Daniel and Austin went to the restaurant

In [53]:
# Pull the 'S1' entry out of the word-index mapping returned by load_batch.
# NOTE(review): presumably the token position of the first subject mention in
# the prompt template -- confirm against the dataset file.
S1 = word_idxs['S1']

In [11]:
from eindex import eindex
def correct_logit(logits, correct_ids, wrong_ids):
    """Batch-mean logit of the correct answer token at the final position.

    For each batch row, the final-position logits are indexed with that row's
    entry of ``correct_ids``; the resulting values are averaged.
    ``wrong_ids`` is accepted but not used (same signature as ``wrong_logit``
    and ``logit_diff``).
    """
    final_position_logits = logits[:, -1]
    per_example = eindex(final_position_logits, correct_ids, "batch [batch]")
    return per_example.mean()

def wrong_logit(logits, correct_ids, wrong_ids):
    """Batch-mean logit of the wrong answer token at the final position.

    For each batch row, the final-position logits are indexed with that row's
    entry of ``wrong_ids``; the resulting values are averaged.
    ``correct_ids`` is accepted but not used (same signature as
    ``correct_logit`` and ``logit_diff``).
    """
    final_position_logits = logits[:, -1]
    per_example = eindex(final_position_logits, wrong_ids, "batch [batch]")
    return per_example.mean()

def logit_diff(logits, correct_ids, wrong_ids):
    """Mean correct-answer logit minus mean wrong-answer logit.

    Both terms are computed at the final sequence position by
    ``correct_logit`` and ``wrong_logit``; each is averaged over the batch
    before the subtraction.
    """
    mean_correct = correct_logit(logits, correct_ids, wrong_ids)
    mean_wrong = wrong_logit(logits, correct_ids, wrong_ids)
    return mean_correct - mean_wrong

In [None]:
# Names of the attention-pattern hooks to cache, one per layer index 0..11
attn_pattern_hooks = [f'blocks.{layer}.attn.hook_pattern' for layer in range(12)]

# Run the clean prompts with a single ablation and nothing patched.
# NOTE(review): the ablation tuple presumably reads
# (module kind, layer, feature index, token position) -- confirm against
# run_with_ablations.
logits, cache = run_with_ablations(
    model,
    clean_prompts,
    attn_saes,
    mlp_transcoders,
    ablation_list=[("attn", 6, 17410, S1)],
    patch_list=[],
    cache_names_filter=attn_pattern_hooks,
)

# Last expression of the cell: logit difference measured on the ablated run
logit_diff(logits, correct_ids, wrong_ids)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
       device='cuda:0')


tensor(4.2633, device='cuda:0')

In [42]:
import plotly.express as px
# Bar chart for one layer's cached attention pattern: average over the first
# two axes, then keep the last row of what remains.
# NOTE(review): the averaged axes are presumably (batch, head), leaving the
# final query position's attention over key positions -- confirm.
layer = 10
layer_pattern = cache[f'blocks.{layer}.attn.hook_pattern']
px.bar(layer_pattern.mean([0,1])[-1].cpu()).show()

In [None]:
mean_pattern = cache['blocks.10.attn.hook_pattern'].mean([0,1]