# Saving and Loading Model Checkpoints

This tutorial demonstrates how to save and load model checkpoints in BrainState, enabling you to:
- Preserve trained models for later use
- Resume interrupted training
- Share models with collaborators
- Deploy models in production

## Learning Objectives

By the end of this tutorial, you will:
- Extract and save model states
- Use Orbax for checkpointing
- Load models with abstract initialization
- Handle model versioning
- Implement best practices for model persistence

## Setup and Imports

In [None]:
import os
from tempfile import TemporaryDirectory

import jax
import jax.numpy as jnp
import orbax.checkpoint as orbax  # Google's checkpointing library

import brainstate

# Set random seed
brainstate.random.seed(42)

## Building a Simple Model

Let's create a multi-layer perceptron (MLP) for demonstration:

In [None]:
class MLP(brainstate.nn.Module):
    """Multi-layer perceptron for classification."""
    
    def __init__(self, din: int, dmid: int, dout: int):
        super().__init__()
        
        # Two dense layers
        self.dense1 = brainstate.nn.Linear(din, dmid)
        self.dense2 = brainstate.nn.Linear(dmid, dout)
    
    def __call__(self, x: jax.Array) -> jax.Array:
        """Forward pass through the network."""
        x = self.dense1(x)
        x = jax.nn.relu(x)
        x = self.dense2(x)
        return x
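For reference, the forward pass above computes the following. This assumes that `Linear` applies an affine map $\mathbf{x} W + \mathbf{b}$ with a bias term (row-vector convention); the exact storage layout inside BrainState may differ.

$$
\mathbf{h} = \mathrm{ReLU}\!\left(\mathbf{x} W_1 + \mathbf{b}_1\right), \qquad
\mathbf{y} = \mathbf{h} W_2 + \mathbf{b}_2,
$$

$$
W_1 \in \mathbb{R}^{d_\mathrm{in} \times d_\mathrm{mid}},\;
\mathbf{b}_1 \in \mathbb{R}^{d_\mathrm{mid}},\;
W_2 \in \mathbb{R}^{d_\mathrm{mid} \times d_\mathrm{out}},\;
\mathbf{b}_2 \in \mathbb{R}^{d_\mathrm{out}}.
$$

Counting these entries shows how much data a checkpoint must hold. For the `din=10, dmid=20, dout=30` configuration used in this tutorial:

$$
N = d_\mathrm{in} d_\mathrm{mid} + d_\mathrm{mid} + d_\mathrm{mid} d_\mathrm{out} + d_\mathrm{out}
  = 10 \cdot 20 + 20 + 20 \cdot 30 + 30 = 850.
$$

In float32 (4 bytes per value) that is 3,400 bytes of raw parameter data. The files Orbax writes will typically be larger, since they also store metadata.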

## Model Factory Function

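Loading a checkpoint requires rebuilding exactly the same architecture, so it is good practice to define a factory function that always creates the model with the same hyperparameters (and an explicit random seed):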

In [None]:
def create_model(seed: int) -> MLP:
    """Create an MLP model with a given random seed.
    
    Args:
        seed: Random seed for initialization
        
    Returns:
        model: Initialized MLP
    """
    brainstate.random.seed(seed)
    return MLP(din=10, dmid=20, dout=30)


# Create and inspect a model
test_model = create_model(42)
print("Model created:")
print(test_model)
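print("\nParameter shapes:")
for name, param in test_model.states(brainstate.ParamState).to_dict_values().items():
    # A state's value may be a single array or a dict of arrays (e.g. weight and bias),
    # so map over it as a pytree instead of calling `.shape` on it directly
    print(f"  {name}: {jax.tree.map(jnp.shape, param)}")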
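Because the loader has to call the factory with the same arguments that produced the checkpoint, one option is to gather those arguments into a single config object that can be written next to the checkpoint. The sketch below assumes a hypothetical `MLPConfig` dataclass; it is not part of BrainState.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class MLPConfig:
    """Hypothetical container for the factory's hyperparameters."""
    din: int = 10
    dmid: int = 20
    dout: int = 30
    seed: int = 0


def config_to_json(config: MLPConfig) -> str:
    # asdict turns the dataclass into a plain dict that json can serialize
    return json.dumps(asdict(config), sort_keys=True)


def config_from_json(text: str) -> MLPConfig:
    # Rebuild the same config; unknown keys raise a TypeError instead of being ignored
    return MLPConfig(**json.loads(text))


saved = config_to_json(MLPConfig(seed=42))
restored = config_from_json(saved)
print(restored)  # MLPConfig(din=10, dmid=20, dout=30, seed=42)
```

A factory could then accept such a config and call `MLP(config.din, config.dmid, config.dout)`, so that saving and loading code cannot drift apart.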

## Saving Model Checkpoints

### Method 1: Using Orbax (Recommended)

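Orbax is Google's checkpointing library for JAX. It provides:
- Efficient serialization of pytrees of arrays
- Asynchronous saving (via `AsyncCheckpointer`)
- Management of multiple checkpoints, such as keeping only the most recent ones (via `CheckpointManager`)
- Support for cloud storage such as Google Cloud Storage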

In [None]:
def save_model_orbax(model: MLP, path: str):
    """Save model parameters using Orbax.
    
    Args:
        model: Model to save
        path: Directory path for checkpoint
    """
    # Extract all states as a pytree
    state_tree = brainstate.graph.treefy_states(model)
    
    # Create Orbax checkpointer
    checkpointer = orbax.PyTreeCheckpointer()
    
    # Save the state tree
    checkpoint_path = os.path.join(path, 'state')
    checkpointer.save(checkpoint_path, state_tree)
    
    print(f"Model saved to {checkpoint_path}")


# Example: Save a model
with TemporaryDirectory() as tmpdir:
    model = create_model(seed=42)
    save_model_orbax(model, tmpdir)
    
    # Check what was saved
    print("\nFiles in checkpoint directory:")
    for root, dirs, files in os.walk(tmpdir):
        for file in files:
            filepath = os.path.join(root, file)
            size = os.path.getsize(filepath)
            print(f"  {file}: {size:,} bytes")
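`treefy_states` hands Orbax a nested pytree, and a checkpoint can be thought of as a mapping from each leaf's key path to its value. The sketch below illustrates that idea with plain dictionaries, using lists in place of arrays; `flatten_tree` and `unflatten_tree` are hypothetical helpers, the key names are illustrative, and Orbax's actual on-disk format is more involved.

```python
from typing import Any, Dict, Tuple


def flatten_tree(tree: Dict[str, Any], prefix: Tuple[str, ...] = ()) -> Dict[Tuple[str, ...], Any]:
    """Map every leaf of a nested dict to its key path."""
    flat: Dict[Tuple[str, ...], Any] = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            # Recurse into sub-modules
            flat.update(flatten_tree(value, path))
        else:
            flat[path] = value
    return flat


def unflatten_tree(flat: Dict[Tuple[str, ...], Any]) -> Dict[str, Any]:
    """Inverse of flatten_tree: rebuild the nesting from key paths."""
    tree: Dict[str, Any] = {}
    for path, value in flat.items():
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return tree


# Toy stand-in for a model's state tree
state = {
    "dense1": {"weight": [[0.1, 0.2]], "bias": [0.0, 0.0]},
    "dense2": {"weight": [[0.3], [0.4]], "bias": [0.0]},
}
flat = flatten_tree(state)
for path in sorted(flat):
    print("/".join(path))  # dense1/bias, dense1/weight, dense2/bias, dense2/weight
```

The round trip `unflatten_tree(flatten_tree(state))` recovers the original nesting, which is why the loader needs a tree of the same structure to restore into.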

## Loading Model Checkpoints

### Abstract Initialization

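BrainState provides `abstract_init` to create a model whose states carry only shape and dtype information, without allocating any arrays. This is memory-efficient because the real values come from the checkpoint anyway, while the abstract states still give Orbax the tree structure and target shapes to restore into: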

In [None]:
def load_model_orbax(path: str) -> MLP:
    """Load model parameters using Orbax.
    
    Args:
        path: Directory path for checkpoint
        
    Returns:
        model: Loaded model with restored parameters
    """
    # Create model with abstract shapes (memory efficient)
    model = brainstate.transform.abstract_init(lambda: create_model(0))
    
    # Get state tree structure
    state_tree = brainstate.graph.treefy_states(model)
    
    # Load parameters from checkpoint
    checkpointer = orbax.PyTreeCheckpointer()
    checkpoint_path = os.path.join(path, 'state')
    restored_state = checkpointer.restore(checkpoint_path, item=state_tree)
    
    # Update model with loaded parameters
    brainstate.graph.update_states(model, restored_state)
    
    print(f"Model loaded from {checkpoint_path}")
    return model


# Example: Save and load cycle
with TemporaryDirectory() as tmpdir:
    # Save a model
    print("=== Saving Model ===")
    original_model = create_model(seed=42)
    save_model_orbax(original_model, tmpdir)
    
    # Load the model
    print("\n=== Loading Model ===")
    loaded_model = load_model_orbax(tmpdir)
    
    # Verify parameters match
    print("\n=== Verification ===")
    original_params = brainstate.graph.treefy_states(original_model, brainstate.ParamState)
    loaded_params = brainstate.graph.treefy_states(loaded_model, brainstate.ParamState)
    
    # Check if parameters are identical
    params_match = jax.tree.map(
        lambda x, y: jnp.allclose(x, y),
        original_params.to_dict_values(),
        loaded_params.to_dict_values()
    )
    
    all_match = all(jax.tree.leaves(params_match))
    print(f"Parameters match: {all_match}")
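The abstract model acts as a template: it fixes which states exist and what shape each one has, and the checkpoint only supplies values. A mismatch, for example restoring a checkpoint saved with `dmid=20` into a model built with `dmid=32`, should fail loudly rather than silently. The stdlib sketch below shows the kind of structural check involved, using a hypothetical `restore_into_template` helper and flat lists in place of arrays; it illustrates the idea and is not how Orbax or `update_states` is implemented.

```python
import math
from typing import Dict, List, Tuple

Path = Tuple[str, ...]
Shape = Tuple[int, ...]


def restore_into_template(
    template: Dict[Path, Shape],
    checkpoint: Dict[Path, Tuple[Shape, List[float]]],
) -> Dict[Path, List[float]]:
    """Fill an abstract template (path -> expected shape) with checkpoint values.

    `checkpoint` maps each path to (stored shape, flat list of values).
    """
    # The set of states must match exactly: no missing and no extra entries
    missing = template.keys() - checkpoint.keys()
    unexpected = checkpoint.keys() - template.keys()
    if missing or unexpected:
        raise KeyError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")

    restored = {}
    for path, expected_shape in template.items():
        stored_shape, values = checkpoint[path]
        name = "/".join(path)
        # Each stored shape must equal the shape the abstract model expects
        if tuple(stored_shape) != expected_shape:
            raise ValueError(f"{name}: expected {expected_shape}, got {tuple(stored_shape)}")
        if len(values) != math.prod(expected_shape):
            raise ValueError(f"{name}: {len(values)} values do not fill shape {expected_shape}")
        restored[path] = values
    return restored


template = {("dense1", "weight"): (2, 3), ("dense1", "bias"): (3,)}
ckpt = {
    ("dense1", "weight"): ((2, 3), [0.0] * 6),
    ("dense1", "bias"): ((3,), [0.1] * 3),
}
params = restore_into_template(template, ckpt)

bad = dict(ckpt)
bad[("dense1", "bias")] = ((4,), [0.1] * 4)  # e.g. saved from a model with a different width
try:
    restore_into_template(template, bad)
except ValueError as err:
    print(err)  # dense1/bias: expected (3,), got (4,)
```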

## Testing the Loaded Model

Verify that the loaded model produces identical outputs:

In [None]:
with TemporaryDirectory() as tmpdir:
    # Create and save a model
    original_model = create_model(seed=42)
    save_model_orbax(original_model, tmpdir)
    
    # Load the model
    loaded_model = load_model_orbax(tmpdir)
    
    # Test with a fixed input: a batch of 5 all-ones samples
    test_input = jnp.ones((5, 10))
    
    original_output = original_model(test_input)
    loaded_output = loaded_model(test_input)
    
    print("\n=== Model Inference Test ===")
    print(f"Original output shape: {original_output.shape}")
    print(f"Loaded output shape: {loaded_output.shape}")
    print(f"\nOutputs match: {jnp.allclose(original_output, loaded_output)}")
    print(f"Max difference: {jnp.max(jnp.abs(original_output - loaded_output)):.2e}")

## Advanced: Selective State Saving

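Sometimes you want to save only specific types of states, for example only trainable parameters (`ParamState`) and not hidden or optimizer states. `treefy_states` accepts a state type as a filter for this: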

In [None]:
def save_params_only(model: MLP, path: str):
    """Save only trainable parameters.
    
    Args:
        model: Model to save
        path: Directory path for checkpoint
    """
    # Extract only ParamState instances
    param_tree = brainstate.graph.treefy_states(model, brainstate.ParamState)
    
    checkpointer = orbax.PyTreeCheckpointer()
    checkpoint_path = os.path.join(path, 'params')
    checkpointer.save(checkpoint_path, param_tree)
    
    print(f"Parameters saved to {checkpoint_path}")


def load_params_only(model: MLP, path: str):
    """Load only trainable parameters.
    
    Args:
        model: Model to update
        path: Directory path for checkpoint
    """
    # Get parameter structure
    param_tree = brainstate.graph.treefy_states(model, brainstate.ParamState)
    
    checkpointer = orbax.PyTreeCheckpointer()
    checkpoint_path = os.path.join(path, 'params')
    restored_params = checkpointer.restore(checkpoint_path, item=param_tree)
    
    # Update only parameters
    brainstate.graph.update_states(model, restored_params)
    
    print(f"Parameters loaded from {checkpoint_path}")


# Example usage
with TemporaryDirectory() as tmpdir:
    # Save only parameters
    model = create_model(42)
    save_params_only(model, tmpdir)
    
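    # Create a model with a different initialization and load the saved parameters into it
    new_model = create_model(0)
    load_params_only(new_model, tmpdir)
    
    # Confirm the transfer: both models should now produce the same outputs
    x = jnp.ones((1, 10))
    transferred = bool(jnp.allclose(model(x), new_model(x)))
    print(f"\nParameter transfer successful: {transferred}")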
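For the MLP above every state is likely a `ParamState`, so selective saving only makes a visible difference for models that also carry other state types. The stdlib sketch below illustrates the filter-then-partial-update pattern with hypothetical `Param` and `Hidden` stand-in classes (not BrainState's classes): only the selected states are copied, and every other state in the target keeps its own value.

```python
from typing import Dict, Type


class State:
    """Hypothetical stand-in for a state container holding a value."""
    def __init__(self, value):
        self.value = value


class Param(State):
    """Stand-in for trainable parameters."""


class Hidden(State):
    """Stand-in for non-trainable, per-run state."""


def select_states(states: Dict[str, State], kind: Type[State]) -> Dict[str, object]:
    # Keep only states of the requested type, as plain values
    return {name: s.value for name, s in states.items() if isinstance(s, kind)}


def update_values(states: Dict[str, State], values: Dict[str, object]) -> None:
    # Overwrite matching entries in place; everything else is left untouched
    for name, value in values.items():
        states[name].value = value


source = {"w": Param(1.0), "v": Hidden(-65.0)}
target = {"w": Param(0.0), "v": Hidden(-70.0)}

params = select_states(source, Param)  # {'w': 1.0}
update_values(target, params)
print(target["w"].value, target["v"].value)  # 1.0 -70.0
```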

## Best Practices

### 1. Version Your Checkpoints

In [None]:
import json

def save_with_metadata(model: MLP, path: str, metadata: dict):
    """Save model with metadata.
    
    Args:
        model: Model to save
        path: Directory path
        metadata: Dictionary with version info, hyperparameters, etc.
    """
    # Save model states
    save_model_orbax(model, path)
    
    # Save metadata as JSON
    metadata_path = os.path.join(path, 'metadata.json')
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    
    print(f"Metadata saved to {metadata_path}")


# Example with metadata
with TemporaryDirectory() as tmpdir:
    model = create_model(42)
    
    metadata = {
        'model_version': '1.0.0',
        'architecture': 'MLP',
        'input_dim': 10,
        'hidden_dim': 20,
        'output_dim': 30,
        'training_date': '2025-10-10',
        'framework': 'brainstate',
    }
    
    save_with_metadata(model, tmpdir, metadata)
    
    # Load and check metadata
    metadata_path = os.path.join(tmpdir, 'metadata.json')
    with open(metadata_path, 'r') as f:
        loaded_metadata = json.load(f)
    
    print("\nLoaded metadata:")
    print(json.dumps(loaded_metadata, indent=2))
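Metadata is most useful when it is checked before restoring. A minimal sketch, assuming the metadata keys used above and a semantic-versioning convention (MAJOR.MINOR.PATCH, where a major bump means an incompatible change); `check_compatible` and `parse_version` are hypothetical helpers.

```python
from typing import Any, Dict, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    # "1.2.3" -> (1, 2, 3); assumes a plain numeric MAJOR.MINOR.PATCH string
    return tuple(int(part) for part in version.split("."))


def check_compatible(saved: Dict[str, Any], expected: Dict[str, Any]) -> None:
    """Raise ValueError if checkpoint metadata does not match the model about to be built."""
    # Architecture fields must match exactly, or the state tree will not line up
    for key in ("architecture", "input_dim", "hidden_dim", "output_dim"):
        if saved.get(key) != expected.get(key):
            raise ValueError(
                f"{key}: checkpoint has {saved.get(key)!r}, expected {expected.get(key)!r}"
            )
    # Semantic-versioning convention: a different MAJOR version signals a breaking change
    if parse_version(saved["model_version"])[0] != parse_version(expected["model_version"])[0]:
        raise ValueError(
            f"incompatible versions: {saved['model_version']} vs {expected['model_version']}"
        )


saved_meta = {"model_version": "1.0.0", "architecture": "MLP",
              "input_dim": 10, "hidden_dim": 20, "output_dim": 30}

check_compatible(saved_meta, {**saved_meta, "model_version": "1.4.2"})  # OK: same major version
try:
    check_compatible(saved_meta, {**saved_meta, "hidden_dim": 32})
except ValueError as err:
    print(err)  # hidden_dim: checkpoint has 20, expected 32
```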

### 2. Periodic Checkpointing During Training

The loop below simulates training and saves a checkpoint every `checkpoint_every` epochs, each in its own `epoch_N` subdirectory:

In [None]:
def train_with_checkpoints(
    model: MLP,
    checkpoint_dir: str,
    num_epochs: int = 10,
    checkpoint_every: int = 5
):
    """Training loop with periodic checkpointing.
    
    Args:
        model: Model to train
        checkpoint_dir: Directory for checkpoints
        num_epochs: Number of training epochs
        checkpoint_every: Save checkpoint every N epochs
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    for epoch in range(num_epochs):
        # Simulate training
        print(f"Epoch {epoch + 1}/{num_epochs}")
        
        # Save checkpoint periodically
        if (epoch + 1) % checkpoint_every == 0:
            epoch_dir = os.path.join(checkpoint_dir, f'epoch_{epoch + 1}')
            # Create the per-epoch directory before Orbax writes into it
            os.makedirs(epoch_dir, exist_ok=True)
            save_model_orbax(model, epoch_dir)
            print("  Checkpoint saved!\n")


# Example (simulated)
with TemporaryDirectory() as tmpdir:
    model = create_model(42)
    print("Training with periodic checkpoints:\n")
    train_with_checkpoints(model, tmpdir, num_epochs=10, checkpoint_every=3)
    
    # List saved checkpoints
    print("\nSaved checkpoints:")
    for item in sorted(os.listdir(tmpdir)):
        if os.path.isdir(os.path.join(tmpdir, item)):
            print(f"  {item}")
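Long runs accumulate many checkpoints, and resuming requires finding the newest one. Orbax's `CheckpointManager` can handle retention (for example via `CheckpointManagerOptions(max_to_keep=...)`); the stdlib sketch below only illustrates the underlying logic for the `epoch_N` layout used by `train_with_checkpoints`. `list_checkpoints`, `latest_checkpoint`, and `prune_checkpoints` are hypothetical helpers. Note the numeric sort: a plain string sort would place `epoch_10` before `epoch_9`.

```python
import os
import re
import shutil
import tempfile
from typing import List, Optional

_EPOCH_DIR = re.compile(r"^epoch_(\d+)$")


def list_checkpoints(checkpoint_dir: str) -> List[int]:
    """Return saved epoch numbers in ascending numeric order."""
    epochs = []
    for name in os.listdir(checkpoint_dir):
        match = _EPOCH_DIR.match(name)
        if match and os.path.isdir(os.path.join(checkpoint_dir, name)):
            epochs.append(int(match.group(1)))
    return sorted(epochs)


def latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Path of the most recent checkpoint, or None if nothing has been saved yet."""
    epochs = list_checkpoints(checkpoint_dir)
    return os.path.join(checkpoint_dir, f"epoch_{epochs[-1]}") if epochs else None


def prune_checkpoints(checkpoint_dir: str, keep: int) -> None:
    """Delete all but the `keep` most recent checkpoints."""
    if keep < 1:
        raise ValueError("keep must be at least 1")
    for epoch in list_checkpoints(checkpoint_dir)[:-keep]:
        shutil.rmtree(os.path.join(checkpoint_dir, f"epoch_{epoch}"))


with tempfile.TemporaryDirectory() as tmpdir:
    for epoch in (3, 6, 9, 10):
        os.makedirs(os.path.join(tmpdir, f"epoch_{epoch}"))  # empty dirs stand in for checkpoints
    prune_checkpoints(tmpdir, keep=2)
    print(list_checkpoints(tmpdir))                     # [9, 10]
    print(os.path.basename(latest_checkpoint(tmpdir)))  # epoch_10
```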

## Key Concepts Summary

### Checkpointing Workflow

1. **Save**:
   - Extract states with `treefy_states`
   - Serialize with Orbax checkpointer
   - Include metadata for versioning

2. **Load**:
   - Create model with `abstract_init` (memory efficient)
   - Restore states from checkpoint
   - Update model with `update_states`

### Key Functions

- `brainstate.graph.treefy_states`: Extract states as pytree
- `brainstate.graph.update_states`: Update model with new states
- `brainstate.transform.abstract_init`: Create model with abstract shapes
- `orbax.PyTreeCheckpointer`: Efficient serialization

### Best Practices

1. **Use Orbax**: The standard checkpointing library in the JAX ecosystem, with asynchronous saving and cloud-storage support
2. **Version Control**: Save metadata with checkpoints
3. **Periodic Saving**: Checkpoint during long training runs
4. **Test Loading**: Verify that a loaded model reproduces the original model's outputs
5. **Selective Saving**: Save only what you need (params vs. all states)

### Common Use Cases

- **Resume Training**: Load last checkpoint after interruption
- **Model Deployment**: Save trained model for production
- **Experimentation**: Save multiple versions for comparison
- **Transfer Learning**: Load pre-trained weights
- **Ensembling**: Save multiple models for ensemble predictions

## Exercises

1. **Implement Early Stopping**:
   - Save checkpoint only when validation loss improves
   - Keep only the best checkpoint

2. **Checkpoint Manager**:
   - Create a class to manage multiple checkpoints
   - Implement automatic cleanup of old checkpoints
   - Add rollback functionality

3. **Cloud Storage**:
   - Use Orbax with Google Cloud Storage
   - Implement automatic backup

4. **Model Zoo**:
   - Create a library of pre-trained models
   - Implement model registry
   - Add download functionality
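As a starting point for Exercise 1, here is one possible skeleton. It assumes a patience-based stopping rule (stop after `patience` consecutive epochs without improvement), which is a common choice but not the only one. `BestCheckpointTracker` and its parameters are hypothetical; `save_fn` stands in for whatever writes the checkpoint and would overwrite a single "best" directory so only one checkpoint is kept.

```python
import math
from typing import Callable, Optional


class BestCheckpointTracker:
    """Call `save_fn` only when the monitored validation loss improves."""

    def __init__(self, save_fn: Callable[[int], None], patience: int = 3, min_delta: float = 0.0):
        self.save_fn = save_fn
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = math.inf
        self.best_epoch: Optional[int] = None
        self.bad_epochs = 0

    def update(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
            self.save_fn(epoch)  # overwrite the single "best" checkpoint
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


saved_epochs = []
tracker = BestCheckpointTracker(save_fn=saved_epochs.append, patience=2)
for epoch, loss in enumerate([0.9, 0.7, 0.75, 0.72, 0.6], start=1):
    if tracker.update(epoch, loss):
        print(f"stopping at epoch {epoch}")  # stopping at epoch 4
        break
print(tracker.best_epoch, saved_epochs)  # 2 [1, 2]
```

Here the loss of 0.6 at epoch 5 is never seen, which is the usual trade-off of early stopping: a larger `patience` tolerates longer plateaus at the cost of more training time.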

## Next Steps

- **Distributed Training**: Checkpointing across multiple devices
- **Model Quantization**: Save compressed models
- **ONNX Export**: Convert to ONNX format
- **Serving**: Deploy models with TensorFlow Serving (for example, after converting them with `jax2tf`)

## References

- [Orbax Documentation](https://orbax.readthedocs.io/)
- [How to Think in JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)
- [BrainState Graph API](https://brainstate.readthedocs.io/en/latest/apis/graph.html)
- [Model Versioning Best Practices](https://neptune.ai/blog/version-control-for-ml-models)