# Save and Load Checkpoints

In this tutorial, we show how to save and load checkpoints in `brainstate` using two libraries: `orbax` and the more lightweight `braintools`. Checkpointing saves the state of your model during training, so that you can resume training from where you left off or use the trained model for inference later. The examples below apply the checkpointing functionality of both libraries to a simple MLP model.

Install the `orbax` library by running the following command:

`pip install orbax-checkpoint`

You may also install directly from GitHub, using the following command. This can be used to obtain the most recent version of Orbax.

`pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'`

Install the `braintools` library by running the following command:

`pip install braintools`

With both libraries installed, let's import everything we need.

In [7]:
import tempfile
import os

import jax
import jax.numpy as jnp
import orbax.checkpoint as orbax
import braintools
import brainstate

## Define the Model
We define a simple Multi-Layer Perceptron (MLP) model using `brainstate`.

In [8]:
class MLP(brainstate.nn.Module):
    def __init__(self, din: int, dmid: int, dout: int):
        super().__init__()
        self.dense1 = brainstate.nn.Linear(din, dmid)
        self.dense2 = brainstate.nn.Linear(dmid, dout)

    def __call__(self, x: jax.Array) -> jax.Array:
        x = self.dense1(x)
        x = jax.nn.relu(x)
        x = self.dense2(x)
        return x

## Create the Model
We create an instance of the model with a given seed for reproducibility.

In [9]:
SEED = 42
brainstate.random.seed(SEED)   # set seed in brainstate
model1 = MLP(10, 20, 30)    # create model
model1

MLP(
  dense1=Linear(
    in_size=(10,),
    out_size=(20,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[20]),
        'weight': ShapedArray(float32[10,20])
      }
    )
  ),
  dense2=Linear(
    in_size=(20,),
    out_size=(30,),
    w_mask=None,
    weight=ParamState(
      value={
        'bias': ShapedArray(float32[30]),
        'weight': ShapedArray(float32[20,30])
      }
    )
  )
)

## Save the Model State

### Save the Model State Using `orbax`
We collect the model's states, convert them to a plain nested dictionary, and save that dictionary with an `orbax` checkpointer.

In [10]:
tmpdir = tempfile.mkdtemp()    # create temporary directory

# Helper function to convert State objects to plain dictionaries for orbax
def to_plain_dict(obj):
    # Check if it's a dict-like object first
    if isinstance(obj, dict):
        return {k: to_plain_dict(v) for k, v in obj.items()}
    # Try to access 'value' attribute safely
    try:
        if 'value' in dir(obj):
            return to_plain_dict(obj.value)
    except (TypeError, AttributeError):
        pass
    # Return as-is if it's a leaf value (array, number, etc.)
    return obj

# Save using orbax - convert to plain dict for compatibility
state_nest = brainstate.graph.states(model1).to_nest()
state_plain = to_plain_dict(state_nest)
checkpointer = orbax.PyTreeCheckpointer()   # create checkpointer
checkpointer.save(os.path.join(tmpdir, 'state'), state_plain)    # save state

The model's parameters are now saved under `tmpdir/state`, written by the `orbax` library.
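
To make the conversion step concrete, here is a minimal, self-contained sketch of what a helper like `to_plain_dict` does to a nested state dictionary. `FakeState` is a hypothetical stand-in for a `brainstate` state (it only wraps a `value`), and the Python lists stand in for arrays; neither is part of `brainstate` or `orbax`.

```python
class FakeState:
    """Hypothetical stand-in for a state object: it only wraps a `value`."""

    def __init__(self, value):
        self.value = value


def to_plain_dict(obj):
    # Recurse into dictionaries, unwrap state-like objects, keep leaves as-is.
    if isinstance(obj, dict):
        return {k: to_plain_dict(v) for k, v in obj.items()}
    if isinstance(obj, FakeState):
        return to_plain_dict(obj.value)
    return obj


# A nest shaped like the one in this tutorial, with lists instead of arrays.
nest = {
    'dense1': {'weight': FakeState({'bias': [0.0, 0.0], 'weight': [[1.0, 2.0]]})},
    'dense2': {'weight': FakeState({'bias': [0.0], 'weight': [[3.0], [4.0]]})},
}

plain = to_plain_dict(nest)
print(plain['dense1']['weight']['bias'])   # [0.0, 0.0]
print(type(plain['dense1']['weight']))     # <class 'dict'>
```

After the conversion, the tree contains only dictionaries and leaf values, which is the shape a generic tree checkpointer can serialize without knowing about the wrapper class.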

### Save the Model State Using `braintools`

In [11]:
checkpoint = brainstate.graph.states(model1).to_nest()   # collect the model's states as a nested dict
braintools.file.msgpack_save(os.path.join(tmpdir, 'state.msgpack'), checkpoint)    # write the checkpoint file

Saving checkpoint into C:\Users\Administrator\AppData\Local\Temp\tmpnjecqtgi\state.msgpack


The model's parameters are now also saved in the file `tmpdir/state.msgpack`, written by the `braintools` library.
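
Note that `tempfile.mkdtemp()` creates the directory but never deletes it, so the checkpoints stay on disk until you remove them yourself. For throwaway experiments, a context-managed temporary directory is one alternative. The sketch below uses only the standard library; the file it writes contains placeholder bytes, not a real checkpoint, and the name `ckpt_dir` is illustrative.

```python
import os
import tempfile

# The directory exists only inside the `with` block.
with tempfile.TemporaryDirectory() as ckpt_dir:
    path = os.path.join(ckpt_dir, 'state.msgpack')
    with open(path, 'wb') as f:
        f.write(b'placeholder bytes, not a real checkpoint')
    saved_files = sorted(os.listdir(ckpt_dir))
    print(saved_files)       # ['state.msgpack']

# On exit, the directory and everything in it are removed.
still_there = os.path.exists(ckpt_dir)
print(still_there)           # False
```

If you need the checkpoint after the script ends, save it to a permanent directory instead.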

## Load the Model State

### Load the Model State Using `orbax`
Let's load the model's parameters from the checkpoint files.

In [12]:
# Create a new model with the same structure but a different seed,
# so its initial parameters differ from those of model1
brainstate.random.seed(0)
model2 = MLP(10, 20, 30)

# Load the parameters from checkpoint files using orbax
checkpointer = orbax.PyTreeCheckpointer()
restored_state = checkpointer.restore(os.path.join(tmpdir, 'state'))

# Helper function to update model states from loaded dictionary
def update_from_dict(model_dict, loaded_dict):
    for key in model_dict:
        if isinstance(model_dict[key], dict) and isinstance(loaded_dict.get(key), dict):
            update_from_dict(model_dict[key], loaded_dict[key])
        elif hasattr(model_dict[key], 'value'):
            model_dict[key].value = loaded_dict[key]

# Update the model with the loaded state
model2_states = brainstate.graph.states(model2).to_nest()
update_from_dict(model2_states, restored_state)
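
The update step above writes the loaded values into the model's existing state objects rather than replacing those objects. The following self-contained sketch illustrates that behavior and adds one assumption of its own: it raises an error when the checkpoint lacks an entry the model expects. `FakeState` is a hypothetical stand-in for a state object, and the lists stand in for arrays.

```python
class FakeState:
    """Hypothetical stand-in for a state object: it only wraps a `value`."""

    def __init__(self, value):
        self.value = value


def update_from_dict(model_dict, loaded_dict, path=()):
    """Copy loaded values into the states of `model_dict`, in place."""
    for key, target in model_dict.items():
        here = path + (str(key),)
        if key not in loaded_dict:
            # Fail loudly instead of leaving a state at its initial value.
            raise KeyError("checkpoint has no entry for " + "/".join(here))
        if isinstance(target, dict):
            update_from_dict(target, loaded_dict[key], here)
        else:
            target.value = loaded_dict[key]


model_states = {'dense1': {'weight': FakeState({'bias': [0.0], 'weight': [[0.5]]})}}
restored = {'dense1': {'weight': {'bias': [1.0], 'weight': [[2.0]]}}}

state_before = model_states['dense1']['weight']
update_from_dict(model_states, restored)
print(state_before.value['bias'])                          # [1.0]
print(model_states['dense1']['weight'] is state_before)    # True

# A checkpoint with a missing entry is reported with its path.
try:
    update_from_dict(model_states, {'dense1': {}})
    missing_key_error = None
except KeyError as err:
    missing_key_error = err.args[0]
print(missing_key_error)    # checkpoint has no entry for dense1/weight
```

Because the state objects keep their identity, any module that already holds a reference to them sees the restored values.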

### Load the Model State Using `braintools`
Now let's load the model's parameters from the `state.msgpack` file.

In [13]:
# Create a model with the same structure but a different seed,
# so its initial parameters differ from those of model1
brainstate.random.seed(0)
model3 = MLP(10, 20, 30)
checkpoint = brainstate.graph.states(model3).to_nest()

# Read the model parameters from the msgpack file into the states held in `checkpoint`
braintools.file.msgpack_load(os.path.join(tmpdir, 'state.msgpack'), checkpoint)

Loading checkpoint from C:\Users\Administrator\AppData\Local\Temp\tmpnjecqtgi\state.msgpack


{'dense1': {'weight': ParamState(
    value={
      'bias': ShapedArray(float32[20]),
      'weight': ShapedArray(float32[10,20])
    }
  )},
 'dense2': {'weight': ParamState(
    value={
      'bias': ShapedArray(float32[30]),
      'weight': ShapedArray(float32[20,30])
    }
  )}}

## Demonstrate the Loaded Model
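Let's run both loaded models and check that they produce the same output as the original model. Because `model2` and `model3` were initialized with a different seed than `model1`, matching outputs indicate that the saved parameters were restored.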

In [14]:
y1 = model1(jnp.ones((1, 10)))
y2 = model2(jnp.ones((1, 10)))
y3 = model3(jnp.ones((1, 10)))
print(jnp.allclose(y1, y2))    # True
print(jnp.allclose(y1, y3))    # True

True
True
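
Matching outputs on a single input is a quick sanity check, not a proof that every parameter was restored. A stricter check compares the parameter trees leaf by leaf. The sketch below is a self-contained illustration that uses flat Python lists in place of arrays; the helper names `flatten` and `mismatched_leaves` are hypothetical, and with real arrays you would compare each leaf with `jnp.allclose` instead of `math.isclose`.

```python
import math


def flatten(tree, prefix=()):
    """Turn a nested dict into a flat {path_tuple: leaf} mapping."""
    flat = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            flat.update(flatten(value, prefix + (key,)))
        else:
            flat[prefix + (key,)] = value
    return flat


def mismatched_leaves(tree_a, tree_b, tol=1e-6):
    """Return the paths whose leaves are missing on one side or differ."""
    flat_a, flat_b = flatten(tree_a), flatten(tree_b)
    bad = []
    for path in sorted(set(flat_a) | set(flat_b)):
        if path not in flat_a or path not in flat_b:
            bad.append(path)
            continue
        a, b = flat_a[path], flat_b[path]
        same = len(a) == len(b) and all(
            math.isclose(x, y, abs_tol=tol) for x, y in zip(a, b)
        )
        if not same:
            bad.append(path)
    return bad


original = {'dense1': {'bias': [0.1, 0.2], 'weight': [1.0, 2.0]}}
restored = {'dense1': {'bias': [0.1, 0.2], 'weight': [1.0, 2.0]}}
stale = {'dense1': {'bias': [0.1, 0.2], 'weight': [1.0, 2.5]}}

print(mismatched_leaves(original, restored))   # []
print(mismatched_leaves(original, stale))      # [('dense1', 'weight')]
```

An empty result means every leaf matched within the tolerance; a non-empty result names the parameters that were not restored.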
