# Open-source SAE L0 analysis

This notebook looks through the open-source SAEs from SAELens / Neuronpedia that have a known L0, and plots the distribution of L0 values for each model.


In [None]:
from collections import defaultdict
from sae_lens.loading.pretrained_saes_directory import get_pretrained_saes_directory

saes = get_pretrained_saes_directory()

# Collect the expected L0 values, grouped by the model each SAE belongs to.
# An SAE contributes only if its directory entry lists expected L0s, it has a
# Neuronpedia ID, and its L0 is strictly positive.
model_l0s = defaultdict(list)
for entry in saes.values():
    expected_l0 = entry.expected_l0
    if expected_l0 is None:
        continue
    for sae_id, l0 in expected_l0.items():
        on_neuronpedia = entry.neuronpedia_id.get(sae_id) is not None
        if on_neuronpedia and l0 > 0:
            model_l0s[entry.model].append(l0)

# Summarise how many usable SAEs were found for each model.
print("Models found:", list(model_l0s))
print("Number of L0 values per model:")
for model_name, values in model_l0s.items():
    n_saes = len(values)
    print(f"  {model_name}: {n_saes} SAEs")


In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


sns.set_theme()

# Draw one L0 histogram per model on a grid that is at most 3 panels wide,
# then save the figure to plots/open_source_saes_analysis.pdf and show it.
models = list(model_l0s)
n_models = len(models)
if n_models == 0:
    # Without this guard the ceiling division below would fail with an
    # uninformative ZeroDivisionError (n_cols would be 0).
    raise ValueError("No SAEs with a known L0 were found; nothing to plot.")

# Calculate grid dimensions
n_cols = min(3, n_models)  # Max 3 columns
n_rows = (n_models + n_cols - 1) // n_cols  # Ceiling division

# squeeze=False makes subplots() always return a 2-D array of Axes, so a
# single flatten() covers the 1-panel, 1-row and multi-row layouts alike.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False
)
axes = axes.flatten()

for ax, model in zip(axes, models):
    l0_values = model_l0s[model]

    # Create histogram
    ax.hist(l0_values, bins=30, alpha=0.7, edgecolor='black')
    ax.set_title(f'{model}\n({len(l0_values)} SAEs)')
    ax.set_xlabel('L0 (Average number of active features)')
    ax.set_ylabel('Number of SAEs')
    ax.grid(True, alpha=0.3)

    # Mark the mean and median with dashed vertical lines; the values are
    # reported in the legend.
    mean_l0 = np.mean(l0_values)
    median_l0 = np.median(l0_values)
    ax.axvline(mean_l0, color='red', linestyle='--', alpha=0.7, label=f'Mean: {mean_l0:.1f}')
    ax.axvline(median_l0, color='orange', linestyle='--', alpha=0.7, label=f'Median: {median_l0:.1f}')
    ax.legend(fontsize=8)

# Hide any unused subplots (trailing grid cells when n_models does not fill
# the last row)
for ax in axes[n_models:]:
    ax.set_visible(False)

plt.tight_layout()
Path("plots").mkdir(parents=True, exist_ok=True)
plt.savefig("plots/open_source_saes_analysis.pdf")
plt.show()
