# API Overview

In [None]:
import orbax.checkpoint as ocp
from orbax.checkpoint.checkpoint_managers import preservation_policy as preservation_policy_lib
from orbax.checkpoint.checkpoint_managers import save_decision_policy as save_decision_policy_lib
import jax
import numpy as np
from jax import numpy as jnp

path = ocp.test_utils.erase_and_create_empty('/tmp/my-checkpoints/')

state = {'layer0': {'bias': np.ones((4,)), 'weight': jnp.arange(16)}}
abstract_state = jax.tree.map(ocp.tree.to_shape_dtype_struct, state)
metadata = {'version': 1.0}
extra_metadata = {'version': 1.0, 'step': 0}
dataset = {'my_data': 2}

## CheckpointManager Layer

The highest-level API layer provided by Orbax is the [`CheckpointManager`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_manager.html). This is the API of choice for users dealing with a series of checkpoints denoted as steps in the context of a training run.

`CheckpointManagerOptions` allows customizing the behavior of the `CheckpointManager` along various dimensions. A partial list of important customization options is given below. See the API reference for a complete list.

*   `save_decision_policy`: A policy that determines when to save checkpoints.
*   `preservation_policy`: A policy that determines which checkpoints to keep.
*   `step_format_fixed_length`: Formats step numbers to a fixed length of `n` digits, padding with leading zeros. This can make visually examining the checkpoints in sorted order easier.
*   `cleanup_tmp_directories`: Automatically cleans up existing temporary/incomplete directories when the `CheckpointManager` is created.
*   `read_only`: If True, checkpoint saves and deletes are skipped. Restore works as usual.
*   `enable_async_checkpointing`: True by default. Be wary of turning this off, as save performance may be significantly impacted.
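To make the first three options concrete, here is a minimal pure-Python sketch of how a fixed-interval save rule, a latest-N preservation rule, and fixed-length step formatting might interact. The function names (`should_save`, `keep_latest_n`, `format_step`) are hypothetical and are not part of Orbax; the real policies may handle edge cases differently.

```python
def should_save(step: int, interval: int) -> bool:
    """Fixed-interval rule (assumed): save when step is a multiple of interval."""
    return step % interval == 0


def keep_latest_n(saved_steps: list[int], n: int) -> list[int]:
    """Latest-N rule (assumed): keep only the n most recent saved steps."""
    return sorted(saved_steps)[-n:]


def format_step(step: int, fixed_length: int) -> str:
    """Zero-pad a step so that names sort lexicographically in numeric order."""
    return f"{step:0{fixed_length}d}"


# Five training steps, saving every 2 steps and keeping the latest 2.
saved = [step for step in range(5) if should_save(step, interval=2)]
kept = keep_latest_n(saved, n=2)
names = [format_step(step, fixed_length=4) for step in kept]
print(saved, kept, names)
```

Under these assumptions, steps 0, 2, and 4 are saved, and only steps 2 and 4 remain once the preservation rule is applied.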

If dealing with a single checkpointable object, like a train state, `CheckpointManager` can be created as shown in the Basic Usage section below.



Note that `CheckpointManager` always saves asynchronously, unless you set `enable_async_checkpointing=False` in `CheckpointManagerOptions`. Make sure to use `wait_until_finished()` if you need to block until a save is complete.
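The asynchronous pattern described above can be illustrated with a small standard-library model. `ToyAsyncSaver` is a hypothetical class, not Orbax's implementation; it only shows why `save` returns quickly and why a blocking call is needed before relying on the result.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor


class ToyAsyncSaver:
    """Hypothetical stand-in for an object that saves in the background."""

    def __init__(self) -> None:
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._pending: list[Future] = []
        self.completed: list[int] = []

    def _write(self, step: int) -> None:
        time.sleep(0.05)  # Simulate a slow write to storage.
        self.completed.append(step)

    def save(self, step: int) -> None:
        # Returns immediately; the write continues on a worker thread.
        self._pending.append(self._executor.submit(self._write, step))

    def wait_until_finished(self) -> None:
        # Block until every outstanding write is done, re-raising any error.
        for future in self._pending:
            future.result()
        self._pending.clear()

    def close(self) -> None:
        self.wait_until_finished()
        self._executor.shutdown()


saver = ToyAsyncSaver()
for step in range(3):
    saver.save(step)  # Does not wait for the write.
saver.wait_until_finished()  # All three writes are done after this line.
completed_steps = list(saver.completed)
saver.close()
print(completed_steps)
```

Reading `saver.completed` before `wait_until_finished()` could observe a partial list, which is the same hazard as inspecting a checkpoint directory before an asynchronous save has finished.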

### Basic Usage

In [2]:
import jax

directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-single/')

In [5]:
options = ocp.CheckpointManagerOptions(
    save_decision_policy=save_decision_policy_lib.FixedIntervalPolicy(2),
    preservation_policy=preservation_policy_lib.LatestN(2),
    # other options
)
mngr = ocp.CheckpointManager(
    directory,
    options=options,
)

In [None]:
num_steps = 5

def train_step(state):
  return jax.tree_util.tree_map(lambda x: x + 1, state)

for step in range(num_steps):
  state = train_step(state)
  mngr.save(step, args=ocp.args.StandardSave(state))
mngr.wait_until_finished()

In [None]:
mngr.latest_step()

In [None]:
mngr.all_steps()

In [None]:
mngr.restore(mngr.latest_step())

In [None]:
# Restore with additional arguments, like dtype or sharding.
def set_dtype(abstract_arr):
  return abstract_arr.update(dtype=np.float32)

mngr.restore(mngr.latest_step(), args=ocp.args.StandardRestore(
    jax.tree.map(set_dtype, abstract_state)))

### Managing Multiple Items

Often, we need to deal with multiple items, representing the training state, dataset, and some custom metadata, for instance.

In [None]:
directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-multiple/')

In [None]:
options = ocp.CheckpointManagerOptions(
    save_decision_policy=save_decision_policy_lib.FixedIntervalPolicy(2),
    preservation_policy=preservation_policy_lib.LatestN(2),
    # other options
)
mngr = ocp.CheckpointManager(
    directory,
    options=options,
)

In [None]:
num_steps = 5

def train_step(step, _state, _extra_metadata):
  return jax.tree_util.tree_map(lambda x: x + 1, _state), {**_extra_metadata, **{'step': step}}

for step in range(num_steps):
  state, extra_metadata = train_step(step, state, extra_metadata)
  mngr.save(
      step,
      args=ocp.args.Composite(
        state=ocp.args.StandardSave(state),
        extra_metadata=ocp.args.JsonSave(extra_metadata),
      )
  )
mngr.wait_until_finished()

In [None]:
# Restore exactly as saved
result = mngr.restore(mngr.latest_step())

In [None]:
result

In [None]:
result.state

In [None]:
result.extra_metadata

In [None]:
# Skip `state` when restoring.
# Note that it is possible to provide `extra_metadata=None` because we already
# saved using `JsonSave`. This is internally cached, so we know it uses JSON
# logic to save and restore. If you had called `restore` without first calling
# `save`, however, it would have been necessary to provide
# `ocp.args.JsonRestore`.
mngr.restore(mngr.latest_step(), args=ocp.args.Composite(extra_metadata=None))

In [None]:
# Restoration of the state can be customized by specifying an abstract state.
# For example, we can change the dtypes to automatically cast the restored
# arrays.
def set_dtype(abstract_arr):
  return abstract_arr.update(dtype=np.float32)

mngr.restore(
    mngr.latest_step(),
    args=ocp.args.Composite(
      state=ocp.args.StandardRestore(jax.tree.map(set_dtype, abstract_state)),
      extra_metadata=None
    )
)

There are some scenarios in which the mapping between items and their respective `CheckpointHandler`s needs to be provided when creating a `CheckpointManager` instance.

The `CheckpointManager` constructor argument `item_handlers` resolves those scenarios. See [Using the Refactored CheckpointManager API](https://orbax.readthedocs.io/en/latest/guides/checkpoint/api_refactor.html) for details.



## Checkpointer Layer

Conceptually, the [`Checkpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html) exists to work with a single checkpoint that exists at a single path. It is no-frills (relative to `CheckpointManager`) but guarantees atomicity and allows for asynchronous saving via `AsyncCheckpointer`.

### Saving and Restoring a PyTree

Typically, you may wish to save and restore a PyTree of arrays to a given path.
This is easily accomplished with `StandardCheckpointer`.

In [None]:
with ocp.StandardCheckpointer() as ckptr:
  ckptr.save(path / 'standard-ckpt-1', state)
  result = ckptr.restore(path / 'standard-ckpt-1', abstract_state)
  print(result)

Note that `StandardCheckpointer` always saves asynchronously! In order to block until a save completes, use `ckptr.wait_until_finished()`.

Equivalently, this can be expressed as follows (see the following section):

In [None]:
with ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler()) as ckptr:
  ckptr.save(path / 'standard-ckpt-2', args=ocp.args.StandardSave(state))

### Understanding Checkpointers

When greater customization of save and restore behavior is desired, Orbax must be instructed which logic to use to save and restore a given object. This is achieved by combining a `Checkpointer` with a `CheckpointHandler`. You can think of the `CheckpointHandler` as providing a configuration that tells the `Checkpointer` what serialization logic to use to deal with a particular object, while the `Checkpointer` provides shared logic used by all `CheckpointHandler`s, like thread management and atomicity.

In [None]:
with ocp.Checkpointer(ocp.JsonCheckpointHandler()) as ckptr:
  ckptr.save(path / 'json-ckpt-1', args=ocp.args.JsonSave({'a': 'b'}))
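The division of labor between the two classes can be sketched with the standard library alone. `ToyJsonHandler` and `ToyCheckpointer` are hypothetical names, and the write-to-temporary-directory-then-rename approach is one common way to get atomicity; it is an assumption of this sketch, not a description of Orbax internals.

```python
import json
import os
import tempfile
from pathlib import Path


class ToyJsonHandler:
    """Hypothetical handler: knows only how to serialize one kind of object."""

    def save(self, directory: Path, item: dict) -> None:
        (directory / "data.json").write_text(json.dumps(item))

    def restore(self, directory: Path) -> dict:
        return json.loads((directory / "data.json").read_text())


class ToyCheckpointer:
    """Hypothetical checkpointer: shared logic that wraps any handler."""

    def __init__(self, handler) -> None:
        self._handler = handler

    def save(self, path: Path, item) -> None:
        # Write into a temporary sibling directory first...
        tmp_dir = Path(
            tempfile.mkdtemp(prefix=path.name + ".tmp-", dir=path.parent)
        )
        self._handler.save(tmp_dir, item)
        # ...then rename into place so readers never see a partial checkpoint.
        os.rename(tmp_dir, path)

    def restore(self, path: Path):
        return self._handler.restore(path)


root = Path(tempfile.mkdtemp())
ckptr = ToyCheckpointer(ToyJsonHandler())
ckptr.save(root / "json-ckpt", {"a": "b"})
restored = ckptr.restore(root / "json-ckpt")
print(restored)
```

Swapping in a different handler changes the serialization format without touching the commit logic, which is the point of keeping the two roles separate.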

Async checkpointing provided via `AsyncCheckpointer` can often help to realize significant resource savings and training speedups because writes to disk happen in a background thread. See [here](https://orbax.readthedocs.io/en/latest/guides/checkpoint/async_checkpointing.html) for more details.

In [None]:
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())

While most `Checkpointer`/`CheckpointHandler` pairs deal with a single object that is saved and restored, pairing a `Checkpointer` with `CompositeCheckpointHandler` allows dealing with multiple distinct objects at once.

In [None]:
with ocp.Checkpointer(ocp.CompositeCheckpointHandler()) as ckptr:
  ckptr.save(
      path / 'composite-ckpt-1',
      args=ocp.args.Composite(
          state=ocp.args.StandardSave(state),
          metadata=ocp.args.JsonSave(metadata),
      )
  )
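As a rough mental model of the composite case, each named item can be thought of as being routed to its own save logic. In the sketch below, `composite_save`, `save_json`, and `save_text` are hypothetical helpers, and the one-subdirectory-per-item layout is an assumption made for illustration.

```python
import json
import tempfile
from pathlib import Path


def save_json(directory: Path, item) -> None:
    (directory / "data.json").write_text(json.dumps(item))


def save_text(directory: Path, item) -> None:
    (directory / "data.txt").write_text(str(item))


def composite_save(path: Path, **items) -> None:
    """Save each named item with its own function, in its own subdirectory."""
    for name, (save_fn, item) in items.items():
        item_dir = path / name
        item_dir.mkdir(parents=True)
        save_fn(item_dir, item)


root = Path(tempfile.mkdtemp())
composite_save(
    root / "composite-ckpt",
    metadata=(save_json, {"version": 1.0}),
    notes=(save_text, "first run"),
)
# List every file written, relative to the temporary root.
layout = sorted(
    p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file()
)
print(layout)
```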

## Understanding Items and Registration

Let's return to the subject of "items". This is the term Orbax uses to refer to logically distinct checkpointable units. These units may be bundled together as part of the same state, but it is frequently convenient to maintain some separation between them, as they are often used for very different purposes.

Some common examples may include the training state, dataset, embeddings, custom metadata, etc.

Each of these items may require different logic in order to save, and it is neither possible nor desirable for Orbax to "just figure it out" automatically. It is important to have confidence that the item you're saving is being saved as you expect it to be.

You can see a list of handlers available for checkpointing different objects in the [API reference](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.args.html). If none of these meet your needs, you can [create your own](https://orbax.readthedocs.io/en/latest/guides/checkpoint/custom_handlers.html).

Let's return to our standard example. In this section we will always use `CheckpointManager`, but all the following principles apply in the same way when using `Checkpointer(CompositeCheckpointHandler())`.

In [None]:
directory = ocp.test_utils.erase_and_create_empty('/tmp/checkpoint-manager-items-1/')

mngr = ocp.CheckpointManager(directory)
mngr.save(
    0,
    args=ocp.args.Composite(
      state=ocp.args.StandardSave(state),
      extra_metadata=ocp.args.JsonSave(extra_metadata),
    )
)
restored = mngr.restore(0)
print(restored.state)
print(restored.extra_metadata)

For any given item, be it `state` or `extra_metadata`, the first `arg` used to save or restore that item is then "locked in" and used for all subsequent saves and restores. This is what allows us to restore without specifying any arguments.

In [None]:
mngr.save(1, args=ocp.args.Composite(
    state=ocp.args.StandardSave(state), extra_metadata=None))
restored = mngr.restore(1)
print(restored.state)
print(restored.extra_metadata)

We can also obtain metadata about our saved state, again without needing to specify any arguments.

In [None]:
meta = mngr.item_metadata(1)
print(meta.state)
mngr.close()

However, if we create a new `CheckpointManager` and try to restore, we will get an error because the `CheckpointHandler` for `state` is not configured. `item_metadata`, in contrast, does not raise an error but returns `None`, so we have some indication that the item exists but could not be reconstructed.

In [None]:
with ocp.CheckpointManager(directory) as mngr:
  try:
    print(mngr.restore(0))
  except BaseException as e:
    print(e)
  print('')
  print(mngr.item_metadata(0))

To fix this, we can pre-configure with a handler registry in order to specify the behavior that should be taken when restoring a particular item.

In [None]:
registry = ocp.handlers.DefaultCheckpointHandlerRegistry()
registry.add('state', ocp.args.StandardSave)
registry.add('state', ocp.args.StandardRestore)
with ocp.CheckpointManager(
    directory,
    handler_registry=registry,
) as mngr:
  print(mngr.restore(0, args=ocp.args.Composite(state=None)))
  print('')
  print(mngr.item_metadata(0))

As previously mentioned, once we have "locked in" the type for an item, either through eager configuration with the registry, or lazy configuration by providing `args`, we cannot change the item type without reinitializing the `CheckpointManager`.

In [None]:
with ocp.CheckpointManager(
    directory,
    handler_registry=registry,
) as mngr:
  mngr.save(2, args=ocp.args.PyTreeSave({'a': 'b'}))
  try:
    print(mngr.save(3, args=ocp.args.JsonSave({'a': 'b'})))
  except BaseException as e:
    print(e)
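The "lock-in" behavior discussed in this section can be modeled in a few lines of plain Python. `ToyItemRegistry`, `ToyTreeSave`, and `ToyJsonSave` are hypothetical names; the sketch deliberately ignores the pairing of save and restore argument types and only shows the first-use-wins rule.

```python
class ToyJsonSave:
    """Hypothetical stand-in for a JSON-style save argument."""


class ToyTreeSave:
    """Hypothetical stand-in for a PyTree-style save argument."""


class ToyItemRegistry:
    """Remembers the first args type used for each item name."""

    def __init__(self) -> None:
        self._locked: dict[str, type] = {}

    def resolve(self, item_name: str, args=None) -> type:
        locked = self._locked.get(item_name)
        if args is None:
            if locked is None:
                raise KeyError(f"No handler configured for item '{item_name}'.")
            return locked  # Reuse what was locked in earlier.
        if locked is None:
            self._locked[item_name] = type(args)  # First use locks the type.
        elif locked is not type(args):
            raise ValueError(
                f"Item '{item_name}' is locked to {locked.__name__}, "
                f"got {type(args).__name__}."
            )
        return self._locked[item_name]


registry = ToyItemRegistry()
registry.resolve("state", ToyTreeSave())  # Locks `state`.
reused = registry.resolve("state")  # No args needed now.
try:
    registry.resolve("state", ToyJsonSave())  # Different type: rejected.
    conflict = None
except ValueError as e:
    conflict = str(e)
print(reused.__name__, "|", conflict)
```

A freshly constructed `ToyItemRegistry` raises `KeyError` when asked to resolve `state` without arguments, which mirrors why a new manager needs either explicit `args` or a pre-configured registry.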

## CheckpointHandler Layer

The lowest-level API that users typically interact with in Orbax is the [`CheckpointHandler`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html). Every `CheckpointHandler` is also paired with one or two [`CheckpointArgs`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.args.html) objects which encapsulate all necessary and optional arguments that a user can provide when saving or restoring. At a high level, `CheckpointHandler` exists to provide the logic required to save or restore a particular object in a checkpoint.

`CheckpointHandler` allows for synchronous saving. Subclasses of [`AsyncCheckpointHandler`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html#asynccheckpointhandler) allow for asynchronous saving. (Restoration is always synchronous.)

Crucially, a `CheckpointHandler` instance **should not be used in isolation**, but should always be used **in conjunction with a `Checkpointer`**. Otherwise, save operations will not be atomic and async operations cannot be waited upon. This means that in most cases, you will be working with `Checkpointer` APIs rather than `CheckpointHandler` APIs.

However, it is still essential to understand `CheckpointHandler` because you need to know how you want your object to be saved and restored, and what arguments are necessary to make that happen.

Let's consider the example of [`StandardCheckpointHandler`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html#standardcheckpointhandler). This class is paired with [`StandardSave`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.args.html#standardsave) and [`StandardRestore`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.args.html#standardrestore).

`StandardSave` allows specifying the `item` argument, which is the PyTree to be saved using Tensorstore. It also includes `save_args`, which is an optional `PyTree` with a structure matching `item`. Each leaf is a `ocp.type_handlers.SaveArgs` object, which can be used to customize things like the `dtype` of the saved array.
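The idea of a `save_args` tree that mirrors the structure of `item` can be illustrated without any array library. `ToySaveArgs` and `apply_save_args` below are hypothetical, plain Python numbers stand in for arrays, and only a `dtype`-like cast is modeled.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ToySaveArgs:
    """Hypothetical per-leaf options; only `dtype` is modeled here."""

    dtype: Optional[Callable[[Any], Any]] = None


def apply_save_args(item, save_args):
    """Walk two nested dicts in lockstep and cast each leaf as requested."""
    if isinstance(item, dict):
        return {k: apply_save_args(v, save_args[k]) for k, v in item.items()}
    return item if save_args.dtype is None else save_args.dtype(item)


item = {"layer0": {"bias": 1.0, "weight": 16.0}}
# Same nesting as `item`; each leaf says how that value should be stored.
save_args = {"layer0": {"bias": ToySaveArgs(), "weight": ToySaveArgs(dtype=int)}}
converted = apply_save_args(item, save_args)
print(converted)
```

Because the two trees share a structure, a per-leaf setting needs no separate addressing scheme: the position of a `ToySaveArgs` in the tree identifies the value it applies to.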

`StandardRestore` only has one possible argument, the `item`, which is a PyTree of concrete or abstract arrays matching the structure of the checkpoint. This is optional, and the checkpoint will be restored exactly as saved if no argument is provided.

In general, other `CheckpointHandler`s may have other arguments, and the contract can be discerned by looking at the corresponding `CheckpointArgs`. Additionally, `CheckpointHandler`s can be [customized](https://orbax.readthedocs.io/en/latest/guides/checkpoint/custom_handlers.html) for specific needs by providing your own implementation.

[`CompositeCheckpointHandler`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_handlers.html#compositecheckpointhandler) is a special case that allows composing multiple `CheckpointHandler`s at once. More details are provided throughout this page.