# Toy Model Experiments: JumpReLU edition

We focus on BatchTopK SAEs in the paper since it's easier to directly control their L0, but the same findings hold for JumpReLU SAEs. Here we run only the nth decoder projection experiment, and show that JumpReLU SAEs also minimize this metric at the correct L0.

In [None]:
from functools import partial
from pathlib import Path
import torch
import pandas as pd
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

from sparse_but_wrong.toy_models.get_training_batch import generate_random_correlation_matrix, get_training_batch
from sparse_but_wrong.toy_models.toy_model import ToyModel
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT
from sparse_but_wrong.util import DEFAULT_DEVICE

# Drop any progress bars left dangling by earlier (possibly interrupted) runs of this notebook.
tqdm._instances.clear()  # type: ignore

# Firing probabilities decrease linearly with feature index, from 0.05 + 0.345 * 49 / 50
# for feature 0 down to 0.05 for the last of the 50 features.
feat_probs = 0.05 + 0.345 * (49 - torch.arange(50)) / 50
correlations = generate_random_correlation_matrix(
    num_features=50,
    correlation_strength_range=(0.3, 0.9),
    seed=42,
)
# Feature names are 1-based strings, used as the categorical x-axis of the bar plot below.
indices = torch.arange(50) + 1
df = pd.DataFrame(
    {
        "P_i": feat_probs,
        "feature": [str(idx) for idx in indices.tolist()],
    }
)

toy_model = ToyModel(num_feats=50, hidden_dim=100).to(DEFAULT_DEVICE)

# Batch generator with the dataset statistics baked in; callers only supply the batch size.
generate_batch = partial(
    get_training_batch,
    firing_probabilities=feat_probs,
    correlation_matrix=correlations,
    std_firing_magnitudes=torch.full_like(feat_probs, 0.15),
)

Path("plots/toy_setup_jumprelu").mkdir(parents=True, exist_ok=True)

# Bar chart of each feature's firing probability P_i.
plt.rcParams.update({"figure.dpi": 150})
with plt.rc_context(SEABORN_RC_CONTEXT):
    plt.figure(figsize=(3, 2))
    sns.barplot(data=df, x="feature", y="P_i")
    plt.title("Feature firing probabilities $P_i$")
    plt.xlabel("Feature")
    plt.ylabel("$P_i$")
    # Label only every 5th bar so tick labels don't overlap; labels are 0-indexed bar positions.
    tick_positions = list(range(0, len(df), 5))
    plt.xticks(tick_positions, [str(pos) for pos in tick_positions])
    plt.tight_layout()
    plt.savefig("plots/toy_setup_jumprelu/toy_model_feature_firing_probabilities.pdf")
    plt.show()

# Heatmap of the feature-feature correlation matrix used to generate training batches.
plt.rcParams.update({"figure.dpi": 150})
with plt.rc_context(SEABORN_RC_CONTEXT):
    plt.figure(figsize=(2.5, 2))
    sns.heatmap(correlations, cmap="RdBu", center=0, vmin=-1, vmax=1)
    plt.xlabel("Feature")
    plt.ylabel("Feature")
    plt.title("Feature correlation matrix")
    # Label only every 10th feature (0-indexed) to prevent overlapping. seaborn draws
    # cell i over [i, i + 1], so ticks at integer positions would sit on cell
    # boundaries; offset by 0.5 to centre each label on its cell.
    tick_features = range(0, len(correlations), 10)
    tick_centers = [i + 0.5 for i in tick_features]
    tick_labels = [str(i) for i in tick_features]
    plt.xticks(tick_centers, tick_labels)
    plt.yticks(tick_centers, tick_labels)
    # plt.tight_layout()
    plt.savefig("plots/toy_setup_jumprelu/toy_model_correlation_matrix.pdf")
    plt.show()

## Finding the True L0

Next, we calculate the true L0 of this dataset, i.e. the average number of active features per sample (spoiler: it's ~11).

In [None]:
sample = generate_batch(100_000)
# A feature counts as active when its value is strictly positive; the true L0 is the
# mean number of active features per sample.
is_active = sample.gt(0)
true_l0 = is_active.float().sum(dim=-1).mean()
print(f"True L0: {true_l0}")

# Can we detect when we're at the correct L0?

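Let's train SAEs across a range of L0 values, from below the correct L0 to above it. A JumpReLU SAE's L0 can't be set directly, so we sweep the L0 coefficient and record the L0 each SAE ends up with.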

In [None]:
from sae_lens import JumpReLUTrainingSAE, JumpReLUTrainingSAEConfig
from sparse_but_wrong.toy_models.train_toy_sae import train_toy_sae
from sparse_but_wrong.toy_models.eval_sae import eval_sae



# Sweep the L0 sparsity coefficient and file each trained SAE under the L0 it reaches.
saes_by_k = {}
for l0_coeff in [0.05, 0.07, 0.09, 0.1, 0.3, 0.5, 0.75, 1.0, 1.2, 1.3, 1.4, 1.5]:
    print(f"Training SAE with l1={l0_coeff}")
    cfg = JumpReLUTrainingSAEConfig(
        d_in=toy_model.embed.weight.shape[0],
        d_sae=50,
        l0_coefficient=l0_coeff,
        l0_warm_up_steps=10_000,
        jumprelu_bandwidth=2.0,
        jumprelu_init_threshold=0.1,
        jumprelu_sparsity_loss_mode="tanh",
        normalize_activations="expected_average_only_in",
    )
    sae = JumpReLUTrainingSAE(cfg)
    train_toy_sae(sae, toy_model, generate_batch)
    stats = eval_sae(sae, toy_model, generate_batch)
    print(f"SAE L0: {stats.sae_l0}, dead latents: {stats.dead_features}, L1 coefficient: {l0_coeff:.2f}")
    # Keyed by achieved L0: two coefficients that land on an identical L0 overwrite each other.
    saes_by_k[stats.sae_l0] = sae


In [None]:
from pathlib import Path
from typing import Callable
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from sae_lens import TrainingSAE
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT

# This cell reuses `saes_by_k`, `toy_model` and `generate_batch` from the earlier cells.

# NOTE(review): `sorted_dec_hidden_pres` is not read or written again anywhere in the visible
# notebook; it looks like a leftover — confirm nothing else depends on it before removing.
sorted_dec_hidden_pres = {}

def calc_l0_dec_thresholds(
    sae: "TrainingSAE",
    model,
    generate_batch: Callable[[int], torch.Tensor],
    n_samples: int = 100_000,
) -> list[float]:
    """Compute the n-th decoder projection s_n for every latent index n.

    Draws ``n_samples`` feature vectors and embeds them with ``model.embed``. It then
    projects the centred inputs ``inputs - sae.b_dec`` onto the decoder
    (``@ sae.W_dec.T``) and sorts all projection values in descending order.
    ``s_n`` is the value at rank ``n * n_samples`` of that ordering, so ``s_0`` is
    the single largest projection.

    Args:
        sae: SAE providing ``b_dec`` and ``W_dec`` (presumably ``W_dec`` is
            ``(d_sae, d_in)`` as in sae_lens — TODO confirm).
        model: toy model whose ``embed`` maps a feature batch to model inputs.
        generate_batch: callable returning a batch of ``n`` feature vectors.
        n_samples: number of samples to draw (defaults to the previous fixed 100_000).

    Returns:
        One float per column of the projection matrix: ``[s_0, s_1, ...]``.
    """
    # Pure evaluation: disable autograd so the large projection matrix and the sort
    # over it don't retain a computation graph through W_dec / b_dec.
    with torch.no_grad():
        inputs = model.embed(generate_batch(n_samples))
        hidden_pre_dec = (inputs - sae.b_dec) @ sae.W_dec.T
        sorted_hidden_pre_dec = hidden_pre_dec.flatten().sort(descending=True).values
        num_rows = hidden_pre_dec.shape[0]
        num_latents = hidden_pre_dec.shape[-1]
        # Rank n * num_rows for each n; build the indices on the data's device.
        k_inds = torch.arange(num_latents, device=sorted_hidden_pre_dec.device) * num_rows
        return sorted_hidden_pre_dec[k_inds].tolist()

# Decoder-projection thresholds for every trained SAE, keyed by that SAE's L0.
thresholds = {
    k: calc_l0_dec_thresholds(sae, toy_model, generate_batch)
    for k, sae in tqdm(saes_by_k.items())
}

# One row per SAE: its L0 followed by one `l0_dec_{n}` column per projection rank n.
data = [
    {"l0": l0, **{f"l0_dec_{i}": value for i, value in enumerate(l0_thresholds)}}
    for l0, l0_thresholds in thresholds.items()
]
df = pd.DataFrame(data)

Path("plots/dpn_jumprelu").mkdir(parents=True, exist_ok=True)

# The rcParams/theme setup is the same for every figure, so apply it once up front;
# each rc_context below restores back to this state on exit.
plt.rcParams.update({"figure.dpi": 150})
sns.set_theme()

for k in range(12, 26):
    with plt.rc_context(SEABORN_RC_CONTEXT):
        plt.figure(figsize=(3, 1.5))
        sns.lineplot(data=df, x="l0", y=f"l0_dec_{k}")
        # Dashed reference line marking the true L0 of the toy dataset.
        plt.axvline(x=11, color='red', linestyle='--', linewidth=0.5, alpha=0.7, label='True L0')
        plt.title(f"N={k} Decoder Projection vs SAE L0")
        plt.xlabel("SAE L0")
        plt.ylabel("$s_n$")
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(f"plots/dpn_jumprelu/dp_{k}.pdf")
        plt.show()

In [None]:
from sparse_but_wrong.toy_models.plotting import SEABORN_RC_CONTEXT
import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path

# Pair each SAE's achieved L0 with the L0 coefficient it was trained with.
data = [
    {"l0": k, "l0_coeff": sae.cfg.l0_coefficient}
    for k, sae in tqdm(saes_by_k.items())
]
df = pd.DataFrame(data)

# savefig does not create missing directories, and no earlier cell creates this one,
# so make sure it exists before saving (otherwise savefig raises FileNotFoundError).
Path("plots/toy_l0_jumprelu").mkdir(parents=True, exist_ok=True)

plt.rcParams.update({"figure.dpi": 150})
with plt.rc_context(SEABORN_RC_CONTEXT):
    plt.figure(figsize=(3, 2))
    sns.scatterplot(data=df, x="l0", y="l0_coeff")

    # Add vertical line at true L0
    plt.axvline(x=11, color='red', linestyle='--', linewidth=0.5, alpha=0.7, label='True L0')
    plt.xlabel("SAE L0")
    plt.ylabel("L0 Coefficient")
    plt.title("SAE L0 vs L0 Coefficient for JumpReLU SAEs")
    plt.tight_layout()
    plt.savefig("plots/toy_l0_jumprelu/l0_vs_l0_coeff_jumprelu.pdf")
    plt.show()