# Optimization Tricks and Training Techniques

In this tutorial, we'll explore advanced optimization techniques and training tricks to improve model performance and training stability.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Implement learning rate schedules
- Apply gradient clipping for stability
- Use weight decay for regularization
- Create checkpoint management systems
- Apply early stopping
- Use gradient accumulation
- Monitor training with metrics

## What We'll Build

We'll create:
- Various learning rate schedulers
- Gradient clipping utilities
- Checkpoint manager
- Complete training pipeline with all tricks
- Training monitoring dashboard

In [None]:
import brainstate as bst
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, List, Tuple, Optional, Callable
import time
import pickle
import os

# Set random seed
bst.random.seed(42)

print(f"JAX devices: {jax.devices()}")

## 1. Learning Rate Schedules

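Learning rate schedules change the learning rate as training progresses. A typical pattern is a short warmup, which keeps the first updates from being destabilized by a large step size, followed by a gradual decay that lets later updates settle into a minimum.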
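For reference, the schedulers defined in the next cell implement the following formulas. Here $t$ is `current_step`, $\eta_0$ is `initial_lr`, $\gamma$ is `gamma`, $s$ is `step_size`, $T$ is `T_max` (or `total_steps`), $\eta_{\min}$ is `eta_min`, and $t_w$ is `warmup_steps`:

$$
\begin{aligned}
\text{Step:}\quad & \eta_t = \eta_0\,\gamma^{\lfloor t/s \rfloor} \\
\text{Exponential:}\quad & \eta_t = \eta_0\,\gamma^{t} \\
\text{Cosine annealing:}\quad & \eta_t = \eta_{\min} + \tfrac{1}{2}\left(\eta_0 - \eta_{\min}\right)\left(1 + \cos\frac{\pi t}{T}\right) \\
\text{Warmup + cosine:}\quad & \eta_t =
\begin{cases}
\eta_0\, t / t_w, & t < t_w \\
\tfrac{1}{2}\,\eta_0\left(1 + \cos\left(\pi\,\dfrac{t - t_w}{T - t_w}\right)\right), & t_w \le t \le T
\end{cases}
\end{aligned}
$$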

In [None]:
class LRScheduler:
    """Base class for learning rate schedulers."""
    
    def __init__(self, initial_lr: float):
        self.initial_lr = initial_lr
        self.current_step = 0
    
    def step(self) -> float:
        """Get current learning rate and increment step."""
        lr = self.get_lr()
        self.current_step += 1
        return lr
    
    def get_lr(self) -> float:
        """Compute current learning rate."""
        raise NotImplementedError

class ConstantLR(LRScheduler):
    """Constant learning rate."""
    
    def get_lr(self) -> float:
        return self.initial_lr

class StepLR(LRScheduler):
    """Decay learning rate by gamma every step_size steps (one step per call to step())."""
    
    def __init__(self, initial_lr: float, step_size: int, gamma: float = 0.1):
        super().__init__(initial_lr)
        self.step_size = step_size
        self.gamma = gamma
    
    def get_lr(self) -> float:
        return self.initial_lr * (self.gamma ** (self.current_step // self.step_size))

class ExponentialLR(LRScheduler):
    """Exponentially decay learning rate."""
    
    def __init__(self, initial_lr: float, gamma: float = 0.95):
        super().__init__(initial_lr)
        self.gamma = gamma
    
    def get_lr(self) -> float:
        return self.initial_lr * (self.gamma ** self.current_step)

class CosineAnnealingLR(LRScheduler):
    """Cosine annealing learning rate schedule."""
    
    def __init__(self, initial_lr: float, T_max: int, eta_min: float = 0):
        super().__init__(initial_lr)
        self.T_max = T_max
        self.eta_min = eta_min
    
    def get_lr(self) -> float:
        return self.eta_min + (self.initial_lr - self.eta_min) * \
               (1 + np.cos(np.pi * self.current_step / self.T_max)) / 2

class WarmupCosineSchedule(LRScheduler):
    """Linear warmup from 0 to initial_lr, then cosine annealing to 0."""
    
    def __init__(self, initial_lr: float, warmup_steps: int, total_steps: int):
        super().__init__(initial_lr)
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
    
    def get_lr(self) -> float:
        if self.current_step < self.warmup_steps:
            # Linear warmup
            return self.initial_lr * self.current_step / self.warmup_steps
        else:
            # Cosine annealing; clamp progress at 1 so the LR stays at 0 after
            # total_steps instead of climbing back up the cosine curve
            decay_steps = max(1, self.total_steps - self.warmup_steps)
            progress = min(1.0, (self.current_step - self.warmup_steps) / decay_steps)
            return self.initial_lr * (1 + np.cos(np.pi * progress)) / 2

# Visualize schedules
n_steps = 100
initial_lr = 0.1

schedules = {
    'Constant': ConstantLR(initial_lr),
    'Step (decay every 30)': StepLR(initial_lr, step_size=30, gamma=0.5),
    'Exponential (γ=0.95)': ExponentialLR(initial_lr, gamma=0.95),
    'Cosine Annealing': CosineAnnealingLR(initial_lr, T_max=n_steps),
    'Warmup + Cosine': WarmupCosineSchedule(initial_lr, warmup_steps=10, total_steps=n_steps)
}

plt.figure(figsize=(12, 6))

for name, scheduler in schedules.items():
    lrs = []
    scheduler.current_step = 0
    for _ in range(n_steps):
        lrs.append(scheduler.step())
    plt.plot(lrs, label=name, linewidth=2)

plt.xlabel('Training Step')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedules Comparison')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
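The classes above keep a mutable `current_step` counter. An alternative design, sketched below under the assumption that you want the schedule as a pure function of the step index (the name `warmup_cosine_lr` is hypothetical), lets you recompute the rate directly from a step count, for example one stored in a checkpoint, without replaying `step()` calls. It also clamps the decay phase so the rate stays at `min_lr` once `step` passes `total_steps`. Optax ships a comparable built-in, `optax.warmup_cosine_decay_schedule`.

```python
import math


def warmup_cosine_lr(step: int, peak_lr: float, warmup_steps: int,
                     total_steps: int, min_lr: float = 0.0) -> float:
    """Learning rate at `step` for linear warmup followed by cosine decay."""
    if step < warmup_steps:
        # Linear ramp from 0 up to peak_lr during warmup
        return peak_lr * step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    # Clamp progress so the rate stays at min_lr after total_steps
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Evaluate a few steps past total_steps to show the clamped tail
lrs = [warmup_cosine_lr(s, peak_lr=0.1, warmup_steps=10, total_steps=100)
       for s in range(120)]
print(lrs[0], lrs[10], lrs[100], lrs[119])
```

Because the rate depends only on its arguments, there is no scheduler state to save or restore.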

## 2. Gradient Clipping

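Gradient clipping guards against exploding gradients by bounding their magnitude before the parameter update. Clipping by global norm rescales all gradients by the same factor, so the update direction is preserved; clipping by value clamps each element independently, which can change the direction.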
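A two-element example makes the difference concrete. With $g = (3, 4)$, whose norm is 5, clipping the norm to 1 scales by 0.2 and gives $(0.6, 0.8)$, while clipping values at 1 gives $(1, 1)$, which points in a different direction. The toy gradient and the helper `cosine_similarity` are illustrative only:

```python
import numpy as np

g = np.array([3.0, 4.0])  # toy gradient with L2 norm 5
max_norm = 1.0
clip_value = 1.0

# Norm clipping: rescale the whole vector so its L2 norm is at most max_norm
scale = min(1.0, max_norm / np.linalg.norm(g))
by_norm = g * scale

# Value clipping: clamp each component to [-clip_value, clip_value] independently
by_value = np.clip(g, -clip_value, clip_value)


def cosine_similarity(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


print("norm-clipped: ", by_norm, cosine_similarity(g, by_norm))    # cosine ~1: same direction
print("value-clipped:", by_value, cosine_similarity(g, by_value))  # cosine < 1: direction changed
```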

In [None]:
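def clip_gradients_by_norm(grads: Dict, max_norm: float) -> Tuple[Dict, float]:
    """Clip gradients by global norm.
    
    Each value in ``grads`` may be an array or a nested pytree of arrays
    (e.g. a layer's {'weight': ..., 'bias': ...}).
    
    Args:
        grads: Dictionary of gradients
        max_norm: Maximum allowed global norm
        
    Returns:
        clipped_grads: Clipped gradients (same structure as ``grads``)
        global_norm: Global norm before clipping
    """
    # Global L2 norm over every leaf array of every entry
    global_norm = jnp.sqrt(
        sum(jnp.sum(leaf ** 2)
            for g in grads.values()
            for leaf in jax.tree_util.tree_leaves(g))
    )
    
    # Scale factor is 1 when the norm is already within max_norm
    clip_factor = jnp.minimum(1.0, max_norm / (global_norm + 1e-6))
    
    # Rescale every leaf by the same factor, preserving the gradient direction
    clipped_grads = {
        k: jax.tree_util.tree_map(lambda x: x * clip_factor, g)
        for k, g in grads.items()
    }
    
    return clipped_grads, float(global_norm)

def clip_gradients_by_value(grads: Dict, clip_value: float) -> Dict:
    """Clip gradients element-wise to [-clip_value, clip_value].
    
    Unlike norm clipping, this can change the gradient direction.
    
    Args:
        grads: Dictionary of gradients (arrays or nested pytrees of arrays)
        clip_value: Maximum absolute value
        
    Returns:
        clipped_grads: Clipped gradients
    """
    return {
        k: jax.tree_util.tree_map(lambda x: jnp.clip(x, -clip_value, clip_value), g)
        for k, g in grads.items()
    }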

# Demonstrate gradient clipping
# Create artificial gradients with some very large values
test_grads = {
    'weight1': jnp.array([[1.0, 2.0], [3.0, 100.0]]),  # Contains large gradient
    'weight2': jnp.array([0.5, 0.3, 0.1]),
    'bias': jnp.array([0.1])
}

# Compute original norm
original_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in test_grads.values()))
print(f"Original gradient norm: {original_norm:.4f}")

# Clip by norm
clipped_by_norm, norm = clip_gradients_by_norm(test_grads, max_norm=5.0)
clipped_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in clipped_by_norm.values()))
print(f"\nAfter norm clipping (max_norm=5.0):")
print(f"  Gradient norm: {clipped_norm:.4f}")
print(f"  weight1: {clipped_by_norm['weight1']}")

# Clip by value
clipped_by_value = clip_gradients_by_value(test_grads, clip_value=2.0)
print(f"\nAfter value clipping (clip_value=2.0):")
print(f"  weight1: {clipped_by_value['weight1']}")

## 3. Weight Decay (L2 Regularization)

In [None]:
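def apply_weight_decay(params: Dict, weight_decay: float) -> Dict:
    """Shrink every parameter by the factor (1 - weight_decay).
    
    With plain SGD, shrinking by (1 - lr * lam) each step is equivalent to
    adding an L2 penalty (lam / 2) * ||params||^2 to the loss.
    
    Args:
        params: Dictionary of parameters (arrays or nested pytrees of arrays)
        weight_decay: Per-step decay factor
        
    Returns:
        Decayed parameters with the same structure
    """
    return {
        k: jax.tree_util.tree_map(lambda x: x * (1 - weight_decay), p)
        for k, p in params.items()
    }

def compute_l2_loss(params: Dict) -> float:
    """Compute the L2 regularization term: the sum of squared parameters."""
    return sum(
        jnp.sum(leaf ** 2)
        for p in params.values()
        for leaf in jax.tree_util.tree_leaves(p)
    )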

# Demonstrate weight decay
test_params = {
    'weight': jnp.array([[1.0, 2.0], [3.0, 4.0]]),
    'bias': jnp.array([0.5, 0.3])
}

l2_before = compute_l2_loss(test_params)
print(f"L2 loss before decay: {l2_before:.4f}")

# Apply weight decay multiple times
l2_history = [l2_before]
params = test_params.copy()

for i in range(10):
    params = apply_weight_decay(params, weight_decay=0.01)
    l2_history.append(compute_l2_loss(params))

print(f"L2 loss after 10 steps: {l2_history[-1]:.4f}")

# Plot L2 loss evolution
plt.figure(figsize=(10, 5))
plt.plot(l2_history, 'b-o', linewidth=2, markersize=6)
plt.xlabel('Step')
plt.ylabel('L2 Loss')
plt.title('Effect of Weight Decay on L2 Loss')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
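The heading equates weight decay with L2 regularization. That equivalence holds for plain SGD, and the NumPy check below verifies it under illustrative random values: a gradient step on the loss plus $(\lambda/2)\lVert\theta\rVert^2$ gives the same parameters as shrinking by $(1 - \eta\lambda)$ and then taking the plain gradient step. The trainer in Section 8 takes the gradient step first and shrinks afterwards, which differs only by the second-order term $\eta^2\lambda g$. With adaptive optimizers such as Adam the L2 penalty and decoupled decay are no longer equivalent, which is the motivation for AdamW.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=4)  # current parameters
grad = rng.normal(size=4)   # gradient of the data loss at theta
lr, lam = 0.1, 0.01         # learning rate and decay coefficient

# (a) SGD on loss + (lam / 2) * ||theta||^2: the penalty contributes lam * theta
theta_l2 = theta - lr * (grad + lam * theta)

# (b) Decoupled decay applied to the pre-update parameters, then the gradient step
theta_decay = theta * (1 - lr * lam) - lr * grad

# (c) Gradient step first, then shrink (decay applied after the update)
theta_after = (theta - lr * grad) * (1 - lr * lam)

gap = np.max(np.abs(theta_after - theta_l2))
print(np.allclose(theta_l2, theta_decay), gap)
```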

## 4. Checkpoint Management

In [None]:
class CheckpointManager:
    """Manage model checkpoints during training."""
    
    def __init__(self, checkpoint_dir: str, max_to_keep: int = 5):
        self.checkpoint_dir = checkpoint_dir
        self.max_to_keep = max_to_keep
        self.checkpoints = []  # List of (metric, path) tuples
        
        # Create directory if it doesn't exist
        os.makedirs(checkpoint_dir, exist_ok=True)
    
    def save_checkpoint(
        self,
        model: bst.graph.Node,
        epoch: int,
        metric: float,
        metadata: Optional[Dict] = None
    ) -> str:
        """Save model checkpoint.
        
        Args:
            model: Model to save
            epoch: Current epoch
            metric: Metric value (e.g., validation loss)
            metadata: Optional metadata to save
            
        Returns:
            Path to saved checkpoint
        """
        # Create checkpoint path
        checkpoint_path = os.path.join(
            self.checkpoint_dir,
            f'checkpoint_epoch_{epoch}_metric_{metric:.4f}.pkl'
        )
        
        # Collect model states; a state's value may be a nested pytree
        # (e.g. {'weight': ..., 'bias': ...}), so convert every leaf to NumPy
        states = {}
        for name, state in model.states().items():
            states[name] = jax.tree_util.tree_map(np.asarray, state.value)
        
        # Save checkpoint
        checkpoint = {
            'epoch': epoch,
            'metric': metric,
            'states': states,
            'metadata': metadata or {}
        }
        
        with open(checkpoint_path, 'wb') as f:
            pickle.dump(checkpoint, f)
        
        # Track checkpoint
        self.checkpoints.append((metric, checkpoint_path))
        
        # Remove old checkpoints if needed (this may delete the one just saved
        # if its metric is worse than every checkpoint being kept)
        self._cleanup_old_checkpoints()
        
        return checkpoint_path
    
    def _cleanup_old_checkpoints(self):
        """Remove old checkpoints, keeping only the best ones."""
        if len(self.checkpoints) > self.max_to_keep:
            # Sort by metric (ascending: lower is better)
            self.checkpoints.sort(key=lambda x: x[0])
            
            # Remove worst checkpoints
            while len(self.checkpoints) > self.max_to_keep:
                _, path_to_remove = self.checkpoints.pop(-1)
                if os.path.exists(path_to_remove):
                    os.remove(path_to_remove)
    
    def load_checkpoint(self, model: bst.graph.Node, checkpoint_path: str):
        """Load model from checkpoint.
        
        Args:
            model: Model to load into
            checkpoint_path: Path to checkpoint file
            
        Returns:
            The full checkpoint dictionary (epoch, metric, states, metadata)
        """
        with open(checkpoint_path, 'rb') as f:
            checkpoint = pickle.load(f)
        
        # Restore states, converting each NumPy leaf back to a JAX array
        model_states = model.states()
        for name, value in checkpoint['states'].items():
            if name in model_states:
                model_states[name].value = jax.tree_util.tree_map(jnp.asarray, value)
        
        return checkpoint
    
    def get_best_checkpoint(self) -> Optional[str]:
        """Get path to best checkpoint."""
        if not self.checkpoints:
            return None
        return min(self.checkpoints, key=lambda x: x[0])[1]

# Demonstrate checkpoint manager
checkpoint_dir = './temp_checkpoints'
manager = CheckpointManager(checkpoint_dir, max_to_keep=3)

# Create simple model
class SimpleModel(bst.graph.Node):
    def __init__(self):
        super().__init__()
        self.weight = bst.ParamState(jnp.ones((5, 3)))
    
    def __call__(self, x):
        return x @ self.weight.value

model = SimpleModel()

# Simulate saving checkpoints
print("Saving checkpoints...")
for epoch in range(5):
    metric = 1.0 - epoch * 0.1  # Simulated decreasing loss
    path = manager.save_checkpoint(model, epoch, metric)
    print(f"Epoch {epoch}: Saved checkpoint with metric={metric:.4f}")

print(f"\nBest checkpoint: {manager.get_best_checkpoint()}")
print(f"Number of checkpoints kept: {len(manager.checkpoints)}")

# Cleanup
import shutil
if os.path.exists(checkpoint_dir):
    shutil.rmtree(checkpoint_dir)
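`save_checkpoint` writes the pickle directly to its final path, so a crash mid-write can leave a truncated file that later fails to load. A common remedy, sketched here with a hypothetical helper `atomic_pickle_dump`, is to write to a temporary file in the same directory and then rename it over the target with `os.replace`, which the Python documentation describes as atomic on POSIX systems. Keep in mind that `pickle.load` can execute arbitrary code, so only load checkpoints you created yourself.

```python
import os
import pickle
import tempfile


def atomic_pickle_dump(obj, path: str) -> None:
    """Pickle obj to path via a temp file + rename, so path is never half-written."""
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temp file next to the target so os.replace stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())  # push bytes to disk before the rename
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temp file, then re-raise the original error
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Usage: round-trip a small checkpoint dictionary
with tempfile.TemporaryDirectory() as ckpt_dir:
    ckpt_path = os.path.join(ckpt_dir, "checkpoint_epoch_3.pkl")
    atomic_pickle_dump({"epoch": 3, "metric": 0.42}, ckpt_path)
    with open(ckpt_path, "rb") as f:
        restored = pickle.load(f)
    leftovers = [n for n in os.listdir(ckpt_dir) if n.endswith(".tmp")]
```

Inside `CheckpointManager.save_checkpoint`, the `open(...)`/`pickle.dump` pair could be swapped for a call like this one.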

## 5. Early Stopping

In [None]:
class EarlyStopping:
    """Early stopping to prevent overfitting."""
    
    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.should_stop = False
    
    def __call__(self, val_loss: float) -> bool:
        """Check if training should stop.
        
        Args:
            val_loss: Current validation loss
            
        Returns:
            True if training should stop
        """
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss < self.best_loss - self.min_delta:
            # Improvement
            self.best_loss = val_loss
            self.counter = 0
        else:
            # No improvement
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        
        return self.should_stop

# Demonstrate early stopping
early_stop = EarlyStopping(patience=5, min_delta=0.01)

# Simulated validation losses
val_losses = [1.0, 0.8, 0.7, 0.65, 0.64, 0.64, 0.65, 0.64, 0.65, 0.66, 0.65]

print("Early Stopping Demo:")
print("=" * 60)

for epoch, loss in enumerate(val_losses):
    should_stop = early_stop(loss)
    print(f"Epoch {epoch}: val_loss={loss:.2f}, "
          f"counter={early_stop.counter}, "
          f"best={early_stop.best_loss:.2f}")
    
    if should_stop:
        print(f"\nEarly stopping triggered at epoch {epoch}!")
        break
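`EarlyStopping` above assumes lower is better. For metrics such as validation accuracy, a `mode` switch flips the comparison, and recording `best_epoch` tells you which checkpoint to restore afterwards. The class name `EarlyStoppingWithMode` is hypothetical; this is a sketch rather than a drop-in replacement:

```python
class EarlyStoppingWithMode:
    """Early stopping for metrics where either lower ('min') or higher ('max') is better."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0, mode: str = "min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None
        self.best_epoch = None
        self.counter = 0

    def _improved(self, value: float) -> bool:
        # An improvement must beat the best value by more than min_delta
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def __call__(self, value: float, epoch: int) -> bool:
        """Record this epoch's metric; return True when training should stop."""
        if self.best is None or self._improved(value):
            self.best, self.best_epoch, self.counter = value, epoch, 0
        else:
            self.counter += 1
        return self.counter >= self.patience


# Usage: stop on a validation accuracy that plateaus
stopper = EarlyStoppingWithMode(patience=3, mode="max")
val_acc = [0.5, 0.6, 0.7, 0.7, 0.69, 0.7]
stopped_at = None
for epoch, acc in enumerate(val_acc):
    if stopper(acc, epoch):
        stopped_at = epoch
        break
print(f"stopped at epoch {stopped_at}, best={stopper.best} at epoch {stopper.best_epoch}")
```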

## 6. Gradient Accumulation

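Gradient accumulation sums gradients over several mini-batches and then applies a single averaged update. This emulates a larger batch size without having to hold the larger batch in memory at once.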

In [None]:
class GradientAccumulator:
    """Accumulate gradients over multiple mini-batches."""
    
    def __init__(self, accumulation_steps: int):
        self.accumulation_steps = accumulation_steps
        self.accumulated_grads = None
        self.step_count = 0
    
    def accumulate(self, grads: Dict) -> Tuple[bool, Optional[Dict]]:
        """Accumulate gradients.
        
        Args:
            grads: Gradients from current mini-batch
            
        Returns:
            should_update: Whether to update parameters
            averaged_grads: Averaged gradients (if should_update=True)
        """
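        # Initialize accumulator with zeros shaped like the incoming gradients
        # (each value may be an array or a nested pytree of arrays)
        if self.accumulated_grads is None:
            self.accumulated_grads = {
                k: jax.tree_util.tree_map(jnp.zeros_like, v) for k, v in grads.items()
            }
        
        # Accumulate gradients leaf by leaf
        for k, v in grads.items():
            self.accumulated_grads[k] = jax.tree_util.tree_map(
                jnp.add, self.accumulated_grads[k], v
            )
        
        self.step_count += 1
        
        # Check if we should update
        if self.step_count % self.accumulation_steps == 0:
            # Average accumulated gradients
            averaged_grads = {
                k: jax.tree_util.tree_map(lambda x: x / self.accumulation_steps, v)
                for k, v in self.accumulated_grads.items()
            }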
            
            # Reset accumulator
            self.accumulated_grads = None
            
            return True, averaged_grads
        else:
            return False, None

# Demonstrate gradient accumulation
accumulator = GradientAccumulator(accumulation_steps=4)

print("Gradient Accumulation Demo:")
print("=" * 60)

for step in range(10):
    # Simulate gradients
    grads = {'weight': jnp.ones(5) * (step + 1)}
    
    should_update, averaged_grads = accumulator.accumulate(grads)
    
    if should_update:
        print(f"Step {step}: UPDATE with averaged gradients: {averaged_grads['weight']}")
    else:
        print(f"Step {step}: Accumulate (no update)")
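Averaging the accumulated gradients reproduces the full-batch gradient exactly when the loss is a mean and every micro-batch has the same size. A quick NumPy check on a linear least-squares model confirms this; the data and the helper `mse_grad` are illustrative, not part of the tutorial:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 3))
y = rng.normal(size=64)
w = rng.normal(size=3)


def mse_grad(w, X, y):
    # Gradient of mean((X @ w - y) ** 2) with respect to w
    return 2.0 * X.T @ (X @ w - y) / len(y)


full_batch_grad = mse_grad(w, X, y)

# Four equal micro-batches of 16 samples: sum the gradients, then divide by the count
micro_grads = [mse_grad(w, X[i:i + 16], y[i:i + 16]) for i in range(0, 64, 16)]
accumulated_grad = sum(micro_grads) / len(micro_grads)

print(np.allclose(full_batch_grad, accumulated_grad))
```

If micro-batches differ in size, as with the smaller final batch of 8 samples in Section 8's loop, each gradient would need to be weighted by its batch size for the equality to hold.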

## 7. Training Metrics Monitor

In [None]:
class MetricsMonitor:
    """Monitor and visualize training metrics."""
    
    def __init__(self):
        self.history = {}
    
    def log(self, **metrics):
        """Log metrics for current step."""
        for name, value in metrics.items():
            if name not in self.history:
                self.history[name] = []
            self.history[name].append(float(value))
    
    def get_latest(self, metric_name: str) -> Optional[float]:
        """Get latest value of a metric."""
        if metric_name in self.history and self.history[metric_name]:
            return self.history[metric_name][-1]
        return None
    
    def get_best(self, metric_name: str, mode: str = 'min') -> Optional[float]:
        """Get best value of a metric."""
        if metric_name not in self.history or not self.history[metric_name]:
            return None
        
        if mode == 'min':
            return min(self.history[metric_name])
        else:
            return max(self.history[metric_name])
    
    def plot(self, figsize=(14, 10)):
        """Plot all metrics."""
        n_metrics = len(self.history)
        if n_metrics == 0:
            print("No metrics to plot")
            return
        
        # Calculate grid size
        n_cols = 2
        n_rows = (n_metrics + 1) // 2
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize)
        if n_rows == 1:
            axes = axes.reshape(1, -1)
        axes = axes.flatten()
        
        for idx, (name, values) in enumerate(self.history.items()):
            ax = axes[idx]
            ax.plot(values, linewidth=2)
            ax.set_xlabel('Step')
            ax.set_ylabel(name)
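            # Treat accuracy-like metrics as "higher is better" when picking the best value
            mode = 'max' if 'acc' in name else 'min'
            ax.set_title(f'{name} (best: {self.get_best(name, mode=mode):.4f})')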
            ax.grid(True, alpha=0.3)
        
        # Hide unused subplots
        for idx in range(n_metrics, len(axes)):
            axes[idx].axis('off')
        
        plt.tight_layout()
        plt.show()

# Demonstrate metrics monitor
monitor = MetricsMonitor()

# Simulate training
for epoch in range(50):
    train_loss = 1.0 / (epoch + 1) + np.random.randn() * 0.05
    val_loss = 1.2 / (epoch + 1) + np.random.randn() * 0.05
    train_acc = 1 - 1.0 / (epoch + 2) + np.random.randn() * 0.02
    val_acc = 1 - 1.2 / (epoch + 2) + np.random.randn() * 0.02
    
    monitor.log(
        train_loss=train_loss,
        val_loss=val_loss,
        train_acc=train_acc,
        val_acc=val_acc
    )

# Plot metrics
monitor.plot(figsize=(12, 8))

# Print summary
print("\nTraining Summary:")
print("=" * 60)
print(f"Best train loss: {monitor.get_best('train_loss'):.4f}")
print(f"Best val loss: {monitor.get_best('val_loss'):.4f}")
print(f"Best train acc: {monitor.get_best('train_acc', mode='max'):.4f}")
print(f"Best val acc: {monitor.get_best('val_acc', mode='max'):.4f}")
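Per-step curves such as gradient norms are often noisy, so a smoothed copy is easier to read alongside the raw values. The sketch below assumes a bias-corrected exponential moving average, the same correction Adam applies to its moment estimates; the name `ema_smooth` is hypothetical, and in this notebook it could be applied to a list such as `monitor.history['val_loss']`:

```python
def ema_smooth(values, beta=0.9):
    """Bias-corrected exponential moving average of a metric history."""
    smoothed, avg = [], 0.0
    for t, value in enumerate(values, start=1):
        avg = beta * avg + (1 - beta) * value
        # Dividing by (1 - beta**t) removes the pull toward the 0.0 starting value
        smoothed.append(avg / (1 - beta ** t))
    return smoothed


# Synthetic decaying loss with alternating +/-0.05 noise
noisy = [1.0 / (step + 1) + (0.05 if step % 2 else -0.05) for step in range(50)]
smooth = ema_smooth(noisy)


def total_variation(xs):
    return sum(abs(b - a) for a, b in zip(xs, xs[1:]))


print(total_variation(noisy), total_variation(smooth))
```

Larger `beta` values smooth more but react more slowly to genuine changes in the metric.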

## 8. Complete Training Pipeline with All Tricks

In [None]:
class AdvancedTrainer:
    """Advanced trainer with all optimization tricks."""
    
    def __init__(
        self,
        model: bst.graph.Node,
        lr_scheduler: LRScheduler,
        gradient_clip_norm: Optional[float] = None,
        weight_decay: float = 0.0,
        gradient_accumulation_steps: int = 1,
        checkpoint_dir: Optional[str] = None,
        early_stopping_patience: Optional[int] = None,
    ):
        self.model = model
        self.lr_scheduler = lr_scheduler
        self.gradient_clip_norm = gradient_clip_norm
        self.weight_decay = weight_decay
        
        # Gradient accumulation
        self.grad_accumulator = GradientAccumulator(gradient_accumulation_steps)
        
        # Checkpoint manager
        self.checkpoint_manager = CheckpointManager(checkpoint_dir) if checkpoint_dir else None
        
        # Early stopping
        self.early_stopping = EarlyStopping(early_stopping_patience) if early_stopping_patience else None
        
        # Metrics monitor
        self.monitor = MetricsMonitor()
    
    def train_step(self, x_batch, y_batch):
        """Single training step."""
        with bst.environ.context(fit=True):
            # Define loss function
            def loss_fn():
                logits = self.model(jnp.array(x_batch))
                # Simple MSE loss for demonstration
                return jnp.mean((logits - jnp.array(y_batch)) ** 2)
            
            # Compute gradients; with return_value=True, brainstate returns
            # (gradients, loss value) in that order
            grads, loss = bst.augment.grad(
                loss_fn,
                self.model.states(bst.ParamState),
                return_value=True
            )()
            
            return float(loss), grads
    
    def update_parameters(self, grads):
        """Update model parameters with all tricks."""
        # Gradient clipping
        if self.gradient_clip_norm is not None:
            grads, grad_norm = clip_gradients_by_norm(grads, self.gradient_clip_norm)
            self.monitor.log(grad_norm=grad_norm)
        
        # Get learning rate
        lr = self.lr_scheduler.step()
        self.monitor.log(learning_rate=lr)
        
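        # SGD update; a state's value may be an array or a nested pytree
        # (e.g. a Linear layer's {'weight': ..., 'bias': ...})
        param_states = self.model.states(bst.ParamState)
        for name, grad in grads.items():
            state = param_states[name]
            state.value = jax.tree_util.tree_map(
                lambda p, g: p - lr * g, state.value, grad
            )
        
        # Decoupled weight decay, scaled by the current learning rate
        if self.weight_decay > 0:
            decay = 1 - self.weight_decay * lr
            for state in param_states.values():
                state.value = jax.tree_util.tree_map(lambda p: p * decay, state.value)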
    
    def train(
        self,
        train_data: Tuple,
        val_data: Optional[Tuple] = None,
        num_epochs: int = 10,
        batch_size: int = 32,
    ):
        """Complete training loop."""
        X_train, y_train = train_data
        
        for epoch in range(num_epochs):
            # Training
            epoch_losses = []
            n_samples = len(X_train)
            
            for start_idx in range(0, n_samples, batch_size):
                end_idx = min(start_idx + batch_size, n_samples)
                x_batch = X_train[start_idx:end_idx]
                y_batch = y_train[start_idx:end_idx]
                
                # Train step
                loss, grads = self.train_step(x_batch, y_batch)
                epoch_losses.append(loss)
                
                # Gradient accumulation
                should_update, avg_grads = self.grad_accumulator.accumulate(grads)
                
                if should_update:
                    self.update_parameters(avg_grads)
            
            # Log training metrics
            train_loss = np.mean(epoch_losses)
            self.monitor.log(train_loss=train_loss, epoch=epoch)
            
            # Validation
            if val_data is not None:
                val_loss = self.evaluate(val_data, batch_size)
                self.monitor.log(val_loss=val_loss)
                
                # Checkpointing
                if self.checkpoint_manager is not None:
                    self.checkpoint_manager.save_checkpoint(
                        self.model, epoch, val_loss
                    )
                
                # Early stopping
                if self.early_stopping is not None:
                    if self.early_stopping(val_loss):
                        print(f"Early stopping at epoch {epoch}")
                        break
            
            # Print progress
            if epoch % 5 == 0:
                print(f"Epoch {epoch}: train_loss={train_loss:.4f}" +
                      (f", val_loss={val_loss:.4f}" if val_data else ""))
    
    def evaluate(self, data: Tuple, batch_size: int) -> float:
        """Evaluate model."""
        X, y = data
        losses = []
        
        with bst.environ.context(fit=False):
            for start_idx in range(0, len(X), batch_size):
                end_idx = min(start_idx + batch_size, len(X))
                x_batch = X[start_idx:end_idx]
                y_batch = y[start_idx:end_idx]
                
                logits = self.model(jnp.array(x_batch))
                loss = jnp.mean((logits - jnp.array(y_batch)) ** 2)
                losses.append(float(loss))
        
        return np.mean(losses)

# Demo with simple model and data
class DemoModel(bst.graph.Node):
    def __init__(self):
        super().__init__()
        self.linear = bst.nn.Linear(10, 5)
    
    def __call__(self, x):
        return self.linear(x)

# Generate synthetic data
X_train = bst.random.randn(200, 10)
y_train = bst.random.randn(200, 5)
X_val = bst.random.randn(40, 10)
y_val = bst.random.randn(40, 5)

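# Create model and trainer
model = DemoModel()

num_epochs = 30
batch_size = 32
accumulation_steps = 2

# The scheduler advances once per optimizer update, not per epoch:
# ceil(200 / 32) = 7 mini-batches per epoch, one update every accumulation_steps
steps_per_epoch = int(np.ceil(len(X_train) / batch_size))
total_updates = (num_epochs * steps_per_epoch) // accumulation_steps
scheduler = WarmupCosineSchedule(initial_lr=0.01, warmup_steps=5, total_steps=total_updates)

trainer = AdvancedTrainer(
    model=model,
    lr_scheduler=scheduler,
    gradient_clip_norm=1.0,
    weight_decay=0.0001,
    gradient_accumulation_steps=accumulation_steps,
    early_stopping_patience=10
)

# Train
print("Training with all optimization tricks...")
print("=" * 70)
trainer.train(
    train_data=(X_train, y_train),
    val_data=(X_val, y_val),
    num_epochs=num_epochs,
    batch_size=batch_size
)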

# Visualize results
trainer.monitor.plot(figsize=(12, 8))
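`WarmupCosineSchedule` is advanced once per optimizer update (inside `update_parameters`), so its `total_steps` should count updates rather than epochs or mini-batches. Because `GradientAccumulator` keeps its counter across epoch boundaries, the demo performs ceil(200 / 32) = 7 mini-batches per epoch, times 30 epochs, divided by 2 accumulation steps. The helper below (a hypothetical name) does that arithmetic; early stopping may end training before this many updates happen:

```python
import math


def count_optimizer_updates(n_samples: int, batch_size: int,
                            num_epochs: int, accumulation_steps: int) -> int:
    """Number of parameter updates when the accumulation counter runs across epochs."""
    batches_per_epoch = math.ceil(n_samples / batch_size)  # last batch may be partial
    return (num_epochs * batches_per_epoch) // accumulation_steps


print(count_optimizer_updates(n_samples=200, batch_size=32,
                              num_epochs=30, accumulation_steps=2))
```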

## Summary

In this tutorial, we covered advanced optimization techniques:

1. **Learning Rate Schedules**: Constant, Step, Exponential, Cosine, Warmup
2. **Gradient Clipping**: By norm and by value
3. **Weight Decay**: L2 regularization
4. **Checkpoint Management**: Save/load best models
5. **Early Stopping**: Prevent overfitting
6. **Gradient Accumulation**: Effective large batch sizes
7. **Metrics Monitoring**: Track and visualize training
8. **Complete Pipeline**: All tricks integrated

## Key Takeaways

- **Learning rate schedules** improve convergence
- **Gradient clipping** prevents exploding gradients
- **Weight decay** helps generalization
- **Checkpoints** save best models
- **Early stopping** prevents overfitting
- **Gradient accumulation** enables large batches with limited memory
- **Monitoring** is essential for debugging training
- These techniques are complementary and can be combined in a single training loop

## Best Practices

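1. Start with a warmup + cosine schedule as a reasonable default
2. Clip gradients when training RNNs, which are prone to exploding gradients
3. Use weight decay for regularization
4. Save checkpoints regularly
5. Monitor gradient norms to catch instabilities early
6. Use early stopping to avoid wasting compute once validation loss stops improving
7. Accumulate gradients when the desired batch size does not fit in memory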

## Next Steps

In the final tutorial, we'll cover:
- **Model Deployment**: Saving and loading models
- Export to different formats
- Inference optimization
- Batch processing for deployment