In [1]:
# flake8: noqa
# %%
import sys
sys.path.append("/workspace/smol-sae")
import torch.nn.functional as F
import transformer_lens as tl
import sae_lens as sl
from functools import partial
from torch import Tensor
from jaxtyping import Float
from smol_sae.base import Config
from smol_sae.utils import get_splits
from smol_sae.vanilla import VanillaSAE

from datasets import load_dataset 

# define loss function

# x [batch, d_model]
# grad_sae_acts [batch, ]
# assume gradients have a batch dim

def loss(
    sae, 
    x: Float[Tensor, "batch d_model"], 
    x_rec:  Float[Tensor, "batch d_model"], 
    sae_acts:  Float[Tensor, "batch d_sae"], 
    # backward hook
    grad_sae_acts:  Float[Tensor, "batch d_sae"],
    grad_x: Float[Tensor, "batch d_model"],
    lamda: float = 1.0,
    alpha: float = 1.0,
    beta: float = 1.0,
):
    # reconstruction term 
    l2_loss = (x-x_rec).square().sum()
    l1_loss = sae_acts.abs().sum()
    attr_loss = (sae_acts * grad_sae_acts).abs().sum()
    unexplained_loss = ((x-x_rec) * grad_x).abs().sum()
    
    return (
        l2_loss 
        + lamda * l1_loss 
        + alpha * attr_loss 
        + beta * unexplained_loss
    )
# %%


In [2]:
device = "cuda"
model = tl.HookedTransformer.from_pretrained("gelu-1l")
for hook_point in model.hook_points():
    print(hook_point.name)



Loaded pretrained model gelu-1l into HookedTransformer
hook_embed
hook_pos_embed
blocks.0.ln1.hook_scale
blocks.0.ln1.hook_normalized
blocks.0.ln2.hook_scale
blocks.0.ln2.hook_normalized
blocks.0.attn.hook_k
blocks.0.attn.hook_q
blocks.0.attn.hook_v
blocks.0.attn.hook_z
blocks.0.attn.hook_attn_scores
blocks.0.attn.hook_pattern
blocks.0.attn.hook_result
blocks.0.mlp.hook_pre
blocks.0.mlp.hook_post
blocks.0.hook_attn_in
blocks.0.hook_q_input
blocks.0.hook_k_input
blocks.0.hook_v_input
blocks.0.hook_mlp_in
blocks.0.hook_attn_out
blocks.0.hook_mlp_out
blocks.0.hook_resid_pre
blocks.0.hook_resid_mid
blocks.0.hook_resid_post
ln_final.hook_scale
ln_final.hook_normalized


In [3]:
train_dataset = load_dataset(
    "NeelNanda/c4-tokenized-2b", split="train", streaming=True
).with_format("torch")
train_batch = list(train_dataset.take(32))
print(train_batch)

Resolving data files:   0%|          | 0/23 [00:00<?, ?it/s]

[{'tokens': tensor([    1,   769, 41811,  ...,  6109,  1849,   360])}, {'tokens': tensor([    1,  2009,  7402,  ...,   825,   274, 13085])}, {'tokens': tensor([    1,  1582,  7537,  ...,  2411,  5379, 12147])}, {'tokens': tensor([    1,  1056,   274,  ...,  7538, 10523,   274])}, {'tokens': tensor([    1, 23947,  9863,  ...,   274,   254,  4237])}, {'tokens': tensor([    1,  7732, 13492,  ...,   368, 15813,   772])}, {'tokens': tensor([    1,  1391,  1209,  ...,   282, 13905,   276])}, {'tokens': tensor([   1,  282, 6134,  ...,  254, 7736, 2832])}, {'tokens': tensor([    1,   300,    15,  ..., 15543,    14,   390])}, {'tokens': tensor([   1, 3056,   16,  ...,  368,  670, 3202])}, {'tokens': tensor([   1, 3983,  328,  ...,  618,  801,  671])}, {'tokens': tensor([    1,     0,  7345,  ...,   254,  7885, 19282])}, {'tokens': tensor([    1,   286,  1006,  ..., 44292,    16,   380])}, {'tokens': tensor([    1, 11444,   407,  ...,    25,    28,  2004])}, {'tokens': tensor([    1, 10077,   47

In [None]:
config = Config(
    n_buffers=100, expansion=4, buffer_size=2**8, sparsities=(0.1, 1.0), device=device
)
sae = VanillaSAE(config, model)
print(sae.d_model)
print(model.cfg.d_model)


In [6]:
gradients = {}

def fwd_patch_model_with_sae(act, hook, sae):
    sae_out, hidden = sae(act)[:2]
    sae_err = act - sae_out.detach()
    return sae_out + sae_err 

def bwd_patch_model_gradient(grad_act, hook):
    global gradients 
    gradients[hook.name] = grad_act.detach()
    return grad_act



with model.hooks(
    fwd_hooks=[("blocks.0.mlp.hook_pre", partial(fwd_patch_model_with_sae, sae=sae))],
    bwd_hooks=[("blocks.0.mlp.hook_pre", bwd_patch_model_gradient)],
):
    model(train_batch[0]["tokens"].to(device))
    # sae(train_batch[0]["input_ids"])

RuntimeError: The size of tensor a (2048) must match the size of tensor b (512) at non-singleton dimension 2