In [1]:
"""dimension annotation
b: batch
t: token position
d: d_model

A numeric suffix fixes the size of a dimension, e.g. t64 is the t dimension with size 64.
"""

from functools import partial

import numpy as np
import torch

import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from tqdm import tqdm

from openwebtext import load_owt, sample
from pretrained_sae import load_sae

# Inference only: switch autograd off globally for everything that follows.
torch.set_grad_enabled(False)

# --- experiment configuration -------------------------------------------
layer_index = 6
location = "resid_post_mlp"
device = utils.get_device()

# Seeded NumPy generator, later handed to `sample` when drawing batches.
seed = 42
rng = np.random.default_rng(seed)

# --- data and base model --------------------------------------------------
ds = load_owt()
gpt2 = HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)

# --- pretrained sparse autoencoders ---------------------------------------
# Same location/layer, two sizes, loaded in order (32 first, then 128).
# NOTE(review): the first argument presumably selects the latent count in
# thousands (32k / 128k) -- confirm against load_sae.
sae_32k, sae_128k = (
    load_sae(size, location, layer_index, device) for size in (32, 128)
)

Loading dataset from disk:   0%|          | 0/152 [00:00<?, ?it/s]

Loaded 8,013,769 sample texts from data/owt_tokenized




Loaded pretrained model gpt2 into HookedTransformer
Loaded pretrained SAE data/sae/v5_32k_location_resid_post_mlp_layer_6.pt
Loaded pretrained SAE data/sae/v5_128k_location_resid_post_mlp_layer_6.pt


In [2]:
def hook_fn_reconstruct_act(act_btd, hook, sae):
    """Forward hook that swaps an activation for its SAE reconstruction.

    Args:
        act_btd: activation handed to the hook (b, t, d in this notebook's
            naming convention).
        hook: hook point argument; not used by this function.
        sae: object exposing ``encode(act) -> (latent, info)`` and
            ``decode(latent, info)``.

    Returns:
        Whatever ``sae.decode`` returns for the encoded activation.
    """
    codes, aux = sae.encode(act_btd)
    return sae.decode(codes, aux)

In [3]:
batch_size = 16
n_batch = 256

# The hook point does not change between batches, so resolve its name once.
# NOTE(review): assumes the 'resid_post' hook is the activation the SAEs
# loaded for `location` were trained on -- confirm against load_sae.
hook_name = utils.get_act_name('resid_post', layer_index)


def loss_with_sae(tokens, sae):
    """Return gpt2's loss on `tokens` with `sae`'s reconstruction spliced in.

    The activation at `hook_name` is replaced by
    `hook_fn_reconstruct_act(..., sae=sae)` during the forward pass.
    """
    return gpt2.run_with_hooks(
        tokens,
        return_type='loss',
        fwd_hooks=[(hook_name, partial(hook_fn_reconstruct_act, sae=sae))],
    )


# One scalar loss per batch for the unmodified model and for each SAE splice.
loss = []
loss_sae32k = []
loss_sae128k = []

for _ in tqdm(range(n_batch), unit='batch'):
    batch_bt64 = sample(ds, batch_size, rng=rng)

    # Run all three forward passes before appending, so the three lists stay
    # the same length even if one of the passes raises.
    clean_loss = gpt2(batch_bt64, return_type='loss')
    recon_loss_32k = loss_with_sae(batch_bt64, sae_32k)
    recon_loss_128k = loss_with_sae(batch_bt64, sae_128k)

    loss.append(clean_loss.item())
    loss_sae32k.append(recon_loss_32k.item())
    loss_sae128k.append(recon_loss_128k.item())


  0%|          | 0/256 [00:00<?, ?batch/s]

100%|██████████| 256/256 [01:09<00:00,  3.69batch/s]


In [4]:
# Average the per-batch losses once, then report each with three decimals.
mean_loss = np.mean(loss)
mean_loss_sae32k = np.mean(loss_sae32k)
mean_loss_sae128k = np.mean(loss_sae128k)

print(f"loss {mean_loss: .3f}")
print(f"sae 32k loss {mean_loss_sae32k: .3f}")
print(f"sae 128k loss {mean_loss_sae128k: .3f}")

loss  3.733
sae 32k loss  3.834
sae 128k loss  3.792


In [5]:
# Loss increase caused by splicing in each SAE, relative to the clean model.
baseline_loss = np.mean(loss)
delta_32k = np.mean(loss_sae32k) - baseline_loss
delta_128k = np.mean(loss_sae128k) - baseline_loss

print(f'delta loss, sae 32k {delta_32k:.5f}')
print(f'delta loss, sae 128k {delta_128k:.5f}')

delta loss, sae 32k 0.10161
delta loss, sae 128k 0.05963
