# 02 - Train Matryoshka SAE

This notebook:
1. Loads extracted activations from blocks 5, 20, 35
2. Trains MSAE with hierarchical k levels [16, 32, 64, 128]
3. Tracks reconstruction R² at each k level
4. Saves trained models and metrics

## Setup

In [1]:
# ============================================================================
# COLAB SETUP - Run this cell first!
# ============================================================================
import sys
from pathlib import Path

# Colab is recognised by whether the google.colab package can be imported.
try:
    import google.colab
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True

# DRIVE_ROOT stays None for local runs and becomes the Drive project folder
# once Drive has been mounted on Colab.
DRIVE_ROOT = None

if not IN_COLAB:
    print("Running locally")
else:
    print("Running in Google Colab")

    # Mount Google Drive
    from google.colab import drive
    drive.mount('/content/drive')

    # Project folder on Drive (created on first use)
    DRIVE_ROOT = Path('/content/drive/MyDrive/chaos')
    DRIVE_ROOT.mkdir(parents=True, exist_ok=True)

    print(f"Drive mounted. Project root: {DRIVE_ROOT}")

Running in Google Colab
Mounted at /content/drive
Drive mounted. Project root: /content/drive/MyDrive/chaos


In [2]:
# ============================================================================
# COLAB: Install dependencies and upload src.zip
# ============================================================================
if IN_COLAB:
    print("Installing dependencies...")
    !pip install -q torch>=2.0.0 h5py>=3.8.0 sgfmill>=1.1.0
    !pip install -q matplotlib>=3.7.0 tqdm>=4.65.0
    !pip install -q scikit-learn>=1.2.0 scipy>=1.10.0
    print("Dependencies installed!")

    # Unzip src.zip
    !unzip -n -q src.zip -d /content/

Installing dependencies...
Dependencies installed!


In [3]:
import os
import sys
from pathlib import Path

# Add search path for 'src' module: /content is where src.zip is unpacked on
# Colab; '.' is inserted last so it ends up first on sys.path.
if IN_COLAB:
    sys.path.insert(0, '/content')
sys.path.insert(0, '.')

# Configuration - Literature-validated hyperparameters
# Sources:
#   - Gao et al. "Scaling and Evaluating Sparse Autoencoders" (topk_sae_paper.md)
#   - Multi-budget SAE (multi_budget_sae.md)
#   - Matryoshka SAE (MatryoshkaSAE_paper.md)
CONFIG = {
    'output_dir': 'outputs',  # Will be overridden for Colab below

    'block_indices': [5, 20, 35],  # Layers to train on

    # MSAE architecture
    'input_dim': 256,  # Leela Zero channel dimension
    'expansion_factor': 16,  # Hidden = 256 * 16 = 4096 (paper: 8-32x)
    'k_levels': [16, 32, 64, 128],  # Matryoshka sparsity levels
    'weighting': 'uniform',  # 'uniform' or 'reverse' (paper: both work)

    # Training hyperparameters (from Gao et al. & multi_budget_sae papers)
    # Batch size: Paper uses 8096-131072. We use 4096 for memory efficiency
    # LR: Paper suggests 1/sqrt(n) scaling. For n=4096, optimal ~0.0008-0.001
    'batch_size': 4096,  # Paper: 8096, reduced for Colab T4 memory
    'learning_rate': 8e-4,  # Paper: 0.0008 (multi_budget_sae Table 1)

    # Option B: With 500K samples (vs 18M with flatten), increase epochs
    # ~122 batches/epoch with 500K samples vs ~4400 with 18M
    'epochs': 30,  # Increased from 15 to compensate for fewer samples
    'val_split': 0.1,
    'early_stopping_patience': 5,  # Increased patience for longer training

    # Auxiliary loss for dead latent prevention
    # Paper: aux_k=512, aux_coefficient=1/32
    # NOTE(review): 'aux_k', 'aux_coefficient' and 'normalize_decoder' are not
    # read by any cell visible in this notebook - confirm that the model/trainer
    # defaults in src match these values, or pass them explicitly.
    'aux_k': 512,
    'aux_coefficient': 1/32,

    # Optimizer settings (from topk_sae_paper Appendix A)
    # Adam with beta1=0.9, beta2=0.999
    # Epsilon very small: 6.25e-10 for large scale
    'adam_beta1': 0.9,
    'adam_beta2': 0.999,
    'adam_eps': 1e-8,  # Standard for our scale

    # Decoder normalization: True (paper: essential)
    'normalize_decoder': True,

    # CRITICAL: Shuffling is handled by DataLoader (shuffle=True)
    # This prevents learning spurious order-dependent patterns
    # See: Anthropic "Engineering Challenges in Interpretability"
}

# ============================================================================
# Configure output paths (Colab: Drive storage, local: ./outputs)
# ============================================================================
if IN_COLAB:
    CONFIG['output_dir'] = str(DRIVE_ROOT)

# Ensure output directories exist in BOTH environments. Previously this was
# done only on Colab, so a local run could fail later when a cell writes into
# a sub-directory (e.g. figures/) that nothing had created yet.
output_root = Path(CONFIG['output_dir'])
for relative_dir in ('data/activations', 'models', 'results', 'figures'):
    (output_root / relative_dir).mkdir(parents=True, exist_ok=True)

print(f"Output directory: {CONFIG['output_dir']}")

Output directory: /content/drive/MyDrive/chaos


In [4]:
import json
import os
from pathlib import Path

# The MPS CPU-fallback switch has to be in the environment BEFORE torch is
# imported: torch picks its fallback behaviour while the library loads, so
# setting the variable afterwards (as this cell used to do inside get_device)
# has no effect. setdefault keeps an explicit user override intact; the
# variable is ignored on CUDA/CPU machines.
os.environ.setdefault('PYTORCH_ENABLE_MPS_FALLBACK', '1')

import numpy as np
import torch


# Device selection
def get_device():
    """
    Pick the compute device, preferring CUDA, then Apple MPS, then CPU.

    Prints which device was chosen.

    Returns:
        torch.device for 'cuda', 'mps' or 'cpu'
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"Using CUDA: {torch.cuda.get_device_name(0)}")
        return device

    # getattr guard: torch builds without the MPS backend have no
    # torch.backends.mps attribute at all.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        print("Using MPS (Apple Silicon)")
        return torch.device('mps')

    print("Using CPU")
    return torch.device('cpu')


device = get_device()

Using CUDA: Tesla T4


In [5]:
from src.models import MatryoshkaSAE, create_msae
from src.training import MSAETrainer, create_activation_dataloader
from src.utils import clear_memory

## 0. Load Data (Colab: from Google Drive)

If running on Colab, activations should be available at `DRIVE_ROOT/data/activations/`.

**Expected Drive structure:**
```
MyDrive/chaos/
└── data/
    └── activations/
        ├── block5/
        ├── block20/
        └── block35/
```

In [6]:
# Verify that notebook 01 has produced activation chunks for every block
from pathlib import Path

activations_dir = Path(CONFIG['output_dir'], 'data', 'activations')

print("Checking for activation data...")
for block_idx in CONFIG['block_indices']:
    block_dir = activations_dir.joinpath(f'block{block_idx}')

    # Guard clause: report a missing block directory and move on.
    if not block_dir.exists():
        print(f"  Block {block_idx}: NOT FOUND - run notebook 01 first!")
        continue

    chunk_count = sum(1 for _ in block_dir.glob('chunk_*.npy'))
    print(f"  Block {block_idx}: {chunk_count} chunks found")

Checking for activation data...
  Block 5: 10 chunks found
  Block 20: 10 chunks found
  Block 35: 10 chunks found


## 1. Check Available Activations

In [7]:
# Per block: number of chunk files and whether normalization stats exist.
activations_dir = Path(CONFIG['output_dir'], 'data', 'activations')

print("Available activation data:")
for block_idx in CONFIG['block_indices']:
    block_dir = activations_dir.joinpath(f'block{block_idx}')

    if not block_dir.exists():
        print(f"  Block {block_idx}: NOT FOUND")
        print(f"    Run 01_setup_and_extraction.ipynb first!")
        continue

    chunk_count = sum(1 for _ in block_dir.glob('chunk_*.npy'))
    has_stats = block_dir.joinpath('normalization_stats.npz').exists()
    print(f"  Block {block_idx}: {chunk_count} chunks, stats: {has_stats}")

Available activation data:
  Block 5: 10 chunks, stats: True
  Block 20: 10 chunks, stats: True
  Block 35: 10 chunks, stats: True


## 2. Train MSAE for Each Block

We train a separate MSAE for each layer (blocks 5, 20, 35).

In [8]:
from typing import Tuple


def train_msae_for_block(block_idx: int) -> Tuple[dict, dict]:
    """
    Train an MSAE on the activations of a single block and save the result.

    Reads the module-level CONFIG, `activations_dir` and `device`. The model
    is written to <output_dir>/models/msae_block{block_idx}.pt together with
    its configuration, the normalization statistics and the training history.

    Args:
        block_idx: Index of the block whose activations are used.

    Returns:
        (final_metrics, history):
            final_metrics - train/val loss, dead ratio and R² per k level,
                taken from the epoch with the lowest validation loss (or the
                last epoch when no validation loss was recorded).
            history - the history object returned by MSAETrainer.train().
    """
    print(f"\n{'='*60}")
    print(f"Training MSAE for Block {block_idx}")
    print(f"{'='*60}\n")

    output_root = Path(CONFIG['output_dir'])
    models_dir = output_root / 'models'

    # Load data. Both the chunk directory and the HDF5 file are handed to the
    # loader; which source it actually reads is decided inside
    # create_activation_dataloader.
    h5_path = output_root / 'data' / 'activations.h5'
    train_loader, val_loader, norm_stats = create_activation_dataloader(
        activations_dir=str(activations_dir),
        block_idx=block_idx,
        batch_size=CONFIG['batch_size'],
        normalize=True,
        val_split=CONFIG['val_split'],
        h5_path=str(h5_path),
    )

    # Create model
    model = create_msae(
        input_dim=CONFIG['input_dim'],
        expansion_factor=CONFIG['expansion_factor'],
        k_levels=CONFIG['k_levels'],
        weighting=CONFIG['weighting'],
        device=device,
    )

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Create trainer with literature-validated optimizer settings.
    # The fallback values mirror the ones declared in CONFIG (the eps fallback
    # used to be 6.25e-10, which disagreed with CONFIG's 1e-8).
    trainer = MSAETrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        lr=CONFIG['learning_rate'],  # Paper: 0.0008 from multi_budget_sae
        output_dir=CONFIG['output_dir'],
        device=device,
        checkpoint_every=5,
        log_every=50,
        adam_betas=(CONFIG.get('adam_beta1', 0.9), CONFIG.get('adam_beta2', 0.999)),
        adam_eps=CONFIG.get('adam_eps', 1e-8),
    )

    # Train with early stopping (keeps best model automatically)
    history = trainer.train(
        epochs=CONFIG['epochs'],
        early_stopping_patience=CONFIG['early_stopping_patience'],
    )

    # Load best model (not the last epoch's model).
    # NOTE(review): this path does not contain block_idx, so every block (and
    # every earlier run) shares the same file. If the trainer ever finishes
    # without rewriting it, weights from a previous block would be loaded
    # here - confirm that MSAETrainer always writes msae_best.pt.
    best_model_path = models_dir / 'msae_best.pt'
    if best_model_path.exists():
        best_checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(best_checkpoint['model_state_dict'])
        best_epoch = best_checkpoint.get('epoch', 'unknown')
        print(f"\nLoaded best model from epoch {best_epoch}")

    # Save final model with normalization stats (using best weights).
    # The target directory is created explicitly so the save does not depend
    # on another cell or the trainer having made it.
    models_dir.mkdir(parents=True, exist_ok=True)
    final_path = models_dir / f'msae_block{block_idx}.pt'
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': {
            'input_dim': CONFIG['input_dim'],
            'hidden_dim': CONFIG['input_dim'] * CONFIG['expansion_factor'],
            'k_levels': CONFIG['k_levels'],
            'weighting': CONFIG['weighting'],
        },
        'normalization': {
            # .detach().cpu() first: .numpy() raises for tensors that live on
            # an accelerator or require grad; it is a no-op for CPU tensors.
            'mean': norm_stats['mean'].detach().cpu().numpy(),
            'std': norm_stats['std'].detach().cpu().numpy(),
        },
        'history': history,
    }, final_path)
    print(f"Saved final model: {final_path}")

    # Extract final metrics from the best epoch, i.e. the one with the lowest
    # validation loss; without validation data fall back to the last epoch.
    val_losses = history['val_loss']
    if val_losses:
        best_idx = val_losses.index(min(val_losses))
    else:
        best_idx = -1

    final_metrics = {
        'train_loss': history['train_loss'][best_idx],
        'val_loss': val_losses[best_idx] if val_losses else None,
        'dead_ratio': history['dead_ratio'][best_idx],
    }

    # R² at each k level (from best epoch); levels without recorded values
    # are simply left out of the result.
    for k in CONFIG['k_levels']:
        key = f'r2_k{k}'
        if key in history['val_r2'] and history['val_r2'][key]:
            final_metrics[key] = history['val_r2'][key][best_idx]

    # Clean up so the next block starts with the memory released
    del model, trainer, train_loader, val_loader
    clear_memory(verbose=True)

    return final_metrics, history

In [None]:
# Train one MSAE per configured block; results are keyed by 'block<idx>'.
all_results, all_histories = {}, {}

for block_idx in CONFIG['block_indices']:
    block_key = f'block{block_idx}'

    if (activations_dir / block_key).exists():
        block_metrics, block_history = train_msae_for_block(block_idx)
        all_results[block_key] = block_metrics
        all_histories[block_key] = block_history
    else:
        # No activations for this block: report it and continue with the next.
        print(f"Skipping block {block_idx} - no activations found")


Training MSAE for Block 5

=== System Capabilities ===
CPU cores: 8
Total RAM: 51.0 GB
Available RAM: 48.9 GB
Memory budget: 50%
GPU: Tesla T4 (14.7 GB)
Auto-detected optimal workers: 7
Using HDF5 streaming from: /content/drive/MyDrive/chaos/data/activations.h5
Using chunked streaming: 495,995 samples
  Train: 446,396, Val: 49,599
  Chunk size: 100,000
Model parameters: 2,097,408
Training MSAE on cuda
  K levels: [16, 32, 64, 128]
  Hidden dim: 4096
  Learning rate: 0.0008
  Epochs: 30

Epoch 1/30
  Batch 50/109: loss=0.822782, dead=0.000
  Batch 100/109: loss=0.595525, dead=0.000
  Batch 150/109: loss=0.485413, dead=0.000
  Batch 200/109: loss=0.415995, dead=0.000
  Batch 250/109: loss=0.369041, dead=0.000
  Batch 300/109: loss=0.336191, dead=0.000
  Batch 350/109: loss=0.310493, dead=0.000
  Batch 400/109: loss=0.289900, dead=0.000
  Batch 450/109: loss=0.273528, dead=0.000
  Batch 500/109: loss=0.259915, dead=0.000
  Batch 550/109: loss=0.248005, dead=0.000
  Batch 600/109: loss=0.

## 3. Summary Results

In [None]:
# Print the best-epoch metrics collected for every trained block.
print("\nFinal Results:")
print("=" * 70)

for block_name, metrics in all_results.items():
    print(f"\n{block_name}:")
    print(f"  Train Loss: {metrics['train_loss']:.6f}")
    # 'val_loss' is stored as None when no validation loss was recorded.
    # Compare against None explicitly: a plain truthiness test would also
    # hide a legitimate loss of 0.0.
    if metrics['val_loss'] is not None:
        print(f"  Val Loss: {metrics['val_loss']:.6f}")
    print(f"  Dead Ratio: {metrics['dead_ratio']:.3f}")
    print(f"  R² by k:")
    for k in CONFIG['k_levels']:
        key = f'r2_k{k}'
        # R² entries exist only for k levels that had validation values.
        if key in metrics:
            print(f"    k={k:3d}: {metrics[key]:.4f}")

In [None]:
# Persist the per-block R² values as <output_dir>/results/reconstruction.json
results_path = Path(CONFIG['output_dir'], 'results', 'reconstruction.json')
results_path.parent.mkdir(parents=True, exist_ok=True)

# Layout: {block_name: {'k<level>': R² or None}}; None marks a level for which
# no R² was recorded.
reconstruction_results = {}
for block_name, metrics in all_results.items():
    r2_by_level = {}
    for k in CONFIG['k_levels']:
        r2_by_level[f'k{k}'] = metrics.get(f'r2_k{k}', None)
    reconstruction_results[block_name] = r2_by_level

with open(results_path, 'w') as results_file:
    json.dump(reconstruction_results, results_file, indent=2)

print(f"\nSaved results to {results_path}")

## 4. Visualize Training

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

colors = {'block5': 'C0', 'block20': 'C1', 'block35': 'C2'}

# Plots 1-3: per-epoch curves taken from the training histories.
# (axis, history key, label used for both y-axis and title, skip empty series)
# Only the validation curve is skipped when empty, because validation may not
# have been run; the other two series are always plotted.
history_panels = [
    (axes[0, 0], 'train_loss', 'Training Loss', False),
    (axes[0, 1], 'val_loss', 'Validation Loss', True),
    (axes[1, 0], 'dead_ratio', 'Dead Latent Ratio', False),
]
for ax, history_key, panel_label, skip_if_empty in history_panels:
    for block_name, history in all_histories.items():
        series = history[history_key]
        if skip_if_empty and not series:
            continue
        ax.plot(series, label=block_name, color=colors.get(block_name))
    ax.set_xlabel('Epoch')
    ax.set_ylabel(panel_label)
    ax.set_title(panel_label)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Plot 4: R² at different k levels. The values come from all_results, i.e.
# from each block's best epoch (lowest validation loss), not necessarily the
# last epoch; a missing level is drawn as 0.
ax = axes[1, 1]
for block_name, metrics in all_results.items():
    r2_values = [metrics.get(f'r2_k{k}', 0) for k in CONFIG['k_levels']]
    ax.plot(CONFIG['k_levels'], r2_values, 'o-', label=block_name, color=colors.get(block_name))
ax.set_xlabel('k (sparsity level)')
ax.set_ylabel('R²')
ax.set_title('Final R² by Sparsity Level')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()

# Make sure the target directory exists: savefig does not create it.
figure_path = Path(CONFIG['output_dir']) / 'figures' / 'msae_training.png'
figure_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()

## 5. Check Success Criteria

In [None]:
# Success criteria from CLAUDE.md
print("\nSuccess Criteria Check:")
print("=" * 50)

# One row per criterion: (label, metric key, value assumed when the metric is
# missing, pass test). The fallback values make a missing metric count as FAIL.
criterion_specs = (
    ('R² (k=128) > 0.85', 'r2_k128', 0, lambda value: value > 0.85),
    ('R² (k=16) > 0.40', 'r2_k16', 0, lambda value: value > 0.40),
    ('Dead features < 30%', 'dead_ratio', 1, lambda value: value < 0.30),
)

# criteria maps each label to a list of (block_name, value, passed) tuples.
criteria = {label: [] for label, _, _, _ in criterion_specs}

for block_name, metrics in all_results.items():
    for label, metric_key, fallback, is_met in criterion_specs:
        observed = metrics.get(metric_key, fallback)
        criteria[label].append((block_name, observed, is_met(observed)))

for criterion, results in criteria.items():
    print(f"\n{criterion}:")
    for block, value, passed in results:
        print(f"  {block}: {value:.4f} [{'PASS' if passed else 'FAIL'}]")

In [None]:
# Zip the essential outputs (models, results, figures).
# The files are collected from CONFIG['output_dir'] rather than a hard-coded
# relative 'outputs/' folder: on Colab the output directory is the Drive
# project root, so 'outputs/' does not exist there. Inside the archive the
# files still sit under 'outputs/<subdir>/...'.
import zipfile

output_root = Path(CONFIG['output_dir'])
archive_path = Path('msae_models.zip')

with zipfile.ZipFile(archive_path, 'w', compression=zipfile.ZIP_DEFLATED) as archive:
    for subdir in ('models', 'results', 'figures'):
        source_dir = output_root / subdir
        if not source_dir.is_dir():
            print(f"Skipping {source_dir} - not found")
            continue
        for file_path in sorted(source_dir.rglob('*')):
            if file_path.is_file():
                arcname = Path('outputs') / file_path.relative_to(output_root)
                archive.write(file_path, arcname=arcname.as_posix())

# Then download msae_models.zip (~25MB). The browser download helper only
# exists on Colab; locally the archive is simply left on disk.
if IN_COLAB:
    from google.colab import files
    files.download(str(archive_path))
else:
    print(f"Archive written to {archive_path.resolve()}")

## Summary

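Trained MSAEs are saved under `<output_dir>/models/`, where `<output_dir>` is `outputs/` for local runs and `MyDrive/chaos/` on Colab:
- `<output_dir>/models/msae_block5.pt`
- `<output_dir>/models/msae_block20.pt`
- `<output_dir>/models/msae_block35.pt`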

Each model file includes:
- Model state dict
- Configuration (k_levels, hidden_dim, etc.)
- Normalization statistics (mean, std)
- Training history

## Next Steps

1. **03_train_baselines.ipynb**: Train single-k baseline SAEs for comparison
2. **04_concept_labeling.ipynb**: Create concept labels for Go positions
3. **05_run_probes.ipynb**: Train linear probes to evaluate feature quality