# Training and evaluating an SAE on Gemma-2-2B

You probably don't want to run this sort of training in a notebook, but the process is demonstrated here regardless. We'll train a single SAE with 32k latents and L0=100 on layer 12 of Gemma-2-2b using 500M tokens from the Pile. In the paper, we train a suite of SAEs with different L0s. We won't attempt that here, since a sweep like that really belongs on a cluster, but the process is the same apart from changing L0 for each run.
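To make "L0=100" concrete, here is a minimal sketch of the selection step in a standard BatchTopK SAE, written in plain Python with toy numbers. It is an illustration only, not the `EnhancedBatchTopKTrainingSAEConfig` implementation used below, which may differ in its details. The helper names `batch_topk` and `mean_l0` are hypothetical. The idea is that the top `k * batch_size` activations are kept across the whole batch, so individual tokens can have more or fewer than `k` active latents while the average L0 stays at `k`.

```python
def batch_topk(acts, k):
    """Keep the k * batch_size largest positive activations across the batch.

    acts: list of rows, one row of latent activations per token.
    Returns a new list of rows with all other entries set to 0.0.
    """
    n_keep = k * len(acts)
    # Flatten to (value, row, col) so we can rank across the whole batch.
    flat = [(v, i, j) for i, row in enumerate(acts) for j, v in enumerate(row)]
    flat.sort(key=lambda t: t[0], reverse=True)
    keep = {(i, j) for v, i, j in flat[:n_keep] if v > 0}
    return [
        [v if (i, j) in keep else 0.0 for j, v in enumerate(row)]
        for i, row in enumerate(acts)
    ]


def mean_l0(acts):
    """Average number of non-zero latents per token."""
    return sum(sum(1 for v in row if v != 0) for row in acts) / len(acts)


# Toy batch: 3 tokens, 4 latents each (made-up values).
acts = [
    [0.9, 0.1, 0.8, 0.7],
    [0.2, 0.0, 0.3, 0.1],
    [0.6, 0.5, 0.05, 0.4],
]
sparse = batch_topk(acts, k=2)
per_token_l0 = [sum(1 for v in row if v != 0) for row in sparse]
print(per_token_l0, mean_l0(sparse))
```

In this toy batch the first and third tokens keep three latents each and the second keeps none, yet the mean L0 is exactly `k=2`. In the training run below, `k=100` plays the same role.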

In [None]:
from sparse_but_wrong.enchanced_batch_topk_sae import (
    EnhancedBatchTopKTrainingSAEConfig,
)

from sae_lens import LanguageModelSAETrainingRunner, LanguageModelSAERunnerConfig


runner_cfg = LanguageModelSAERunnerConfig(
    sae=EnhancedBatchTopKTrainingSAEConfig(
        k=100,
        d_in=2304,  # residual stream width (hidden_size) of Gemma-2-2b
        d_sae=32 * 1024,
        normalize_acts_by_decoder_norm=True,
    ),
    model_name="google/gemma-2-2b",
    model_class_name="AutoModelForCausalLM",
    hook_name="model.layers.12",
    context_size=1024,
    training_tokens=500_000_000,
    device="cuda",
    lr=3e-4,
    n_batches_in_buffer=64,
    train_batch_size_tokens=4096,
    store_batch_size_prompts=12,
    eval_batch_size_prompts=6,
    autocast_lm=True,
    autocast=True,
    n_checkpoints=2,
)

runner = LanguageModelSAETrainingRunner(runner_cfg)

sae = runner.run()
sae.save_inference_model("l0_100_sae")
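As a quick sanity check on the configuration above, the following sketch works out how many optimizer steps and how many model contexts the run implies. It only redoes arithmetic on values copied from the config, under the assumption that each training step consumes exactly `train_batch_size_tokens` activations. The variable names `n_steps` and `n_contexts` are ours, not SAE Lens attributes.

```python
# Values copied from the runner config above.
training_tokens = 500_000_000
train_batch_size_tokens = 4096
context_size = 1024
d_sae = 32 * 1024

# Assumption: one optimizer step per batch of train_batch_size_tokens activations.
n_steps = training_tokens // train_batch_size_tokens

# Number of full-length contexts the language model must process.
n_contexts = training_tokens // context_size

print(f"latents:  {d_sae}")
print(f"steps:    {n_steps}")
print(f"contexts: {n_contexts}")
```

That works out to roughly 122k training steps over about 488k contexts of 1024 tokens each, which is why this kind of run is better suited to a script on a dedicated GPU than to an interactive notebook.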

Now that we've trained the SAE, let's evaluate it on the SAE Probes benchmark and calculate the nth decoder projection.

In [None]:
from sae_probes import run_sae_evals

run_sae_evals(
    sae=sae,
    model_name="gemma-2-2b",
    hook_name="blocks.12.hook_resid_post",
    reg_type="l1",
    setting="normal",
    ks=[1, 16],
    results_path="l0_100_sae_probes_results",
)