In [1]:
%reload_ext autoreload
%autoreload 2

import time

import numpy as np
import torch

import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from sparse_autoencoder.loss import normalized_mean_squared_error
from tqdm import tqdm

from openwebtext import load_owt, sample
from pretrained_sae import load_sae

torch.set_grad_enabled(False)

seed = 42
rng = np.random.default_rng(seed)

layer_index = 6
location = "resid_post_mlp"
device = utils.get_device()

In [2]:
ds = load_owt()
gpt2 = HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)

sae_32k = load_sae(32, location, layer_index, device)
sae_128k = load_sae(128, location, layer_index, device)

Loading dataset from disk:   0%|          | 0/152 [00:00<?, ?it/s]

Loaded 8,013,769 sample texts from data/owt_tokenized




Loaded pretrained model gpt2 into HookedTransformer
Loaded pretrained SAE data/sae/v5_32k_location_resid_post_mlp_layer_6.pt
Loaded pretrained SAE data/sae/v5_128k_location_resid_post_mlp_layer_6.pt


In [3]:
n_batch = 256
batch_size = 16
mse_32k_bin = []
mse_128k_bin = []

def sae_mse(sae, act, bin):
    latent_act, info = sae.encode(act)
    recon_act_btd = sae.decode(latent_act, info)
    mse = normalized_mean_squared_error(recon_act_btd, act)
    bin.append(mse.item())

def hook_fn_compute_mse(act_btd, hook):
    sae_mse(sae_32k, act_btd, mse_32k_bin)
    sae_mse(sae_128k, act_btd, mse_128k_bin)


hook_name = utils.get_act_name("resid_post", layer_index)

print(f"start processing MSE for {n_batch * batch_size * 64:,} tokens")
with tqdm(range(n_batch), unit="batch", postfix={"tps": 0}) as pbar:
    for _ in pbar:
        start = time.perf_counter()
        
        batch = sample(ds, batch_size=batch_size, rng=rng)
        gpt2.run_with_hooks(batch, return_type=None, fwd_hooks=[(hook_name, hook_fn_compute_mse)])

        delta = time.perf_counter() - start
        tok_per_batch = batch_size * 64
        tps = tok_per_batch / delta
        
        pbar.set_postfix({"tps": f"{tps:,.2f}"})


start processing MSE for 262,144 tokens


100%|██████████| 256/256 [00:36<00:00,  7.08batch/s, tps=7,811.72]


In [4]:
np.mean(mse_32k_bin), np.mean(mse_128k_bin)

(0.13054200427723117, 0.0986246868269518)