# Training and evaluating an SAE on Gemma-2-2B layer 12

You probably don't want to do this sort of training in a notebook, but the process is demonstrated here regardless. We'll train a single SAE with 32k latents and L0=100 on Gemma-2-2B layer 12 using 500M tokens from the Pile. In the paper, we train a suite of SAEs with different L0s. We won't attempt that here, since a sweep like that should really be run on a cluster, but the process is the same (except for changing L0 for each run).

In [None]:
from sparse_but_wrong.enchanced_batch_topk_sae import (
    EnhancedBatchTopKTrainingSAEConfig,
)

from sae_lens import LanguageModelSAETrainingRunner, LanguageModelSAERunnerConfig

# These parameters assume an H100 80GB GPU. You can change them to fit your GPU.

runner_cfg = LanguageModelSAERunnerConfig(
    sae=EnhancedBatchTopKTrainingSAEConfig(
        k=100,  # target L0 of the SAE (L0=100, as described in the intro)
        d_in=2304,  # must match the width of the hooked activations
        d_sae=32 * 1024,  # 32k latents
        normalize_acts_by_decoder_norm=True,
    ),
    model_name="google/gemma-2-2b",
    model_class_name="AutoModelForCausalLM",
    dataset_path="monology/pile-uncopyrighted",
    # Output of HF decoder layer 12; presumably the same activations as the
    # TransformerLens hook "blocks.12.hook_resid_post" used for evaluation.
    hook_name="model.layers.12",
    context_size=1024,
    training_tokens=500_000_000,
    device="cuda",
    lr=3e-4,
    n_batches_in_buffer=64,
    train_batch_size_tokens=4096,
    store_batch_size_prompts=12,
    eval_batch_size_prompts=6,
    autocast_lm=True,
    autocast=True,
)

runner = LanguageModelSAETrainingRunner(runner_cfg)

# Train, then save an inference-only copy of the SAE to ./l0_100_sae.
sae = runner.run()
sae.save_inference_model("l0_100_sae")

Now that we've trained the SAE, let's calculate its nth decoder projection.

In [None]:
import torch
from tqdm.auto import tqdm
from sae_lens import ActivationsStore, SAE
from transformer_lens import HookedTransformer

from sparse_but_wrong.nth_decoder_projection import nth_decoder_projection

# Reload the inference SAE saved by the training cell, placing it on the GPU.
saved_sae_dir = "l0_100_sae"
loaded_sae = SAE.load_from_disk(saved_sae_dir, device="cuda")

# First, we load some training activations from the Pile dataset to test nth decoder projection against

@torch.inference_mode()
def get_activation_batches(
    n_batches: int = 100,
    store_batch_size_prompts: int = 8,
    n_batches_in_buffer: int = 8,
    train_batch_size_tokens: int = 4096,
    *,
    model_name: str = "gemma-2-2b",
    dataset: str = "monology/pile-uncopyrighted",
    hook_name: str = "blocks.12.hook_resid_post",
    d_in: int = 2304,
    context_size: int = 1024,
    device: str = "cuda",
) -> list[torch.Tensor]:
    """Collect ``n_batches`` batches of activations from a sae_lens ActivationsStore.

    Loads ``model_name`` with TransformerLens (no weight processing). Streams
    ``dataset`` through it and reads the activations at ``hook_name``, with no
    activation normalization and a BOS token prepended to each context. The
    defaults are intended to match the SAE trained above (Gemma-2-2B layer 12,
    1024-token contexts).

    Args:
        n_batches: Number of batches to draw from the store.
        store_batch_size_prompts: Prompts per forward pass when the store
            fills its buffer.
        n_batches_in_buffer: Store buffer size, in batches.
        train_batch_size_tokens: Tokens per returned batch.
        model_name: TransformerLens model name to load.
        dataset: Dataset path streamed by the store.
        hook_name: TransformerLens hook point to read activations from.
        d_in: Width of the activations at ``hook_name``.
        context_size: Tokens per context.
        device: Device for both the model and the stored activations.

    Returns:
        A list of ``n_batches`` float32 tensors, each with dim 1 squeezed out.
        Presumably shaped ``(train_batch_size_tokens, d_in)``; confirm against
        the installed sae_lens version.
    """
    model = HookedTransformer.from_pretrained_no_processing(model_name, device=device)
    store = ActivationsStore(
        model=model,
        dataset=dataset,
        d_in=d_in,
        hook_name=hook_name,
        hook_head_index=None,
        context_size=context_size,
        prepend_bos=True,
        streaming=True,
        store_batch_size_prompts=store_batch_size_prompts,
        train_batch_size_tokens=train_batch_size_tokens,
        n_batches_in_buffer=n_batches_in_buffer,
        # Upper bound for the store only; we stop after n_batches anyway.
        total_training_tokens=10**9,
        normalize_activations="none",
        dataset_trust_remote_code=True,
        dtype="float32",
        device=torch.device(device),
        seqpos_slice=(None,),
    )
    return [store.next_batch().squeeze(1) for _ in tqdm(range(n_batches), desc="Generating inputs")]

train_activation_batches = get_activation_batches()

# We'll use N = 12_000, a bit over a third of the SAE's width (d_sae = 32 * 1024).
# This is an arbitrary but reasonable choice. In practice, anything less than ~ d_sae / 2 seems to work well.
N = 12_000

with torch.no_grad():
    # One nth_decoder_projection result per activation batch; these are
    # stacked and averaged into a single scalar score s_n.
    projections = [
        nth_decoder_projection(batch, loaded_sae, N)
        for batch in train_activation_batches
    ]
    s_n = torch.stack(projections).mean().item()

print(f"nth decoder projection (N={N}): {s_n}")

Finally, let's evaluate the SAE using the [sae-probes benchmark](https://github.com/sae-probes/sae-probes).

In [None]:
from sae_probes import run_sae_evals

# Run the sae-probes sparse-probing evaluation on the reloaded SAE.
# Per-dataset results are written as JSON under `results_path`. `ks` presumably
# sets how many SAE latents each sparse probe may use; confirm in the sae-probes docs.
sae_probes_eval_kwargs = dict(
    sae=loaded_sae,
    model_name="gemma-2-2b",
    hook_name="blocks.12.hook_resid_post",
    reg_type="l1",
    setting="normal",
    ks=[1, 16],
    results_path="l0_100_sae_probes_results",
)
run_sae_evals(**sae_probes_eval_kwargs)

This saves the sparse probing results for each dataset to the `l0_100_sae_probes_results` directory. Let's load these and calculate the SAE's mean test AUC, accuracy, and F1 score across datasets for each value of `k`.

In [None]:
import pandas as pd
from pathlib import Path
import json

# Gather every per-dataset result record written by run_sae_evals, then average
# the test metrics across datasets for each probe size k.
sae_probes_result_paths = sorted(
    Path("l0_100_sae_probes_results").glob("*/normal_setting/*.json")
)
if not sae_probes_result_paths:
    # Without this check, an empty glob only fails later as an opaque
    # KeyError on the missing 'k' column in groupby.
    raise FileNotFoundError(
        "No sae-probes result files found matching "
        "'l0_100_sae_probes_results/*/normal_setting/*.json'. "
        "Did the run_sae_evals cell complete successfully?"
    )

rows = []
for sae_probes_path in sae_probes_result_paths:
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(sae_probes_path, encoding="utf-8") as f:
        sae_probes_data = json.load(f)
    # Each file is iterated as a sequence of result records (presumably one
    # per dataset); every record becomes one DataFrame row.
    rows.extend(sae_probes_data)

sae_probes_df = pd.DataFrame(rows)
sae_probes_mean_df = (
    sae_probes_df.groupby(["k"])[["test_auc", "test_acc", "test_f1"]]
    .mean()
    .reset_index()
)

sae_probes_mean_df