In [1]:
"""dimension annotation
b: batch
t: token position
d: d_model
v: model token vocab size
l: SAE n latent
k: topk
"""

from functools import partial

import numpy as np
import torch
import einops

import transformer_lens.utils as utils
from transformer_lens import HookedTransformer

from openwebtext import load_owt, sample
from pretrained_sae import load_sae

torch.set_grad_enabled(False)

layer_index = 8
location = "resid_post_mlp"
device = utils.get_device()

ds = load_owt()
gpt2 = HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)
sae = load_sae(32, location, layer_index, device)

Loading dataset from disk:   0%|          | 0/152 [00:00<?, ?it/s]

Loaded 8,013,769 sample texts from data/owt_tokenized




Loaded pretrained model gpt2 into HookedTransformer
Loaded pretrained SAE data/sae/v5_32k_location_resid_post_mlp_layer_8.pt


In [2]:
ablate_token_idx = 32
T = 16

batch = sample(ds, 1)
logit, cache = gpt2.run_with_cache(batch, return_type='logits')

In [3]:
def fn(act_btd, hook, ablate_feature_idx, ablate_token_idx=ablate_token_idx):
    lact_btk, _ = sae.encode(act_btd)
    lact_k = lact_btk[0, ablate_token_idx]

    ldir_dk = sae.decoder.weight
    all_feature_dk = ldir_dk * lact_k

    active_feature = all_feature_dk[:, all_feature_dk.sum(dim=0) != 0]  # (d, 32)
    ablate_feature = active_feature[:, ablate_feature_idx]  # (d, )

    act = act_btd.clone()

    # subtrack ablate feature from target token activation
    act[0, ablate_token_idx] -= ablate_feature

    return act

In [4]:
bin = []

for i in range(32):
    ablated_logit = gpt2.run_with_hooks(
        batch,
        return_type="logits",
        fwd_hooks=[
            (
                utils.get_act_name("resid_post", layer_index),
                partial(fn, ablate_feature_idx=0),
            )
        ],
    )

    logit_diff = (
        logit[0, ablate_token_idx : ablate_token_idx + T]
        - ablated_logit[0, ablate_token_idx : ablate_token_idx + T]
    )

    median_diff = torch.median(logit_diff, dim=1)[0]

    logit_diff -= median_diff[..., None]
    bin.append(logit_diff)


In [5]:
vt = torch.stack(bin)
vt.shape

torch.Size([32, 16, 50257])

In [6]:
l1 = torch.abs(vt).sum(-1)
l2 = (vt ** 2).sum(-1) ** 0.5

l1.shape, l2.shape

(torch.Size([32, 16]), torch.Size([32, 16]))

In [7]:
bench = (l1 / l2) ** 2
bench.shape


torch.Size([32, 16])

In [8]:
bench = (bench - bench.min()) / (bench.max() - bench.min())
bench.mean()

tensor(0.6502, device='cuda:0')