# Tutorial 5: Advanced Optimizers and Techniques

**Difficulty**: Advanced  
**Duration**: 40-50 minutes  
**Prerequisites**: Completion of Tutorials [3](./optim_tutorial_03_optax_getting_started.ipynb) and [4](./optim_tutorial_04_learning_rate_scheduling.ipynb)

## Learning Objectives
- Use specialized optimizers for specific scenarios
- Apply quasi-Newton (approximate second-order) methods such as L-BFGS
- Apply gradient-free optimization
- Understand memory-efficient optimizers

## Topics Covered
1. **Specialized gradient-based optimizers**
   - Lion: Memory-efficient optimizer
   - Adafactor: Factorized second moments
   - Lookahead: k steps forward, 1 step back
   - RAdam: Rectified Adam

2. **Large-scale training optimizers**
   - LAMB: Layer-wise adaptive large batch
   - LARS: Layer-wise adaptive rate scaling
   - SM3: Memory-efficient for large models

3. **Alternative optimization paradigms**
   - LBFGS: Quasi-Newton method
   - Rprop: Resilient backpropagation
   - Yogi: Additive adaptive methods

4. **Gradient-free optimization**
   - NevergradOptimizer integration
   - ScipyOptimizer for constrained problems

In [None]:
import time

import brainstate
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec

import braintools

## 1. Setting up Test Models and Data

We'll create different model architectures to test various optimizer characteristics.

In [None]:
class TransformerBlock(brainstate.nn.Module):
    """Simplified Transformer block with a classification head, for testing large-scale optimizers."""

    def __init__(self, dim=512, num_heads=8, mlp_ratio=4.0, num_classes=10):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads

        # Multi-head attention components
        self.qkv = brainstate.nn.Linear(dim, dim * 3)
        self.proj = brainstate.nn.Linear(dim, dim)

        # MLP components
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.fc1 = brainstate.nn.Linear(dim, mlp_hidden_dim)
        self.fc2 = brainstate.nn.Linear(mlp_hidden_dim, dim)

        # Layer norms
        self.norm1 = brainstate.nn.LayerNorm(dim)
        self.norm2 = brainstate.nn.LayerNorm(dim)

        # Classification head: the loss expects logits of shape (batch, num_classes)
        self.head = brainstate.nn.Linear(dim, num_classes)

    def __call__(self, x):
        # x shape: (batch, seq_len, dim)
        # Simplified attention (without actual attention computation)
        residual = x
        x = self.norm1(x)

        # QKV projection
        qkv = self.qkv(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        # Simplified attention output (just use v for demonstration)
        attn_output = self.proj(v)
        x = residual + attn_output

        # MLP block
        residual = x
        x = self.norm2(x)
        x = self.fc1(x)
        x = jax.nn.gelu(x)
        x = self.fc2(x)
        x = residual + x

        # Mean-pool over the sequence axis, then project to class logits
        return self.head(jnp.mean(x, axis=-2))


class CNNModel(brainstate.nn.Module):
    """CNN for testing memory-efficient optimizers."""

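    def __init__(self, in_size=(32, 32, 3), num_classes=10):
        # in_size: (height, width, channels) of a single image; defaults to the synthetic vision data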
        super().__init__()
        # Conv layers
        self.conv1 = brainstate.nn.Conv2d(in_size, 64, kernel_size=3, padding=1)
        self.pool1 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv1.out_size)
        self.conv2 = brainstate.nn.Conv2d(self.pool1.out_size, 128, kernel_size=3, padding=1)
        self.pool2 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv2.out_size)
        self.conv3 = brainstate.nn.Conv2d(self.pool2.out_size, 256, kernel_size=3, padding=1)
        self.pool3 = brainstate.nn.MaxPool2d(2, 2, in_size=self.conv3.out_size)

        # Dense layers
        self.fc1 = brainstate.nn.Linear(int(np.prod(self.pool3.out_size)), 512)
        self.fc2 = brainstate.nn.Linear(512, num_classes)

    def __call__(self, x):
        # Reshape if needed
        if len(x.shape) == 2:
            x = x.reshape(-1, 32, 32, 3)

        # Conv blocks
        x = self.conv1(x)
        x = jax.nn.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = jax.nn.relu(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = jax.nn.relu(x)
        x = self.pool3(x)

        # Flatten and FC layers
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = jax.nn.relu(x)
        x = self.fc2(x)

        return x


class SimpleRNN(brainstate.nn.Module):
    """Simple RNN for testing gradient stability."""

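    def __init__(self, input_dim=10, hidden_dim=128, output_dim=10):
        super().__init__()
        # Single-layer vanilla RNN cell ("ValinaRNNCell" is brainstate's spelling)
        self.rnn = brainstate.nn.ValinaRNNCell(input_dim, hidden_dim)
        self.fc = brainstate.nn.Linear(hidden_dim, output_dim)

    def __call__(self, x):
        # x shape: (batch, seq_len, features)
        # Reset the hidden state for the current batch size
        brainstate.nn.init_all_states(self.rnn, x.shape[0])
        # for_loop scans over the leading axis, so make time the leading axis
        x = jnp.transpose(x, (1, 0, 2))
        outputs = brainstate.transform.for_loop(self.rnn, x)  # (seq_len, batch, hidden_dim)
        # Use last timestep
        return self.fc(outputs[-1])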

In [None]:
def create_synthetic_data(data_type='vision', n_samples=1000, seed=42):
    """Create synthetic data for different model types."""
    with brainstate.random.seed_context(seed):
        if data_type == 'vision':
            # Image-like data (32x32x3)
            X = brainstate.random.normal(size=(n_samples, 32, 32, 3)) * 0.5
            y = brainstate.random.randint(0, 10, size=(n_samples,))
        elif data_type == 'transformer':
            # Sequence data for transformer (seq_len=64, dim=512)
            X = brainstate.random.normal(size=(n_samples, 64, 512)) * 0.1
            y = brainstate.random.randint(0, 10, size=(n_samples,))
        elif data_type == 'sequence':
            # Sequence data for RNN (seq_len=20, features=10)
            X = brainstate.random.normal(size=(n_samples, 20, 10)) * 0.5
            y = brainstate.random.randint(0, 10, size=(n_samples,))
        else:
            # Default: flat features
            X = brainstate.random.normal(size=(n_samples, 784)) * 0.5
            y = brainstate.random.randint(0, 10, size=(n_samples,))

    return X, y


# Create datasets
X_vision, y_vision = create_synthetic_data('vision', n_samples=2000)
X_transformer, y_transformer = create_synthetic_data('transformer', n_samples=1000)
X_sequence, y_sequence = create_synthetic_data('sequence', n_samples=2000)

print(f"Vision data shape: {X_vision.shape}")
print(f"Transformer data shape: {X_transformer.shape}")
print(f"Sequence data shape: {X_sequence.shape}")

## 2. Gradient Computation and Training Infrastructure

Following the style from previous tutorials, we'll set up gradient computation and a generic training loop that works with any optimizer.

In [None]:
def compute_loss_and_grads(model, X, y, param_states, loss_type='classification'):
    """Compute loss and gradients following braintools style."""

    def loss_fn():
        # Forward pass
        outputs = model(X)

        if loss_type == 'classification':
            # Cross-entropy loss
            log_probs = jax.nn.log_softmax(outputs, axis=-1)
            one_hot = jax.nn.one_hot(y, num_classes=10)
            loss = -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
        else:
            # MSE loss for regression
            loss = jnp.mean((outputs - y) ** 2)

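        # Add L2 regularization; ParamState values can be pytrees
        # (e.g. {'weight': ..., 'bias': ...} for Linear), so sum over all leaves
        l2_reg = 1e-4
        for leaf in jax.tree_util.tree_leaves([state.value for state in param_states.values()]):
            loss = loss + l2_reg * jnp.sum(leaf ** 2)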

        return loss

    # Compute loss and gradients
    loss = loss_fn()
    grads = brainstate.transform.grad(loss_fn, grad_states=param_states)()

    # Compute accuracy for classification
    if loss_type == 'classification':
        outputs = model(X)
        predictions = jnp.argmax(outputs, axis=-1)
        accuracy = jnp.mean(predictions == y)
    else:
        accuracy = -loss  # Use negative loss as metric for regression

    return loss, grads, accuracy


def train_with_optimizer(
    model: brainstate.nn.Module,
    optimizer: braintools.optim.OptaxOptimizer,
    X_train, y_train,
    X_val, y_val,
    n_epochs=30,
    batch_size=64,
    verbose=True
):
    """Generic training function for any optimizer."""

    # Get parameter states
    param_states = braintools.optim.UniqueStateManager(
        model.states(brainstate.ParamState)
    ).to_dict()

    # Register parameters with optimizer
    optimizer.register_trainable_weights(param_states)

    @brainstate.transform.jit
    def train_step(X_batch, y_batch):
        loss, grads, acc = compute_loss_and_grads(model, X_batch, y_batch, param_states)
        optimizer.update(grads)
        return loss, acc

    @brainstate.transform.jit
    def eval_step(X_batch, y_batch):
        loss, _, acc = compute_loss_and_grads(model, X_batch, y_batch, param_states)
        return loss, acc

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'epoch_time': []
    }

    n_batches = len(X_train) // batch_size

    for epoch in range(n_epochs):
        epoch_start = time.time()

        # Shuffle data
        perm = brainstate.random.permutation(len(X_train))
        X_train_shuffled = X_train[perm]
        y_train_shuffled = y_train[perm]

        train_losses = []
        train_accs = []

        for batch_idx in range(n_batches):
            start_idx = batch_idx * batch_size
            end_idx = start_idx + batch_size
            X_batch = X_train_shuffled[start_idx:end_idx]
            y_batch = y_train_shuffled[start_idx:end_idx]

            loss, acc = train_step(X_batch, y_batch)
            train_losses.append(float(loss))
            train_accs.append(float(acc))

        # Validation
        val_loss, val_acc = eval_step(X_val[:500], y_val[:500])  # Use subset for speed

        # Update learning rate if scheduler is attached
        optimizer.lr.step()

        # Record metrics
        history['train_loss'].append(np.mean(train_losses))
        history['train_acc'].append(np.mean(train_accs))
        history['val_loss'].append(float(val_loss))
        history['val_acc'].append(float(val_acc))
        history['epoch_time'].append(time.time() - epoch_start)

        if verbose and (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{n_epochs} - "
                  f"Loss: {history['train_loss'][-1]:.4f}, "
                  f"Acc: {history['train_acc'][-1]:.4f}, "
                  f"Val Loss: {history['val_loss'][-1]:.4f}, "
                  f"Val Acc: {history['val_acc'][-1]:.4f}")

    return history

## 3. Specialized Gradient-Based Optimizers

Let's explore advanced optimizers designed for specific scenarios.

### 3.1 Lion Optimizer - Memory Efficient

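Lion (EvoLved Sign Momentum) keeps a single momentum buffer per parameter (Adam keeps two) and updates each weight by the sign of an interpolation between that momentum and the current gradient, so every coordinate moves by the same magnitude, `lr`.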
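To make the update rule concrete, here is a minimal pure-JAX sketch of the Lion step as described in the Lion paper, for a single array. `lion_step` is a hypothetical helper written for illustration, not the braintools implementation; it ignores pytrees and learning-rate schedules.

```python
import jax
import jax.numpy as jnp


def lion_step(param, grad, momentum, lr=3e-4, beta1=0.9, beta2=0.99, weight_decay=1e-4):
    """One Lion step for a single array (illustrative sketch, not the braintools code)."""
    # Sign of an interpolation between the stored momentum and the current gradient
    update = jnp.sign(beta1 * momentum + (1.0 - beta1) * grad)
    # Decoupled weight decay, applied as in AdamW
    new_param = param - lr * (update + weight_decay * param)
    # The only optimizer state: an EMA of gradients with its own coefficient beta2
    new_momentum = beta2 * momentum + (1.0 - beta2) * grad
    return new_param, new_momentum


# Every coordinate moves by lr, regardless of the gradient's magnitude
w1, _ = lion_step(jnp.ones(2), jnp.array([5.0, 0.1]), jnp.zeros(2), lr=0.1, weight_decay=0.0)

# Minimize f(w) = sum(w**2) starting from w = 1
grad_fn = jax.grad(lambda p: jnp.sum(p ** 2))
w, m = jnp.ones(4), jnp.zeros(4)
for _ in range(100):
    w, m = lion_step(w, grad_fn(w), m, lr=1e-2)
```

Because the update is a sign, its size is set entirely by `lr`; this is one reason Lion is usually run with a smaller learning rate than Adam.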

In [None]:
# Lion optimizer
model_lion = CNNModel()

lion_optimizer = braintools.optim.Lion(
    lr=3e-4,  # Lion typically uses smaller learning rates
    betas=(0.9, 0.99),
    weight_decay=1e-4
)

print("Training with Lion optimizer (memory-efficient)...")
history_lion = train_with_optimizer(
    model_lion, lion_optimizer,
    X_vision[:1000], y_vision[:1000],
    X_vision[1000:1500], y_vision[1000:1500],
    n_epochs=30, batch_size=32
)

### 3.2 Adafactor - Factorized Second Moments

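Adafactor reduces memory usage by factorizing the second-moment estimate: for an n×m weight matrix it stores running row and column statistics (n + m numbers) instead of a full n×m matrix of squared-gradient averages.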
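The factorization can be sketched in a few lines of JAX. This is a simplified illustration that assumes a fixed decay `beta2` and omits Adafactor's update clipping and relative step sizes. In the Adafactor paper the decay instead follows a schedule β2_t = 1 − t^(−c); in optax's `adafactor`, `decay_rate` is that exponent c, and we assume braintools follows the same convention. `factored_second_moment` is a hypothetical helper.

```python
import jax.numpy as jnp


def factored_second_moment(grad, row_acc, col_acc, beta2=0.999, eps=1e-30):
    """Adafactor-style factored estimate of E[g^2] for a 2-D gradient (sketch).

    Stores only row_acc (n,) and col_acc (m,) instead of a full (n, m) matrix.
    """
    sq = grad ** 2 + eps
    # Running averages of the row means and column means of g^2
    row_acc = beta2 * row_acc + (1.0 - beta2) * jnp.mean(sq, axis=1)
    col_acc = beta2 * col_acc + (1.0 - beta2) * jnp.mean(sq, axis=0)
    # Rank-1 reconstruction: v[i, j] ~= row_acc[i] * col_acc[j] / mean(row_acc)
    v_hat = jnp.outer(row_acc, col_acc) / jnp.mean(row_acc)
    return row_acc, col_acc, v_hat


# A gradient whose square is exactly rank-1, so one step with beta2=0 reconstructs it exactly
g = jnp.outer(jnp.array([1.0, 2.0, 3.0]), jnp.array([0.5, -1.0]))  # shape (3, 2)
row, col, v = factored_second_moment(g, jnp.zeros(3), jnp.zeros(2), beta2=0.0)
```

Real gradients are not exactly rank-1, so `v_hat` is an approximation in general. Vectors such as biases have nothing to factor and typically keep a full second-moment estimate.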

In [None]:
# Adafactor optimizer
model_adafactor = TransformerBlock()

adafactor_optimizer = braintools.optim.Adafactor(
    lr=1e-3,
    decay_rate=0.8,
    factored=True,  # Enable factorization for memory efficiency
    clip_threshold=1.0
)

print("Training with Adafactor (factorized second moments)...")
history_adafactor = train_with_optimizer(
    model_adafactor, adafactor_optimizer,
    X_transformer[:500], y_transformer[:500],
    X_transformer[500:700], y_transformer[500:700],
    n_epochs=30, batch_size=16
)

### 3.3 Lookahead Optimizer - k-step Forward

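Lookahead maintains two sets of weights. The inner ("fast") optimizer takes k = `sync_period` steps; the "slow" weights then move a fraction `alpha` toward the fast weights, and the fast weights restart from the updated slow weights.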
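A minimal sketch of the Lookahead loop around plain SGD, assuming a differentiable `loss_fn` of a single array. `lookahead_sgd` is a hypothetical helper for illustration; the braintools `Lookahead` wrapper in the next cell takes a base optimizer object instead.

```python
import jax
import jax.numpy as jnp


def lookahead_sgd(loss_fn, params, lr=0.1, sync_period=5, alpha=0.5, n_outer=20):
    """Lookahead around plain SGD for a single array (illustrative sketch)."""
    grad_fn = jax.grad(loss_fn)
    slow = params
    for _ in range(n_outer):
        # Fast weights start from the current slow weights
        fast = slow
        # k = sync_period steps of the inner optimizer ("k steps forward")
        for _ in range(sync_period):
            fast = fast - lr * grad_fn(fast)
        # Slow weights move a fraction alpha toward the fast weights ("1 step back")
        slow = slow + alpha * (fast - slow)
    return slow


result = lookahead_sgd(lambda p: jnp.sum((p - 3.0) ** 2), jnp.zeros(2))
```

Only the slow weights are carried between outer iterations, which is why Lookahead adds one parameter-sized buffer on top of the base optimizer's state.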

In [None]:
# Lookahead optimizer wrapping SGD
model_lookahead = CNNModel()

# Base optimizer
base_optimizer = braintools.optim.SGD(lr=0.1, momentum=0.9)

# Wrap with Lookahead
lookahead_optimizer = braintools.optim.Lookahead(
    base_optimizer,
    sync_period=5,  # Update slow weights every 5 steps
    alpha=0.5  # Interpolation factor
)

print("Training with Lookahead optimizer (k-step forward)...")
history_lookahead = train_with_optimizer(
    model_lookahead, lookahead_optimizer,
    X_vision[:1000], y_vision[:1000],
    X_vision[1000:1500], y_vision[1000:1500],
    n_epochs=30, batch_size=32
)

### 3.4 RAdam - Rectified Adam

RAdam rectifies the variance of the adaptive learning rate, which is poorly estimated during the first steps of training; in effect it acts as an automatic learning-rate warmup.
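The rectification term can be computed on its own. This sketch follows the formulas in the RAdam paper. The paper enables the adaptive step once ρ_t > 4, while PyTorch and optax use a threshold of 5, which is the default assumed here. `radam_rectification` is a hypothetical helper.

```python
import math


def radam_rectification(step, beta2=0.999, threshold=5.0):
    """RAdam variance-rectification factor r_t at a 1-based step (sketch).

    Returns None while the adaptive learning rate is considered unreliable; RAdam then
    takes an un-adapted, momentum-only step instead.
    """
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** step
    # Approximate length of the simple moving average; grows from about 1 toward rho_inf
    rho_t = rho_inf - 2.0 * step * beta2_t / (1.0 - beta2_t)
    if rho_t <= threshold:
        return None
    return math.sqrt(((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
                     / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))


early = radam_rectification(3)       # None: too early to trust the adaptive term
mid = radam_rectification(100)       # partial rectification
late = radam_rectification(100_000)  # close to 1.0, i.e. an ordinary Adam step
```

Since r_t rises smoothly from small values toward 1, the adaptive part of the step is phased in gradually, much like a warmup schedule.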

In [None]:
# RAdam optimizer
model_radam = SimpleRNN()

radam_optimizer = braintools.optim.RAdam(
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-4
)

print("Training with RAdam (Rectified Adam)...")
history_radam = train_with_optimizer(
    model_radam, radam_optimizer,
    X_sequence[:1000], y_sequence[:1000],
    X_sequence[1000:1500], y_sequence[1000:1500],
    n_epochs=30, batch_size=32
)

## 4. Large-Scale Training Optimizers

These optimizers are designed for training with large batch sizes and distributed settings.

### 4.1 LAMB - Layer-wise Adaptive Large Batch

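LAMB enables large-batch training by rescaling Adam-style updates per layer: each layer's update is multiplied by a trust ratio ‖w‖ / ‖update‖, so the step size is proportional to that layer's weight norm.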
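A sketch of LAMB's layer-wise scaling for one layer, given bias-corrected Adam moments `m_hat` and `v_hat`, following the LAMB paper with the identity scaling function. `lamb_layer_update` is a hypothetical helper, and treating a zero norm as a trust ratio of 1 is a common convention rather than a core part of the algorithm.

```python
import jax.numpy as jnp


def lamb_layer_update(weight, m_hat, v_hat, lr=2e-3, weight_decay=0.01, eps=1e-6):
    """Apply a LAMB step to one layer, given bias-corrected Adam moments (sketch)."""
    # Adam-style direction plus decoupled weight decay
    adam_dir = m_hat / (jnp.sqrt(v_hat) + eps) + weight_decay * weight
    w_norm = jnp.linalg.norm(weight)
    d_norm = jnp.linalg.norm(adam_dir)
    # Layer-wise trust ratio ||w|| / ||update||; fall back to 1 when either norm is zero
    trust_ratio = jnp.where((w_norm > 0) & (d_norm > 0), w_norm / d_norm, 1.0)
    return weight - lr * trust_ratio * adam_dir, trust_ratio


w = jnp.full((3, 4), 2.0)
m_hat = jnp.full((3, 4), 0.1)
v_hat = jnp.full((3, 4), 0.01)
new_w, ratio = lamb_layer_update(w, m_hat, v_hat)
```

Because the step is `lr * (||w|| / ||d||) * d`, its norm is exactly `lr * ||w||`: each layer moves by the same fraction of its own weight norm, whatever the scale of its gradients.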

In [None]:
# LAMB optimizer for large batch training
model_lamb = TransformerBlock()

lamb_optimizer = braintools.optim.Lamb(
    lr=2e-3,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=0.01,
    grad_clip_value=10.0  # Gradient clipping
)

print("Training with LAMB (Large Batch optimizer)...")
# Simulate large batch by using larger batch size
history_lamb = train_with_optimizer(
    model_lamb, lamb_optimizer,
    X_transformer[:800], y_transformer[:800],
    X_transformer[800:], y_transformer[800:],
    n_epochs=30, batch_size=128  # Large batch size
)

### 4.2 LARS - Layer-wise Adaptive Rate Scaling

LARS adapts the learning rate for each layer based on the ratio of the weight norm to the gradient norm, scaled by `trust_coefficient`, and applies this local rate on top of SGD with momentum.
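A sketch of a single-layer LARS step with momentum and weight decay, following the local learning-rate formula from the LARS paper. `lars_step` is a hypothetical helper, and the fallback to 1 for zero norms is a common convention, not part of the paper's formula.

```python
import jax.numpy as jnp


def lars_step(weight, grad, velocity, lr=0.1, momentum=0.9, weight_decay=1e-4,
              trust_coefficient=0.001, eps=1e-8):
    """One LARS step for a single layer (illustrative sketch)."""
    w_norm = jnp.linalg.norm(weight)
    g_norm = jnp.linalg.norm(grad)
    # Layer-wise local learning rate: eta * ||w|| / (||g|| + wd * ||w||)
    local_lr = trust_coefficient * w_norm / (g_norm + weight_decay * w_norm + eps)
    # Fall back to a ratio of 1 when either norm is zero (e.g. freshly zeroed biases)
    local_lr = jnp.where((w_norm > 0) & (g_norm > 0), local_lr, 1.0)
    # Heavy-ball momentum on the rescaled, weight-decayed gradient
    velocity = momentum * velocity + lr * local_lr * (grad + weight_decay * weight)
    return weight - velocity, velocity


w = jnp.ones(4)
g = jnp.array([1.0, 2.0, 3.0, 4.0])
v0 = jnp.zeros(4)
# Without weight decay, rescaling the gradient leaves the step unchanged
w_a, _ = lars_step(w, g, v0, weight_decay=0.0)
w_b, _ = lars_step(w, 100.0 * g, v0, weight_decay=0.0)
```

The invariance to gradient scale is what lets LARS tolerate the large learning rates that come with large batches: layers with tiny or huge gradients still take steps proportional to their weight norms.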

In [None]:
# LARS optimizer
model_lars = CNNModel()

lars_optimizer = braintools.optim.Lars(
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    trust_coefficient=0.001,  # LARS-specific parameter
    eps=1e-8
)

print("Training with LARS (Layer-wise Adaptive Rate Scaling)...")
history_lars = train_with_optimizer(
    model_lars, lars_optimizer,
    X_vision[:1000], y_vision[:1000],
    X_vision[1000:1500], y_vision[1000:1500],
    n_epochs=30, batch_size=128
)

### 4.3 SM3 - Memory-Efficient for Large Models

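SM3 approximates adaptive (Adagrad-style) learning rates with far less memory: for a matrix parameter it keeps one accumulator per row and one per column, and estimates each entry's accumulated squared gradient from the smaller of its row and column accumulators.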
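A sketch of the SM3-II update for a single matrix with row and column cover sets, following the SM3 paper. `sm3_step` is a hypothetical helper; momentum, which the next cell enables with `momentum=0.9`, is left out.

```python
import jax
import jax.numpy as jnp


def sm3_step(weight, grad, row_acc, col_acc, lr=1e-3, eps=1e-8):
    """One SM3-II step for a 2-D weight with row and column cover sets (sketch)."""
    # Per-entry estimate of the accumulated squared gradient: tightest cover plus g^2
    nu = jnp.minimum(row_acc[:, None], col_acc[None, :]) + grad ** 2
    new_weight = weight - lr * grad / (jnp.sqrt(nu) + eps)
    # Each accumulator keeps the maximum over the entries it covers
    return new_weight, jnp.max(nu, axis=1), jnp.max(nu, axis=0)


key = jax.random.PRNGKey(0)
w = jnp.zeros((4, 3))
row, col = jnp.zeros(4), jnp.zeros(3)  # 4 + 3 numbers of state instead of 12
cum_sq = jnp.zeros((4, 3))             # exact running sum of g^2, kept only for comparison
for _ in range(10):
    key, sub = jax.random.split(key)
    g = jax.random.normal(sub, (4, 3))
    cum_sq = cum_sq + g ** 2
    w, row, col = sm3_step(w, g, row, col)
```

By construction, `jnp.minimum(row[:, None], col[None, :])` never falls below the exact running sum `cum_sq`; this upper-bound property is what lets SM3 use it in place of Adagrad's full accumulator.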

In [None]:
# SM3 optimizer for memory efficiency
model_sm3 = TransformerBlock()

sm3_optimizer = braintools.optim.SM3(
    lr=1e-3,
    momentum=0.9,
    eps=1e-8
)

print("Training with SM3 (Memory-efficient optimizer)...")
history_sm3 = train_with_optimizer(
    model_sm3, sm3_optimizer,
    X_transformer[:500], y_transformer[:500],
    X_transformer[500:700], y_transformer[500:700],
    n_epochs=30, batch_size=16
)

## 5. Alternative Optimization Paradigms

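These optimizers depart from standard adaptive gradient descent in different ways: L-BFGS uses curvature information, Rprop uses only gradient signs, and Yogi changes how the second-moment estimate adapts.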

### 5.1 L-BFGS - Quasi-Newton Method

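L-BFGS approximates the inverse Hessian from the most recent parameter and gradient differences (`memory_size` pairs) and uses it to precondition the gradient, giving quasi-Newton steps without ever forming the full Hessian.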
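The core of L-BFGS is the two-loop recursion, which applies the approximate inverse Hessian to the gradient using only the stored pairs. This is a minimal pure-JAX sketch of the textbook recursion without the line search; `lbfgs_direction` is a hypothetical helper. Storing `memory_size` pairs means keeping 2 × `memory_size` parameter-sized vectors.

```python
import jax.numpy as jnp


def lbfgs_direction(grad, s_hist, y_hist):
    """L-BFGS two-loop recursion (sketch): approximate -H^{-1} @ grad from stored pairs.

    s_hist[k] = x_{k+1} - x_k and y_hist[k] = g_{k+1} - g_k, ordered oldest first.
    """
    q = grad
    coeffs = []
    # First loop: newest pair to oldest
    for s, y in reversed(list(zip(s_hist, y_hist))):
        rho = 1.0 / jnp.dot(y, s)
        alpha = rho * jnp.dot(s, q)
        q = q - alpha * y
        coeffs.append((rho, alpha))
    # Initial inverse-Hessian guess gamma * I, scaled from the newest pair
    gamma = jnp.dot(s_hist[-1], y_hist[-1]) / jnp.dot(y_hist[-1], y_hist[-1]) if s_hist else 1.0
    r = gamma * q
    # Second loop: oldest pair to newest (coefficients were stored newest first)
    for (s, y), (rho, alpha) in zip(zip(s_hist, y_hist), reversed(coeffs)):
        beta = rho * jnp.dot(y, r)
        r = r + s * (alpha - beta)
    return -r


# Quadratic f(x) = 0.5 * x^T A x with A = diag(1, 10); the pairs satisfy y = A s
s_hist = [jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])]
y_hist = [jnp.array([1.0, 0.0]), jnp.array([0.0, 10.0])]
x = jnp.array([3.0, -2.0])
direction = lbfgs_direction(jnp.array([3.0, -20.0]), s_hist, y_hist)  # grad = A @ x
```

Here the two pairs are A-conjugate and span the space, so the recursion recovers the exact Newton step and `x + direction` lands on the minimizer at the origin; on real training losses it is only an approximation.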

In [None]:
# L-BFGS optimizer on a small MLP (works best with full-batch, deterministic gradients)
from brainstate.nn import Linear


class SimpleMLP(brainstate.nn.Module):
    """Simple MLP for L-BFGS testing."""

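    def __init__(self, in_dim=32 * 32 * 3, num_classes=10):
        super().__init__()
        # Inputs are flattened 32x32x3 images (3072 features)
        self.fc1 = Linear(in_dim, 128)
        self.fc2 = Linear(128, num_classes)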

    def __call__(self, x):
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = jax.nn.relu(x)
        x = self.fc2(x)
        return x


model_lbfgs = SimpleMLP()

# L-BFGS is typically run full-batch, since its curvature pairs assume consistent gradients
lbfgs_optimizer = braintools.optim.LBFGS(
    lr=1.0,
    memory_size=10,
    line_search_fn='zoom'
)

print("Training with L-BFGS (Quasi-Newton method)...")
# Note: L-BFGS typically works better with full-batch
X_small = X_vision[:200].reshape(200, -1)
y_small = y_vision[:200]
X_val_small = X_vision[1000:1100].reshape(100, -1)
y_val_small = y_vision[1000:1100]

history_lbfgs = train_with_optimizer(
    model_lbfgs, lbfgs_optimizer,
    X_small, y_small,
    X_val_small, y_val_small,
    n_epochs=20, batch_size=200  # Full batch
)

### 5.2 Rprop - Resilient Backpropagation

Rprop uses only the sign of the gradient and adapts a separate step size for every parameter: the step grows while the gradient sign stays the same and shrinks when it flips.
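A sketch of the Rprop step in the variant that skips the update after a sign flip instead of reverting the previous step, similar to PyTorch's `torch.optim.Rprop`. `rprop_step` is a hypothetical helper; its defaults mirror the `etas` and `step_sizes` used in the next cell.

```python
import jax.numpy as jnp


def rprop_step(param, grad, prev_grad, step_size, etas=(0.5, 1.2), step_sizes=(1e-6, 50.0)):
    """One Rprop step that skips the update after a sign flip (illustrative sketch)."""
    eta_minus, eta_plus = etas
    step_min, step_max = step_sizes
    agreement = grad * prev_grad
    # Grow the step where the gradient kept its sign, shrink it where the sign flipped
    factor = jnp.where(agreement > 0, eta_plus, jnp.where(agreement < 0, eta_minus, 1.0))
    step_size = jnp.clip(step_size * factor, step_min, step_max)
    # After a sign flip, skip this update and store a zero gradient for the next comparison
    grad = jnp.where(agreement < 0, 0.0, grad)
    # Only the sign of the gradient is used, never its magnitude
    new_param = param - jnp.sign(grad) * step_size
    return new_param, grad, step_size


# Minimize f(x) = (x - 2)^2, with the initial step size playing the role of lr=1e-3
x, prev, step = jnp.zeros(1), jnp.zeros(1), jnp.full(1, 1e-3)
for _ in range(100):
    x, prev, step = rprop_step(x, 2.0 * (x - 2.0), prev, step)
```

Because the step-size adaptation compares gradient signs between consecutive steps, noisy mini-batch gradients that flip sign at random keep shrinking the steps; Rprop is therefore usually run with full-batch gradients.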

In [None]:
# Rprop optimizer
model_rprop = SimpleMLP()

rprop_optimizer = braintools.optim.Rprop(
    lr=1e-3,
    etas=(0.5, 1.2),  # Step size adaptation factors
    step_sizes=(1e-6, 50)  # Min and max step sizes
)

print("Training with Rprop (Resilient Backpropagation)...")
history_rprop = train_with_optimizer(
    model_rprop, rprop_optimizer,
    X_small, y_small,
    X_val_small, y_val_small,
    n_epochs=30, batch_size=200  # Full batch: Rprop relies on consistent gradient signs
)

### 5.3 Yogi - Additive Adaptive Methods

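Yogi modifies Adam's second-moment update: instead of moving the estimate a fixed fraction of the way toward g², it changes it by (1 − β2)·g² in the direction of g². The size of each change therefore depends only on the current gradient, which keeps the effective learning rate from jumping when gradients suddenly become small.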
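The difference from Adam sits entirely in the second-moment update. This sketch isolates that one line for each optimizer, following the Yogi paper and omitting bias correction and the parameter update; both helper names are hypothetical.

```python
import jax.numpy as jnp


def adam_second_moment(v, grad, beta2=0.999):
    # Adam: move v a fraction (1 - beta2) of the way toward g^2
    return v - (1.0 - beta2) * (v - grad ** 2)


def yogi_second_moment(v, grad, beta2=0.999):
    # Yogi: move v by (1 - beta2) * g^2 in the direction of g^2,
    # so the size of the change does not depend on v itself
    return v - (1.0 - beta2) * jnp.sign(v - grad ** 2) * grad ** 2


# A large accumulated second moment, followed by a suddenly small gradient
v = jnp.array(1e4)
g = jnp.array(1e-2)
v_adam = adam_second_moment(v, g)  # drops by about (1 - beta2) * v
v_yogi = yogi_second_moment(v, g)  # barely changes
```

With Adam, repeated small gradients shrink `v` geometrically, so the effective step `lr / sqrt(v)` grows quickly; Yogi's bounded change lets `v` decay only as fast as the new gradients justify.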

In [None]:
# Yogi optimizer
model_yogi = CNNModel()

yogi_optimizer = braintools.optim.Yogi(
    lr=1e-2,
    betas=(0.9, 0.999),
    eps=1e-3  # Yogi typically uses larger epsilon
)

print("Training with Yogi (Additive adaptive method)...")
history_yogi = train_with_optimizer(
    model_yogi, yogi_optimizer,
    X_vision[:1000], y_vision[:1000],
    X_vision[1000:1500], y_vision[1000:1500],
    n_epochs=30, batch_size=32
)

## 6. Comparing Optimizer Performance

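Let's visualize and compare the performance of different optimizer categories. Note that the runs above use different models and datasets (Lion and Yogi train the CNN, Adafactor and SM3 the Transformer block, RAdam the RNN, LAMB the Transformer block, LARS the CNN), so read these plots as a view of each optimizer's training behavior rather than a head-to-head ranking. The memory panel shows rough estimates, not measurements.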

In [None]:
def plot_optimizer_comparison(histories, names, title="Optimizer Comparison"):
    """Create comprehensive comparison plots."""

    fig = plt.figure(figsize=(16, 10))
    gs = GridSpec(3, 3, figure=fig)

    # Define color palette
    colors = plt.cm.tab10(np.linspace(0, 1, len(histories)))

    # Training loss
    ax1 = fig.add_subplot(gs[0, 0])
    for hist, name, color in zip(histories, names, colors):
        ax1.plot(hist['train_loss'], label=name, color=color, linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss')
    ax1.set_title('Training Loss')
    ax1.legend(fontsize=8)
    ax1.grid(True, alpha=0.3)

    # Validation loss
    ax2 = fig.add_subplot(gs[0, 1])
    for hist, name, color in zip(histories, names, colors):
        ax2.plot(hist['val_loss'], label=name, color=color, linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Validation Loss')
    ax2.set_title('Validation Loss')
    ax2.legend(fontsize=8)
    ax2.grid(True, alpha=0.3)

    # Training accuracy
    ax3 = fig.add_subplot(gs[0, 2])
    for hist, name, color in zip(histories, names, colors):
        ax3.plot(hist['train_acc'], label=name, color=color, linewidth=2)
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Training Accuracy')
    ax3.set_title('Training Accuracy')
    ax3.legend(fontsize=8)
    ax3.grid(True, alpha=0.3)

    # Convergence speed (loss reduction)
    ax4 = fig.add_subplot(gs[1, 0])
    for hist, name, color in zip(histories, names, colors):
        loss_reduction = np.array(hist['train_loss']) / hist['train_loss'][0]
        ax4.plot(loss_reduction, label=name, color=color, linewidth=2)
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Loss Reduction Ratio')
    ax4.set_title('Convergence Speed')
    ax4.legend(fontsize=8)
    ax4.grid(True, alpha=0.3)

    # Training time per epoch
    ax5 = fig.add_subplot(gs[1, 1])
    avg_times = [np.mean(hist['epoch_time']) for hist in histories]
    bars = ax5.bar(range(len(names)), avg_times, color=colors)
    ax5.set_xticks(range(len(names)))
    ax5.set_xticklabels(names, rotation=45, ha='right')
    ax5.set_ylabel('Average Time per Epoch (s)')
    ax5.set_title('Training Efficiency')
    ax5.grid(True, alpha=0.3, axis='y')

    # Final performance comparison
    ax6 = fig.add_subplot(gs[1, 2])
    final_train_loss = [hist['train_loss'][-1] for hist in histories]
    final_val_loss = [hist['val_loss'][-1] for hist in histories]

    x = np.arange(len(names))
    width = 0.35

    bars1 = ax6.bar(x - width / 2, final_train_loss, width, label='Train Loss', color='steelblue')
    bars2 = ax6.bar(x + width / 2, final_val_loss, width, label='Val Loss', color='coral')

    ax6.set_xticks(x)
    ax6.set_xticklabels(names, rotation=45, ha='right')
    ax6.set_ylabel('Final Loss')
    ax6.set_title('Final Performance')
    ax6.legend()
    ax6.grid(True, alpha=0.3, axis='y')

    # Training stability: rolling variance of the training loss
    ax7 = fig.add_subplot(gs[2, 0])
    for hist, name, color in zip(histories, names, colors):
        # Variance of the loss within each sliding window of `window` epochs
        window = 5
        loss_array = np.array(hist['train_loss'])
        if len(loss_array) >= window:
            rolling_var = np.lib.stride_tricks.sliding_window_view(loss_array, window).var(axis=-1)
            # Plot each value at the last epoch of its window
            epochs = np.arange(window - 1, len(loss_array))
            ax7.plot(epochs, rolling_var, label=name, color=color, linewidth=2)
    ax7.set_xlabel('Epoch')
    ax7.set_ylabel('Loss Variance')
    ax7.set_title('Training Stability')
    ax7.legend(fontsize=8)
    ax7.grid(True, alpha=0.3)

    # Memory usage estimate (simplified)
    ax8 = fig.add_subplot(gs[2, 1:])  # Span two columns

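    # Optimizer-state memory relative to Adam (Adam = 1.0, i.e. two parameter-sized buffers).
    # Rough estimates from counting per-parameter buffers in standard formulations; not measured.
    memory_factors = {
        'Lion': 0.5,  # One momentum buffer
        'Adafactor': 0.1,  # Factored second moments (row + column statistics), assuming no momentum
        'SM3': 0.5,  # Momentum buffer plus sublinear row/column accumulators
        'Rprop': 1.0,  # Previous gradient plus per-parameter step sizes
        'SGD': 0.5,  # Momentum buffer
        'Adam': 1.0,  # Baseline (first and second moments)
        'RAdam': 1.0,  # Same as Adam
        'Yogi': 1.0,  # Same as Adam
        'Lookahead': 1.0,  # Slow weights plus the SGD-momentum base optimizer's buffer
        'LAMB': 1.0,  # Adam-style first and second moments
        'LARS': 0.5,  # Momentum buffer
        'L-BFGS': 10.0,  # 2 x memory_size (=10) history vectors
    }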

    mem_values = [memory_factors.get(name, 1.0) for name in names]
    bars = ax8.barh(range(len(names)), mem_values, color=colors)
    ax8.set_yticks(range(len(names)))
    ax8.set_yticklabels(names)
    ax8.set_xlabel('Relative Memory Usage')
    ax8.set_title('Memory Efficiency Comparison')
    ax8.grid(True, alpha=0.3, axis='x')

    plt.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Compare specialized optimizers
specialized_histories = [history_lion, history_adafactor, history_radam, history_yogi]
specialized_names = ['Lion', 'Adafactor', 'RAdam', 'Yogi']

plot_optimizer_comparison(
    specialized_histories,
    specialized_names,
    "Specialized Gradient-Based Optimizers"
)

In [None]:
# Compare large-scale optimizers
largescale_histories = [history_lamb, history_lars, history_sm3]
largescale_names = ['LAMB', 'LARS', 'SM3']

plot_optimizer_comparison(
    largescale_histories,
    largescale_names,
    "Large-Scale Training Optimizers"
)

## 7. Gradient-Free Optimization

For gradient-free optimization, braintools provides integration with specialized libraries.

### 7.1 Nevergrad Integration

Nevergrad provides a wide range of gradient-free optimization algorithms. See the [Nevergrad tutorial](./optim_tutorial_01_nevergrad_optimizer.ipynb) for details.

### 7.2 SciPy Optimization

SciPy provides classical optimization algorithms, including constrained optimization. See the [SciPy tutorial](./optim_tutorial_02_scipy_optimizer.ipynb) for details.

## Summary and Best Practices

**Key Takeaways**

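1. **Memory-Efficient Optimizers**
   - **Lion**: One optimizer-state buffer instead of Adam's two; a good fit for very large models under memory constraints
   - **Adafactor**: Factored second moments reduce optimizer memory for an n×m matrix to O(n + m)
   - **SM3**: Row and column accumulators give sublinear memory for very large models

2. **Large-Scale Training**
   - **LAMB/LARS**: Designed for large-batch training through layer-wise trust ratios
   - Help keep training stable when the learning rate is scaled up together with the batch size
   - Commonly used in data-parallel training with very large global batches

3. **Stability and Robustness**
   - **RAdam**: Rectifies the variance of the adaptive learning rate early in training
   - **Lookahead**: Reduces variance by interpolating toward slow weights
   - **Yogi**: Limits how quickly the second-moment estimate can change

4. **Alternative Paradigms**
   - **L-BFGS**: Effective for small, smooth, full-batch problems; uses curvature from recent gradients
   - **Rprop**: Sign-based steps with per-parameter step sizes; works best with full-batch gradients
   - **Gradient-free**: When gradients are unavailable or unreliable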
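The memory claims above can be checked with simple arithmetic. This sketch counts optimizer-state floats for a single 2-D weight under stated assumptions: no momentum for Adafactor and SM3 (enabling momentum, as the SM3 example in Section 4.3 does, adds a full-size buffer), and no step counters or implementation-specific extras. `optimizer_state_floats` is a hypothetical helper, not a braintools API.

```python
def optimizer_state_floats(shape, optimizer):
    """Rough count of per-parameter optimizer-state floats for one 2-D weight (sketch)."""
    n, m = shape
    full = n * m      # a buffer with the same shape as the weight
    factored = n + m  # one row vector plus one column vector
    counts = {
        'SGD (momentum)': full,  # velocity
        'Lion': full,            # single momentum buffer
        'Adam': 2 * full,        # first and second moments
        'Adafactor': factored,   # factored second moment, assuming no momentum
        'SM3': factored,         # row/column accumulators, assuming no momentum
    }
    return counts[optimizer]


shape = (4096, 4096)  # e.g. one large projection matrix
relative_to_adam = {
    name: optimizer_state_floats(shape, name) / optimizer_state_floats(shape, 'Adam')
    for name in ['SGD (momentum)', 'Lion', 'Adam', 'Adafactor', 'SM3']
}
```

For this matrix, Lion needs half of Adam's optimizer state, while the factored methods need a few thousand floats instead of tens of millions.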

**When to Use Advanced Optimizers**

| Scenario | Recommended Optimizer | Reason |
|----------|----------------------|--------|
| Large Language Models | Lion, Adafactor | Memory efficiency |
| Large-Batch / Distributed Training | LAMB, LARS | Layer-wise trust ratios for large batches |
| Noisy Gradients | RAdam, Lookahead | Stability |
| Small, Full-Batch Problems | L-BFGS | Fast convergence on smooth objectives |
| Research/Experimentation | Yogi, Custom | Novel behaviors |
| Constrained Optimization | ScipyOptimizer | Built-in constraints |
| Black-box Optimization | NevergradOptimizer | No gradients needed |

## Exercises

1. **Memory Comparison**: Train the same large model with Adam, Lion, and Adafactor. Monitor and compare memory usage.

2. **Large Batch Scaling**: Test how well different optimizers handle increasing batch sizes from 32 to 1024.

3. **Stability Analysis**: Add artificial noise to gradients and compare optimizer robustness.

4. **Hybrid Approach**: Implement a training schedule that switches optimizers (e.g., Adam → L-BFGS for fine-tuning).

5. **Custom Optimizer**: Create your own optimizer by combining ideas from different methods.

6. **Constraint Satisfaction**: Use ScipyOptimizer to solve a constrained optimization problem in neural network training.

7. **Hyperparameter Optimization**: Use NevergradOptimizer to tune the hyperparameters of another optimizer.