# Replicating How We Train SAEs

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2
import torch
import os
import sys

sys.path.append("..")

from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.lm_runner import language_model_sae_runner

# Choose the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
# The conditional expression short-circuits in the same order as an if/elif chain.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print("Using device:", device)
# Disable HuggingFace `tokenizers` parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
from transformer_lens import HookedTransformer

# Load the one-layer GELU model on the device chosen in the setup cell, so the
# printed "Using device" actually reflects where the model lives.
# NOTE(review): keep this consistent with the `device` in the SAE config used by
# ActivationsStore further down — TODO confirm they match.
model = HookedTransformer.from_pretrained("gelu-1l", device=device)

Let's use test-driven development: I'll first put in the config args that should make the library replicate the SAE training result, and then work backward from there.

# SAE Initialisation and L1 Schedule


In [None]:


import plotly.express as px
from sae_lens.training.sparse_autoencoder import SparseAutoencoder

# Build a freshly initialised SAE from the run config and inspect its decoder.
# NOTE(review): no cell in this notebook defines `cfg`. Add a cell that builds a
# LanguageModelSAERunnerConfig before this one, otherwise Restart & Run All
# stops here with a NameError.
sparse_autoencoder = SparseAutoencoder(cfg)

# L2 norm of each row of W_dec (presumably one row per SAE feature, i.e.
# W_dec is (d_sae, d_in) — TODO confirm). `.float()` keeps `.numpy()` working
# if the SAE uses a dtype numpy lacks (e.g. bfloat16).
decoder_row_norms = sparse_autoencoder.W_dec.norm(dim=1).detach().cpu().float().numpy()
px.histogram(
    x=decoder_row_norms,
    title="Decoder row norms at initialisation",
    labels={"x": "||W_dec[i]||"},
).show()


In [None]:
from sae_lens.training.optim import L1Scheduler

# Simulate the L1-coefficient schedule and plot it.
# Use a single horizon, derived from the config, for both the scheduler's
# `total_steps` and the simulation loop, so the plot covers exactly the schedule
# the scheduler was configured for.
total_training_steps = (
    sparse_autoencoder.cfg.training_tokens // sparse_autoencoder.cfg.train_batch_size
)
l1_warmup_steps = 1_000
final_l1_value = sparse_autoencoder.cfg.l1_coefficient

# Stepping the scheduler changes `sparse_autoencoder.l1_coefficient` (the loop
# below reads the updated value after each `step()`). Remember the value from
# before this cell and restore it afterwards, because later cells reuse this SAE.
l1_coefficient_before_simulation = sparse_autoencoder.l1_coefficient

l1_scheduler = L1Scheduler(
    total_steps=total_training_steps,
    l1_warm_up_steps=l1_warmup_steps,
    sparse_autoencoder=sparse_autoencoder,
)

l1_values = []
for _ in range(total_training_steps):
    l1_values.append(sparse_autoencoder.l1_coefficient)
    l1_scheduler.step()

sparse_autoencoder.l1_coefficient = l1_coefficient_before_simulation

l1_fig = px.line(
    y=l1_values,
    title="Simulated L1 coefficient schedule",
    labels={"index": "step", "y": "l1_coefficient"},
)
# Dashed reference line: the configured cfg.l1_coefficient.
l1_fig.add_hline(y=final_l1_value, line_dash="dash")
l1_fig.show()

# Loss Function

In [None]:
from sae_lens import ActivationsStore
from transformer_lens import HookedTransformer

# Reuse the `model` (gelu-1l) loaded at the top of the notebook instead of
# downloading and instantiating it a second time.
activation_store = ActivationsStore.from_config(model, sparse_autoencoder.cfg)

In [None]:
# sqrt(d_in): a reference value to compare with the store's estimated norm
# scaling factor. Read d_in from the SAE's own config, as the other cells do,
# because no cell in this notebook defines `cfg`.
sparse_autoencoder.cfg.d_in ** 0.5

In [None]:
# The norm scaling factor as estimated by the activation store.
norm_scaling_factor = activation_store.estimated_norm_scaling_factor
norm_scaling_factor

In [None]:
# Mean L2 norm (over the last dimension) of the activations in one buffer.
(
    activation_store.get_buffer(32)
    .norm(dim=-1)
    .flatten()
    .detach()
    .cpu()
    .mean()
)

In [None]:
# Distribution of activation norms (L2 over the last dimension) in one buffer.
# This is a new get_buffer call, so presumably these are different samples from
# the mean computed in the previous cell.
activation_norms = (
    activation_store.get_buffer(32)
    .norm(dim=-1)
    .flatten()
    .detach()
    .cpu()
    .float()
    .numpy()
)
px.histogram(
    x=activation_norms,
    title="Activation norms in one buffer (get_buffer(32))",
    labels={"x": "||activation||"},
).show()

In [None]:
# Fetch one batch from the store and encode it with the SAE.
# `_encode_with_hidden_pre` is underscore-prefixed (private), so its signature
# may change between sae_lens versions.
activations = activation_store.next_batch()
encoder_outputs = sparse_autoencoder._encode_with_hidden_pre(activations)
feature_acts, hidden_pre = encoder_outputs

In [None]:
# Output shape of the decoder-norm sparsity term for the batch encoded above.
(
    sparse_autoencoder
    .get_sparsity_loss_term_decoder_norm(feature_acts)
    .shape
)

# Activation Scaling

In [None]:
# Configured number of batches per activation buffer.
n_batches_in_buffer = sparse_autoencoder.cfg.n_batches_in_buffer
n_batches_in_buffer

In [None]:
import numpy as np

# Mean activation norm for several independently fetched buffers, to see how
# stable the per-buffer estimate is.
# Take the norm over the last dimension, as the single-buffer cells above do.
# The earlier `.squeeze().norm(dim=1)` reduced the feature axis only when the
# squeezed tensor happened to be 2-D. For a (N, 1, d) buffer both forms give
# the same mean.
n_buffer_samples = 10
buffer_norm_means = [
    activation_store.get_buffer(64).norm(dim=-1).mean().item()
    for _ in range(n_buffer_samples)
]
    


In [None]:
# Histogram of the per-buffer mean norms. Each entry of `buffer_norm_means` is
# one buffer, so the count in the title is taken from the data instead of being
# hard-coded.
buffer_norm_means_arr = np.array(buffer_norm_means)
px.histogram(
    x=buffer_norm_means,
    title=f"Buffer Norm Mean over {len(buffer_norm_means)} buffers, mean: {buffer_norm_means_arr.mean()} std:{buffer_norm_means_arr.std()}",
).show()


In [None]:
import numpy as np

# sqrt of the SAE's input dimension. Presumably the same value as
# `sparse_autoencoder.cfg.d_in ** 0.5` — TODO confirm the attribute matches the config.
sae_input_dim = sparse_autoencoder.d_in
np.sqrt(sae_input_dim)