In [1]:
"""dimension annotation
b: batch
t: token position
d: d_model
v: model token vocab size
l: SAE n latent
k: topk
"""

from functools import partial

import numpy as np
import torch

import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from tqdm import tqdm

from openwebtext import load_owt, sample
from pretrained_sae import load_sae

torch.set_grad_enabled(False)

layer_index = 8
location = "resid_post_mlp"
device = utils.get_device()

ds = load_owt()
gpt2 = HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)
sae = load_sae(32, location, layer_index, device)

Loading dataset from disk:   0%|          | 0/152 [00:00<?, ?it/s]

Loaded 8,013,769 sample texts from data/owt_tokenized




Loaded pretrained model gpt2 into HookedTransformer
Loaded pretrained SAE data/sae/v5_32k_location_resid_post_mlp_layer_8.pt


In [2]:
ablate_token_idx = 0
V = 50257
T = 16

# IndexError: index 31 is out of bounds for dimension 1 with size 31
# was using 32, some sample doesn't have 32 features, so do ablation on top 16 strongest activated feature
K = 30 

D = 768

In [3]:
def fn_ablate_feature(act_btd, hook, ablate_feature_idx, ablate_token_idx):
    lact_btk, _ = sae.encode(act_btd)
    lact_k = lact_btk[0, ablate_token_idx]
    
    # Sort lact_k by absolute value, descending
    sorted_indices = torch.argsort(lact_k.abs(), descending=True)
    lact_k = lact_k[sorted_indices]

    ldir_dk = sae.decoder.weight[:, sorted_indices]
    all_feature_dk = ldir_dk * lact_k

    active_feature = all_feature_dk[:, all_feature_dk.sum(dim=0) != 0]  # (d, 32)s
    ablate_feature = active_feature[:, ablate_feature_idx]  # (d, )

    act = act_btd.clone()

    # subtract ablate_feature from target token AND all previous tokens's activation
    # act[0, : ablate_token_idx + 1] -= ablate_feature

    # subtract ablate_feature only from target token activation
    act[0, ablate_token_idx] -= ablate_feature

    return act

In [4]:
def proc_ablate_feature(sample_1t):
    logit_btv = gpt2(sample_1t)    
    
    bin = []
    for i in range(K):
        ablated_logit_btv = gpt2.run_with_hooks(
            sample_1t,
            return_type="logits",
            fwd_hooks=[
                (
                    utils.get_act_name("resid_post", layer_index),
                    partial(
                        fn_ablate_feature,
                        ablate_feature_idx=i,
                        ablate_token_idx=ablate_token_idx,
                    ),
                )
            ],
        )

        logit_diff_tv = (
            logit_btv[0, ablate_token_idx : ablate_token_idx + T]
            - ablated_logit_btv[0, ablate_token_idx : ablate_token_idx + T]
        )

        median_diff_t = torch.median(logit_diff_tv, dim=1)[0]

        logit_diff_tv -= median_diff_t[..., None]
        bin.append(logit_diff_tv)

    vt = torch.stack(bin).view(-1, V * T)
    l1 = torch.abs(vt).sum(-1)
    l2 = (vt**2).sum(-1) ** 0.5

    bench = (l1 / l2) ** 2
    normalized_bench = bench / (V * T)

    return normalized_bench.mean().item()

In [5]:
def fn_ablate_resid_stream_channel(act_btd, hook, ablate_channel_idx, ablate_token_idx):
    act = act_btd.clone()
    act[:, ablate_token_idx, ablate_channel_idx] = 0 
    return act

In [6]:
def proc_ablate_resid_stream_channel(sample_1t):
    logit_btv = gpt2(sample_1t)

    bin = []
    for i in range(D):
        ablated_logit_btv = gpt2.run_with_hooks(
            sample_1t,
            return_type="logits",
            fwd_hooks=[
                (
                    utils.get_act_name("resid_post", layer_index),
                    partial(
                        fn_ablate_resid_stream_channel,
                        ablate_channel_idx=i,
                        ablate_token_idx=ablate_token_idx,
                    ),
                )
            ],
        )

        logit_diff_tv = (
            logit_btv[0, ablate_token_idx : ablate_token_idx + T]
            - ablated_logit_btv[0, ablate_token_idx : ablate_token_idx + T]
        )

        median_diff_t = torch.median(logit_diff_tv, dim=1)[0]

        logit_diff_tv -= median_diff_t[..., None]
        bin.append(logit_diff_tv)

    vt = torch.stack(bin).view(-1, V * T)
    l1 = torch.abs(vt).sum(-1)
    l2 = (vt**2).sum(-1) ** 0.5

    bench = (l1 / l2) ** 2
    normalized_bench = bench / (V * T)
    
    return normalized_bench.mean().item()

In [7]:
n_sample = 64
sparsity_feature, sparsity_channel = [], []

for _ in tqdm(range(n_sample), unit='sample'):
    sample_1t = sample(ds, 1)
    sf = proc_ablate_feature(sample_1t)
    sparsity_feature.append(sf)

    sc = proc_ablate_resid_stream_channel(sample_1t)
    sparsity_channel.append(sc)

100%|██████████| 64/64 [12:23<00:00, 11.62s/sample]


In [8]:
print(f"downstream effect sparsity of SAE feature: {np.mean(sparsity_feature) * 100:.2f}%")
print(f"downstream effect sparsity of resid stream channel: {np.mean(sparsity_channel) * 100:.2f}%")

downstream effect sparsity of SAE feature: 19.31%
downstream effect sparsity of resid stream channel: 24.81%
