In [None]:
from functools import partial
from pathlib import Path
import torch
import pandas as pd
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

from sparse_but_wrong.toy_models.get_training_batch import generate_random_correlation_matrix, get_training_batch
from sparse_but_wrong.toy_models.toy_model import ToyModel
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT
from sparse_but_wrong.util import DEFAULT_DEVICE

tqdm._instances.clear()  # type: ignore

# Number of ground-truth features in the toy dataset. Defined once so the firing
# probabilities, correlation matrix, feature labels and toy model cannot drift apart.
NUM_FEATS = 50

# Firing probability falls linearly with the feature index: ~0.39 for the first
# feature down to 0.05 for the last one.
feat_probs = 0.345 * (NUM_FEATS - torch.arange(NUM_FEATS) - 1) / NUM_FEATS + 0.05
correlations = generate_random_correlation_matrix(
    correlation_strength_range=(0.3, 0.9),
    num_features=NUM_FEATS,
    seed=42,
)
# 1-based feature labels used as the categorical x-axis of the bar plot below.
indices = torch.arange(NUM_FEATS) + 1
df = pd.DataFrame({
    "P_i": feat_probs,
    "feature": map(str, indices.tolist()),
})


toy_model = ToyModel(num_feats=NUM_FEATS, hidden_dim=100).to(DEFAULT_DEVICE)


# Batch generator with the dataset statistics baked in; callers only pass the batch size.
generate_batch = partial(
    get_training_batch,
    firing_probabilities=feat_probs,
    std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,
    correlation_matrix=correlations,
)

# Output directory for the dataset-overview figures.
Path("plots/toy_setup").mkdir(parents=True, exist_ok=True)

plt.rcParams.update({"figure.dpi": 150})
with plt.rc_context(SEABORN_RC_CONTEXT):
    fig, ax = plt.subplots(figsize=(3, 2))
    sns.barplot(data=df, x="feature", y="P_i", ax=ax)
    ax.set_xlabel("Feature")
    ax.set_ylabel("$P_i$")
    ax.set_title("Feature firing probabilities $P_i$")
    # Label only every 5th bar, using 0-indexed positions, so tick labels do not overlap.
    tick_positions = range(0, len(df), 5)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels([str(pos) for pos in tick_positions])
    fig.tight_layout()
    fig.savefig("plots/toy_setup/toy_model_feature_firing_probabilities.pdf")
    plt.show()

plt.rcParams.update({"figure.dpi": 150})
with plt.rc_context(SEABORN_RC_CONTEXT):
    fig, ax = plt.subplots(figsize=(2.5, 2))
    # Diverging colormap centred on 0 with the full [-1, 1] correlation range.
    sns.heatmap(correlations, cmap="RdBu", center=0, vmin=-1, vmax=1, ax=ax)
    ax.set_xlabel("Feature")
    ax.set_ylabel("Feature")
    ax.set_title("Feature correlation matrix")
    # Label only every 10th feature (0-indexed) on both axes to avoid overlapping ticks.
    tick_positions = range(0, len(correlations), 10)
    tick_labels = [str(pos) for pos in tick_positions]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)
    # plt.tight_layout()
    fig.savefig("plots/toy_setup/toy_model_correlation_matrix.pdf")
    plt.show()

## Finding the True L0

Next, we calculate the true L0 for this dataset (spoiler: it's ~11)

In [None]:
# Estimate the true L0 empirically: the average number of features that are
# active (> 0) per sample over a large batch.
sample = generate_batch(100_000)
active_per_sample = (sample > 0).sum(dim=-1)
true_l0 = active_per_sample.float().mean()
print(f"True L0: {true_l0}")

## Training an SAE with the right L0

Next, let's train an SAE with the correct L0 and verify it can perfectly recover the underlying true features

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# BatchTopK SAE with k=11 (the dataset's true L0) and one latent per true feature.
cfg = BatchTopKTrainingSAEConfig(
    k=11,
    d_in=toy_model.embed.weight.shape[0],
    d_sae=50,
)
sae_full = BatchTopKTrainingSAE(cfg)

train_toy_sae(sae_full, toy_model, generate_batch)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# The other figure directories in this notebook are created explicitly before use;
# do the same for plots/toy_l0, which every cosine-similarity figure is saved into.
Path("plots/toy_l0").mkdir(parents=True, exist_ok=True)

# Keyword arguments shared by the saved (seaborn) versions of this figure.
sae_full_plot_kwargs = dict(
    title="SAE L0 = True L0",
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    dtick=5,
    decoder_title=None,
)

plot_sae_feat_cos_sims(sae_full, toy_model, "SAE L0 = True L0", reorder_features=True, dtick=5)
# Same figure saved twice: once as PDF and once as PNG.
plot_sae_feat_cos_sims_seaborn(
    sae_full,
    toy_model,
    save_path="plots/toy_l0/sae_l0_eq_true_l0_decoder_cos_sims.pdf",
    **sae_full_plot_kwargs,
)
plot_sae_feat_cos_sims_seaborn(
    sae_full,
    toy_model,
    save_path="plots/toy_l0/sae_l0_eq_true_l0_decoder_cos_sims.png",
    **sae_full_plot_kwargs,
)

As we can see above, the SAE has learned the correct features.

## What if we reduce the L0 below the number of true features?

If we lower the L0, then the SAE will need to reconstruct the input with fewer latents than the number of true features required. How will it handle this? We set L0=9 below.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# Same architecture as before, but k=9 is below the true L0 of ~11.
cfg = BatchTopKTrainingSAEConfig(
    k=9,
    d_in=toy_model.embed.weight.shape[0],
    d_sae=50,
)
sae_narrow = BatchTopKTrainingSAE(cfg)

train_toy_sae(sae_narrow, toy_model, generate_batch)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Interactive view first, then the saved decoder-only figure.
plot_sae_feat_cos_sims(
    sae_narrow,
    toy_model,
    "SAE L0 < True L0",
    reorder_features=True,
    dtick=5,
)
plot_sae_feat_cos_sims_seaborn(
    sae_narrow,
    toy_model,
    title="SAE L0 $<$ True L0",
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    dtick=5,
    decoder_title=None,
    save_path="plots/toy_l0/sae_l0_lt_true_l0_decoder_cos_sims.pdf",
)


The SAE is now broken - we no longer have a clear latent tracking feature 0, even though this is the most important feature! Instead every latent gets a component of feature 0 merged into it. The latents tracking the feature with the highest correlation with feature 0 are the most messed up by this. Essentially, the SAE is trying to "cheat" and reconstruct feature 0 without having to dedicate a latent to it.

This breaks all the latents of the SAE by hedging a component of feature 0 into them - but this broken behavior will result in a better MSE score than if the SAE learned all latents correctly, as we'll see next.

### MSE comparison: correct SAE vs low L0 SAE

Why does the SAE learn these broken latents instead of learning the correct representations of features and just selecting the top 9 instead of the top 11? Let's compare the MSE loss from running the correct SAE (trained with L0=11) at L0=9 against the broken SAE that we trained explicitly with L0=9.

We'll create 100k training samples, run them through our correct SAE with its L0 set to 9, and compare the MSE of this SAE with the broken SAE trained explicitly with L0=9.

In [None]:
# Pure evaluation: compare reconstruction error on one fresh batch, without
# tracking gradients.
with torch.no_grad():
    sample_feats = toy_model(generate_batch(100_000))

    # Temporarily run the correct SAE (trained with k=11) at k=9 so both SAEs are
    # evaluated at the same sparsity.
    sae_full.activation_fn.k = 9  # type: ignore[attr-defined]
    try:
        correct_sae_mse = (sae_full(sample_feats) - sample_feats).pow(2).sum(dim=-1).mean().item()
    finally:
        # Always restore k=11, even if the evaluation above fails, so later cells
        # never see a silently modified sae_full.
        sae_full.activation_fn.k = 11  # type: ignore[attr-defined]

    narrow_sae_mse = (sae_narrow(sample_feats) - sample_feats).pow(2).sum(dim=-1).mean().item()

print(f"Correct SAE MSE: {correct_sae_mse:.2f}")
print(f"Narrow SAE MSE: {narrow_sae_mse:.2f}")


The correct SAE run at L0=9 results in a noticeably higher MSE loss than the broken SAE trained at k=9 (compare the two values printed above)!

## Let's drop the L0 even further!

What if we drop the L0 further, so k=5?

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# Push the L0 far below the true value: k=5 versus a true L0 of ~11.
cfg = BatchTopKTrainingSAEConfig(
    k=5,
    d_in=toy_model.embed.weight.shape[0],
    d_sae=50,
)
sae_narrower = BatchTopKTrainingSAE(cfg)

# NOTE: training occasionally gets stuck in a poor local minimum. If that happens,
# rerun this cell and it should converge properly.
train_toy_sae(sae_narrower, toy_model, generate_batch)

In [None]:
# Import both plotting helpers here, as the sibling plotting cells do, so this cell
# does not depend on an earlier cell having imported plot_sae_feat_cos_sims_seaborn.
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Keyword arguments shared by all saved (seaborn) variants of this figure.
sae_narrower_plot_kwargs = dict(
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    dtick=5,
    decoder_title=None,
)

plot_sae_feat_cos_sims(sae_narrower, toy_model, "SAE L0 << True L0", reorder_features=True, dtick=5)
plot_sae_feat_cos_sims_seaborn(
    sae_narrower,
    toy_model,
    title="SAE L0 $\\ll$ True L0",
    save_path="plots/toy_l0/sae_l0_llt_true_l0_decoder_cos_sims.pdf",
    **sae_narrower_plot_kwargs,
)
# "v2" variant: identical figure, but titled with "<" instead of "\ll".
plot_sae_feat_cos_sims_seaborn(
    sae_narrower,
    toy_model,
    title="SAE L0 $<$ True L0",
    save_path="plots/toy_l0/sae_l0_llt_true_l0_decoder_cos_sims_v2.pdf",
    **sae_narrower_plot_kwargs,
)
plot_sae_feat_cos_sims_seaborn(
    sae_narrower,
    toy_model,
    title="SAE L0 $\\ll$ True L0",
    save_path="plots/toy_l0/sae_l0_llt_true_l0_decoder_cos_sims.png",
    **sae_narrower_plot_kwargs,
)


In [None]:
# Pure evaluation: compare reconstruction error on one fresh batch, without
# tracking gradients.
with torch.no_grad():
    sample_feats = toy_model(generate_batch(100_000))

    # Temporarily run the correct SAE (trained with k=11) at k=5 so both SAEs are
    # evaluated at the same sparsity.
    sae_full.activation_fn.k = 5  # type: ignore[attr-defined]
    try:
        correct_sae_mse = (sae_full(sample_feats) - sample_feats).pow(2).sum(dim=-1).mean().item()
    finally:
        # Always restore k=11, even if the evaluation above fails, so later cells
        # never see a silently modified sae_full.
        sae_full.activation_fn.k = 11  # type: ignore[attr-defined]

    narrower_sae_mse = (sae_narrower(sample_feats) - sample_feats).pow(2).sum(dim=-1).mean().item()

print(f"Correct SAE MSE: {correct_sae_mse:.2f}")
print(f"Narrower SAE MSE: {narrower_sae_mse:.2f}")

As expected, the SAE is even worse now! The lower-frequency features are still reconstructed mostly correctly aside from hedging a component of feature 0, but now even more of the earlier latents are messed up as the SAE finds degenerate ways to improve its reconstruction error despite not having enough L0 to do so correctly. It's hard to see from the plot, but the magnitude of the hedging of feature 0 is also higher now.

## What if we increase the L0 beyond the number of true features?

If we increase the L0, then the SAE will have more latents than necessary to reconstruct the input. Will the SAE still do the right thing? We set L0=14 below.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# k=14 is above the true L0 of ~11.
cfg = BatchTopKTrainingSAEConfig(
    k=14,
    d_in=toy_model.embed.weight.shape[0],
    d_sae=50,
)
sae_wide = BatchTopKTrainingSAE(cfg)

train_toy_sae(sae_wide, toy_model, generate_batch)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Interactive view first, then the saved decoder-only figure.
plot_sae_feat_cos_sims(
    sae_wide,
    toy_model,
    "SAE L0 > True L0",
    reorder_features=True,
    dtick=5,
)
plot_sae_feat_cos_sims_seaborn(
    sae_wide,
    toy_model,
    title="SAE L0 $>$ True L0",
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    dtick=5,
    decoder_title=None,
    save_path="plots/toy_l0/sae_l0_gt_true_l0_decoder_cos_sims.pdf",
)


This is still mostly correct - we don't see any hedging at least. But the SAE is finding more degenerate local minima than it did before. Let's see what happens if we increase the L0 even more. 

## Increase L0 even more!

Let's see what happens if we increase the L0 to 18

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# k=18 is far above the true L0 of ~11.
cfg = BatchTopKTrainingSAEConfig(
    k=18,
    d_in=toy_model.embed.weight.shape[0],
    d_sae=50,
)
sae_wider = BatchTopKTrainingSAE(cfg)

train_toy_sae(sae_wider, toy_model, generate_batch)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Keyword arguments shared by all saved (seaborn) variants of this figure.
sae_wider_plot_kwargs = dict(
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    dtick=5,
    decoder_title=None,
)

plot_sae_feat_cos_sims(sae_wider, toy_model, "SAE L0 >> True L0", reorder_features=True, dtick=5)
plot_sae_feat_cos_sims_seaborn(
    sae_wider,
    toy_model,
    title="SAE L0 $\\gg$ True L0",
    save_path="plots/toy_l0/sae_l0_ggt_true_l0_decoder_cos_sims.pdf",
    **sae_wider_plot_kwargs,
)
# "v2" variant: identical figure, but titled with ">" instead of "\gg".
plot_sae_feat_cos_sims_seaborn(
    sae_wider,
    toy_model,
    title="SAE L0 $>$ True L0",
    save_path="plots/toy_l0/sae_l0_ggt_true_l0_decoder_cos_sims_v2.pdf",
    **sae_wider_plot_kwargs,
)
plot_sae_feat_cos_sims_seaborn(
    sae_wider,
    toy_model,
    title="SAE L0 $\\gg$ True L0",
    save_path="plots/toy_l0/sae_l0_ggt_true_l0_decoder_cos_sims.png",
    **sae_wider_plot_kwargs,
)


As expected, the SAE is even worse now, using its extra capacity to allow some more degenerate solutions. There is still notably no hedging though, but the SAE is clearly not doing what we want.

# Can we detect when we're at the correct L0?

Let's train SAEs at a range of L0, ranging from below the correct L0 to above the correct L0

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# Train one SAE for every k from 1 to 25 inclusive, spanning values below and above
# the true L0. The trained SAEs are keyed by k.
saes_by_k = {}
for k in range(1, 26):
    print(f"Training SAE with k={k}")
    cfg = BatchTopKTrainingSAEConfig(
        k=k,
        d_in=toy_model.embed.weight.shape[0],
        d_sae=50,
    )
    saes_by_k[k] = BatchTopKTrainingSAE(cfg)
    train_toy_sae(saes_by_k[k], toy_model, generate_batch)

In [None]:
from pathlib import Path
from typing import Callable
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sae_lens import TrainingSAE
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT

# toy_model, saes_by_k and generate_batch come from earlier cells of this notebook.

# NOTE(review): not referenced anywhere in the visible notebook; kept in case it is used elsewhere.
sorted_dec_hidden_pres = {}

def calc_l0_dec_thresholds(
    sae: "TrainingSAE",
    model,
    generate_batch: Callable[[int], torch.Tensor],
    n_samples: int = 100_000,
) -> list[float]:
    """Compute decoder-projection thresholds for each candidate L0.

    A fresh batch is embedded with ``model.embed``, the SAE decoder bias is
    subtracted, and the result is projected onto every decoder row
    (``sae.W_dec``). All projections are pooled and sorted in descending order;
    entry ``n`` of the result is the value at rank ``n * batch_size``, i.e. the
    threshold that on average ``n`` projections per sample exceed. Entry 0 is
    therefore the largest projection in the batch.

    Args:
        sae: SAE exposing ``b_dec`` and ``W_dec``.
        model: Toy model exposing an ``embed`` callable.
        generate_batch: Callable returning a batch of feature activations for a
            requested batch size.
        n_samples: Batch size requested from ``generate_batch``.

    Returns:
        One threshold per SAE latent, as plain Python floats.
    """
    # Pure measurement: no gradients needed, so avoid building an autograd graph
    # over the (batch_size * num_latents)-element projection tensor.
    with torch.no_grad():
        inputs = model.embed(generate_batch(n_samples))
        hidden_pre_dec = (inputs - sae.b_dec) @ sae.W_dec.T
        sorted_hidden_pre_dec = hidden_pre_dec.flatten().sort(descending=True).values
        # Rank n * batch_size for n = 0 .. num_latents - 1.
        k_inds = torch.arange(hidden_pre_dec.shape[-1]) * hidden_pre_dec.shape[0]
        return sorted_hidden_pre_dec[k_inds].tolist()

# Decoder-projection thresholds for every trained SAE, keyed by the SAE's k.
thresholds = {
    sae.cfg.k: calc_l0_dec_thresholds(sae, toy_model, generate_batch)
    for sae in tqdm(saes_by_k.values())
}

# One row per SAE: its L0 plus one "l0_dec_<n>" column per threshold index.
data = [
    {"l0": l0, **{f"l0_dec_{i}": threshold for i, threshold in enumerate(l0_thresholds)}}
    for l0, l0_thresholds in thresholds.items()
]
df = pd.DataFrame(data)

# One figure per threshold index N in 12..25: how the N-th decoder-projection
# threshold changes with the L0 the SAE was trained at.
Path("plots/dpn").mkdir(parents=True, exist_ok=True)
for k in range(12, 26):
    plt.rcParams.update({"figure.dpi": 150})
    sns.set_theme()
    with plt.rc_context(SEABORN_RC_CONTEXT):
        fig, ax = plt.subplots(figsize=(3, 1.5))
        sns.lineplot(data=df, x="l0", y=f"l0_dec_{k}", ax=ax)

        # Mark the true L0 of the dataset with a vertical line.
        ax.axvline(x=11, color='red', linestyle='--', linewidth=0.5, alpha=0.7, label='True L0')

        ax.set_title(f"N={k} Decoder Projection vs SAE L0")
        ax.set_xlabel("SAE L0")
        ax.set_ylabel("$dp_n$")
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(f"plots/dpn/dp_{k}.pdf")
        plt.show()

# Sparsity vs Reconstruction Tradeoff



In [None]:
from copy import deepcopy
import pandas as pd
import torch

import seaborn as sns
import matplotlib.pyplot as plt
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT

# Fresh batch of toy-model outputs, shared as the evaluation set for every SAE scored below.
test_inputs = toy_model(generate_batch(100_000))

def calculate_variance_explained(sae, test_inputs):
    """Return the fraction of the inputs' squared norm that ``sae`` reconstructs.

    Computed as ``1 - mean(||sae(x) - x||^2) / mean(||x||^2)``, where the squared
    norms are taken over the last dimension and averaged over the remaining ones.
    The inputs are not mean-centred before the denominator is computed.
    """
    residual = sae(test_inputs) - test_inputs
    reconstruction_error = residual.pow(2).sum(dim=-1).mean().item()
    total_energy = test_inputs.pow(2).sum(dim=-1).mean().item()
    return 1 - reconstruction_error / total_energy

# Build a "ground-truth" SAE by copying a trained SAE and overwriting its weights
# with the toy model's embedding, so each latent corresponds to one true feature.
with torch.no_grad():
    ground_truth_sae = deepcopy(sae_full)
    ground_truth_sae.activation_fn.k = 11 # type: ignore[attr-defined]

    embed_weight = toy_model.embed.weight
    embed_bias = toy_model.embed.bias

    ground_truth_sae.W_dec.data = embed_weight.T.clone()
    ground_truth_sae.W_enc.data = embed_weight.clone()
    ground_truth_sae.b_enc.data = torch.zeros_like(ground_truth_sae.b_enc.data)
    # The decoder bias mirrors the embedding bias when there is one, otherwise zero.
    ground_truth_sae.b_dec.data = (
        torch.zeros_like(ground_truth_sae.b_dec.data)
        if embed_bias is None
        else embed_bias.clone()
    )

# For every k, score the SAE trained at that k and the ground-truth SAE run at the same k.
data = []
for k, sae in saes_by_k.items():
    learned_score = calculate_variance_explained(sae, test_inputs)
    ground_truth_sae.activation_fn.k = k # type: ignore[attr-defined]
    ground_truth_score = calculate_variance_explained(ground_truth_sae, test_inputs)
    data.extend([
        {"k": k, "variance_explained": learned_score, "variant": "learned SAE"},
        {"k": k, "variance_explained": ground_truth_score, "variant": "ground-truth SAE"},
    ])

df = pd.DataFrame(data).sort_values(by="k")

sns.set_theme()
plt.rcParams.update({"figure.dpi": 150})
with plt.rc_context(SEABORN_RC_CONTEXT):
    fig, ax = plt.subplots(figsize=(3, 2))
    # One line per variant (learned vs ground-truth SAE).
    sns.lineplot(data=df, x="k", y="variance_explained", hue="variant", ax=ax)
    ax.set_xlabel("L0")
    ax.set_ylabel("Variance explained")
    ax.set_title("Sparsity vs reconstruction tradeoff")
    ax.legend()
    fig.tight_layout()
    fig.savefig("plots/sparsity_vs_reconstruction_tradeoff.pdf")
    plt.show()



In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Keyword arguments shared by the saved (seaborn) figures in this cell.
low_l0_plot_kwargs = dict(
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    dtick=5,
    decoder_title=None,
)

plot_sae_feat_cos_sims(ground_truth_sae, toy_model, "Ground truth SAE", reorder_features=True, dtick=5)
# Saved under plots/toy_l0/ like every other cosine-similarity figure in this notebook
# (previously written to the current working directory).
plot_sae_feat_cos_sims_seaborn(
    ground_truth_sae,
    toy_model,
    title="Ground truth SAE",
    save_path="plots/toy_l0/sae_ground_truth_decoder_cos_sims.pdf",
    **low_l0_plot_kwargs,
)
plot_sae_feat_cos_sims_seaborn(
    saes_by_k[1],
    toy_model,
    title="L0=1 learned SAE",
    save_path="plots/toy_l0/sae_l0_1_decoder_cos_sims.pdf",
    **low_l0_plot_kwargs,
)
plot_sae_feat_cos_sims_seaborn(
    saes_by_k[2],
    toy_model,
    title="L0=2 learned SAE",
    save_path="plots/toy_l0/sae_l0_2_decoder_cos_sims.pdf",
    **low_l0_plot_kwargs,
)

# Transitioning to the correct L0

Does it make a difference if we start at too low an L0 and transition to the correct L0, vs starting too high?

In [None]:
from sparse_but_wrong.enchanced_batch_topk_sae import (
    EnchancedBatchTopKTrainingSAE,
    EnhancedBatchTopKTrainingSAEConfig,
)
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# Start with too high an L0 (initial_k=20) and end at the correct one (k=11).
# NOTE(review): presumably k is annealed from initial_k to k over
# transition_k_duration_steps steps, beginning at transition_k_start_step - confirm
# against EnhancedBatchTopKTrainingSAEConfig.
cfg = EnhancedBatchTopKTrainingSAEConfig(
    d_in=toy_model.embed.weight.shape[0],
    d_sae=50,
    initial_k=20,
    k=11,
    transition_k_start_step=3_000,
    transition_k_duration_steps=8_000,
)
sae_transition_down = EnchancedBatchTopKTrainingSAE(cfg)

train_toy_sae(sae_transition_down, toy_model, generate_batch)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Interactive view first, then the saved decoder-only figure.
plot_sae_feat_cos_sims(
    sae_transition_down,
    toy_model,
    "SAE L0 Decrease 20 to 11",
    reorder_features=True,
    dtick=5,
)
plot_sae_feat_cos_sims_seaborn(
    sae_transition_down,
    toy_model,
    title="SAE L0 Decrease 20 $\\to$ 11",
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    dtick=5,
    decoder_title=None,
    save_path="plots/toy_l0/sae_l0_20_to_11_decoder_cos_sims.pdf",
)

In [None]:
from sparse_but_wrong.enchanced_batch_topk_sae import (
    EnchancedBatchTopKTrainingSAE,
    EnhancedBatchTopKTrainingSAEConfig,
)
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# Start with too low an L0 (initial_k=2) and end at the correct one (k=11).
# NOTE(review): presumably k is annealed from initial_k to k over
# transition_k_duration_steps steps, beginning at transition_k_start_step - confirm
# against EnhancedBatchTopKTrainingSAEConfig.
cfg = EnhancedBatchTopKTrainingSAEConfig(
    d_in=toy_model.embed.weight.shape[0],
    d_sae=50,
    initial_k=2,
    k=11,
    transition_k_start_step=0,
    transition_k_duration_steps=25_000,
)
sae_transition_up = EnchancedBatchTopKTrainingSAE(cfg)

train_toy_sae(sae_transition_up, toy_model, generate_batch)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Interactive view first, then the saved decoder-only figure.
plot_sae_feat_cos_sims(
    sae_transition_up,
    toy_model,
    "SAE L0 Increase 2 to 11",
    reorder_features=True,
    dtick=5,
)
plot_sae_feat_cos_sims_seaborn(
    sae_transition_up,
    toy_model,
    title="SAE L0 Increase 2 $\\to$ 11",
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    dtick=5,
    decoder_title=None,
    save_path="plots/toy_l0/sae_l0_2_to_11_decoder_cos_sims.pdf",
)