In [4]:
# @title Running in Colab

from IPython.display import HTML, Javascript, display

# Render a "Run in Colab" button linking to this notebook on GitHub.
display(
    HTML(
        """<a href="https://colab.research.google.com/github/evan-lloyd/mechinterp-experiments/blob/main/next_layer_sae/e2e_demo.ipynb" target="_blank" id="colab-button">
            <button style="background-color: #4285f4; color: white; padding: 10px 20px; border: none; border-radius: 4px; cursor: pointer; font-size: 14px;">
                Run this notebook in Google Colab
            </button>
        </a>"""
    )
)
# When the page is already served by Colab the button is redundant, so remove it.
display(
    Javascript("""
        setTimeout(() => {
            // `typeof google.colab` throws a ReferenceError when `google` itself is
            // undeclared (i.e. outside Colab), so probe the global first.
            const inColab =
                typeof google !== "undefined" && typeof google.colab !== "undefined";
            const button = document.querySelector("#colab-button");
            // The button may already be gone if this cell's output is re-rendered.
            if (inColab && button !== null) {
                button.remove();
            }
        }, 0);
        """)
)

<IPython.core.display.Javascript object>

# Initialize notebook environment


In [None]:
import os

# If we're running in Colab, we need to clone the non-notebook source from git.
# The COLAB_RELEASE_TAG environment variable is used as the "running in Colab" signal.
COLAB_REPO_DIR = "/content/mechinterp-experiments"
running_in_colab = bool(os.getenv("COLAB_RELEASE_TAG"))

if running_in_colab and not os.path.isdir(COLAB_REPO_DIR):
    ip = get_ipython()  #  pyright: ignore[reportUndefinedVariable]
    # Blob-less sparse clone: only the /next_layer_sae subtree gets checked out.
    ip.run_cell_magic(
        "bash",
        "",
        """
    git clone --filter=blob:none --no-checkout https://github.com/evan-lloyd/mechinterp-experiments.git
    cd mechinterp-experiments
    git sparse-checkout init --no-cone
    echo "/next_layer_sae" > .git/info/sparse-checkout
    git checkout
  """,
    )
else:
    print("Already cloned source, or not running in Colab.")

if running_in_colab:
    # Enter the source directory on every run, not only right after cloning: after a
    # kernel restart the clone is still on disk but the working directory has been
    # reset. An absolute path keeps re-running this cell idempotent. This relies on
    # the clone having been made from /content, the same assumption the isdir check
    # above already makes.
    ip = get_ipython()  #  pyright: ignore[reportUndefinedVariable]
    ip.run_line_magic("cd", f"{COLAB_REPO_DIR}/next_layer_sae")

# Nice for dev, but not needed for Colab.
try:
    # This uses a library called jurigged to hot-reload code when it is changed.
    # For reasons I've never been able to figure out, the IPython %autoreload magic
    # completely fails to work with the kind of structure I use in this notebook.
    import next_layer_sae._autoreload  # noqa: F401
except ImportError as exc:
    # Hot-reloading is a development convenience only, so a missing dependency must
    # not stop the notebook. Any other kind of failure still propagates.
    print(f"Hot-reloading disabled: {exc}")

Already cloned source, or not running in Colab.


In [8]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tweak TRAINING_BATCH_SIZE for your hardware if necessary
if torch.cuda.is_available():
    TRAINING_DEVICE = "cuda"
    TRAINING_BATCH_SIZE = 64
# torch.backends.mps.is_available() is the long-standing, documented MPS probe;
# torch.mps.is_available() is not present on every torch release.
elif torch.backends.mps.is_available():
    TRAINING_DEVICE = "mps"
    TRAINING_BATCH_SIZE = 8
else:
    TRAINING_DEVICE = "cpu"
    TRAINING_BATCH_SIZE = 8

# The tokenizer comes from the GPT-Neo 125M checkpoint, the model weights from
# TinyStories-33M.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
# Both splits are opened with streaming=True.
training_dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
validation_dataset = load_dataset(
    "roneneldan/TinyStories", split="validation", streaming=True
)
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-33M").to(
    TRAINING_DEVICE
)

# Show the module tree so layer names and sizes are visible in the notebook.
print(model)

GPTNeoForCausalLM(
  (transformer): GPTNeoModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-3): 4 x GPTNeoBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPTNeoAttention(
          (attention): GPTNeoSelfAttention(
            (attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_dropout): Dropout(p=0.0, inplace=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPTNeoMLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_feat

In [9]:
# Caches model activations to these directories; modify if necessary, or set to None to disable.
# These take up ~16GB and ~1.5GB respectively, but will save a fair bit of time when running
# the notebook, since they can be re-used across all methods we're comparing.
TRAINING_CACHE_DIR = ".training_cache"
VALIDATION_CACHE_DIR = ".validation_cache"
# Token budgets are counts, so they are written as integers rather than as float
# literals such as 1e6.
NUM_TRAINING_TOKENS = 1_000_000
EVAL_INTERVAL = 100_000
NUM_VALIDATION_TOKENS = 100_000
# SAE input width is the model's hidden size; the dictionary is 4x wider.
D_MODEL = model.config.hidden_size
D_SAE = D_MODEL * 4
TOKENIZER_BATCH_SIZE = 128

# Train SAEs for comparison


In [10]:
from next_layer_sae.sae import SAE
from next_layer_sae.training import TrainingConfig, TrainingMethod, train

def _fresh_saes_per_layer():
    """Return a new ``{layer_index: SAE}`` dict with one top-k SAE per model layer."""
    per_layer = {}
    for layer_idx in range(model.config.num_layers):
        per_layer[layer_idx] = SAE(
            D_MODEL,
            D_SAE,
            device=model.device,
            kind="topk",
            topk=100,
        )
    return per_layer


# Every training method gets its own independent set of SAEs:
# saes[method][layer] -> SAE
saes = {method: _fresh_saes_per_layer() for method in TrainingMethod}

# One TrainingConfig per training method. Only the next_layer method switches on the
# next-layer SAE, its reconstruction weight and loss balancing; every other setting
# is shared across methods.
training_config = {
    method: TrainingConfig(
        tokenizer_batch_size=TOKENIZER_BATCH_SIZE,
        training_batch_size=TRAINING_BATCH_SIZE,
        # Use the notebook-level constants (also used when building the activation
        # caches) instead of repeating the literals, so the two cannot drift apart.
        num_train_tokens=int(NUM_TRAINING_TOKENS),
        dense_weight=0.0,
        idempotency_weight=0.0,
        eval_interval=int(EVAL_INTERVAL),
        train_layers=list(range(model.config.num_layers)),
        lr=1e-3,
        use_next_layer_sae=method is TrainingMethod.next_layer,
        next_reconstruction_weight=1.0 if method is TrainingMethod.next_layer else 0.0,
        reconstruction_weight=1.0,
        use_kl_on_final_layer=True,
        balance_reconstruction_losses=method is TrainingMethod.next_layer,
        use_weighted_mask=False,
        method=method,
    )
    for method in TrainingMethod
}

In [11]:
import os

from next_layer_sae.training import build_cache

def _ensure_activation_cache(cache_dir, dataset, num_tokens):
    """Call ``build_cache`` for ``cache_dir`` unless a populated cache is already there.

    Nothing happens when ``cache_dir`` is falsy (caching disabled) or when the
    directory exists and lists at least one entry.
    """
    if not cache_dir:
        return
    if os.path.exists(cache_dir) and os.listdir(cache_dir):
        return
    build_cache(
        cache_dir,
        model,
        tokenizer,
        dataset,
        tokenizer_batch_size=TOKENIZER_BATCH_SIZE,
        inference_batch_size=TRAINING_BATCH_SIZE,
        num_tokens=num_tokens,
    )


# Training cache first, then validation cache.
_ensure_activation_cache(TRAINING_CACHE_DIR, training_dataset, NUM_TRAINING_TOKENS)
_ensure_activation_cache(
    VALIDATION_CACHE_DIR, validation_dataset, NUM_VALIDATION_TOKENS
)

## Next-layer auxiliary loss (my method)


In [8]:
# Train the SAEs registered under the next-layer method, using that method's config
# and the shared training activation cache directory.
next_layer_method = TrainingMethod.next_layer
train(
    model,
    tokenizer,
    saes[next_layer_method],
    training_dataset,
    training_config[next_layer_method],
    cache_dir=TRAINING_CACHE_DIR,
    reinit_weights=True,  # presumably resets SAE weights before training; confirm in train()
)

Layer 3

  0%|          | 0/1000000 [00:00<?, ?it/s]

Layer 2

  0%|          | 0/1000000 [00:00<?, ?it/s]

Layer 1

  0%|          | 0/1000000 [00:00<?, ?it/s]

Layer 0

  0%|          | 0/1000000 [00:00<?, ?it/s]

[1m[32mAdd next_layer_sae.training.train_e2e @L379[m
[1m[33mUpdate next_layer_sae.sae_data.get_sae_data @L179[m
[1m[33mUpdate next_layer_sae.sae_data.get_sae_data @L179[m


## Full end-to-end training
Recreation of the method SAE_e2e+ds from
> Braun, Dan, Jordan Taylor, Nicholas Goldowsky-Dill, and Lee Sharkey. 2024. “Identifying Functionally Important Features with End-to-End Sparse Dictionary Learning.” arXiv [Cs.LG]. arXiv. http://arxiv.org/abs/2405.12241.


In [12]:
# Train the SAEs registered under the end-to-end method, using that method's config
# and the shared training activation cache directory.
e2e_method = TrainingMethod.e2e
train(
    model,
    tokenizer,
    saes[e2e_method],
    training_dataset,
    training_config[e2e_method],
    cache_dir=TRAINING_CACHE_DIR,
    reinit_weights=True,  # presumably resets SAE weights before training; confirm in train()
)

Layer 3

Layer 2

Layer 1

Layer 0

  0%|          | 0/1000000 [00:00<?, ?it/s]

## End-to-end fine-tuning

Recreation of the KL fine-tuning method from
> Karvonen, Adam. 2025. “Revisiting End-to-End Sparse Autoencoder Training: A Short Finetune Is All You Need.” arXiv [Cs.LG]. arXiv. http://arxiv.org/abs/2503.17272.


In [7]:
# Train the SAEs registered under the fine-tuning method, using that method's config
# and the shared training activation cache directory.
finetuned_method = TrainingMethod.finetuned
train(
    model,
    tokenizer,
    saes[finetuned_method],
    training_dataset,
    training_config[finetuned_method],
    cache_dir=TRAINING_CACHE_DIR,
    reinit_weights=True,  # presumably resets SAE weights before training; confirm in train()
)

Layer 3

  0%|          | 0/1000000 [00:00<?, ?it/s]

Layer 2

  0%|          | 0/1000000 [00:00<?, ?it/s]

Layer 1

  0%|          | 0/1000000 [00:00<?, ?it/s]

Layer 0

  0%|          | 0/1000000 [00:00<?, ?it/s]

# Evaluations and comparisons


In [None]:
import matplotlib.pyplot as plt
import numpy as np

from next_layer_sae.validation import validate_saes

# Evaluate the end-to-end SAEs on the validation split. Batch sizes are taken from
# that method's training config so validation runs with the same batching.
all_evals, replacement_evals, position_ids = validate_saes(
    model,
    tokenizer,
    saes[TrainingMethod.e2e],
    validation_dataset,
    # Same constant the validation activation cache is built with, so the cache and
    # the evaluation budget stay in sync.
    num_tokens=int(NUM_VALIDATION_TOKENS),
    # num_batches=1,
    tokenizer_batch_size=training_config[TrainingMethod.e2e].tokenizer_batch_size,
    inference_batch_size=training_config[TrainingMethod.e2e].training_batch_size,
    use_next_layer_sae=False,
    cache_dir=VALIDATION_CACHE_DIR,
)

# Let NumPy print arrays of up to 10,000 elements in full before summarising them.
np.set_printoptions(threshold=10_000)

# Per-layer plot colours; the plotting loop below indexes this list modulo its length.
colors = [
    "blue", "red", "green", "orange", "purple",
    "brown", "pink", "gray", "olive", "cyan",
]  # fmt: skip

# Plot histogram of position_ids
# plt.figure(figsize=(12, 6))
# plt.hist(position_ids, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
# plt.title("Distribution of Position IDs")
# plt.xlabel("Position ID")
# plt.ylabel("Frequency")
# plt.grid(True, alpha=0.3)
# plt.tight_layout()
# plt.show()


# for plot_vs_pos in ["rcn", "next_rcn"]:
#     # Plot next_rcn vs position_ids
#     if plot_vs_pos in all_evals:
#         plt.figure(figsize=(12, 6))
#         # Create scatter plot for each layer
#         for i, (layer, layer_eval) in enumerate(all_evals[plot_vs_pos].items()):
#             plt.scatter(
#                 position_ids,
#                 layer_eval,
#                 alpha=0.6,
#                 color=colors[i % len(colors)],
#                 label=f"Layer {layer}",
#                 s=1
#             )

#         plt.title(f"{plot_vs_pos} vs Position IDs")
#         plt.xlabel("Position ID")
#         plt.ylabel(plot_vs_pos)
#         plt.legend()
#         plt.grid(True, alpha=0.3)
#         plt.tight_layout()
#         plt.show()

#     # Plot mean next_rcn vs position_ids
#     if plot_vs_pos in all_evals:
#         plt.figure(figsize=(12, 6))

#         # Calculate mean next_rcn for each position and layer
#         for i, (layer, layer_eval) in enumerate(all_evals[plot_vs_pos].items()):
#             # Create a dictionary to store values for each position
#             position_values = {}

#             # Group values by position
#             for pos, val in zip(position_ids[:len(layer_eval)], layer_eval):
#                 if pos not in position_values:
#                     position_values[pos] = []
#                 position_values[pos].append(val)

#             # Calculate mean for each position
#             positions = sorted(position_values.keys())
#             mean_values = [np.mean(position_values[pos]) for pos in positions]

#             plt.plot(
#                 positions,
#                 mean_values,
#                 color=colors[i % len(colors)],
#                 label=f"Layer {layer}",
#                 linewidth=2,
#                 marker='o',
#                 markersize=3
#             )

#         plt.title(f"Mean {plot_vs_pos} vs Position IDs")
#         plt.xlabel("Position ID")
#         plt.ylabel(f"Mean {plot_vs_pos}")
#         plt.legend()
#         plt.grid(True, alpha=0.3)
#         plt.tight_layout()
#         plt.show()

# One figure per evaluation metric, with every layer's histogram overlaid on it.
for metric_name in all_evals.keys():
    fig, ax = plt.subplots(figsize=(12, 6))

    # Draw layers in reverse order while keeping each layer's colour tied to its
    # original position in the mapping.
    indexed_layers = list(enumerate(all_evals[metric_name].items()))
    for color_idx, (layer, values) in indexed_layers[::-1]:
        ax.hist(
            values,
            bins=50,
            alpha=0.7,
            edgecolor="black",
            color=colors[color_idx % len(colors)],
            label=f"Layer {layer}",
        )

    ax.set_title(f"{metric_name} distribution comparison")
    ax.set_xlabel(metric_name)
    ax.set_ylabel("Frequency")
    ax.legend()


# Plot replacement evaluation metrics
print("Replacement evaluation metrics:")
print("=" * 50)

# One histogram figure per replacement metric; the mean is printed alongside.
for key, value in replacement_evals.items():
    plt.figure(figsize=(12, 6))

    plt.hist(
        value,
        bins=50,
        alpha=0.7,
        edgecolor="black",
        color="blue",
    )
    plt.title(f"{key} distribution")
    plt.xlabel(key)
    plt.ylabel("Frequency")
    # Grid and layout act on the *current* figure only, so they must run inside the
    # loop; after the loop they would only affect the last figure created.
    plt.grid(True)
    plt.tight_layout()
    print(f"{key}: mean = {np.mean(value):.6f}")


# Render every figure created so far in one go.
plt.show()

print()


# Text summary: the mean of every metric, listed per layer in reverse layer order.
print("Mean evaluation metrics:")
print("=" * 50)

for metric_name in all_evals.keys():
    print(f"{metric_name}:")
    for layer, values in reversed(all_evals[metric_name].items()):
        print(f"  Layer {layer}: {np.mean(values):.6f}")
