In [1]:
"""dimension annotation
b: batch
t: token position
d: d_model
v: model token vocab size
l: SAE n latent
k: topk
"""

import numpy as np
import torch

import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from sparse_autoencoder.model import Autoencoder, TopK
from sparse_autoencoder.loss import autoencoder_loss
from tqdm import tqdm

from openwebtext import load_owt, sample


ds = load_owt()
gpt2 = HookedTransformer.from_pretrained("gpt2", center_writing_weights=False)

Loading dataset from disk:   0%|          | 0/152 [00:00<?, ?it/s]

Loaded 8,013,769 sample texts from data/owt_tokenized




Loaded pretrained model gpt2 into HookedTransformer


In [2]:
target_layer = 8
k = 32
batch_size = 64
n_batch = 64

In [3]:
train_data = []

with torch.no_grad():
    for _ in range(n_batch):
        batch = sample(ds, batch_size)
        loss, cache = gpt2.run_with_cache(batch, return_type="loss")
        act_btd = cache[utils.get_act_name("resid_post", layer=target_layer)]
        act_bd = act_btd[:, -1]
        train_data.append(
            (batch.detach().cpu(), act_bd.detach().cpu(), loss.detach().cpu())
        )
    
del(gpt2)

In [4]:
n_latents = 2**15
n_inputs = 768
act_fn = TopK(k)

device = utils.get_device()
sae = Autoencoder(n_latents, n_inputs, act_fn, tied=True, normalize=True).to(device)
optimizer = torch.optim.Adam(sae.parameters(), lr=5e-4)

for i, (_, act, _) in enumerate(train_data):
    act = act.to(device)
    _, latent, recon = sae(act)
    loss = autoencoder_loss(recon, act, latent, l1_weight=0)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    print(f"step({i}): loss = {loss.item():.3f}")

step(0): loss = 0.772
step(1): loss = 0.765
step(2): loss = 0.760
step(3): loss = 0.754
step(4): loss = 0.745
step(5): loss = 0.738
step(6): loss = 0.730
step(7): loss = 0.718
step(8): loss = 0.710
step(9): loss = 0.707
step(10): loss = 0.701
step(11): loss = 0.684
step(12): loss = 0.670
step(13): loss = 0.667
step(14): loss = 0.664
step(15): loss = 0.657
step(16): loss = 0.648
step(17): loss = 0.641
step(18): loss = 0.631
step(19): loss = 0.626
step(20): loss = 0.624
step(21): loss = 0.613
step(22): loss = 0.615
step(23): loss = 0.592
step(24): loss = 0.599
step(25): loss = 0.602
step(26): loss = 0.584
step(27): loss = 0.587
step(28): loss = 0.586
step(29): loss = 0.578
step(30): loss = 0.567
step(31): loss = 0.566
step(32): loss = 0.559
step(33): loss = 0.559
step(34): loss = 0.534
step(35): loss = 0.533
step(36): loss = 0.539
step(37): loss = 0.519
step(38): loss = 0.505
step(39): loss = 0.531
step(40): loss = 0.527
step(41): loss = 0.502
step(42): loss = 0.522
step(43): loss = 0.50