# Muon Optimizer: Accelerating Grokking Reproduction

This notebook demonstrates the Muon optimizer and compares how quickly it produces grokking against AdamW. Grokking is a sudden jump in held-out accuracy that arrives long after the training set has been memorized.

## What is Muon?

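Muon (MomentUm Orthogonalized by Newton-Schulz) is an optimizer for the 2D weight matrices of hidden layers. The remaining parameters are handled by an auxiliary Adam. It combines:
1. **Orthogonalized momentum updates**: the momentum matrix is replaced by an approximation of its nearest semi-orthogonal matrix $UV^\top$. This is computed with a few Newton-Schulz iterations instead of an SVD.
2. **Spectral-norm steepest descent**: because every singular value of the update is pushed toward 1, each step has a roughly fixed spectral norm. This holds regardless of how large or ill-conditioned the raw gradient is.
3. **Matrix-aware preconditioning at first-order cost**: orthogonalization is closely related to Shampoo-style preconditioning, yet Muon uses only gradients (no Hessian or curvature estimates).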
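The core of the update fits in a few lines. Below is a simplified NumPy sketch for a single weight matrix, written for illustration rather than taken from the `muon` package. Several details follow Keller Jordan's public reference implementation as understood here:

- the quintic Newton-Schulz coefficients `(3.4445, -4.7750, 2.0315)`;
- the 5 iterations;
- the Nesterov-style momentum;
- the `sqrt(max(1, rows / cols))` rescaling.

The function names `newton_schulz_orthogonalize` and `muon_step` are hypothetical.

```python
import numpy as np

# Quintic Newton-Schulz coefficients as used in the public Muon reference code
NS_COEFFS = (3.4445, -4.7750, 2.0315)


def newton_schulz_orthogonalize(g, steps=5, eps=1e-7):
    """Approximate U @ V.T from the SVD g = U S V.T without computing an SVD.

    Each iteration maps every singular value s to a*s + b*s**3 + c*s**5,
    pushing them toward 1 (roughly into [0.7, 1.2]) while leaving the
    singular vectors unchanged.
    """
    a, b, c = NS_COEFFS
    x = g / (np.linalg.norm(g) + eps)  # Frobenius scaling => singular values <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # iterate on the smaller Gram matrix x @ x.T
        x = x.T
    for _ in range(steps):
        A = x @ x.T
        B = b * A + c * (A @ A)
        x = a * x + B @ x
    return x.T if transposed else x


def muon_step(weight, grad, momentum_buf, lr=0.02, beta=0.95, weight_decay=0.0):
    """One simplified Muon update for a single 2D weight matrix (in place).

    Assumption: mirrors the reference update rule, minus batching/distribution.
    """
    # Exponential moving average of gradients
    momentum_buf *= beta
    momentum_buf += (1 - beta) * grad
    # Nesterov-style lookahead: blend the fresh gradient with the momentum
    update = (1 - beta) * grad + beta * momentum_buf
    # Orthogonalize, then rescale so tall matrices get comparable per-entry steps
    update = newton_schulz_orthogonalize(update)
    update *= max(1.0, weight.shape[0] / weight.shape[1]) ** 0.5
    # Decoupled weight decay, then the descent step
    weight *= 1 - lr * weight_decay
    weight -= lr * update
    return weight
```

The orthogonalized update's singular values all sit near 1, so its spectral norm is roughly fixed. That is the sense in which Muon takes spectral-norm-bounded steps. Biases, gains, and other 1D parameters have no singular value decomposition to orthogonalize, which is why they go to the auxiliary Adam.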

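**Key Results** (as reported in the original paper linked at the end):
- About 33% fewer epochs to grok: the mean grokking epoch was 153 with AdamW versus 103 with Muon.
- The improvement was statistically significant across 7 modular arithmetic tasks.
- Muon made an earlier transition from memorizing the training set to generalizing on held-out data.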
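"33% faster" is easy to misread, because the rounded averages above can be expressed two ways. One is the fraction of epochs saved. The other is the ratio `AdamW epochs / Muon epochs`, which is the ratio the demo below reports as its speedup:

```python
adamw_epochs, muon_epochs = 153, 103  # rounded mean grokking epochs quoted above

reduction = 1 - muon_epochs / adamw_epochs  # fraction of epochs saved
speedup = adamw_epochs / muon_epochs        # ratio of epochs needed

print(f"Epoch reduction: {reduction:.1%}")  # Epoch reduction: 32.7%
print(f"Speedup: {speedup:.2f}x")           # Speedup: 1.49x
```

A one-third cut in epochs therefore corresponds to roughly a 1.49x speedup, not 1.33x.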

## Setup and Installation

In [None]:
! [ ! -d "muon" ] && git clone https://github.com/bangyen/muon.git
! cd muon && pip install -e .

print("Setup complete!")

In [None]:
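import os

# Enter the cloned repo once; re-running this cell is then harmless
if os.path.basename(os.getcwd()) != "muon":
    os.chdir("muon")
print(f"Current working directory: {os.getcwd()}")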

## Imports and Configuration

In [None]:
import numpy as np
import torch
from torch import nn

# Import Muon optimizer
from muon import SingleDeviceMuonWithAuxAdam

# Import local modules
from src.dataset import DatasetConfig, ModularArithmeticDataset
from src.model import GrokkingTransformer, ModelConfig

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

In [None]:
# Constants
MAX_EPOCHS = 50  # Reduced for quick demo
BATCH_SIZE = 32
MIN_DIMENSIONS = 2  # Parameters with at least this many dims (matrices) use Muon
GROKKING_THRESHOLD = 95.0  # Validation accuracy (%) that counts as grokking
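Grokking is conventionally detected on held-out accuracy. A run "groks" at the first epoch whose validation accuracy reaches the threshold. The gap between fitting the training set and generalizing is the grokking delay.

The helper below illustrates that bookkeeping on accuracy histories. The function names and the sample numbers are illustrative and do not come from the repository.

```python
def first_epoch_reaching(accuracies, threshold=95.0):
    """1-based epoch at which accuracy (in %) first reaches threshold, else None.

    The default of 95.0 matches GROKKING_THRESHOLD above.
    """
    for epoch, acc in enumerate(accuracies, start=1):
        if acc >= threshold:
            return epoch
    return None


def grokking_delay(train_accs, val_accs, threshold=95.0):
    """Epochs between fitting the training set and generalizing, else None."""
    fit = first_epoch_reaching(train_accs, threshold)
    grok = first_epoch_reaching(val_accs, threshold)
    if fit is None or grok is None:
        return None
    return grok - fit


# Made-up history: train accuracy saturates early, validation jumps late
train = [40.0, 80.0, 99.0, 100.0, 100.0, 100.0]
val = [5.0, 6.0, 8.0, 30.0, 96.0, 99.0]
print(first_epoch_reaching(val))   # 5
print(grokking_delay(train, val))  # 2
```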

## Muon vs AdamW


In [None]:
def quick_grokking_demo():
    """Quick demonstration of Muon vs AdamW on modular addition"""
    print("Muon vs AdamW Grokking Demo")
    print("=" * 40)

    # Create dataset for modular addition (simplest task)
    dataset_config = DatasetConfig(
        task_type="add", modulus=97, train_split=0.8
    )
    dataset = ModularArithmeticDataset(dataset_config)

    # Create model
    model_config = ModelConfig(
        vocab_size=dataset.vocab_size,
        hidden_size=128,
        num_layers=4,
        num_heads=8,
        ff_size=512,
    )

    # Create two identical models
    model_muon = GrokkingTransformer(model_config).to(device)
    model_adamw = GrokkingTransformer(model_config).to(device)

    # Copy weights to ensure identical starting points
    model_adamw.load_state_dict(model_muon.state_dict())

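    # Create optimizers.
    # Muon updates matrix-shaped (ndim >= 2) weights; vectors such as biases
    # and norm gains fall back to the auxiliary Adam. Note: this simple ndim
    # filter also routes the embedding and output layers through Muon.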
    hidden_weights = [
        p for p in model_muon.parameters() if p.ndim >= MIN_DIMENSIONS
    ]
    other_params = [
        p for p in model_muon.parameters() if p.ndim < MIN_DIMENSIONS
    ]

    param_groups_muon = [
        dict(params=hidden_weights, use_muon=True, lr=0.02, weight_decay=1e-2),
        dict(
            params=other_params,
            use_muon=False,
            lr=0.002,
            betas=(0.9, 0.95),
            weight_decay=1e-2,
        ),
    ]
    optimizer_muon = SingleDeviceMuonWithAuxAdam(param_groups_muon)

    # AdamW optimizer
    optimizer_adamw = torch.optim.AdamW(
        model_adamw.parameters(), lr=0.001, weight_decay=1e-2
    )

    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding tokens

    # Train on the training split and measure grokking on the held-out
    # validation split. get_train_data() is assumed to mirror get_val_data()
    # (a list of {"input": tensor, "target": tensor} samples).
    train_data = dataset.get_train_data()
    val_data = dataset.get_val_data()

    def make_batches(samples, order):
        """Yield (input, target) batches of up to BATCH_SIZE samples."""
        for start in range(0, len(order), BATCH_SIZE):
            chunk = [samples[i] for i in order[start : start + BATCH_SIZE]]
            inputs = torch.stack([s["input"] for s in chunk]).to(device)
            targets = torch.stack([s["target"] for s in chunk]).to(device)
            yield inputs, targets

    def count_correct(outputs, targets):
        """Count correct predictions, skipping padding (index 0) like the loss."""
        mask = targets != 0
        correct = (outputs.argmax(dim=-1) == targets) & mask
        return correct.sum().item(), mask.sum().item()

    def train_epoch(model, optimizer, order):
        """One full pass over the training split; returns train accuracy (%)."""
        model.train()
        correct, total = 0, 0
        for inputs, targets in make_batches(train_data, order):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(
                outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)
            )
            loss.backward()
            optimizer.step()
            c, n = count_correct(outputs, targets)
            correct, total = correct + c, total + n
        return 100 * correct / max(total, 1)

    @torch.no_grad()
    def evaluate(model):
        """Accuracy (%) on the validation split, the signal used for grokking."""
        model.eval()
        correct, total = 0, 0
        for inputs, targets in make_batches(val_data, list(range(len(val_data)))):
            c, n = count_correct(model(inputs), targets)
            correct, total = correct + c, total + n
        return 100 * correct / max(total, 1)

    # Training loop
    print(f"Training for {MAX_EPOCHS} epochs...")

    muon_grokking_epoch = None
    adamw_grokking_epoch = None

    for epoch in range(MAX_EPOCHS):
        # Same shuffled batch order for both models keeps the comparison fair
        order = np.random.permutation(len(train_data)).tolist()
        muon_train_acc = train_epoch(model_muon, optimizer_muon, order)
        adamw_train_acc = train_epoch(model_adamw, optimizer_adamw, order)
        muon_acc = evaluate(model_muon)
        adamw_acc = evaluate(model_adamw)

        print(
            f"Epoch {epoch + 1:2d}: "
            f"Muon train/val {muon_train_acc:5.1f}%/{muon_acc:5.1f}%, "
            f"AdamW train/val {adamw_train_acc:5.1f}%/{adamw_acc:5.1f}%"
        )

        # Grokking = validation accuracy first reaching GROKKING_THRESHOLD
        if muon_grokking_epoch is None and muon_acc >= GROKKING_THRESHOLD:
            muon_grokking_epoch = epoch + 1
            print(f"🎉 Muon grokked at epoch {muon_grokking_epoch}!")

        if adamw_grokking_epoch is None and adamw_acc >= GROKKING_THRESHOLD:
            adamw_grokking_epoch = epoch + 1
            print(f"🎉 AdamW grokked at epoch {adamw_grokking_epoch}!")

    # Results
    print("\n" + "=" * 40)
    print("RESULTS:")
    print(f"Muon grokking epoch:   {muon_grokking_epoch or 'Not achieved'}")
    print(f"AdamW grokking epoch:  {adamw_grokking_epoch or 'Not achieved'}")

    if muon_grokking_epoch and adamw_grokking_epoch:
        speedup = adamw_grokking_epoch / muon_grokking_epoch
        print(f"Speedup (AdamW epochs / Muon epochs): {speedup:.2f}x")
    elif muon_grokking_epoch:
        print("✅ Muon achieved grokking, AdamW did not")
    elif adamw_grokking_epoch:
        print("❌ AdamW achieved grokking, Muon did not")
    else:
        print("⚠️  Neither optimizer achieved grokking in this quick demo")


# Run the demo
quick_grokking_demo()
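The demo sends every parameter with `ndim >= MIN_DIMENSIONS` through Muon, which includes the token embedding and the output projection. The reference Muon implementation's documentation recommends optimizing those layers, along with biases and gains, with a standard method such as AdamW.

A name-based split is one way to do that. The substrings in `adam_keywords` and the parameter names in the example are hypothetical. Inspect `model_muon.named_parameters()` for `GrokkingTransformer`'s real layer names first, since a loose substring (for example `"head"` inside a name like `multihead_attn`) could misroute hidden weights.

```python
import numpy as np


def split_muon_params(named_params, adam_keywords=("embed", "head"), min_dims=2):
    """Split (name, param) pairs into Muon and auxiliary-Adam lists.

    Muon gets matrix-shaped hidden weights. Everything else (biases, norm gains,
    and layers whose name contains an adam_keywords substring) goes to Adam.
    Works on any objects exposing `.ndim` (torch tensors, NumPy arrays).
    """
    muon, adam = [], []
    for name, param in named_params:
        if param.ndim >= min_dims and not any(k in name for k in adam_keywords):
            muon.append(param)
        else:
            adam.append(param)
    return muon, adam


# Stand-in parameters with hypothetical names, shaped like a tiny transformer
named = [
    ("embedding.weight", np.zeros((10, 4))),
    ("layers.0.attn.weight", np.zeros((4, 4))),
    ("layers.0.attn.bias", np.zeros(4)),
    ("head.weight", np.zeros((10, 4))),
]
muon, adam = split_muon_params(named)
print(len(muon), len(adam))  # 1 3
```

With the real model, `hidden_weights, other_params = split_muon_params(model_muon.named_parameters())` would replace the two `ndim` list comprehensions. The `param_groups_muon` dictionaries stay unchanged.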

## Key Takeaways

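This demo compares how quickly each optimizer groks on modular addition:

1. **Faster grokking**: in the paper's experiments, Muon reached the grokking threshold in fewer epochs than AdamW on average. A single short run like this one may or may not reproduce that gap.
2. **Different update geometry**: Muon orthogonalizes each hidden-layer update, so every singular direction of a weight matrix moves by a similar amount. AdamW instead rescales each entry independently.
3. **Grokking phenomenon**: validation accuracy jumps to near-perfect well after training accuracy has saturated. This marks the shift from memorizing the training pairs to learning the underlying modular rule.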

### Next Steps

- Run full experiments: `make run-experiments`
- Try different tasks: `python -m scripts.train_tasks --single_task`
- Analyze results: `make analyze`

For more information, see the [GitHub repository](https://github.com/bangyen/muon) and the [original paper](https://arxiv.org/abs/2504.16041).