# Small toy model experiments

In this notebook, we use a tiny toy model, where everything is easily visible, to explore how the choice of L0 affects SAEs when the underlying features are correlated.

# Positive correlations

We'll start with a simple toy model with 5 features, and a correlation matrix where all features are positively correlated with feature 0. All other features are uncorrelated with each other.

In [None]:
from functools import partial
import torch
from tqdm import tqdm

from sparse_but_wrong.toy_models.get_training_batch import  get_training_batch
from sparse_but_wrong.toy_models.toy_model import ToyModel
from sparse_but_wrong.toy_models.plotting import plot_correlation_matrix
from sparse_but_wrong.util import DEFAULT_DEVICE

# Empty tqdm's private `_instances` collection before any training runs.
# NOTE(review): presumably this clears stale progress bars left over from
# earlier executions of the notebook so new bars render cleanly -- confirm.
# `_instances` is not part of tqdm's public API, hence the type-ignore.
tqdm._instances.clear()  # type: ignore

# Five features that each fire with probability 2/5, so the expected number of
# simultaneously active features (the true L0) is 5 * 2/5 = 2.
feat_probs = torch.full((5,), 2.0) / 5

# Correlation matrix: start from the identity (each feature has correlation 1
# with itself and 0 with every other feature), then give feature 0 a
# correlation of 0.4 with each remaining feature. Row and column are both set
# so the matrix stays symmetric.
pos_correlations = torch.eye(5)
pos_correlations[0, 1:] = 0.4
pos_correlations[1:, 0] = 0.4

# Toy model with 5 features and a hidden dimension of 20, moved to the
# project's default device.
toy_model = ToyModel(num_feats=5, hidden_dim=20)
toy_model = toy_model.to(DEFAULT_DEVICE)

# Batch generator for the positively correlated setup. The firing
# probabilities, the firing-magnitude std (0.15 for every feature) and the
# correlation matrix are pre-bound, so callers only pass the remaining
# arguments (the batch size, as used later in this notebook).
generate_batch_pos = partial(
    get_training_batch,
    correlation_matrix=pos_correlations,
    firing_probabilities=feat_probs,
    std_firing_magnitudes=torch.full_like(feat_probs, 0.15),
)

# Render the correlation matrix and save it alongside the other setup plots.
plot_correlation_matrix(
    pos_correlations,
    save_path="plots/toy_setup_small/positive_correlation_matrix.pdf",
)

## Finding the True L0

Next, we calculate the true L0 for this dataset (spoiler: it's ~2).

In [None]:
# Draw a large batch and estimate the true L0 empirically: count the strictly
# positive entries along the last dimension of each sample, then average those
# counts over the whole batch.
sample = generate_batch_pos(100_000)
true_l0 = sample.gt(0).sum(dim=-1).float().mean()
print(f"True L0: {true_l0}")

## Training an SAE with the right L0

Next, let's train an SAE with the correct L0 and verify it can perfectly recover the underlying true features.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# BatchTopK SAE whose k is set to 2, the true L0 of the dataset, with d_sae=5
# latents (the same number as the toy model's features). d_in is read from the
# first dimension of the toy model's embedding weight.
# NOTE(review): presumably that dimension is the model's hidden size -- confirm
# against ToyModel.
cfg = BatchTopKTrainingSAEConfig(k=2, d_in=toy_model.embed.weight.shape[0], d_sae=5)
sae_full = BatchTopKTrainingSAE(cfg)

# NOTE: occasionally this gets stuck in a poor local minimum. If this happens, try rerunning and it should converge properly.
train_toy_sae(sae_full, toy_model, generate_batch_pos)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Plot the SAE/toy-model cosine similarities for the SAE trained at the true L0.
# NOTE(review): presumably this compares SAE latents against the toy model's
# true feature directions -- confirm in the plotting module.
plot_sae_feat_cos_sims(sae_full, toy_model, "SAE L0 = True L0", reorder_features=True, dtick=5)
# Save the decoder-only seaborn version of the figure twice, with identical
# settings: once as a PDF and once as a PNG.
plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title="SAE L0 = True L0", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path="plots/toy_l0_small/pos_corr_sae_l0_eq_true_l0_decoder_cos_sims.pdf")
plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title="SAE L0 = True L0", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path="plots/toy_l0_small/pos_corr_sae_l0_eq_true_l0_decoder_cos_sims.png")

As we can see above, the SAE has learned the correct features despite the feature correlations.

## What if we reduce the L0 slightly below the number of true features?

If we lower the L0, then the SAE will need to reconstruct the input with fewer latents than the number of true features required. We'll reduce the SAE L0 by 10%, from 2 to 1.8 (BatchTopK SAEs have no problem with fractional L0s, which is nice for our experiments). We'll further initialize the SAE to the ground-truth correct solution, so we can be certain that anything that results is a result of gradient pressure rather than just a poor local minimum.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae
from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model

# Same SAE dimensions as the SAE trained at the true L0, but with k=1.8,
# i.e. 10% below the true L0 of 2.
cfg = BatchTopKTrainingSAEConfig(k=1.8, d_in=toy_model.embed.weight.shape[0], d_sae=5)
sae_narrow = BatchTopKTrainingSAE(cfg)

# Start from the ground-truth solution so that any deviation after training
# comes from gradient pressure rather than from a poor local minimum.
init_sae_to_match_model(sae_narrow, toy_model)

# Train on the positively correlated batches.
train_toy_sae(sae_narrow, toy_model, generate_batch_pos)

In [None]:
import plotly.express as px
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Plot the SAE/toy-model cosine similarities for the SAE trained below the true L0.
plot_sae_feat_cos_sims(sae_narrow, toy_model, "SAE L0 < True L0", reorder_features=True)
# Show the positive feature correlation matrix as a heatmap (diverging RdBu
# colour scale fixed to [-1, 1]) so it can be compared against the plot above.
px.imshow(pos_correlations, width=400, height=400, color_continuous_scale="RdBu", zmin=-1, zmax=1, title="Feature correlation matrix", origin="lower").show()
# Save the decoder-only seaborn figure as a PDF.
# NOTE(review): the "$<$" in the title presumably renders "<" via math-text -- confirm.
plot_sae_feat_cos_sims_seaborn(sae_narrow, toy_model, title="SAE L0 $<$ True L0", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path="plots/toy_l0_small/pos_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf")




We see that all SAE latents contain a strong positive component of feature 0, but no component of any other feature, exactly matching the correlation matrix. Clearly, the correlations are what is causing the SAE to mix feature components together.

# Negative correlations

Next, we'll change the correlation matrix so all features are negatively correlated with feature 0 instead of positively correlated. We'll keep the same toy model as before; we're just changing the feature firing correlations.

In [None]:
from functools import partial
import torch

from sparse_but_wrong.toy_models.get_training_batch import  get_training_batch
from sparse_but_wrong.toy_models.plotting import plot_correlation_matrix

# Five features that each fire with probability 2/5, so the expected number of
# simultaneously active features (the true L0) is 5 * 2/5 = 2.
feat_probs = torch.full((5,), 2.0) / 5

# Correlation matrix: start from the identity (each feature has correlation 1
# with itself and 0 with every other feature), then give feature 0 a
# correlation of -0.4 with each remaining feature. Row and column are both set
# so the matrix stays symmetric.
neg_correlations = torch.eye(5)
neg_correlations[0, 1:] = -0.4
neg_correlations[1:, 0] = -0.4


# Batch generator for the negatively correlated setup. The firing
# probabilities, the firing-magnitude std (0.15 for every feature) and the
# correlation matrix are pre-bound, so callers only pass the remaining
# arguments (the batch size, as used later in this notebook).
generate_batch_neg = partial(
    get_training_batch,
    correlation_matrix=neg_correlations,
    firing_probabilities=feat_probs,
    std_firing_magnitudes=torch.full_like(feat_probs, 0.15),
)

# Render the correlation matrix and save it alongside the other setup plots.
plot_correlation_matrix(
    neg_correlations,
    save_path="plots/toy_setup_small/negative_correlation_matrix.pdf",
)

## Finding the True L0

Let's double-check that the true L0 is still 2, as before.

In [None]:
# Draw a large batch and estimate the true L0 empirically: count the strictly
# positive entries along the last dimension of each sample, then average those
# counts over the whole batch.
sample = generate_batch_neg(100_000)
true_l0 = sample.gt(0).sum(dim=-1).float().mean()
print(f"True L0: {true_l0}")

## Training an SAE with the right L0

Next, let's train an SAE with the correct L0 and verify it can perfectly recover the underlying true features.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# BatchTopK SAE whose k is set to 2, the true L0 of the dataset, with d_sae=5
# latents. d_in is read from the first dimension of the toy model's embedding
# weight.
# NOTE(review): presumably that dimension is the model's hidden size -- confirm
# against ToyModel.
cfg = BatchTopKTrainingSAEConfig(k=2, d_in=toy_model.embed.weight.shape[0], d_sae=5)
sae_full = BatchTopKTrainingSAE(cfg)

# NOTE: occasionally this gets stuck in a poor local minimum. If this happens, try rerunning and it should converge properly.
train_toy_sae(sae_full, toy_model, generate_batch_neg)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Plot the SAE/toy-model cosine similarities for the SAE trained at the true L0
# on the negatively correlated data.
# NOTE(review): presumably this compares SAE latents against the toy model's
# true feature directions -- confirm in the plotting module.
plot_sae_feat_cos_sims(sae_full, toy_model, "SAE L0 = True L0", reorder_features=True, dtick=5)
# Save the decoder-only seaborn version of the figure twice, with identical
# settings: once as a PDF and once as a PNG.
plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title="SAE L0 = True L0", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path="plots/toy_l0_small/neg_corr_sae_l0_eq_true_l0_decoder_cos_sims.pdf")
plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title="SAE L0 = True L0", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path="plots/toy_l0_small/neg_corr_sae_l0_eq_true_l0_decoder_cos_sims.png")

As we can see above, the SAE has learned the correct features despite the feature correlations.

## What if we reduce the L0 slightly below the number of true features?

If we lower the L0, then the SAE will need to reconstruct the input with fewer latents than the number of true features required. We'll reduce the SAE L0 by 10%, from 2 to 1.8 (BatchTopK SAEs have no problem with fractional L0s, which is nice for our experiments). We'll further initialize the SAE to the ground-truth correct solution, so we can be certain that anything that results is a result of gradient pressure rather than just a poor local minimum.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae
from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model

# Same SAE dimensions as the SAE trained at the true L0, but with k=1.8,
# i.e. 10% below the true L0 of 2.
cfg = BatchTopKTrainingSAEConfig(k=1.8, d_in=toy_model.embed.weight.shape[0], d_sae=5)
sae_narrow = BatchTopKTrainingSAE(cfg)

# Start from the ground-truth solution so that any deviation after training
# comes from gradient pressure rather than from a poor local minimum.
init_sae_to_match_model(sae_narrow, toy_model)

# Train on the negatively correlated batches.
train_toy_sae(sae_narrow, toy_model, generate_batch_neg)

In [None]:
import plotly.express as px
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

# Plot the SAE/toy-model cosine similarities for the SAE trained below the true L0.
plot_sae_feat_cos_sims(sae_narrow, toy_model, "SAE L0 < True L0", reorder_features=True)
# Show the negative feature correlation matrix as a heatmap (diverging RdBu
# colour scale fixed to [-1, 1]) so it can be compared against the plot above.
px.imshow(neg_correlations, width=400, height=400, color_continuous_scale="RdBu", zmin=-1, zmax=1, title="Feature correlation matrix", origin="lower").show()
# Save the decoder-only seaborn figure as a PDF.
# NOTE(review): the "$<$" in the title presumably renders "<" via math-text -- confirm.
plot_sae_feat_cos_sims_seaborn(sae_narrow, toy_model, title="SAE L0 $<$ True L0", reorder_features=True, decoder_only=True, width=5, height=2, decoder_title=None, save_path="plots/toy_l0_small/neg_corr_sae_l0_lt_true_l0_decoder_cos_sims.pdf")