In [1]:
import torch as t
from torch import Tensor
from jaxtyping import Float
from tqdm import tqdm
import numpy as np

from nnsight.models.UnifiedTransformer import UnifiedTransformer

device = "cuda"

In [2]:
model = UnifiedTransformer(
    'gpt2-small',
    processing=False,
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
tokenizer = model.tokenizer

model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
from ioi_dataset import IOIDataset

N = 25
clean_dataset = IOIDataset(
    prompt_type='mixed',
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
    device=device
)
corr_dataset = clean_dataset.gen_flipped_prompts('ABC->XYZ, BAB->XYZ')

In [4]:
def ave_logit_diff(
    logits: Float[Tensor, 'batch seq d_vocab'],
    ioi_dataset: IOIDataset,
    per_prompt: bool = False
):
    '''
        Return average logit difference between correct and incorrect answers
    '''
    # Get logits for indirect objects
    batch_size = logits.size(0)
    io_logits = logits[range(batch_size), ioi_dataset.word_idx['end'][:batch_size], ioi_dataset.io_tokenIDs[:batch_size]]
    s_logits = logits[range(batch_size), ioi_dataset.word_idx['end'][:batch_size], ioi_dataset.s_tokenIDs[:batch_size]]
    # Get logits for subject
    logit_diff = io_logits - s_logits
    return logit_diff if per_prompt else logit_diff.mean()


with t.no_grad():
    with model.trace(clean_dataset.toks):
        clean_logits = model.output.save()

    with model.trace(corr_dataset.toks):
        corrupt_logits = model.output.save()

clean_logits = clean_logits.value
corrupt_logits = corrupt_logits.value

clean_logit_diff = ave_logit_diff(clean_logits, clean_dataset).item()
corrupt_logit_diff = ave_logit_diff(corrupt_logits, corr_dataset).item()

def ioi_metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
    ioi_dataset: IOIDataset = clean_dataset
 ):
    patched_logit_diff = ave_logit_diff(logits, ioi_dataset)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

def negative_ioi_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    return -ioi_metric(logits)
    
# Get clean and corrupt logit differences
with t.no_grad():
    clean_metric = ioi_metric(clean_logits, corrupt_logit_diff, clean_logit_diff, clean_dataset)
    corrupt_metric = ioi_metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff, corr_dataset)

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Clean direction: 2.805180311203003, Corrupt direction: 1.2550485134124756
Clean metric: 1.0, Corrupt metric: 0.0


In [5]:
import eap

import importlib

importlib.reload(eap)

graph = eap.EAP(model.cfg, components=["head", "mlp"])

In [6]:
graph.run(
    model,
    clean_dataset.toks,
    corr_dataset.toks,
    batch_size=25,
    metric=ioi_metric,
)

tensor([[[-4.2697e-01,  1.1448e+00,  1.3131e-01,  ..., -2.4350e+00,
           6.0842e-02, -1.7969e+00],
         [-1.1787e-01,  1.4245e+00, -1.0625e+00,  ..., -7.6892e-01,
           1.8339e-01, -1.5101e+00],
         [ 9.1228e-01,  1.3094e-01, -8.9158e-01,  ..., -4.3268e-02,
           6.3102e-01, -1.7191e-01],
         ...,
         [-3.7074e-01,  6.8026e-02,  1.1891e+00,  ...,  3.2080e+00,
           6.8503e-01,  8.4383e-01],
         [-4.7524e-01, -7.3930e-02,  1.1964e+00,  ...,  2.9511e+00,
           4.7496e-01,  1.0620e+00],
         [-5.4846e-01, -1.6747e-01,  1.1752e+00,  ...,  2.7625e+00,
           3.4602e-01,  1.1410e+00]],

        [[-4.2697e-01,  1.1448e+00,  1.3131e-01,  ..., -2.4350e+00,
           6.0842e-02, -1.7969e+00],
         [-5.0372e-01, -6.8822e-01, -3.5946e-01,  ...,  3.0826e-01,
          -4.7470e-01, -1.8963e+00],
         [ 1.1867e+00,  2.8087e-01, -9.6663e-01,  ..., -1.8876e-01,
           7.2888e-01, -1.9992e-01],
         ...,
         [-4.2782e-01,  2

In [7]:
edges = graph.top_edges(n=20, format=True)

head.9.9 -> [-0.021] -> head.11.10.q
head.10.7 -> [0.02] -> head.11.10.q
head.5.5 -> [0.016] -> head.8.6.v
mlp.0 -> [0.014] -> head.6.9.q
head.9.9 -> [-0.013] -> head.10.7.q
head.5.5 -> [-0.011] -> head.6.9.q
mlp.0 -> [-0.01] -> head.11.10.k
head.9.6 -> [-0.01] -> head.11.10.q
head.9.6 -> [-0.009] -> head.10.7.q
head.4.11 -> [0.009] -> head.6.9.k
mlp.0 -> [-0.007] -> head.10.7.k
head.5.5 -> [0.007] -> head.7.9.v
mlp.0 -> [0.006] -> head.3.0.k
mlp.0 -> [-0.006] -> head.10.7.v
head.10.10 -> [-0.006] -> head.11.10.q
head.5.5 -> [0.006] -> head.8.10.v
mlp.0 -> [-0.006] -> head.7.9.k
head.8.6 -> [0.006] -> head.9.9.q
mlp.0 -> [0.005] -> head.3.0.q
head.6.9 -> [0.005] -> head.8.6.v
