# Checkpointing with Orbax

## A Simple Recipe

The following example shows how you can synchronously save and restore a PyTree.



```python
from etils import epath
import orbax.checkpoint as ocp
import numpy as np
```

Make sure the top-level directory exists before saving. Here we also delete any leftovers from a previous run so that the example starts from a clean directory.

```python
path = epath.Path('/tmp/my-checkpoints/')
if path.exists():
  path.rmtree()
path.mkdir()
```

Create a basic [PyTree](https://jax.readthedocs.io/en/latest/pytrees.html). This is simply a nested, tree-like structure that can include dicts, lists, or more complicated objects. Orbax can handle many different types of leaves; for our purposes, we will use a nested dict containing a few simple arrays and a scalar.


```python
my_tree = {
    'a': np.arange(8),
    'b': {
        'c': 42,
        'd': np.arange(16),
    }
}
```

To save and restore, we create a `Checkpointer` object. The `Checkpointer` must be constructed with a `CheckpointHandler`, which acts as a configuration that provides the `Checkpointer` with the logic needed to save and restore your object.

For PyTrees, the most common kind of checkpointable object, we can use the convenient shorthand `PyTreeCheckpointer`, which is equivalent to `Checkpointer(PyTreeCheckpointHandler())`.
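The split between a checkpointer and its handler can be illustrated with a toy, pure-Python version of the same pattern. `ToyCheckpointer` and `JsonHandler` are hypothetical names made up for this sketch; they are not Orbax classes and only handle JSON-serializable objects. The checkpointer owns generic bookkeeping (here, refusing to overwrite an existing directory), while the handler decides how the object is written and read.

```python
import json
import pathlib
import tempfile


class JsonHandler:
  """Hypothetical handler: knows how to write and read one kind of object."""

  def save(self, directory: pathlib.Path, item) -> None:
    (directory / 'item.json').write_text(json.dumps(item))

  def restore(self, directory: pathlib.Path):
    return json.loads((directory / 'item.json').read_text())


class ToyCheckpointer:
  """Hypothetical checkpointer: generic bookkeeping, format delegated to a handler."""

  def __init__(self, handler):
    self._handler = handler

  def save(self, directory, item) -> None:
    directory = pathlib.Path(directory)
    # Mirror the rule from the recipe: the target must not already exist.
    if directory.exists():
      raise FileExistsError(directory)
    directory.mkdir(parents=True)
    self._handler.save(directory, item)

  def restore(self, directory):
    return self._handler.restore(pathlib.Path(directory))


with tempfile.TemporaryDirectory() as tmp:
  ckptr = ToyCheckpointer(JsonHandler())
  ckptr.save(pathlib.Path(tmp) / 'ckpt', {'a': [0, 1, 2], 'b': {'c': 42}})
  restored = ckptr.restore(pathlib.Path(tmp) / 'ckpt')
  print(restored)  # {'a': [0, 1, 2], 'b': {'c': 42}}
```

Swapping in a different handler changes the on-disk format without touching the checkpointer, which is the role `PyTreeCheckpointHandler` plays for PyTrees.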

```python
checkpointer = ocp.PyTreeCheckpointer()
# 'checkpoint_name' must not already exist.
checkpointer.save(path / 'checkpoint_name', my_tree)
checkpointer.restore(path / 'checkpoint_name')
```

```
{'a': array([0, 1, 2, 3, 4, 5, 6, 7]),
 'b': {'c': array(42),
  'd': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])}}
```

We can manually inspect the files written by the save operation. Each leaf of the PyTree is stored in its own directory, named after its path in the tree (for example, `b.c`). A final `checkpoint` file stores the structure of the PyTree; if so configured, it can also store the actual parameter values.

```python
[str(f) for f in (path / 'checkpoint_name').iterdir()]
```

```
['/tmp/my-checkpoints/checkpoint_name/a',
 '/tmp/my-checkpoints/checkpoint_name/b.c',
 '/tmp/my-checkpoints/checkpoint_name/b.d',
 '/tmp/my-checkpoints/checkpoint_name/checkpoint']
```
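The per-leaf directory names in this listing match the leaves' key paths joined with `.`. The sketch below reproduces those names for `my_tree` with a small recursive walk. It assumes dot-joined, sorted dict keys (which matches this particular listing), handles nested dicts only, and is an illustration rather than Orbax's own naming code; `leaf_names` is a hypothetical helper.

```python
import numpy as np


def leaf_names(tree, prefix=''):
  """Return dot-joined key paths for every leaf of a nested dict, keys sorted."""
  names = []
  for key in sorted(tree):
    name = f'{prefix}.{key}' if prefix else str(key)
    value = tree[key]
    if isinstance(value, dict):
      # Recurse into sub-dicts, extending the path prefix.
      names.extend(leaf_names(value, name))
    else:
      names.append(name)
  return names


my_tree = {'a': np.arange(8), 'b': {'c': 42, 'd': np.arange(16)}}
print(leaf_names(my_tree))  # ['a', 'b.c', 'b.d']
```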

## Managing Checkpoints

Sometimes, you may have multiple different objects that you want to checkpoint. You may also want higher-level management logic that keeps track of your checkpoints as training progresses.

```python
path = epath.Path('/tmp/checkpoint_manager')
state = {
  'a': np.arange(8),
  'b': np.arange(16),
}
extra_params = [42, 43]
```

```python
# Keeps a maximum of 3 checkpoints, and only saves every other step.
options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    save_interval_steps=2
)
mngr = ocp.CheckpointManager(
    path,
    {
        'state': ocp.PyTreeCheckpointer(),
        'extra_params': ocp.PyTreeCheckpointer()
    },
    options=options)

for step in range(11):  # [0, 1, ..., 10]
  mngr.save(step, {'state': state, 'extra_params': extra_params})
restored = mngr.restore(10)
restored_state, restored_extra_params = restored['state'], restored['extra_params']
```

```python
mngr.all_steps()
```

```
[6, 8, 10]
```

```python
mngr.latest_step()
```

```
10
```

```python
mngr.should_save(11)
```

```
False
```
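These results follow from the two options. With `save_interval_steps=2`, only the even steps 0, 2, ..., 10 are written, and `max_to_keep=3` then deletes all but the three most recent, leaving `[6, 8, 10]`; step 11 is odd, so `should_save(11)` is `False`. The pure-Python model below reproduces that arithmetic. `simulate_manager` and `should_save` here are hypothetical helpers and a simplification: they ignore the manager's other options and are not Orbax's implementation.

```python
def should_save(step, save_interval_steps):
  # Simplified rule: save only on multiples of the interval.
  return step % save_interval_steps == 0


def simulate_manager(num_steps, save_interval_steps, max_to_keep):
  """Return the steps still kept after calling save() for steps 0..num_steps-1."""
  kept = []
  for step in range(num_steps):
    if should_save(step, save_interval_steps):
      kept.append(step)
      # Drop the oldest checkpoints beyond the retention limit.
      kept = kept[-max_to_keep:]
  return kept


print(simulate_manager(11, save_interval_steps=2, max_to_keep=3))  # [6, 8, 10]
print(should_save(11, save_interval_steps=2))  # False
```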

## A Standard Recipe

In most cases, users will want to save and restore a PyTree representing a model state over the course of many training steps. Many users will also want to do this in a multi-host, multi-device environment.

First, we will create a PyTree state whose leaves are sharded `jax.Array`s.

```python
import jax

path = epath.Path('/tmp/checkpoint_manager_sharded')

sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec('model',)
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
train_state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
train_state = jax.tree_util.tree_map(create_sharded_array, train_state)
jax.tree_util.tree_map(lambda x: x.sharding, train_state)
```

```
{'a': NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec('model',)),
 'b': NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec('model',))}
```
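To see how `PartitionSpec('model',)` splits a leaf, you can inspect the array's `addressable_shards`: each shard records the device it lives on and the slice of the global array it holds. With a single device, as in the output above, there is one shard covering all 16 elements; with N devices, each shard would hold 16 / N elements (assuming 16 is divisible by N). This sketch rebuilds the same one-axis mesh on whatever devices are available.

```python
import jax
import numpy as np

mesh = jax.sharding.Mesh(np.array(jax.devices()), ('model',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('model',))
# 16 must be divisible by the number of devices for an even split.
arr = jax.device_put(np.arange(16), sharding)

for shard in arr.addressable_shards:
  # shard.index is a tuple of slices into the global array.
  print(shard.device, shard.index, shard.data.shape)

shard_sizes = [shard.data.size for shard in arr.addressable_shards]
```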

```python
num_steps = 10
options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    save_interval_steps=2
)
mngr = ocp.CheckpointManager(
    path,
    ocp.PyTreeCheckpointer(),
    options=options
)

@jax.jit
def train_fn(state):
  return jax.tree_util.tree_map(lambda x: x + 1, state)

for step in range(num_steps):
  train_state = train_fn(train_state)
  mngr.save(step, train_state)
```

```python
mngr.restore(mngr.latest_step())
```

```
{'a': array([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
       dtype=int32),
 'b': array([10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.,
        10., 10., 10.], dtype=float32)}
```
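The restored values can be checked by hand. Saves happen at steps 0, 2, 4, 6 and 8 (step 9 is skipped by the interval), so `latest_step()` is 8, and by then `train_fn` has been applied 9 times (steps 0 through 8). The NumPy sketch below replays the loop without JAX or Orbax, assuming a save happens exactly on multiples of `save_interval_steps`, and records a copy of the state at each such step.

```python
import numpy as np

num_steps = 10
save_interval_steps = 2
state = {'a': np.arange(16), 'b': np.ones(16)}

saved = {}  # step -> copy of the state at that step
for step in range(num_steps):
  # Same update as train_fn: add 1 to every leaf.
  state = {k: v + 1 for k, v in state.items()}
  if step % save_interval_steps == 0:  # Assumed saving rule.
    saved[step] = {k: v.copy() for k, v in state.items()}

latest = max(saved)
print(latest)                   # 8
print(saved[latest]['a'][:3])   # [ 9 10 11]
print(saved[latest]['b'][0])    # 10.0
```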

Now imagine that we are starting a new training run and would like to restore the previously saved checkpoint. In this case, we know only the tree structure of the checkpoint, not the actual array values. We would also like to load the arrays with different sharding constraints than the ones they were saved with.

```python
train_state = jax.tree_util.tree_map(np.zeros_like, train_state)
# PartitionSpec(None,) does not partition the array's dimension, so each
# array is replicated across all devices in the mesh.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(None,)
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
train_state = jax.tree_util.tree_map(create_sharded_array, train_state)
```
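With `PartitionSpec(None,)`, nothing is split: every device holds a full copy of the array rather than a slice. A quick way to confirm this is to check that every addressable shard has the full global shape. This sketch rebuilds the same replicated sharding on whatever devices are available; on a single device it trivially yields one full-size shard.

```python
import jax
import numpy as np

mesh = jax.sharding.Mesh(np.array(jax.devices()), ('model',))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None,))
x = jax.device_put(np.zeros(16), replicated)

# Each device's shard covers the whole array when nothing is partitioned.
shard_shapes = [shard.data.shape for shard in x.addressable_shards]
print(shard_shapes)
```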

Construct the arguments needed for restoration.

```python
shardings = jax.tree_util.tree_map(lambda x: x.sharding, train_state)
restore_args = ocp.checkpoint_utils.construct_restore_args(
    train_state, shardings)
restore_args
```

```
{'a': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('int32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec(None,)), global_shape=(16,)),
 'b': ArrayRestoreArgs(restore_type=<class 'jax.Array'>, dtype=dtype('float32'), mesh=None, mesh_axes=None, sharding=NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec(None,)), global_shape=(16,))}
```

Alternatively, the arguments can be constructed manually for fine-grained control.

```python
directly_constructed_restore_args = jax.tree_util.tree_map(
  lambda x: ocp.ArrayRestoreArgs(
      # Restore as a jax.Array. Could also be np.ndarray, int, or others.
      restore_type=jax.Array,
      # Cast the restored array to a specific dtype. Note that this casts
      # every leaf, including the int32 'a', to float32.
      dtype=np.float32,
      sharding=x.sharding,
      # Padding or truncation may occur. Ensure that the shape matches the
      # saved shape!
      global_shape=x.shape,
  ),
  train_state)
```
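Because a mismatched `global_shape` can lead to padding or truncation, it may be worth checking shapes before building restore arguments by hand. The helper below, `mismatched_shapes`, is hypothetical (not part of Orbax) and assumes you already know the saved shapes, for example from the original `train_state`; it compares them leaf by leaf with the shapes of a flat target dict.

```python
import numpy as np


def mismatched_shapes(saved_shapes, target_tree):
  """Return {key: (saved, target)} for leaves whose shapes differ.

  Hypothetical helper: both arguments are flat dicts keyed by leaf name.
  """
  mismatches = {}
  for key, saved_shape in saved_shapes.items():
    target_shape = np.shape(target_tree[key])
    if tuple(saved_shape) != target_shape:
      mismatches[key] = (tuple(saved_shape), target_shape)
  return mismatches


saved_shapes = {'a': (16,), 'b': (16,)}
target = {'a': np.zeros(16), 'b': np.zeros(8)}  # 'b' is too short.
print(mismatched_shapes(saved_shapes, target))  # {'b': ((16,), (8,))}
```

An empty result means every target leaf matches its saved shape, so `global_shape=x.shape` is safe for each leaf.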

```python
restored = mngr.restore(
    mngr.latest_step(),
    items=train_state,
    restore_kwargs={'restore_args': restore_args},
)
```

```python
restored
```

```
{'a': Array([ 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
       dtype=int32),
 'b': Array([10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10., 10.,
        10., 10., 10.], dtype=float32)}
```

```python
jax.tree_util.tree_map(lambda x: x.sharding, restored)
```

```
{'a': NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec(None,)),
 'b': NamedSharding(mesh=Mesh('model': 1), spec=PartitionSpec(None,))}
```