# LLM SAE Stability Analysis

This notebook tests whether our stability findings from algorithmic tasks transfer to LLMs.

**Key Finding to Verify:**
On algorithmic tasks, stability DECREASES monotonically as L0 (the average number of active features per input; higher L0 means a less sparse code) increases.

**Question:**
Does this hold for LLMs, or do LLMs show a different pattern (potentially with an optimal L0)?

**Setup:**
1. Make sure GPU is enabled: Runtime > Change runtime type > GPU
2. Run all cells in order

In [None]:
# Install dependencies
!pip install sae-lens transformer-lens transformers -q

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

# Choose the compute device once; later cells pass the module-level `device`
# to tensor constructors and to the SAE loader.
has_cuda = torch.cuda.is_available()
device = 'cuda' if has_cuda else 'cpu'
print(f'Using device: {device}')
if has_cuda:
    # Report which GPU was allocated and its total memory (bytes -> GB).
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f'Memory: {total_bytes / 1e9:.1f} GB')

In [None]:
from sae_lens import SAE

def compute_pwmcc(d1, d2):
    """Compute PWMCC between two decoder matrices.

    Each row of ``d1`` and ``d2`` is treated as one feature direction. Rows
    are L2-normalised, every row of ``d1`` is compared with every row of
    ``d2`` by cosine similarity, and the sign is discarded so that a
    direction and its negation count as a perfect match.

    Args:
        d1: 2-D tensor with one feature per row.
        d2: 2-D tensor with one feature per row; its second dimension must
            equal that of ``d1`` (the matrix product fails otherwise).

    Returns:
        A Python float: the average of (a) the mean, over rows of ``d1``, of
        the best absolute cosine similarity with any row of ``d2`` and
        (b) the same quantity measured from ``d2`` to ``d1``.
    """
    # Only scalar summaries are returned, so no autograd graph is needed.
    with torch.no_grad():
        abs_sim = F.normalize(d1, dim=1) @ F.normalize(d2, dim=1).T
        # The product is a fresh tensor, so taking |.| in place is safe and
        # avoids allocating a second full-size matrix for each direction.
        abs_sim.abs_()
        max_1to2 = abs_sim.max(dim=1)[0].mean().item()
        max_2to1 = abs_sim.max(dim=0)[0].mean().item()
    return (max_1to2 + max_2to1) / 2

def compute_random_baseline(d_model, d_sae, n_trials=10):
    """Estimate the PWMCC obtained between unrelated random dictionaries.

    Draws ``n_trials`` independent pairs of standard-normal matrices of shape
    ``(d_sae, d_model)`` on the module-level ``device``, scores each pair with
    ``compute_pwmcc``, and returns the mean of those scores.
    """
    def _random_decoder():
        return torch.randn(d_sae, d_model, device=device)

    trial_scores = [
        compute_pwmcc(_random_decoder(), _random_decoder())
        for _ in range(n_trials)
    ]
    return np.mean(trial_scores)

## Load Gemma Scope SAEs

We'll load SAEs at different widths (dictionary sizes, which here also come with different sparsity levels) and compare their features.

In [None]:
# Fetch the pretrained SAE for one layer at several dictionary widths from
# the `gemma-scope-2b-pt-res-canonical` release.
layer = 12
widths = ['16k', '32k', '65k']

# width label -> loaded SAE plus the sizes later cells need.
saes = {}
for width in widths:
    print(f'Loading width {width}...')
    sae_id = f'layer_{layer}/width_{width}/canonical'
    try:
        sae, cfg, sparsity = SAE.from_pretrained(
            release='gemma-scope-2b-pt-res-canonical',
            sae_id=sae_id,
            device=device
        )
        d_sae, d_model = sae.cfg.d_sae, sae.cfg.d_in
        saes[width] = dict(
            sae=sae,
            d_sae=d_sae,
            d_model=d_model,
            sparsity=sparsity,
        )
        print(f'  ✓ d_sae={d_sae}, d_model={d_model}')
    except Exception as e:
        # Best-effort: report the failure and continue, so `saes` ends up
        # holding only the widths that loaded successfully.
        print(f'  ✗ Error: {e}')

## Compare SAE Features Across Widths

This tells us how similar the learned features are at different dictionary sizes.

In [None]:
def _best_match_abs_cosine(queries, keys, chunk_size=4096):
    """Return, per row of ``queries``, its largest |cosine| with any row of ``keys``.

    The similarity matrix is built ``chunk_size`` query rows at a time, so
    peak memory is ``chunk_size * keys.shape[0]`` entries instead of the full
    ``queries.shape[0] * keys.shape[0]``. The per-row maxima are the same as
    those of the full matrix because each row is handled independently.
    """
    with torch.no_grad():
        keys_unit_t = F.normalize(keys, dim=1).T
        best = []
        for rows in queries.split(chunk_size, dim=0):
            sims = F.normalize(rows, dim=1) @ keys_unit_t
            # `sims` is a fresh tensor, so |.| can be taken in place.
            best.append(sims.abs_().max(dim=1)[0])
        return torch.cat(best)


print('Comparing SAE features across widths:')
print('-' * 50)

results = []
# Follows the insertion order of `saes`; the loading cell fills it in
# ascending width, so in each pair below w1 is the smaller dictionary.
width_list = list(saes.keys())

for i, w1 in enumerate(width_list):
    for w2 in width_list[i+1:]:
        # Get decoder weights
        d1 = saes[w1]['sae'].W_dec.data
        d2 = saes[w2]['sae'].W_dec.data

        # For each feature in the smaller SAE, find its best match in the
        # larger one. Done in chunks: for the largest pair the full
        # similarity matrix (roughly 32k x 65k entries, assuming the width
        # labels reflect d_sae) plus its absolute-value copy would not fit
        # on a typical single GPU.
        max_sim_1to2 = _best_match_abs_cosine(d1, d2)

        # Statistics
        mean_max_sim = max_sim_1to2.mean().item()
        pct_above_90 = (max_sim_1to2 > 0.9).float().mean().item() * 100
        pct_above_80 = (max_sim_1to2 > 0.8).float().mean().item() * 100

        print(f'{w1} → {w2}:')
        print(f'  Mean max cosine sim: {mean_max_sim:.3f}')
        print(f'  % features with >0.9 match: {pct_above_90:.1f}%')
        print(f'  % features with >0.8 match: {pct_above_80:.1f}%')

        results.append({
            'width1': w1,
            'width2': w2,
            'mean_max_sim': mean_max_sim,
            'pct_above_90': pct_above_90,
            'pct_above_80': pct_above_80
        })

## Train Multiple SAEs with Different Seeds

To properly test stability, we need to train SAEs with the same config but different random seeds. This is more compute-intensive but gives us the true stability measure.

In [None]:
from transformer_lens import HookedTransformer
from sae_lens import SAE, SAEConfig, SAETrainingRunner, LanguageModelSAERunnerConfig

# This is a template for training SAEs with different seeds
# Uncomment and modify as needed
#
# The template below is an unassigned triple-quoted string, so nothing inside
# it executes; this cell only prints the two reminder lines at the bottom.
# To use it, remove the surrounding quotes and run once per seed, changing
# `seed` each time and keeping every other setting fixed.
# NOTE(review): the keyword names accepted by LanguageModelSAERunnerConfig
# (e.g. hook_point / hook_point_layer) presumably depend on the installed
# sae-lens version -- confirm against that version before uncommenting.

'''
# Load model
model = HookedTransformer.from_pretrained('gemma-2-2b', device=device)

# Training config
cfg = LanguageModelSAERunnerConfig(
    model_name='gemma-2-2b',
    hook_point='blocks.12.hook_resid_post',
    hook_point_layer=12,
    d_in=2304,  # Gemma 2 2B hidden size
    dataset_path='monology/pile-uncopyrighted',
    streaming=True,
    
    # SAE config
    expansion_factor=8,  # d_sae = 8 * d_in
    b_dec_init_method='zeros',
    
    # Training
    lr=3e-4,
    l1_coefficient=5e-3,  # Vary this for different sparsity
    train_batch_size_tokens=4096,
    context_size=128,
    
    # For stability testing, train with different seeds
    seed=42,  # Change this for each run
    
    n_batches_in_buffer=64,
    total_training_tokens=1_000_000,  # Small for testing
    store_batch_size_prompts=16,
    
    log_to_wandb=False,
    wandb_project='sae-stability',
)

# Train
runner = SAETrainingRunner(cfg)
sae = runner.run()
'''

# Reminder output only; no training is started by this cell.
print('Training template ready. Uncomment and modify as needed.')
print('For full stability analysis, train 3-5 SAEs with different seeds.')

## Analyze Cross-Layer Stability

How do SAE features change across layers?

In [None]:
# Fetch one pretrained SAE per sampled layer, all at the same dictionary width.
layers = [6, 12, 18]
width = '16k'

# Arguments shared by every load; only the sae_id changes per layer.
load_kwargs = dict(
    release='gemma-scope-2b-pt-res-canonical',
    device=device,
)

# layer index -> loaded SAE
layer_saes = {}
for layer in layers:
    print(f'Loading layer {layer}...')
    sae_id = f'layer_{layer}/width_{width}/canonical'
    try:
        sae, cfg, sparsity = SAE.from_pretrained(sae_id=sae_id, **load_kwargs)
        layer_saes[layer] = sae
        print('  ✓ Loaded')
    except Exception as e:
        # Best-effort: report the failure and move on to the next layer.
        print(f'  ✗ Error: {e}')

In [None]:
print('Cross-layer feature similarity:')
print('-' * 50)

# Compare each loaded layer with the next loaded layer in ascending order.
layer_list = sorted(layer_saes)
for shallow, deep in zip(layer_list, layer_list[1:]):
    d1 = layer_saes[shallow].W_dec.data
    d2 = layer_saes[deep].W_dec.data

    pwmcc = compute_pwmcc(d1, d2)
    # Baseline: same score for random matrices with the shape of d1.
    n_features, n_dims = d1.shape
    random_baseline = compute_random_baseline(n_dims, n_features)

    print(f'Layer {shallow} vs {deep}:')
    print(f'  PWMCC: {pwmcc:.4f}')
    print(f'  Random baseline: {random_baseline:.4f}')
    print(f'  Ratio: {pwmcc/random_baseline:.2f}×')

## Summary

**Key Observations:**

1. **Width comparison**: Features at smaller widths tend to have good matches in larger widths (feature splitting)

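2. **Cross-layer**: Neighbouring sampled layers (6 vs 12, 12 vs 18) have some feature overlap; we expect it to decrease with layer distance, but confirming that requires also comparing non-adjacent layers (e.g. 6 vs 18), which the cross-layer cell does not do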

3. **To fully test our hypothesis**: We need to train multiple SAEs with the SAME config but DIFFERENT seeds, then measure PWMCC between them

**Next Steps:**
1. Train 3-5 SAEs with different seeds at the same layer/width
2. Vary L1 coefficient to get different sparsity levels
3. Measure stability (PWMCC) at each sparsity level
4. Compare to our algorithmic task findings