# Small toy model experiments

In this notebook, we explore the effect of L0 on SAEs when there are correlated features with a tiny toy model where everything is easily visible.

# Positive correlations

We'll start with a simple toy model with 5 features, and a correlation matrix where all features are positively correlated with feature 0. All other features are uncorrelated with each other.

In [None]:
from functools import partial
from pathlib import Path
import torch
from tqdm import tqdm

from sparse_but_wrong.toy_models.get_training_batch import  get_training_batch
from sparse_but_wrong.toy_models.toy_model import ToyModel
from sparse_but_wrong.toy_models.plotting import plot_correlation_matrix
from sparse_but_wrong.util import DEFAULT_DEVICE

# Reset tqdm's (private) registry of progress-bar instances, presumably so
# re-running cells does not leave stale bars behind.
tqdm._instances.clear()  # type: ignore

# Create every plot output directory this notebook saves into up front;
# matplotlib's savefig does not create missing directories.
for plot_dir in ("plots/toy_setup_small", "plots/toy_l0_small"):
    Path(plot_dir).mkdir(parents=True, exist_ok=True)

# Set up the toy model to have L0=2: each of the 5 features fires with p=2/5.
feat_probs = torch.ones(5) * 2 / 5
# Feature 0 has correlation +0.4 with every other feature; features 1-4 are
# mutually uncorrelated.
pos_correlations = torch.zeros(5, 5)
pos_correlations[:, 0] = 0.4
pos_correlations[0, :] = 0.4
pos_correlations.fill_diagonal_(1.0)

# Cache the toy model on disk so every run of the notebook uses the same one.
if Path("small_toy_model.pt").exists():
    print("Loading toy model from disk")
    # map_location: a model cached on one device (e.g. CUDA) must still load
    # when DEFAULT_DEVICE differs. weights_only=False unpickles the whole
    # module, so only load a file this notebook wrote itself.
    toy_model = torch.load("small_toy_model.pt", weights_only=False, map_location=DEFAULT_DEVICE)
else:
    toy_model = ToyModel(num_feats=5, hidden_dim=20).to(DEFAULT_DEVICE)
    torch.save(toy_model, "small_toy_model.pt")


# Batch generator for the positive-correlation setup; only the batch size is
# left unbound (called as generate_batch_pos(n) below).
generate_batch_pos = partial(
    get_training_batch,
    firing_probabilities=feat_probs,
    std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,
    correlation_matrix=pos_correlations,
)


plot_correlation_matrix(pos_correlations, save_path="plots/toy_setup_small/positive_correlation_matrix.pdf")

## Finding the True L0

Next, we calculate the true L0 for this dataset (spoiler: it's ~2)

In [None]:
# Empirical L0: average number of simultaneously active features per sample.
sample = generate_batch_pos(100_000)
true_l0 = torch.count_nonzero(sample > 0, dim=-1).float().mean()
print(f"True L0: {true_l0}")

## Training an SAE with the right L0

Next, let's train an SAE with the correct L0 and verify it can perfectly recover the underlying true features.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# BatchTopK SAE with k equal to the true L0 (2) and one latent per true
# feature (5). d_in is read off the toy model's embedding weight.
cfg = BatchTopKTrainingSAEConfig(
    d_in=toy_model.embed.weight.shape[0],
    d_sae=5,
    k=2,
)
sae_full = BatchTopKTrainingSAE(cfg)

# NOTE: training occasionally gets stuck in a poor local minimum. If that
# happens, re-run this cell; it should converge properly.
train_toy_sae(sae_full, toy_model, generate_batch_pos)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

plot_sae_feat_cos_sims(sae_full, toy_model, "SAE L0 = True L0", reorder_features=True, dtick=5)

# Save the same decoder-only figure twice: first as PDF, then as PNG.
for ext in ("pdf", "png"):
    plot_sae_feat_cos_sims_seaborn(
        sae_full,
        toy_model,
        title="SAE L0 = True L0",
        reorder_features=True,
        decoder_only=True,
        width=5,
        height=2,
        decoder_title=None,
        save_path=f"plots/toy_l0_small/pos_corr_sae_l0_eq_true_l0_decoder_cos_sims.{ext}",
    )

As we can see above, the SAE has learned the correct features despite the feature correlations.

## What if we reduce the L0 slightly below the number of true features?

If we lower the L0, then the SAE will need to reconstruct each input using fewer active latents than the number of true features that are active. We'll reduce the SAE L0 by 10%, from 2 to 1.8 (BatchTopK SAEs have no problem with fractional L0s, which is nice for our experiments). We'll further initialize the SAE to the ground-truth correct solution, so we can be certain that any resulting change is caused by gradient pressure rather than by a poor local minimum.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae
from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model

# Same width as before, but k=1.8 sits 10% below the true L0 of 2.
cfg = BatchTopKTrainingSAEConfig(
    d_in=toy_model.embed.weight.shape[0],
    d_sae=5,
    k=1.8,
)
sae_narrow = BatchTopKTrainingSAE(cfg)

# Start from the ground-truth solution so any drift comes from training
# pressure rather than from a bad initialization.
init_sae_to_match_model(sae_narrow, toy_model)
train_toy_sae(sae_narrow, toy_model, generate_batch_pos)

In [None]:
import plotly.express as px
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

plot_sae_feat_cos_sims(sae_narrow, toy_model, "SAE L0 < True L0", reorder_features=True)

# Show the feature correlation matrix alongside for visual comparison.
px.imshow(
    pos_correlations,
    width=400,
    height=400,
    color_continuous_scale="RdBu",
    zmin=-1,
    zmax=1,
    title="Feature correlation matrix",
    origin="lower",
).show()

plot_sae_feat_cos_sims_seaborn(
    sae_narrow,
    toy_model,
    title="SAE L0 $<$ True L0",
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    decoder_title=None,
    save_path="plots/toy_l0_small/pos_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf",
)




We see that all SAE latents contain a strong positive component of feature 0, but no component of any other feature, exactly matching the correlation matrix. Clearly, the correlations are what cause the SAE to mix feature components together.

# Negative correlations

Next, we'll change the correlation matrix so all features are negatively correlated with feature 0 instead of positively correlated. We'll keep the same toy model as before, we're just changing feature firing correlations.

In [None]:
from functools import partial
import torch

from sparse_but_wrong.toy_models.get_training_batch import  get_training_batch
from sparse_but_wrong.toy_models.plotting import plot_correlation_matrix

# Same toy-model setup as before (L0=2: each feature fires with p=2/5), but
# feature 0 is now anti-correlated (-0.4) with every other feature.
feat_probs = torch.ones(5) * 2 / 5
neg_correlations = torch.zeros(5, 5)
neg_correlations[0, :] = -0.4
neg_correlations[:, 0] = -0.4
neg_correlations.fill_diagonal_(1.0)

# Batch generator for the negative-correlation setup; only the batch size
# is left unbound.
generate_batch_neg = partial(
    get_training_batch,
    correlation_matrix=neg_correlations,
    firing_probabilities=feat_probs,
    std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,
)

plot_correlation_matrix(neg_correlations, save_path="plots/toy_setup_small/negative_correlation_matrix.pdf")

## Finding the True L0

Let's double-check that the true L0 is still 2, as before.

In [None]:
# Empirical L0 under the negative-correlation setup.
sample = generate_batch_neg(100_000)
true_l0 = torch.count_nonzero(sample > 0, dim=-1).float().mean()
print(f"True L0: {true_l0}")

## Training an SAE with the right L0

Next, let's train an SAE with the correct L0 and verify it can perfectly recover the underlying true features.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# BatchTopK SAE with k equal to the true L0 (2) and one latent per true
# feature (5), trained on the negatively correlated data.
cfg = BatchTopKTrainingSAEConfig(
    d_in=toy_model.embed.weight.shape[0],
    d_sae=5,
    k=2,
)
sae_full = BatchTopKTrainingSAE(cfg)

# NOTE: training occasionally gets stuck in a poor local minimum. If that
# happens, re-run this cell; it should converge properly.
train_toy_sae(sae_full, toy_model, generate_batch_neg)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

plot_sae_feat_cos_sims(sae_full, toy_model, "SAE L0 = True L0", reorder_features=True, dtick=5)

# Save the same decoder-only figure twice: first as PDF, then as PNG.
for ext in ("pdf", "png"):
    plot_sae_feat_cos_sims_seaborn(
        sae_full,
        toy_model,
        title="SAE L0 = True L0",
        reorder_features=True,
        decoder_only=True,
        width=5,
        height=2,
        decoder_title=None,
        save_path=f"plots/toy_l0_small/neg_corr_sae_l0_eq_true_l0_decoder_cos_sims.{ext}",
    )

As we can see above, the SAE has learned the correct features despite the feature correlations.

## What if we reduce the L0 slightly below the number of true features?

If we lower the L0, then the SAE will need to reconstruct each input using fewer active latents than the number of true features that are active. We'll reduce the SAE L0 by 10%, from 2 to 1.8 (BatchTopK SAEs have no problem with fractional L0s, which is nice for our experiments). We'll further initialize the SAE to the ground-truth correct solution, so we can be certain that any resulting change is caused by gradient pressure rather than by a poor local minimum.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae
from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model

# Same width as before, but k=1.8 sits 10% below the true L0 of 2.
cfg = BatchTopKTrainingSAEConfig(
    d_in=toy_model.embed.weight.shape[0],
    d_sae=5,
    k=1.8,
)
sae_narrow = BatchTopKTrainingSAE(cfg)

# Start from the ground-truth solution so any drift comes from training
# pressure rather than from a bad initialization.
init_sae_to_match_model(sae_narrow, toy_model)
train_toy_sae(sae_narrow, toy_model, generate_batch_neg)

In [None]:
import plotly.express as px
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

plot_sae_feat_cos_sims(sae_narrow, toy_model, "SAE L0 < True L0", reorder_features=True)

# Show the feature correlation matrix alongside for visual comparison.
px.imshow(
    neg_correlations,
    width=400,
    height=400,
    color_continuous_scale="RdBu",
    zmin=-1,
    zmax=1,
    title="Feature correlation matrix",
    origin="lower",
).show()

plot_sae_feat_cos_sims_seaborn(
    sae_narrow,
    toy_model,
    title="SAE L0 $<$ True L0",
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    decoder_title=None,
    save_path="plots/toy_l0_small/neg_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf",
)

# How does the degree of mixing change with amount of correlation?

We'll see what happens when we vary the correlation with feature 0 from -0.5 to 0.5 in steps of 0.1, training 5 seeds per correlation value.

In [None]:
from collections import defaultdict
from functools import partial
from pathlib import Path
import torch

from sparse_but_wrong.toy_models.get_training_batch import  get_training_batch


def get_generator_for_correlation(correlation_strength: float):
    """Return a batch generator for the 5-feature toy setup with a given correlation.

    Every feature fires with probability 2/5 (so the true L0 is 2) with firing
    magnitude std 0.15. Feature 0 has correlation ``correlation_strength`` with
    every other feature; features 1-4 are mutually uncorrelated.

    Args:
        correlation_strength: Off-diagonal value placed in row and column 0 of
            the correlation matrix.

    Returns:
        ``get_training_batch`` with everything but the batch size bound, used
        like ``generate_batch_pos`` above (presumably ``generator(batch_size)``).
    """
    feat_probs = torch.ones(5) * 2 / 5
    correlations = torch.zeros(5, 5)
    correlations[:, 0] = correlation_strength
    correlations[0, :] = correlation_strength
    correlations.fill_diagonal_(1.0)

    return partial(
        get_training_batch,
        firing_probabilities=feat_probs,
        std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,
        correlation_matrix=correlations,
    )


# Train (or load from the on-disk cache) one k=1.8 SAE per (seed, correlation).
saes_by_correlation = defaultdict(list)
for seed in [0, 1, 2, 3, 4]:
    # Loop-invariant w.r.t. corr, so create the seed directory once.
    Path(f"saes_by_correlation/seed_{seed}").mkdir(parents=True, exist_ok=True)
    for corr in [-0.5, -0.4, -0.3, -0.2, -0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5]:
        sae_path = f"saes_by_correlation/seed_{seed}/{corr}"
        if Path(sae_path).exists():
            print(f"Loading SAE with corr={corr}, seed={seed} from disk")
            sae = BatchTopKTrainingSAE.load_from_disk(sae_path)
        else:
            print(f"Training SAE with corr={corr}, seed={seed}")
            # Seed torch's RNG so each run is reproducible and the seed_{seed}
            # directory actually reflects how the SAE was trained; without this
            # the seed was only a label.
            torch.manual_seed(seed)
            cfg = BatchTopKTrainingSAEConfig(k=1.8, d_in=toy_model.embed.weight.shape[0], d_sae=5)
            sae = BatchTopKTrainingSAE(cfg)
            init_sae_to_match_model(sae, toy_model)
            train_toy_sae(sae, toy_model, get_generator_for_correlation(corr))
            # NOTE(review): presumably W_dec can be non-contiguous after
            # init/training and the saver requires contiguous tensors — confirm.
            sae.W_dec.data = sae.W_dec.data.contiguous()
            sae.save_model(sae_path)
        saes_by_correlation[corr].append(sae)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path
from sparse_but_wrong.util import cos_sims
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT


def _feat0_mixing(sae) -> float:
    """Mean decoder cosine similarity of SAE latents 1-3 with true feature 0."""
    # Rows index SAE latents, columns index true features.
    latent_vs_feature = cos_sims(sae.W_dec.T, toy_model.embed.weight)
    return latent_vs_feature[[1, 2, 3], 0].mean().item()


correlation_values = sorted(saes_by_correlation.keys())
# One inner list per correlation value, one entry per trained seed.
mixing_by_corr = [[_feat0_mixing(sae) for sae in saes_by_correlation[corr]] for corr in correlation_values]
# Mean and (population) std across seeds.
mean_cos_sims_per_corr = np.array([np.mean(seed_vals) for seed_vals in mixing_by_corr])
std_cos_sims_per_corr = np.array([np.std(seed_vals) for seed_vals in mixing_by_corr])

Path("plots/toy_l0_small").mkdir(parents=True, exist_ok=True)
plt.rcParams.update({"figure.dpi": 150})
sns.set_theme()
with plt.rc_context(SEABORN_RC_CONTEXT):
    plt.figure(figsize=(3, 1.5))

    # Mean across seeds, with a +/- 1 std band.
    plt.plot(correlation_values, mean_cos_sims_per_corr, linewidth=1.0)
    plt.fill_between(
        correlation_values,
        mean_cos_sims_per_corr - std_cos_sims_per_corr,
        mean_cos_sims_per_corr + std_cos_sims_per_corr,
        alpha=0.3,
    )

    plt.title("Feature mixing vs correlation")
    plt.xlabel("Correlation between $f_0$ and other features")
    plt.ylabel("Mean cos sim")
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('plots/toy_l0_small/mean_cos_sim_vs_correlation.pdf')
    plt.show()

In [None]:
# Decoder cos-sim heatmaps for the first-seed SAE at both extreme
# correlations and at zero correlation.
shared_plot_kwargs = dict(
    reorder_features=True,
    decoder_only=True,
    width=5,
    height=2,
    decoder_title=None,
)
for corr in (-0.5, 0.0, 0.5):
    first_seed_sae = saes_by_correlation[corr][0]
    plot_sae_feat_cos_sims_seaborn(
        first_seed_sae,
        toy_model,
        title=f"correlation {corr}",
        save_path=f"plots/toy_l0_small/correlation_{corr}.pdf",
        **shared_plot_kwargs,
    )

# Changing L0 while holding correlation constant

Next, we'll see what happens when we change the L0 while holding the correlation constant. We'll train SAEs with correlation -0.4 and 0.4, as in our earlier experiments, but vary the L0 from 1.7 to 2.0 (the true L0 of the toy model is 2.0). We can't go much lower than 1.7 without fully breaking the SAE. Dropping L0 too far makes the SAE latents so distorted that they bear almost no resemblance to the true features, which makes a clean measurement of the cosine similarity between latents 1-3 and feature 0 difficult.

In [None]:
from collections import defaultdict
from pathlib import Path
import torch

# SAEs keyed by correlation sign ("pos"/"neg"), then by L0; one entry per seed.
saes_by_l0 = {
    "pos": defaultdict(list),
    "neg": defaultdict(list),
}
POS_CORR = 0.4
NEG_CORR = -0.4

for corr_type, corr in [("pos", POS_CORR), ("neg", NEG_CORR)]:
    for seed in [0, 1, 2, 3, 4]:
        # Loop-invariant w.r.t. l0, so create the seed directory once.
        Path(f"saes_by_l0/{corr_type}/seed_{seed}").mkdir(parents=True, exist_ok=True)
        for l0 in [1.7, 1.8, 1.9, 2.0]:
            sae_path = f"saes_by_l0/{corr_type}/seed_{seed}/{l0}"
            if Path(sae_path).exists():
                print(f"Loading SAE with corr={corr}, l0={l0}, seed={seed} from disk")
                sae = BatchTopKTrainingSAE.load_from_disk(sae_path)
            else:
                print(f"Training SAE with corr={corr}, l0={l0}, seed={seed}")
                # Seed torch's RNG so each run is reproducible and the seed_{seed}
                # directory actually reflects how the SAE was trained.
                torch.manual_seed(seed)
                cfg = BatchTopKTrainingSAEConfig(k=l0, d_in=toy_model.embed.weight.shape[0], d_sae=5)
                sae = BatchTopKTrainingSAE(cfg)
                # Start from the ground-truth features (see earlier cells).
                init_sae_to_match_model(sae, toy_model)
                train_toy_sae(sae, toy_model, get_generator_for_correlation(corr))
                # NOTE(review): presumably the saver requires a contiguous W_dec — confirm.
                sae.W_dec.data = sae.W_dec.data.contiguous()
                sae.save_model(sae_path)
            saes_by_l0[corr_type][l0].append(sae)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path
from sparse_but_wrong.util import cos_sims
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT

Path("plots/toy_l0_small").mkdir(parents=True, exist_ok=True)
plt.rcParams.update({"figure.dpi": 150})
sns.set_theme()


def _feat0_mixing(sae) -> float:
    """Return the mean decoder cosine similarity of SAE latents 1-3 with true feature 0.

    The SAEs were initialized to the ground-truth features, presumably with
    latent i starting as feature i, so similarity between latents 1-3 and
    feature 0 measures how much feature 0 was mixed into them during training.
    """
    # Rows index SAE latents, columns index true features.
    dec_cos_sims = cos_sims(sae.W_dec.T, toy_model.embed.weight)
    return dec_cos_sims[[1, 2, 3], 0].mean().item()


# Plot both positive and negative correlation variants. The correlation values
# themselves were fixed when the SAEs were trained (POS_CORR / NEG_CORR), so
# they are not repeated here.
for corr_type, corr_label in [("pos", "positive"), ("neg", "negative")]:
    l0_values = sorted(saes_by_l0[corr_type].keys())
    # One inner list per L0 value, one entry per trained seed.
    mixing_by_l0 = [[_feat0_mixing(sae) for sae in saes_by_l0[corr_type][l0]] for l0 in l0_values]
    # Mean and (population) std across seeds.
    mean_cos_sims_per_l0 = np.array([np.mean(seed_vals) for seed_vals in mixing_by_l0])
    std_cos_sims_per_l0 = np.array([np.std(seed_vals) for seed_vals in mixing_by_l0])

    with plt.rc_context(SEABORN_RC_CONTEXT):
        plt.figure(figsize=(3, 1.5))

        # Mean across seeds, with a +/- 1 std band.
        plt.plot(l0_values, mean_cos_sims_per_l0, linewidth=1.0)
        plt.fill_between(
            l0_values,
            mean_cos_sims_per_l0 - std_cos_sims_per_l0,
            mean_cos_sims_per_l0 + std_cos_sims_per_l0,
            alpha=0.3,
        )

        plt.title(f"Feature mixing vs L0 ({corr_label} correlation)")
        plt.xlabel("SAE L0")
        plt.ylabel("Mean cos sim")
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'plots/toy_l0_small/mean_cos_sim_vs_l0_{corr_type}.pdf')
        plt.show()

# Superposition noise

Running without superposition noise is the simplest possible task for an SAE. If an SAE cannot learn the correct latents even without superposition, there is no reason to expect it to do better once superposition noise is added. Nevertheless, we demonstrate this below for completeness.

In [None]:
from pathlib import Path
from sparse_but_wrong.toy_models.toy_model import ToyModel
from sparse_but_wrong.util import DEFAULT_DEVICE, cos_sims
import matplotlib.pyplot as plt
import seaborn as sns
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT


# Same 5-feature / 20-dim toy model, but built with ortho_num_steps=4.
# Presumably this only partially orthogonalizes the feature directions, leaving
# them overlapping (superposition); the heatmap below shows the result.
super_toy_model = ToyModel(num_feats=5, hidden_dim=20, ortho_num_steps=4).to(DEFAULT_DEVICE)

# Move to CPU: seaborn converts its input to a NumPy array, which fails for
# tensors on an accelerator (e.g. when DEFAULT_DEVICE is CUDA).
feature_cos_sims = cos_sims(super_toy_model.embed.weight, super_toy_model.embed.weight).detach().cpu()

# plt.savefig does not create missing directories.
Path("plots/toy_setup_small").mkdir(parents=True, exist_ok=True)

plt.rcParams.update({"figure.dpi": 150})
with plt.rc_context(SEABORN_RC_CONTEXT):
    plt.figure(figsize=(2.5, 2))
    sns.heatmap(feature_cos_sims, cmap="RdBu", center=0, vmin=-1, vmax=1)
    # Put feature 0 at the bottom, matching origin="lower" in the plotly plots.
    plt.gca().invert_yaxis()
    plt.xlabel("Feature")
    plt.ylabel("Feature")
    plt.title("Feature cosine similarities (superposition)")
    plt.savefig("plots/toy_setup_small/super_toy_model_cos_sims.pdf")
    plt.show()

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae
from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model


def _train_super_sae(generate_batch):
    """Train a k=1.9 BatchTopK SAE on ``super_toy_model``.

    The SAE is initialized to the toy model's ground-truth features and then
    trained on batches drawn from ``generate_batch``.
    """
    sae_cfg = BatchTopKTrainingSAEConfig(k=1.9, d_in=super_toy_model.embed.weight.shape[0], d_sae=5)
    trained = BatchTopKTrainingSAE(sae_cfg)
    init_sae_to_match_model(trained, super_toy_model)
    train_toy_sae(trained, super_toy_model, generate_batch)
    return trained


# Reuse the correlation structures from the no-superposition experiments.
super_sae_pos = _train_super_sae(generate_batch_pos)  # positive correlations
super_sae_neg = _train_super_sae(generate_batch_neg)  # negative correlations

In [None]:
import plotly.express as px
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# For each correlation sign, show the interactive plot and then save the
# decoder-only seaborn figure (with adjust_for_superposition=True).
for sign, super_sae in (("pos", super_sae_pos), ("neg", super_sae_neg)):
    plot_title = f"SAE, {sign} correlation, superposition"
    plot_sae_feat_cos_sims(super_sae, super_toy_model, plot_title, reorder_features=True)
    plot_sae_feat_cos_sims_seaborn(
        super_sae,
        super_toy_model,
        title=plot_title,
        reorder_features=True,
        decoder_only=True,
        width=5,
        height=2,
        adjust_for_superposition=True,
        decoder_title=None,
        save_path=f"plots/toy_l0_small/super_{sign}_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf",
    )



In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae
from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model

# NOTE(review): this repeats the negative-correlation training from the earlier
# superposition training cell and replaces super_sae_neg. It is presumably there
# to re-run just that case (e.g. after a poor local minimum) — confirm, or drop
# this cell.
cfg = BatchTopKTrainingSAEConfig(
    d_in=super_toy_model.embed.weight.shape[0],
    d_sae=5,
    k=1.9,
)
super_sae_neg = BatchTopKTrainingSAE(cfg)
init_sae_to_match_model(super_sae_neg, super_toy_model)
train_toy_sae(super_sae_neg, super_toy_model, generate_batch_neg)