# How to Save and Load Models

This guide shows you how to save and load BrainPy models for checkpointing, resuming training, and deployment.

## Quick Start

**Save a trained model:**

In [None]:
import brainpy
import brainstate
import pickle

# After training...
state_dict = {
    'params': net.states(brainstate.ParamState),
    'epoch': current_epoch,
}

with open('model.pkl', 'wb') as f:
    pickle.dump(state_dict, f)

**Load a model:**

In [None]:
# Create model with same architecture
net = MyNetwork()
brainstate.nn.init_all_states(net)

# Load saved state
with open('model.pkl', 'rb') as f:
    state_dict = pickle.load(f)

# Restore parameters
for name, state in state_dict['params'].items():
    net.states(brainstate.ParamState)[name].value = state.value

## Understanding What to Save

### State Types

BrainPy has three state types with different persistence requirements:

**ParamState (Always save)**
   - Learnable weights and biases
   - Required to restore trained model
   - Examples: synaptic weights, neural biases

**LongTermState (Usually save)**
   - Persistent statistics and counters
   - Not updated by gradients
   - Examples: running averages, spike counts

**ShortTermState (Never save)**
   - Temporary dynamics that reset each trial
   - Will be re-initialized anyway
   - Examples: membrane potentials, synaptic conductances
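
The split above can be expressed as a simple filter. The sketch below uses stand-in classes rather than the real `brainstate` types so that it runs on its own; `collect_persistent` and the state names are hypothetical, chosen only for illustration.

```python
class State:
    """Stand-in for a state container; not the real brainstate class."""
    def __init__(self, value):
        self.value = value

class ParamState(State): pass
class LongTermState(State): pass
class ShortTermState(State): pass

def collect_persistent(states):
    """Return the raw values of every state worth saving, grouped by kind."""
    saved = {'params': {}, 'long_term': {}}
    for name, state in states.items():
        if isinstance(state, ParamState):
            saved['params'][name] = state.value
        elif isinstance(state, LongTermState):
            saved['long_term'][name] = state.value
        # ShortTermState is skipped: it is re-initialized before each trial.
    return saved

states = {
    'fc.weight': ParamState([0.1, 0.2]),
    'norm.running_mean': LongTermState([0.0]),
    'lif.V': ShortTermState([-65.0]),
}
snapshot = collect_persistent(states)
```

Storing raw values instead of state objects keeps the saved file independent of the state classes themselves.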

### Recommended Approach

In [None]:
def save_checkpoint(net, optimizer, epoch, filepath):
    """Save model checkpoint. Pass optimizer=None to skip optimizer state."""
    state_dict = {
        # Required: model parameters
        'params': net.states(brainstate.ParamState),

        # Optional but recommended: long-term states
        'long_term': net.states(brainstate.LongTermState),

        # Training metadata
        'epoch': epoch,

        # Model configuration (helpful for loading)
        'config': {
            'n_input': net.n_input,
            'n_hidden': net.n_hidden,
            'n_output': net.n_output,
            # ... other hyperparameters
        }
    }

    # Optimizer state is only needed when continuing training
    if optimizer is not None:
        state_dict['optimizer_state'] = optimizer.state_dict()

    with open(filepath, 'wb') as f:
        pickle.dump(state_dict, f)

    print(f"✅ Saved checkpoint to {filepath}")

## Basic Save/Load

### Using Pickle (Simple)

**Advantages:**
- Simple and straightforward
- Works with any Python object
- Good for quick prototyping

**Disadvantages:**
- Python-specific format
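- Files may fail to load after Python or library version changes
- Not human-readable
- Unsafe with untrusted files, since unpickling can execute arbitrary code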

In [None]:
import pickle
import brainpy
import brainstate
import brainunit as u  # provides the u.mV and u.ms units used below

# Define your model
class SimpleNet(brainstate.nn.Module):
    def __init__(self, n_neurons=100):
        super().__init__()
        self.lif = brainpy.state.LIF(n_neurons, V_rest=-65*u.mV, V_th=-50*u.mV, tau=10*u.ms)
        self.fc = brainstate.nn.Linear(n_neurons, 10)

    def update(self, x):
        self.lif(x)
        return self.fc(self.lif.get_spike())

# Train model
net = SimpleNet()
brainstate.nn.init_all_states(net)
# ... training code ...

# Save
params = net.states(brainstate.ParamState)
with open('simple_net.pkl', 'wb') as f:
    pickle.dump(params, f)

# Load
net_new = SimpleNet()
brainstate.nn.init_all_states(net_new)

with open('simple_net.pkl', 'rb') as f:
    loaded_params = pickle.load(f)

# Restore parameters
for name, state in loaded_params.items():
    net_new.states(brainstate.ParamState)[name].value = state.value
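
One way to soften pickle's version-compatibility problem is to store an explicit format version next to the parameters and check it on load. The sketch below is a minimal illustration using plain Python values; `FORMAT_VERSION`, `save_versioned`, and `load_versioned` are hypothetical names, not part of BrainPy.

```python
import os
import pickle
import tempfile

FORMAT_VERSION = 1  # hypothetical; bump when the checkpoint layout changes

def save_versioned(params, filepath):
    """Write params together with the layout version used to create them."""
    with open(filepath, 'wb') as f:
        pickle.dump({'format_version': FORMAT_VERSION, 'params': params}, f)

def load_versioned(filepath):
    """Load params, refusing files written with an unknown layout."""
    with open(filepath, 'rb') as f:
        payload = pickle.load(f)  # only unpickle files from a trusted source
    found = payload.get('format_version')
    if found != FORMAT_VERSION:
        raise ValueError(f'Unsupported checkpoint format: {found!r}')
    return payload['params']

# Round trip through a temporary directory
path = os.path.join(tempfile.mkdtemp(), 'demo.pkl')
save_versioned({'fc.weight': [0.1, 0.2]}, path)
restored = load_versioned(path)
```

Failing loudly on an unknown version is usually easier to debug than a shape or key error that surfaces later during restoration.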

### Using NumPy (Arrays Only)

**Advantages:**
- Language-agnostic
- Efficient storage
- Widely supported

**Disadvantages:**
- Only saves arrays (not structure)
- Need to manually track parameter names

In [None]:
import numpy as np
import jax.numpy as jnp  # needed to convert loaded arrays back to JAX arrays

# Save parameters as .npz
params = net.states(brainstate.ParamState)
param_dict = {name: np.array(state.value) for name, state in params.items()}
np.savez('model_params.npz', **param_dict)

# Load parameters
loaded = np.load('model_params.npz')
for name, array in loaded.items():
    net.states(brainstate.ParamState)[name].value = jnp.array(array)
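
Because `.npz` archives are keyed by plain strings, any structured parameter name has to be turned into a string on save and back into its original form on load. The sketch below assumes names may be tuples of path components, such as `('fc', 'weight')`; `SEP`, `flatten_names`, and `unflatten_names` are hypothetical helpers. It uses only the standard library, and the flattened dictionary is what would be passed to `np.savez`.

```python
SEP = '/'  # hypothetical separator; must not occur inside a path component

def flatten_names(params):
    """Map tuple paths such as ('fc', 'weight') to the string 'fc/weight'."""
    flat = {}
    for name, value in params.items():
        parts = name if isinstance(name, tuple) else (name,)
        key = SEP.join(str(p) for p in parts)
        if key in flat:
            raise KeyError(f'Duplicate flattened name: {key}')
        flat[key] = value
    return flat

def unflatten_names(flat):
    """Split 'fc/weight' back into ('fc', 'weight'); always returns tuples of strings."""
    return {tuple(key.split(SEP)): value for key, value in flat.items()}

params = {('fc', 'weight'): [1.0, 2.0], ('fc', 'bias'): [0.0]}
flat = flatten_names(params)
restored = unflatten_names(flat)
```

Note that non-string path components (for example integer layer indices) come back as strings, so the loading side has to account for that.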

## Checkpointing During Training

### Periodic Checkpoints

Save at regular intervals during training.

In [None]:
import os
import braintools

# Training setup
net = MyNetwork()
optimizer = braintools.optim.Adam(lr=1e-3)
optimizer.register_trainable_weights(net.states(brainstate.ParamState))

save_interval = 5  # Save every 5 epochs
checkpoint_dir = './checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Training loop
for epoch in range(num_epochs):
    # Training step
    for batch in train_loader:
        loss = train_step(net, optimizer, batch)

    # Periodic save
    if (epoch + 1) % save_interval == 0:
        checkpoint_path = f'{checkpoint_dir}/epoch_{epoch+1}.pkl'
        save_checkpoint(net, optimizer, epoch, checkpoint_path)

        print(f"Epoch {epoch+1}: Loss={loss:.4f}, Checkpoint saved")

### Best Model Checkpoint

Save only when validation performance improves.

In [None]:
best_val_loss = float('inf')
best_model_path = 'best_model.pkl'

for epoch in range(num_epochs):
    # Training
    train_loss = train_epoch(net, optimizer, train_loader)

    # Validation
    val_loss = validate(net, val_loader)

    # Save if best
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(net, optimizer, epoch, best_model_path)
        print(f"✅ New best model! Val loss: {val_loss:.4f}")

    print(f"Epoch {epoch+1}: Train={train_loss:.4f}, Val={val_loss:.4f}")

### Resuming Training

Continue training from a checkpoint.

In [None]:
def load_checkpoint(filepath, net, optimizer=None):
    """Load checkpoint and restore state."""
    with open(filepath, 'rb') as f:
        state_dict = pickle.load(f)

    # Restore model parameters
    params = net.states(brainstate.ParamState)
    for name, state in state_dict['params'].items():
        if name in params:
            params[name].value = state.value

    # Restore long-term states
    if 'long_term' in state_dict:
        long_term = net.states(brainstate.LongTermState)
        for name, state in state_dict['long_term'].items():
            if name in long_term:
                long_term[name].value = state.value

    # Restore optimizer state
    if optimizer is not None and 'optimizer_state' in state_dict:
        optimizer.load_state_dict(state_dict['optimizer_state'])

    start_epoch = state_dict.get('epoch', 0) + 1
    return start_epoch

# Resume training
net = MyNetwork()
brainstate.nn.init_all_states(net)
optimizer = braintools.optim.Adam(lr=1e-3)
optimizer.register_trainable_weights(net.states(brainstate.ParamState))

# Load checkpoint (path follows the periodic-checkpoint naming above)
start_epoch = load_checkpoint('./checkpoints/epoch_50.pkl', net, optimizer)

# Continue training from where we left off
for epoch in range(start_epoch, num_epochs):
    for batch in train_loader:
        loss = train_step(net, optimizer, batch)
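
When the checkpoint to resume from is not known in advance, the most recent one can be found by parsing the epoch number out of each file name. The sketch below assumes the `epoch_<N>.pkl` naming used for the periodic checkpoints; `find_latest_checkpoint` is a hypothetical helper, not a BrainPy function.

```python
import os
import re
import tempfile

EPOCH_PATTERN = re.compile(r'^epoch_(\d+)\.pkl$')

def find_latest_checkpoint(checkpoint_dir):
    """Return the path with the highest epoch number, or None if none exist."""
    best_epoch, best_path = -1, None
    for filename in os.listdir(checkpoint_dir):
        match = EPOCH_PATTERN.match(filename)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(checkpoint_dir, filename)
    return best_path

# Demo with empty placeholder files
demo_dir = tempfile.mkdtemp()
for n in (5, 10, 100):
    open(os.path.join(demo_dir, f'epoch_{n}.pkl'), 'wb').close()
latest = find_latest_checkpoint(demo_dir)
```

Comparing epochs as integers matters here: sorted as strings, `epoch_100.pkl` would come before `epoch_5.pkl`.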

## Advanced Saving Strategies

### Versioned Checkpoints

Keep multiple checkpoints without overwriting.

In [None]:
import os
import pickle
from datetime import datetime

def save_versioned_checkpoint(net, epoch, base_dir='checkpoints'):
    """Save checkpoint with timestamp."""
    os.makedirs(base_dir, exist_ok=True)  # create the directory on first use
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    filename = f'model_epoch{epoch}_{timestamp}.pkl'
    filepath = os.path.join(base_dir, filename)

    state_dict = {
        'params': net.states(brainstate.ParamState),
        'epoch': epoch,
        'timestamp': timestamp,
    }

    with open(filepath, 'wb') as f:
        pickle.dump(state_dict, f)

    return filepath

### Keep Last N Checkpoints

Automatically delete old checkpoints to save disk space.

In [None]:
import glob
import os

def save_with_cleanup(net, epoch, checkpoint_dir='checkpoints', keep_last=5):
    """Save checkpoint and keep only last N."""
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Save new checkpoint (zero-padded epoch keeps file names in sorted order)
    filepath = f'{checkpoint_dir}/epoch_{epoch:04d}.pkl'
    save_checkpoint(net, None, epoch, filepath)

    # Get all checkpoints, oldest first
    checkpoints = sorted(glob.glob(f'{checkpoint_dir}/epoch_*.pkl'))

    # Delete old ones
    if len(checkpoints) > keep_last:
        for old_checkpoint in checkpoints[:-keep_last]:
            os.remove(old_checkpoint)
            print(f"Removed old checkpoint: {old_checkpoint}")

## Model Export for Deployment

### Minimal Model File

Save only what's needed for inference.

In [None]:
def export_for_inference(net, filepath, metadata=None):
    """Export minimal model for inference."""

    export_dict = {
        'params': net.states(brainstate.ParamState),
        'config': {
            # Only architecture info, no training state
            'model_type': net.__class__.__name__,
            # ... architecture hyperparameters
        }
    }

    if metadata:
        export_dict['metadata'] = metadata

    with open(filepath, 'wb') as f:
        pickle.dump(export_dict, f)

    # Report size
    size_mb = os.path.getsize(filepath) / (1024 * 1024)
    print(f"📦 Exported model: {size_mb:.2f} MB")

# Export trained model
export_for_inference(
    net,
    'deployed_model.pkl',
    metadata={
        'description': 'LIF network for digit classification',
        'accuracy': 0.95,
        'date': datetime.now().isoformat()
    }
)

### Loading for Inference

In [None]:
def load_for_inference(filepath, model_class):
    """Load model for inference only."""

    with open(filepath, 'rb') as f:
        export_dict = pickle.load(f)

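    # Create model from config
    config = dict(export_dict['config'])
    config.pop('model_type', None)  # informational only, not a constructor argument
    net = model_class(**config)  # Remaining keys must match the constructor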
    brainstate.nn.init_all_states(net)

    # Load parameters
    params = net.states(brainstate.ParamState)
    for name, state in export_dict['params'].items():
        params[name].value = state.value

    return net, export_dict.get('metadata')

# Load and use
net, metadata = load_for_inference('deployed_model.pkl', MyNetwork)
print(f"Loaded model: {metadata['description']}")

# Run inference
brainstate.nn.init_all_states(net)
output = net(input_data)
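
A saved file and a freshly built model can disagree on parameter names, for example after an architecture change. The sketch below compares the two sets of names before any values are restored; `check_param_names` is a hypothetical helper that works on any pair of mappings, and the names in the demo are made up.

```python
def check_param_names(saved, current):
    """Return (missing, unexpected) parameter names as sorted lists.

    missing:    names the model expects but the file does not contain
    unexpected: names in the file that the model does not have
    """
    saved_names, current_names = set(saved), set(current)
    missing = sorted(current_names - saved_names)
    unexpected = sorted(saved_names - current_names)
    return missing, unexpected

saved = {'fc.weight': 1, 'fc.bias': 2, 'old_layer.weight': 3}
current = {'fc.weight': 0, 'fc.bias': 0, 'readout.weight': 0}
missing, unexpected = check_param_names(saved, current)
if missing or unexpected:
    print(f'Missing: {missing}; unexpected: {unexpected}')
```

Running a check like this before the restore loop turns a silent partial load into an explicit report.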

## Best Practices

✅ **Always save configuration** - Include hyperparameters for reproducibility

✅ **Version your checkpoints** - Track model version for compatibility

✅ **Save metadata** - Include training metrics, date, description

✅ **Regular backups** - Save periodically during long training

✅ **Keep best model** - Separate best and latest checkpoints

✅ **Test loading** - Verify checkpoint can be loaded before continuing

✅ **Use relative paths** - Make checkpoints portable

✅ **Document format** - Comment what's in your checkpoint files

❌ **Don't save ShortTermState** - It resets anyway

❌ **Don't save everything** - Minimize checkpoint size

❌ **Don't overwrite** - Keep multiple checkpoints for safety
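
The "test loading" advice can be automated by reading a checkpoint back right after it is written. The sketch below assumes a checkpoint is a pickled dictionary; `verify_checkpoint` and its `required_keys` default are hypothetical and should be adapted to whatever your checkpoints contain.

```python
import os
import pickle
import tempfile

def verify_checkpoint(filepath, required_keys=('params', 'epoch')):
    """Return True if the file unpickles to a dict containing required_keys."""
    try:
        with open(filepath, 'rb') as f:
            payload = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError,
            AttributeError, ImportError, IndexError):
        return False
    return isinstance(payload, dict) and all(k in payload for k in required_keys)

demo_dir = tempfile.mkdtemp()

# A well-formed checkpoint
good = os.path.join(demo_dir, 'good.pkl')
with open(good, 'wb') as f:
    pickle.dump({'params': {'w': [1.0]}, 'epoch': 3}, f)

# A truncated copy, simulating an interrupted write
truncated = os.path.join(demo_dir, 'truncated.pkl')
with open(good, 'rb') as src, open(truncated, 'wb') as dst:
    dst.write(src.read()[:5])

ok_good = verify_checkpoint(good)
ok_truncated = verify_checkpoint(truncated)
```

A check like this is cheap compared with discovering a corrupt file only when training needs to be resumed.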

## Summary

**Quick reference:**

In [None]:
# Save
checkpoint = {
    'params': net.states(brainstate.ParamState),
    'epoch': epoch,
    'config': net.get_config()  # get_config() is a method you define on your model
}
with open('checkpoint.pkl', 'wb') as f:
    pickle.dump(checkpoint, f)

# Load
with open('checkpoint.pkl', 'rb') as f:
    checkpoint = pickle.load(f)

# from_config() is likewise a constructor helper you define yourself
net = MyNetwork.from_config(checkpoint['config'])
brainstate.nn.init_all_states(net)

for name, state in checkpoint['params'].items():
    net.states(brainstate.ParamState)[name].value = state.value

## See Also

- Core Concepts: State Management
- Tutorials: SNN Training
- GPU/TPU Usage