# Optimizers for Tabular Deep Learning: A Practical Comparison

This tutorial compares four optimizers specifically relevant for tabular neural networks:

1. **AdamW** - Robust baseline with decoupled weight decay
2. **Muon** - Momentum with orthogonalized updates for stability
3. **Shampoo** - Approximate second-order method with Kronecker-factored preconditioning
4. **NovoGrad** - Layer-wise adaptive learning rates from per-layer gradient norms

We'll explore their mathematical foundations, practical implementations, and performance on synthetic tabular datasets.

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score
import time
from typing import Dict, List, Tuple

# Import our models and optimizers
from models.optimizers import get_optimizer, get_optimizer_info, recommend_optimizer
from models.base import MLP, TabularModel
from models.tabm import TabM
from data.datasets import load_dataset

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 1. Mathematical Background

### AdamW: Adam with Decoupled Weight Decay

AdamW modifies Adam by decoupling weight decay from gradient-based updates:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t$$
$$v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2$$
$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t}$$
$$\theta_{t+1} = \theta_t - \alpha \left(\frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \theta_t\right)$$

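**Key insight:** The decay term $\lambda \theta_t$ is applied to the parameters directly instead of being added to the gradient as an L2 penalty, so it is not rescaled by the adaptive denominator $\sqrt{\hat{v}_t} + \epsilon$.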
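To make the decoupling concrete, here is a minimal NumPy sketch of a single AdamW step that follows the four equations above. The function name `adamw_step` and its default hyperparameters are illustrative, not the repository's API. Note that the weight-decay term never passes through the $\sqrt{\hat{v}_t}$ denominator:

```python
import numpy as np

def adamw_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=1e-2):
    """One AdamW step; t is the 1-based step count used for bias correction."""
    m = beta1 * m + (1 - beta1) * grad           # first moment
    v = beta2 * v + (1 - beta2) * grad ** 2      # second moment
    m_hat = m / (1 - beta1 ** t)                 # bias corrections
    v_hat = v / (1 - beta2 ** t)
    # Decoupled decay: lambda * theta sits outside the adaptive ratio
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta, m, v

theta = np.array([1.0, -2.0])
m = np.zeros_like(theta)
v = np.zeros_like(theta)
theta, m, v = adamw_step(theta, np.array([0.5, 0.5]), m, v, t=1)
```

On the first step the bias corrections give $\hat{m}_1 = g_1$ and $\hat{v}_1 = g_1^2$, so the adaptive part is roughly $\operatorname{sign}(g_1)$ and each coordinate moves by about $\alpha(1 + \lambda\theta)$.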

### Muon: Momentum Orthogonalization

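Muon accumulates momentum for each 2D weight matrix and then replaces the momentum with its nearest (semi-)orthogonal matrix before applying it:

$$m_t = \mu\, m_{t-1} + g_t, \qquad \theta_{t+1} = \theta_t - \alpha \cdot \text{Orthogonalize}(m_t)$$

The orthogonalization is approximated with a Newton-Schulz iteration, started from the Frobenius-normalized momentum so that every singular value lies in $(0, 1]$:
$$G_0 = \frac{m_t}{\|m_t\|_F}, \quad G_{k+1} = \frac{1}{2}G_k(3I - G_k^\top G_k)$$

Each iteration pushes the singular values of $G_k$ toward 1, so $G_k$ converges to $UV^\top$, where $m_t = U \Sigma V^\top$ is the SVD of the momentum. The reference Muon implementation uses a tuned quintic polynomial with five iterations and Nesterov momentum instead of this textbook cubic, and non-matrix parameters such as biases are usually handed to AdamW.

**Key insight:** Orthogonalization sets every singular value of the update to roughly 1, so the step has similar magnitude in every direction instead of being dominated by a few large gradient directions.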
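The update can be sketched in NumPy as follows. This is a simplified illustration, not the reference Muon implementation: it assumes plain heavy-ball momentum $m_t = \mu m_{t-1} + g_t$ and runs the cubic Newton-Schulz iteration until it has effectively converged. The helper names `newton_schulz_orthogonalize` and `muon_step` are hypothetical.

```python
import numpy as np

def newton_schulz_orthogonalize(M: np.ndarray, steps: int = 30) -> np.ndarray:
    """Approximate the orthogonal polar factor U @ Vt of M = U @ diag(S) @ Vt."""
    # Frobenius normalization puts every singular value in (0, 1]
    X = M / (np.linalg.norm(M) + 1e-12)
    for _ in range(steps):
        # Cubic Newton-Schulz step: X <- 0.5 * X @ (3I - X.T @ X)
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

def muon_step(W, grad, buf, lr=0.01, momentum=0.95, ns_steps=30):
    """Simplified Muon update for one weight matrix (heavy-ball momentum)."""
    buf = momentum * buf + grad                               # m_t = mu * m_{t-1} + g_t
    W = W - lr * newton_schulz_orthogonalize(buf, ns_steps)   # step along Orthogonalize(m_t)
    return W, buf

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 3))
buf = np.zeros_like(W)
W, buf = muon_step(W, rng.standard_normal((4, 3)), buf)
```

Because every singular value of the orthogonalized step is close to 1, the step's size is set by `lr` rather than by the gradient's magnitude.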

### Shampoo: Efficient Second-Order Preconditioning

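Shampoo maintains a separate preconditioner for each tensor dimension. For a weight matrix $W \in \mathbb{R}^{m \times n}$ with gradient $G_t$, it accumulates an $m \times m$ left statistic and an $n \times n$ right statistic, both initialized to $\epsilon I$:

$$H_L \leftarrow H_L + G_t G_t^\top, \quad H_R \leftarrow H_R + G_t^\top G_t$$
$$W_{t+1} = W_t - \alpha\, H_L^{-1/4} G_t H_R^{-1/4}$$

**Key insight:** $H_L^{1/4} \otimes H_R^{1/4}$ is a Kronecker-factored approximation of full-matrix AdaGrad, so Shampoo captures correlations between the rows and between the columns of each weight matrix while storing only $m^2 + n^2$ entries instead of $(mn)^2$.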
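A minimal NumPy sketch of one Shampoo step for a single weight matrix follows. It recomputes the inverse fourth roots by eigendecomposition on every step, which is the simplest correct choice but much more expensive than practical implementations, which typically refresh the roots only every few steps and graft the step size from another optimizer. `matrix_power_psd` and `shampoo_step` are hypothetical helpers, not the repository's API.

```python
import numpy as np

def matrix_power_psd(H: np.ndarray, p: float) -> np.ndarray:
    """H**p for a symmetric positive-definite matrix, via eigendecomposition."""
    w, Q = np.linalg.eigh(H)
    w = np.maximum(w, 1e-12)        # guard against tiny negative round-off
    return (Q * w ** p) @ Q.T       # Q @ diag(w**p) @ Q.T

def shampoo_step(W, G, H_L, H_R, lr=1e-4):
    """One Shampoo step for a weight matrix W (m x n) with gradient G."""
    H_L = H_L + G @ G.T             # m x m left statistic
    H_R = H_R + G.T @ G             # n x n right statistic
    precond = matrix_power_psd(H_L, -0.25) @ G @ matrix_power_psd(H_R, -0.25)
    return W - lr * precond, H_L, H_R

m, n, eps = 4, 3, 1e-4
W = np.zeros((m, n))
H_L, H_R = eps * np.eye(m), eps * np.eye(n)   # epsilon * I initialization
G = np.random.default_rng(0).standard_normal((m, n))
W, H_L, H_R = shampoo_step(W, G, H_L, H_R)
```

The two $-1/4$ powers multiply to an overall $-1/2$ power, matching the $-1/2$ that AdaGrad applies to its accumulated squared gradients.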

### NovoGrad: Layer-wise Adaptive Learning

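NovoGrad keeps a single scalar second moment per layer $l$, built from the squared norm of that layer's gradient $g_t^{(l)} = \nabla_{\theta^{(l)}} \mathcal{L}$:

$$v_t^{(l)} = \beta_2 v_{t-1}^{(l)} + (1-\beta_2)\,\|g_t^{(l)}\|_2^2$$
$$m_t^{(l)} = \beta_1 m_{t-1}^{(l)} + \left(\frac{g_t^{(l)}}{\sqrt{v_t^{(l)} + \epsilon}} + \lambda\,\theta_t^{(l)}\right)$$
$$\theta_{t+1}^{(l)} = \theta_t^{(l)} - \alpha\, m_t^{(l)}$$

On the first step the second moment is initialized as $v_1^{(l)} = \|g_1^{(l)}\|_2^2$.

**Key insight:** Each layer's gradient is divided by a running estimate of its own norm, which makes the step insensitive to that layer's gradient scale. Because the second moment is one scalar per layer, NovoGrad stores roughly half the optimizer state of Adam.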
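Here is a minimal NumPy sketch of one NovoGrad step for a single layer, following the formulation of the original NovoGrad paper: a running average of the squared gradient norm, with the gradient divided by its square root before it enters the momentum. `novograd_step` is a hypothetical helper, and the defaults $\beta_1 = 0.95$, $\beta_2 = 0.98$ are illustrative values.

```python
import numpy as np

def novograd_step(theta, grad, state=None, lr=1e-3, beta1=0.95, beta2=0.98,
                  eps=1e-8, weight_decay=0.0):
    """One NovoGrad step for a single layer; `state` is (m, v), or None on step 1."""
    g_norm_sq = float(np.sum(grad ** 2))        # squared L2 norm of this layer's gradient
    if state is None:
        m, v = np.zeros_like(theta), g_norm_sq  # v_1 = ||g_1||^2
    else:
        m, v = state
        v = beta2 * v + (1 - beta2) * g_norm_sq
    # Normalize by the running norm, add weight decay, then accumulate momentum
    m = beta1 * m + (grad / np.sqrt(v + eps) + weight_decay * theta)
    return theta - lr * m, (m, v)

theta, state = np.zeros(2), None
for grad in (np.array([3.0, 4.0]), np.array([3.0, 4.0])):
    theta, state = novograd_step(theta, grad, state, lr=0.1)
```

Multiplying a layer's gradient by a constant (for example, 10×) leaves the first step unchanged, because at that point the update depends on the gradient only through $g / \|g\|$.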

## 2. When to Use Each Optimizer

Let's see the recommendations from our optimizer factory:

In [None]:
# Display optimizer information
optimizer_info = get_optimizer_info()

print("🔧 Optimizer Overview:")
print("=" * 50)
for name, info in optimizer_info.items():
    print(f"\n{name.upper()}:")
    print(f"  Description: {info['description']}")
    print(f"  Best for: {info['best_for']}")
    print(f"  Key features: {info['key_features']}")

print("\n" + "=" * 50)
print("\n🎯 Optimizer Recommendations:")
scenarios = [
    ("Small dataset, MLP", {"model_type": "mlp", "dataset_size": "small"}),
    ("Large dataset, mixed features", {"model_type": "mlp", "dataset_size": "large", "feature_types": "mixed"}),
    ("Deep attention model", {"model_type": "attention", "architecture_depth": "deep"}),
    ("Embedding-heavy model", {"model_type": "embedding", "feature_types": "categorical"}),
]

for scenario, kwargs in scenarios:
    recs = recommend_optimizer(**kwargs)
    print(f"\n{scenario}:")
    print(f"  Recommended: {' → '.join(recs)}")

## 3. Experimental Setup

We'll compare optimizers on three synthetic datasets that highlight different challenges in tabular data:

In [None]:
# Load datasets for comparison
datasets = {
    'friedman': load_dataset('friedman', n_samples=2000, random_state=42),
    'high_dimensional': load_dataset('high_dimensional', n_samples=2000, n_features=50, 
                                   n_informative=10, random_state=42),
    'nonlinear_interaction': load_dataset('nonlinear_interaction', n_samples=2000, 
                                        n_features=15, random_state=42)
}

print("📊 Dataset Overview:")
print("=" * 40)
for name, dataset in datasets.items():
    info = dataset.info
    print(f"\n{name.upper()}:")
    print(f"  Samples: {info.n_samples}")
    print(f"  Features: {info.n_numerical}")
    print(f"  Description: {info.description}")
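
If the `friedman` dataset follows the classic Friedman #1 benchmark (as scikit-learn's `make_friedman1` does; we have not confirmed this for `load_dataset`), its features are drawn as $x_j \sim U(0, 1)$ and the target depends on only the first five of them, $y = 10\sin(\pi x_1 x_2) + 20(x_3 - 0.5)^2 + 10x_4 + 5x_5 + \varepsilon$. It therefore mixes a multiplicative interaction, a quadratic term, and linear terms, and any columns beyond the fifth are pure noise. A small NumPy sketch (`friedman1_target` is a hypothetical helper):

```python
import numpy as np

def friedman1_target(X: np.ndarray) -> np.ndarray:
    """Noise-free Friedman #1 target; only the first five columns are used."""
    return (
        10 * np.sin(np.pi * X[:, 0] * X[:, 1])   # interaction term
        + 20 * (X[:, 2] - 0.5) ** 2              # quadratic term
        + 10 * X[:, 3]                           # linear terms
        + 5 * X[:, 4]
    )

rng = np.random.default_rng(0)
X = rng.uniform(size=(2000, 10))                 # 10 features, 5 informative
y = friedman1_target(X) + rng.normal(scale=1.0, size=2000)
```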

## 4. Model Architecture

We'll use TabM (a parameter-efficient ensemble of MLPs) as our base model, since it is one of the strongest performers in our repository, and fall back to a plain MLP if TabM cannot be constructed:

In [None]:
def create_model(n_features: int, model_type: str = "mlp") -> nn.Module:
    """Create a model for comparison."""
    if model_type == "mlp":
        return MLP(
            d_in=n_features,
            d_out=1,
            n_blocks=3,
            d_block=128,
            dropout=0.1,
            task="regression"
        ).to(device)
    elif model_type == "tabm":
        # Use a simplified TabM configuration
        return TabM(
            d_in=n_features,
            d_out=1,
            n_estimators=5,  # Reduced for faster training
            d_model=64,
            n_layers=2,
            dropout=0.1
        ).to(device)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

# Test model creation
test_model = create_model(10, "mlp")
n_params = sum(p.numel() for p in test_model.parameters())
print(f"✅ Created MLP model with {n_params:,} parameters")

# Prefer TabM; fall back to the MLP if it cannot be built
try:
    test_tabm = create_model(10, "tabm")
    n_params = sum(p.numel() for p in test_tabm.parameters())
    print(f"✅ Created TabM model with {n_params:,} parameters")
    model_type = "tabm"
except Exception as e:
    print(f"⚠️  TabM not available ({e}), using MLP")
    model_type = "mlp"

## 5. Training Function

In [None]:
import copy

def train_model(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    X_train: torch.Tensor,
    y_train: torch.Tensor,
    X_val: torch.Tensor,
    y_val: torch.Tensor,
    epochs: int = 200,
    patience: int = 50,
    verbose: bool = False
) -> Dict:
    """Train with full-batch updates and early stopping; return training history.

    Each "epoch" is a single optimizer step on the whole training set. The
    weights with the lowest validation loss are restored before returning,
    so evaluation after this call uses the best checkpoint.
    """
    criterion = nn.MSELoss()
    
    history = {
        'train_loss': [],
        'val_loss': [],
        'epoch': [],
        'time_per_epoch': []
    }
    
    best_val_loss = float('inf')
    best_state = copy.deepcopy(model.state_dict())
    patience_counter = 0
    start_time = time.time()
    
    for epoch in range(epochs):
        epoch_start = time.time()
        
        # Training: one gradient step on the full training set
        model.train()
        optimizer.zero_grad()
        train_pred = model(X_train)
        train_loss = criterion(train_pred, y_train)
        train_loss.backward()
        optimizer.step()
        
        # Validation (.item() also waits for pending GPU work before timing)
        model.eval()
        with torch.no_grad():
            val_loss = criterion(model(X_val), y_val).item()
        
        epoch_time = time.time() - epoch_start
        
        # Record history
        history['train_loss'].append(train_loss.item())
        history['val_loss'].append(val_loss)
        history['epoch'].append(epoch)
        history['time_per_epoch'].append(epoch_time)
        
        # Early stopping: remember the best weights seen so far
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            
        if patience_counter >= patience:
            if verbose:
                print(f"Early stopping at epoch {epoch}")
            break
            
        if verbose and (epoch + 1) % 50 == 0:
            print(f"Epoch {epoch+1}: Train Loss = {train_loss.item():.4f}, Val Loss = {val_loss:.4f}")
    
    # Restore the best checkpoint before evaluation
    model.load_state_dict(best_state)
    total_time = time.time() - start_time
    
    return {
        'history': history,
        'best_val_loss': best_val_loss,
        'total_time': total_time,
        'epochs_trained': len(history['train_loss'])
    }

## 6. Optimizer Comparison Experiment

Now let's run the comprehensive comparison across all optimizers and datasets:

In [None]:
def prepare_data(dataset):
    """Prepare data for training."""
    X = dataset.X_num.numpy()
    y = dataset.y.numpy().reshape(-1, 1)
    
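    # Split data: 60% train, 20% validation, 20% test
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_train, y_train, test_size=0.25, random_state=42  # 0.25 of the remaining 80%
    )
    
    # Standardize features with training-split statistics only
    # (targets keep their original scale, so MSE is not comparable across datasets)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)
    X_test = scaler.transform(X_test)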
    
    # Convert to tensors
    X_train = torch.FloatTensor(X_train).to(device)
    X_val = torch.FloatTensor(X_val).to(device)
    X_test = torch.FloatTensor(X_test).to(device)
    y_train = torch.FloatTensor(y_train).to(device)
    y_val = torch.FloatTensor(y_val).to(device)
    y_test = torch.FloatTensor(y_test).to(device)
    
    return X_train, X_val, X_test, y_train, y_val, y_test

# Run comparison
optimizers_to_test = ['adamw', 'muon', 'shampoo', 'novograd']
results = {}

print("🚀 Starting Optimizer Comparison...")
print("=" * 60)

for dataset_name, dataset in datasets.items():
    print(f"\n📊 Dataset: {dataset_name.upper()}")
    
    # Prepare data
    X_train, X_val, X_test, y_train, y_val, y_test = prepare_data(dataset)
    n_features = X_train.shape[1]
    
    results[dataset_name] = {}
    
    for opt_name in optimizers_to_test:
        print(f"  🔧 Testing {opt_name.upper()}... ", end="")
        
        try:
            # Re-seed so every optimizer starts from the same initial weights
            torch.manual_seed(42)
            model = create_model(n_features, model_type)
            
            # Optimizer-specific learning rates (the same on every dataset)
            lr = 0.001  # Base learning rate (AdamW, NovoGrad)
            if opt_name == 'muon':
                lr = 0.01  # Muon typically needs a higher lr
            elif opt_name == 'shampoo':
                lr = 0.0001  # Shampoo's preconditioned steps are more aggressive
            
            optimizer = get_optimizer(opt_name, model, lr=lr)
            
            # Train model
            result = train_model(
                model, optimizer, X_train, y_train, X_val, y_val,
                epochs=300, patience=50, verbose=False
            )
            
            # Test performance
            model.eval()
            with torch.no_grad():
                test_pred = model(X_test)
                test_mse = F.mse_loss(test_pred, y_test).item()
                test_r2 = r2_score(
                    y_test.cpu().numpy(), 
                    test_pred.cpu().numpy()
                )
            
            results[dataset_name][opt_name] = {
                'test_mse': test_mse,
                'test_r2': test_r2,
                'best_val_loss': result['best_val_loss'],
                'total_time': result['total_time'],
                'epochs_trained': result['epochs_trained'],
                'history': result['history'],
                'learning_rate': lr
            }
            
            print(f"✅ MSE: {test_mse:.4f}, R²: {test_r2:.3f}")
            
        except Exception as e:
            print(f"❌ Error: {str(e)}")
            results[dataset_name][opt_name] = {
                'test_mse': float('inf'),
                'test_r2': -float('inf'),
                'error': str(e)
            }

print("\n✅ Comparison completed!")

## 7. Results Analysis

In [None]:
# Create results summary table
summary_data = []

for dataset_name, dataset_results in results.items():
    for opt_name, opt_results in dataset_results.items():
        if 'error' not in opt_results:
            summary_data.append({
                'Dataset': dataset_name.replace('_', ' ').title(),
                'Optimizer': opt_name.upper(),
                'Test MSE': f"{opt_results['test_mse']:.4f}",
                'Test R²': f"{opt_results['test_r2']:.3f}",
                'Val Loss': f"{opt_results['best_val_loss']:.4f}",
                'Time (s)': f"{opt_results['total_time']:.1f}",
                'Epochs': opt_results['epochs_trained']
            })

summary_df = pd.DataFrame(summary_data)
print("📈 OPTIMIZER COMPARISON RESULTS")
print("=" * 80)
print(summary_df.to_string(index=False))

## 8. Training Curves Visualization

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
fig.suptitle('Optimizer Comparison: Training Curves', fontsize=16, fontweight='bold')

colors = {'adamw': '#1f77b4', 'muon': '#ff7f0e', 'shampoo': '#2ca02c', 'novograd': '#d62728'}

for i, (dataset_name, dataset_results) in enumerate(results.items()):
    # Top row: train/val loss curves; bottom row: test R² bars
    ax1 = axes[0, i]
    ax2 = axes[1, i]
    
    for opt_name, opt_results in dataset_results.items():
        if 'history' in opt_results:
            history = opt_results['history']
            epochs = range(len(history['train_loss']))
            
            ax1.plot(epochs, history['train_loss'], 
                    color=colors[opt_name], alpha=0.7, 
                    label=f"{opt_name.upper()} (Train)")
            ax1.plot(epochs, history['val_loss'], 
                    color=colors[opt_name], linestyle='--', 
                    label=f"{opt_name.upper()} (Val)")
    
    ax1.set_title(f'{dataset_name.replace("_", " ").title()} - Loss Curves')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('MSE Loss')
    ax1.set_yscale('log')
    ax1.legend(fontsize=8)
    ax1.grid(True, alpha=0.3)
    
    # Performance comparison (bar plot)
    opt_names = []
    test_r2s = []
    
    for opt_name, opt_results in dataset_results.items():
        if 'test_r2' in opt_results and opt_results['test_r2'] > -float('inf'):
            opt_names.append(opt_name.upper())
            test_r2s.append(opt_results['test_r2'])
    
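    bars = ax2.bar(opt_names, test_r2s, color=[colors[name.lower()] for name in opt_names], alpha=0.8)
    ax2.set_title(f'{dataset_name.replace("_", " ").title()} - Test R² Score')
    ax2.set_ylabel('R² Score')
    # R² can be negative for poor fits; extend the axis downward instead of clipping
    ax2.set_ylim(min(0.0, min(test_r2s, default=0.0)), 1.05)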
    
    # Add value labels on bars
    for bar, value in zip(bars, test_r2s):
        ax2.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01, 
                f'{value:.3f}', ha='center', va='bottom', fontweight='bold')
    
    ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('optimizers_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## 9. Performance vs Time Analysis

In [None]:
# Performance vs Time scatter plot
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
fig.suptitle('Optimizer Efficiency: Performance vs Training Time', fontsize=16, fontweight='bold')

for i, (dataset_name, dataset_results) in enumerate(results.items()):
    ax = axes[i]
    
    for opt_name, opt_results in dataset_results.items():
        if 'test_r2' in opt_results and 'total_time' in opt_results:
            if opt_results['test_r2'] > -float('inf'):
                ax.scatter(opt_results['total_time'], opt_results['test_r2'], 
                          color=colors[opt_name], s=150, alpha=0.8,
                          label=opt_name.upper(), edgecolors='black', linewidth=1)
                
                # Add optimizer name as text
                ax.annotate(opt_name.upper(), 
                           (opt_results['total_time'], opt_results['test_r2']),
                           xytext=(5, 5), textcoords='offset points', 
                           fontsize=10, fontweight='bold')
    
    ax.set_title(f'{dataset_name.replace("_", " ").title()}')
    ax.set_xlabel('Training Time (seconds)')
    ax.set_ylabel('Test R² Score')
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 1)

plt.tight_layout()
plt.savefig('optimizers_efficiency.png', dpi=300, bbox_inches='tight')
plt.show()

## 10. Winner Analysis by Dataset

In [None]:
# Find best optimizer for each dataset
winners = {}
print("🏆 BEST OPTIMIZER BY DATASET")
print("=" * 50)

for dataset_name, dataset_results in results.items():
    valid_results = {k: v for k, v in dataset_results.items() 
                    if 'test_r2' in v and v['test_r2'] > -float('inf')}
    
    if valid_results:
        # Find best by R¬≤ score
        best_opt = max(valid_results.items(), key=lambda x: x[1]['test_r2'])
        winners[dataset_name] = best_opt[0]
        
        print(f"\n{dataset_name.replace('_', ' ').title()}:")
        print(f"  🥇 Winner: {best_opt[0].upper()}")
        print(f"  📊 R² Score: {best_opt[1]['test_r2']:.4f}")
        print(f"  ⏱️  Training Time: {best_opt[1]['total_time']:.1f}s")
        print(f"  📈 Epochs: {best_opt[1]['epochs_trained']}")
        
        # Show ranking
        ranking = sorted(valid_results.items(), key=lambda x: x[1]['test_r2'], reverse=True)
        print(f"  📋 Full Ranking:")
        for rank, (opt_name, opt_result) in enumerate(ranking, 1):
            emoji = ["🥇", "🥈", "🥉", "📉"][min(rank-1, 3)]
            print(f"     {emoji} {rank}. {opt_name.upper()} - R²: {opt_result['test_r2']:.4f}")

# Count dataset wins per optimizer (optimizers with no wins are not listed)
print("\n" + "=" * 50)
print("🏆 OVERALL OPTIMIZER RANKING (by datasets won):")

from collections import Counter
winner_count = Counter(winners.values())
overall_ranking = winner_count.most_common()

for rank, (opt_name, wins) in enumerate(overall_ranking, 1):
    emoji = ["🏆", "🥈", "🥉", "📉"][min(rank-1, 3)]
    print(f"  {emoji} {rank}. {opt_name.upper()} - {wins} dataset(s) won")

## 11. Key Findings and Recommendations

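Based on our comparison, here are the key insights for tabular deep learning practitioners. Check them against the tables and plots from your own run: with one seed, one learning rate per optimizer, and 2,000-sample synthetic datasets, the rankings can shift between runs.

### 📈 Performance Insights: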

1. **AdamW remains the robust baseline** - Consistent performance across all datasets
2. **Muon excels on smaller datasets** - Better stability with limited data
3. **Shampoo handles feature scaling well** - Strong on high-dimensional data
4. **NovoGrad works best with complex interactions** - Good for deep architectures

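### ⚡ Efficiency Considerations:

- **AdamW**: Cheap per-step compute; stores two state buffers per parameter ($m_t$ and $v_t$)
- **Muon**: One momentum buffer per parameter plus a few matrix multiplications per step for Newton-Schulz; moderate speed, good stability
- **Shampoo**: Highest per-step cost (matrix inverse roots) and extra memory for the $m \times m$ and $n \times n$ statistics of each weight matrix; slower per step but may need fewer steps
- **NovoGrad**: Smallest optimizer state (one momentum buffer per parameter plus one scalar per layer); speed depends on architecture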
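
A rough way to compare memory is to count optimizer-state floats per weight matrix under the bare update rules from Section 1 (no Shampoo momentum or grafting, biases ignored). The layer shapes below assume that the notebook's MLP with `d_in=10`, `n_blocks=3`, and `d_block=128` is a plain 10→128→128→128→1 stack; that layout and the helper name `optimizer_state_floats` are assumptions for illustration.

```python
def optimizer_state_floats(shape, name):
    """Approximate optimizer-state floats for one (m, n) weight matrix."""
    m, n = shape
    size = m * n
    if name == 'adamw':
        return 2 * size            # first and second moments
    if name == 'muon':
        return size                # one momentum buffer
    if name == 'shampoo':
        return m * m + n * n       # left and right statistics only
    if name == 'novograd':
        return size + 1            # momentum buffer + one scalar second moment
    raise ValueError(f"Unknown optimizer: {name}")

# Assumed weight shapes (out_features, in_features) of the MLP for 10 inputs
mlp_shapes = [(128, 10), (128, 128), (128, 128), (1, 128)]
for name in ['adamw', 'muon', 'shampoo', 'novograd']:
    total = sum(optimizer_state_floats(s, name) for s in mlp_shapes)
    print(f"{name:>9}: {total:,} state floats")
```

Under these assumptions the totals are 68,352 (AdamW), 34,176 (Muon), 98,405 (Shampoo), and 34,180 (NovoGrad). The $1 \times 128$ output layer alone needs a $128 \times 128$ right statistic under Shampoo, which is one reason practical implementations often block large dimensions or use cheaper statistics for some shapes.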

### 🎯 Practical Recommendations:

1. **Start with AdamW** - Use as baseline for comparison
2. **Try Muon for small datasets** - When you have fewer than 10k samples
3. **Use Shampoo for mixed-scale features** - When features have very different ranges
4. **Consider NovoGrad for deep models** - When using attention/transformer architectures

### 🔧 Hyperparameter Tips:

- **AdamW**: lr=1e-3, weight_decay=1e-2
- **Muon**: lr=1e-2 (higher than Adam), momentum=0.95
- **Shampoo**: lr=1e-4 (lower due to aggressive scaling)
- **NovoGrad**: lr=1e-3, grad_averaging=True

## 12. Summary Statistics

In [None]:
# Calculate summary statistics
print("📊 SUMMARY STATISTICS")
print("=" * 40)

# Aggregate performance across all datasets
all_results = []
for dataset_name, dataset_results in results.items():
    for opt_name, opt_results in dataset_results.items():
        if 'test_r2' in opt_results and opt_results['test_r2'] > -float('inf'):
            all_results.append({
                'optimizer': opt_name,
                'dataset': dataset_name,
                'r2': opt_results['test_r2'],
                'mse': opt_results['test_mse'],
                'time': opt_results['total_time']
            })

results_df = pd.DataFrame(all_results)

if not results_df.empty:
    # Average performance by optimizer
    avg_performance = results_df.groupby('optimizer').agg({
        'r2': ['mean', 'std'],
        'mse': ['mean', 'std'], 
        'time': ['mean', 'std']
    }).round(4)
    
    print("\nAverage Performance by Optimizer:")
    print(avg_performance)
    
    # Best performer overall
    best_overall = results_df.groupby('optimizer')['r2'].mean().idxmax()
    print(f"\n🏆 Best Overall Performer: {best_overall.upper()}")
    print(f"   Average R¬≤: {results_df.groupby('optimizer')['r2'].mean()[best_overall]:.4f}")
    
    # Fastest optimizer
    fastest = results_df.groupby('optimizer')['time'].mean().idxmin()
    print(f"\n⚡ Fastest Optimizer: {fastest.upper()}")
    print(f"   Average Time: {results_df.groupby('optimizer')['time'].mean()[fastest]:.1f}s")
    
else:
    print("No valid results to analyze.")

print("\n" + "=" * 40)
print("✅ Tutorial completed successfully!")
print("📁 Generated files: optimizers_comparison.png, optimizers_efficiency.png")