# Toy model experiments: Independent features

In this notebook, we explore the effect of L0 on SAEs when features fire independently of each other. This is unrealistic, since real features are always correlated to some degree, but experiments like these are why the field previously thought that L0 does not matter as long as it is low enough. We'll use the same setup as in the main toy model experiments notebook, but with independent features (no correlation matrix).

In [None]:
from functools import partial
from pathlib import Path
import torch
import pandas as pd
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

from sparse_but_wrong.toy_models.get_training_batch import generate_random_correlation_matrix, get_training_batch
from sparse_but_wrong.toy_models.toy_model import ToyModel
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT
from sparse_but_wrong.util import DEFAULT_DEVICE

tqdm._instances.clear()  # type: ignore

toy_model = ToyModel(num_feats=50, hidden_dim=100).to(DEFAULT_DEVICE)

# want 11 features to fire on average
feat_probs = torch.ones(50) * 11 / 50
generate_batch = partial(
    get_training_batch,
    firing_probabilities=feat_probs,
    std_firing_magnitudes=torch.ones_like(feat_probs) * 0.15,
)

## Finding the True L0

Next, we calculate the true L0 for this dataset, i.e. the average number of features firing per sample (spoiler: it's ~11).

In [None]:
sample = generate_batch(100_000)
true_l0 = (sample > 0).float().sum(dim=-1).mean()
print(f"True L0: {true_l0}")
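The value of ~11 can be cross-checked without `get_training_batch`: by linearity of expectation, the expected L0 is the sum of the per-feature firing probabilities, here 50 × 11/50 = 11. The following is a minimal stdlib-only sketch of that check. The names `num_feats`, `firing_prob`, and `simulate_l0` are illustrative and are not part of `sparse_but_wrong`.

```python
import random

num_feats = 50
firing_prob = 11 / 50  # same per-feature probability as feat_probs above

# Linearity of expectation: E[L0] = sum of per-feature firing probabilities.
expected_l0 = num_feats * firing_prob


def simulate_l0(n_samples: int, seed: int = 0) -> float:
    """Estimate the mean L0 by sampling independent Bernoulli firings."""
    rng = random.Random(seed)
    total_active = 0
    for _ in range(n_samples):
        # Each feature fires independently with probability firing_prob.
        total_active += sum(rng.random() < firing_prob for _ in range(num_feats))
    return total_active / n_samples


empirical_l0 = simulate_l0(20_000)
print(f"Expected L0: {expected_l0:.2f}, simulated L0: {empirical_l0:.2f}")
```

The simulated value should land close to the analytic one, which is the same comparison the cell above makes with real training batches.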

## Training an SAE with the right L0

Next, let's train an SAE with the correct L0 and verify that it can perfectly recover the underlying true features.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

cfg = BatchTopKTrainingSAEConfig(k=11, d_in=toy_model.embed.weight.shape[0], d_sae=50)
sae_full = BatchTopKTrainingSAE(cfg)

train_toy_sae(sae_full, toy_model, generate_batch)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

plot_sae_feat_cos_sims(sae_full, toy_model, "SAE L0 = True L0", reorder_features=True, dtick=5)
plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title="SAE L0 = True L0", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path="plots/toy_l0_indep/sae_l0_eq_true_l0_decoder_cos_sims.pdf")
plot_sae_feat_cos_sims_seaborn(sae_full, toy_model, title="SAE L0 = True L0", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path="plots/toy_l0_indep/sae_l0_eq_true_l0_decoder_cos_sims.png")

As we can see above, the SAE has learned the correct features.
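The heatmap can also be summarised numerically: for each true feature, take the largest cosine similarity against any SAE decoder direction. Values near 1 for every feature mean each one has a matching latent. Below is a minimal pure-Python sketch of that metric on a tiny hand-made example. The helpers `cosine` and `best_match_per_feature` and the toy vectors are hypothetical and are not part of `sparse_but_wrong` or `sae_lens`.

```python
import math


def cosine(u: list[float], v: list[float]) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def best_match_per_feature(true_feats, decoder_rows):
    """For each true feature, the max cosine similarity over all decoder rows."""
    return [max(cosine(f, d) for d in decoder_rows) for f in true_feats]


# Toy example: two orthogonal true features in 3 dimensions.
true_feats = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
# A "good" decoder recovers both features (in any order, at any scale).
good_decoder = [[0.0, 2.0, 0.0], [3.0, 0.0, 0.0]]
# A "mixed" decoder has one latent that blends both features.
mixed_decoder = [[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

good_scores = best_match_per_feature(true_feats, good_decoder)
mixed_scores = best_match_per_feature(true_feats, mixed_decoder)
print("good:", good_scores)
print("mixed:", mixed_scores)
```

A clean diagonal in the reordered heatmap corresponds to every score being close to 1, while blended latents pull the scores down.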

## What if we reduce the L0 below the true L0?

If we lower the L0, the SAE will need to reconstruct the input with fewer active latents than the number of true features firing. How will it handle this? We set L0=9 below.
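For context, a BatchTopK activation keeps the `k * batch_size` largest activations across the whole batch and zeroes the rest, so individual samples can use more or fewer than `k` latents while the batch averages `k`. The sketch below illustrates that selection rule in plain Python. `batch_topk` is a hypothetical helper written for this illustration, not the `sae_lens` implementation, and the example values are made up.

```python
def batch_topk(pre_acts: list[list[float]], k: int) -> list[list[float]]:
    """Keep the k * batch_size largest positive entries across the batch."""
    batch_size = len(pre_acts)
    # Flatten to (value, row, col) so selection is global, not per-sample.
    flat = [
        (value, row, col)
        for row, sample in enumerate(pre_acts)
        for col, value in enumerate(sample)
    ]
    flat.sort(key=lambda item: item[0], reverse=True)
    out = [[0.0] * len(sample) for sample in pre_acts]
    for value, row, col in flat[: k * batch_size]:
        if value > 0:  # mimic a ReLU: never keep non-positive activations
            out[row][col] = value
    return out


pre_acts = [
    [0.9, 0.8, 0.7, 0.1],  # a "busy" sample with many strong activations
    [0.6, 0.05, 0.0, -0.2],  # a "quiet" sample with one strong activation
]
acts = batch_topk(pre_acts, k=2)
l0_per_sample = [sum(v > 0 for v in sample) for sample in acts]
print(acts)
print("L0 per sample:", l0_per_sample)
```

In this toy batch the busy sample keeps three latents and the quiet one keeps a single latent, so the mean L0 is still `k=2`. Setting `k` below the true L0 forces the SAE to drop some genuinely active features on average.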

In [None]:
from sparse_but_wrong.enchanced_batch_topk_sae import EnchancedBatchTopKTrainingSAE, EnhancedBatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

cfg = EnhancedBatchTopKTrainingSAEConfig(k=9, d_in=toy_model.embed.weight.shape[0], d_sae=50, initial_k=14, transition_k_duration_steps=10_000)
sae_narrow = EnchancedBatchTopKTrainingSAE(cfg)

train_toy_sae(sae_narrow, toy_model, generate_batch)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

plot_sae_feat_cos_sims(sae_narrow, toy_model, "SAE L0 < True L0", reorder_features=True, dtick=5)
plot_sae_feat_cos_sims_seaborn(sae_narrow, toy_model, title="SAE L0 $<$ True L0", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path="plots/toy_l0_indep/sae_l0_lt_true_l0_decoder_cos_sims.pdf")


With independent features the SAE handles this fine! Let's verify this is still the case when we drop the L0 even further.

## Let's drop the L0 even further!

What if we drop the L0 further, to k=7?

In [None]:
from sparse_but_wrong.enchanced_batch_topk_sae import EnchancedBatchTopKTrainingSAE, EnhancedBatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.initialization import init_sae_to_match_model
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

# transitioning from a higher to lower k seems to help training stability
cfg = EnhancedBatchTopKTrainingSAEConfig(k=7, d_in=toy_model.embed.weight.shape[0], d_sae=50, initial_k=15, transition_k_duration_steps=10_000)
sae_narrower = EnchancedBatchTopKTrainingSAE(cfg)

# init_sae_to_match_model(sae_narrower, toy_model)


# NOTE: occasionally this gets stuck in a poor local minimum. If this happens, try rerunning and it should converge properly.
train_toy_sae(sae_narrower, toy_model, generate_batch)
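The comment in the cell above describes annealing `k` from `initial_k` down to the target `k` over `transition_k_duration_steps`. The exact schedule used by `EnhancedBatchTopKTrainingSAEConfig` is not shown in this notebook, so the sketch below assumes a simple linear interpolation purely for illustration. `k_at_step` is a hypothetical helper, not a function from `sparse_but_wrong`.

```python
def k_at_step(step: int, initial_k: float, final_k: float, duration_steps: int) -> float:
    """Linearly anneal k from initial_k to final_k, then hold at final_k.

    ASSUMPTION: linear interpolation; the real schedule may differ.
    """
    if duration_steps <= 0 or step >= duration_steps:
        return final_k
    frac = step / duration_steps
    return initial_k + frac * (final_k - initial_k)


# Same numbers as the config above: initial_k=15, k=7, 10_000 transition steps.
for step in (0, 2_500, 5_000, 10_000, 20_000):
    print(step, k_at_step(step, initial_k=15, final_k=7, duration_steps=10_000))
```

Under this assumed schedule the SAE starts with more active latents than the true L0 and only later is squeezed down to the target, which is one plausible reading of why the transition helps stability.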

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

plot_sae_feat_cos_sims(sae_narrower, toy_model, "SAE L0 << True L0", reorder_features=True, dtick=5)
plot_sae_feat_cos_sims_seaborn(sae_narrower, toy_model, title="SAE L0 $\\ll$ True L0", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path="plots/toy_l0_indep/sae_l0_llt_true_l0_decoder_cos_sims.pdf")
plot_sae_feat_cos_sims_seaborn(sae_narrower, toy_model, title="SAE L0 $<$ True L0", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path="plots/toy_l0_indep/sae_l0_llt_true_l0_decoder_cos_sims_v2.pdf")
plot_sae_feat_cos_sims_seaborn(sae_narrower, toy_model, title="SAE L0 $\\ll$ True L0", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path="plots/toy_l0_indep/sae_l0_llt_true_l0_decoder_cos_sims.png")


The SAE *still* learns the correct features! If only features were independent in reality too...

## What if we increase the L0 beyond the true L0?

If we increase the L0, the SAE can fire more latents than are needed to reconstruct the input. Will the SAE still do the right thing? We set L0=14 below.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

cfg = BatchTopKTrainingSAEConfig(k=14, d_in=toy_model.embed.weight.shape[0], d_sae=50)
sae_wide = BatchTopKTrainingSAE(cfg)

train_toy_sae(sae_wide, toy_model, generate_batch)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

plot_sae_feat_cos_sims(sae_wide, toy_model, "SAE L0 > True L0", reorder_features=True, dtick=5)
plot_sae_feat_cos_sims_seaborn(sae_wide, toy_model, title="SAE L0 $>$ True L0", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path="plots/toy_l0_indep/sae_l0_gt_true_l0_decoder_cos_sims.pdf")


This is still mostly correct, but the SAE is finding some degenerate local minima. So even when features fire independently, an L0 that is too high keeps the SAE from finding the correct features. Let's see what happens if we increase the L0 even more.

## Increase L0 even more!

Let's see what happens if we increase the L0 to 18.

In [None]:
from sae_lens import BatchTopKTrainingSAE, BatchTopKTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae

cfg = BatchTopKTrainingSAEConfig(k=18, d_in=toy_model.embed.weight.shape[0], d_sae=50)
sae_wider = BatchTopKTrainingSAE(cfg)


train_toy_sae(sae_wider, toy_model, generate_batch)

In [None]:
from sparse_but_wrong.toy_models.plotting import plot_sae_feat_cos_sims, plot_sae_feat_cos_sims_seaborn

plot_sae_feat_cos_sims(sae_wider, toy_model, "SAE L0 >> True L0", reorder_features=True, dtick=5)
plot_sae_feat_cos_sims_seaborn(sae_wider, toy_model, title="SAE L0 $\\gg$ True L0", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path="plots/toy_l0_indep/sae_l0_ggt_true_l0_decoder_cos_sims.pdf")
plot_sae_feat_cos_sims_seaborn(sae_wider, toy_model, title="SAE L0 $>$ True L0", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path="plots/toy_l0_indep/sae_l0_ggt_true_l0_decoder_cos_sims_v2.pdf")
plot_sae_feat_cos_sims_seaborn(sae_wider, toy_model, title="SAE L0 $\\gg$ True L0", reorder_features=True, decoder_only=True, width=5, height=2, dtick=5, decoder_title=None, save_path="plots/toy_l0_indep/sae_l0_ggt_true_l0_decoder_cos_sims.png")


As expected, the SAE is even worse now: it uses its extra capacity to settle into more degenerate solutions.