# Customizing Checkpointing Behavior

Orbax allows users to specify their own logic for dealing with custom objects.
Customization can occur at two levels: that of a "checkpointable" and that of a
"PyTree leaf".

## Custom Checkpointables

First, ensure that you are familiar with the documentation on "checkpointables".
To recap, a "checkpointable" is a distinct unit of an entire checkpoint. For
example, the model state is a checkpointable distinct from the dataset iterator.
Embeddings, if used, may also be represented as a separate checkpointable.

Let us consider a toy example. Suppose that in addition to our PyTree state
(represented as a dictionary of arrays, containing the parameters and optimizer
state) and our dataset iterator (represented using PyGrain), we also have an
object called `Point`, which has integer properties `x` and `y`. (Since this
object is a dataclass, it would be easy to convert it to a PyTree and save it
in the same way as the primary model state, so the example is a bit contrived,
but it demonstrates the point well enough.)

Our `Point` class is defined as follows.

In [13]:
import dataclasses
import json
from typing import Any, Awaitable
import aiofiles
import jax
import numpy as np
import orbax.checkpoint.experimental.v1 as ocp


@dataclasses.dataclass
class Point:
  x: int
  y: int


model_state = {
    'params': np.arange(16),
    'opt_state': np.ones(16),
}

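If we just try to save the `Point` (along with our other checkpointables), it
will fail because the object type is not recognized. (The output recorded below
instead reports an existing destination, because `/tmp/ckpt1` already existed
when this cell was run.)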

In [14]:
try:
  ocp.save_checkpointables(
      '/tmp/ckpt1',
      dict(model_state=model_state, point=Point(1, 2)),
  )
except Exception as e:
  print(e)

Destination /tmp/ckpt1 already exists.


There are two possible approaches for implementing support for `Point` in Orbax.
We will start with the simpler of the two.

### Implementing `Point` as a `StatefulCheckpointable`

The `Point` object must implement the methods of the `StatefulCheckpointable`
Protocol. We need to implement `save` and `load` methods so that Orbax will know
how to deal with the `Point` object.

In [15]:
from __future__ import annotations

del Point


@dataclasses.dataclass
class Point(ocp.StatefulCheckpointable):

  x: int
  y: int

  async def save(
      self, directory: ocp.path.PathAwaitingCreation
  ) -> Awaitable[None]:
    return self._background_save(
        directory,
        # If the object could be modified by the main thread while being
        # written, it is important to make a copy to prevent race conditions.
        dataclasses.asdict(self),
    )

  async def load(self, directory: ocp.path.Path) -> Awaitable[None]:
    return self._background_load(directory)

  async def _background_save(
      self,
      directory: ocp.path.PathAwaitingCreation,
      value: dict[str, int],
  ):
    # In a multiprocess setting, prevent multiple processes from writing the
    # same thing.
    if jax.process_index() == 0:
      directory = await directory.await_creation()
      async with aiofiles.open(directory / 'point.txt', 'w') as f:
        contents = json.dumps(value)
        await f.write(contents)

  async def _background_load(
      self,
      directory: ocp.path.Path,
  ):
    async with aiofiles.open(directory / 'point.txt', 'r') as f:
      contents = json.loads(await f.read())
      self.x = contents['x']
      self.y = contents['y']

Let's break this down.

Both `save` and `load` methods consist of two phases: blocking and non-blocking.
Blocking operations must execute *now*, before returning control to the caller.
Non-blocking operations may occur in a background thread, and are represented by
an `Awaitable` (here, a coroutine object) that is returned to the caller without
having been executed yet.
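The two phases can be illustrated without Orbax at all. The following is a minimal sketch using only `asyncio`; `ToyState` and its attributes are hypothetical names introduced for illustration, and the in-memory `written` list stands in for a file on disk.

```python
import asyncio


class ToyState:
  """Illustrative stand-in for a checkpointable; not an Orbax class."""

  def __init__(self):
    self.state = {'x': 1}
    self.written = []  # Stand-in for files on disk.

  async def save(self):
    # Blocking phase: snapshot the state before control returns to the caller.
    snapshot = dict(self.state)
    # Non-blocking phase: return a coroutine object that has not started yet.
    return self._background_save(snapshot)

  async def _background_save(self, snapshot):
    await asyncio.sleep(0)  # Placeholder for slow I/O.
    self.written.append(snapshot)


async def main():
  obj = ToyState()
  background = await obj.save()  # The blocking phase completes here.
  obj.state['x'] = 99  # Safe: the snapshot was already taken.
  nothing_written_yet = obj.written == []
  await background  # The non-blocking phase runs here.
  return obj, nothing_written_yet


obj, nothing_written_yet = asyncio.run(main())
print(obj.written)  # prints [{'x': 1}]
```

Awaiting `save` only runs the blocking phase; the write happens when the returned coroutine is itself awaited, which is why the later mutation to `99` does not leak into the saved value.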

When saving, in the case of `Point`, we make a copy of the properties to prevent
them from being concurrently modified by the main thread while we are writing
them in the background thread. For a `jax.Array`, we would similarly need to
perform a transfer from device memory to host memory. When the blocking
operations complete, we can construct an awaitable function that writes the
values to a file. Note also that the background function must wait for the
parent directory to be created (via `await_creation()`), since upper layers of
Orbax schedule the directory creation asynchronously.
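To make the directory-creation step concrete, here is a rough sketch of the idea using only the standard library. `DirAwaitingCreation` is a hypothetical stand-in written for this example; it is not how Orbax implements `PathAwaitingCreation`, and only the `await_creation()` method name is taken from the code above.

```python
import asyncio
import json
import pathlib
import tempfile


class DirAwaitingCreation:
  """Hypothetical stand-in for a path whose creation is still in flight."""

  def __init__(self, path: pathlib.Path):
    self._path = path
    # Schedule directory creation on a worker thread without waiting for it.
    # This must be constructed while an event loop is running.
    self._creation = asyncio.create_task(
        asyncio.to_thread(path.mkdir, parents=True, exist_ok=True)
    )

  async def await_creation(self) -> pathlib.Path:
    await self._creation
    return self._path


async def write_value(directory: DirAwaitingCreation, value: dict) -> None:
  # The directory may not exist yet, so wait before writing into it.
  path = await directory.await_creation()
  await asyncio.to_thread((path / 'point.txt').write_text, json.dumps(value))


async def main() -> dict:
  with tempfile.TemporaryDirectory() as tmp:
    target = pathlib.Path(tmp) / 'ckpt' / 'point'
    await write_value(DirAwaitingCreation(target), {'x': 1, 'y': 2})
    return json.loads((target / 'point.txt').read_text())


result = asyncio.run(main())
```

Skipping the `await_creation()` call would make the write race against the `mkdir`, which is the failure mode the text above warns about.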

Loading is similar. Typically there are fewer operations that need to happen
synchronously, as the caller should know they cannot do anything with the object
until it is fully loaded. Again, the awaitable function that is run in the
background should return nothing, and instead set relevant properties in `self`
after loading from disk.
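The in-place loading contract can be sketched in isolation as follows. This is an illustrative example under stated assumptions: `ToyPoint` is a hypothetical class that mimics the shape of the `load` method above but uses plain `pathlib` and `asyncio.to_thread` rather than Orbax paths or `aiofiles`.

```python
import asyncio
import dataclasses
import json
import pathlib
import tempfile


@dataclasses.dataclass
class ToyPoint:
  """Illustrative stand-in for `Point`; it does not use Orbax."""

  x: int
  y: int

  async def load(self, directory: pathlib.Path):
    # No blocking work is needed here, so just hand back the coroutine.
    return self._background_load(directory)

  async def _background_load(self, directory: pathlib.Path) -> None:
    text = await asyncio.to_thread((directory / 'point.txt').read_text)
    contents = json.loads(text)
    # Update the object in place; the coroutine itself returns nothing.
    self.x = contents['x']
    self.y = contents['y']


async def main():
  with tempfile.TemporaryDirectory() as tmp:
    directory = pathlib.Path(tmp)
    (directory / 'point.txt').write_text(json.dumps({'x': 1, 'y': 2}))
    point = ToyPoint(0, 0)  # "Uninitialized" placeholder values.
    background = await point.load(directory)
    returned = await background
    return point, returned


point, returned = asyncio.run(main())
```

The caller keeps a reference to the same object throughout, so the loaded values become visible on it once the background coroutine finishes.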

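Now we can successfully save the `Point`, as long as the destination does not
already exist; otherwise a `ValueError` is raised, as in the recorded output
below.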

In [16]:
ocp.save_checkpointables(
    '/tmp/ckpt1',
    dict(model_state=model_state, point=Point(1, 2)),
)

ValueError: Destination /tmp/ckpt1 already exists.

It is important to note that because `Point` is a stateful checkpointable, we
have to provide a `Point` object in order to restore it. In typical usage, we
should construct a `Point` object with "uninitialized" values. Calling
`load_checkpointables` then updates the provided object as a side effect (it
also returns it).

In [5]:
uninitialized_point = Point(0, 0)
ocp.load_checkpointables(
    '/tmp/ckpt1',
    dict(point=uninitialized_point),
)
uninitialized_point

Point(x=1, y=2)

### Supporting `Point` with `CheckpointableHandler`

While `StatefulCheckpointable` has a simple and powerful interface, it may not
be the right fit in every case. `StatefulCheckpointable` may be insufficient in
cases such as:

*   `Point` may be defined in some third-party library that we cannot easily
    control, so we cannot directly add `save` and `load` methods to the class
    itself.
*   When loading, users might need to customize loading behavior in a more
    dynamic way. For a `jax.Array`, resharding, casting, and reshaping are
    common operations. For a `Point`, users might want to cast `x` and `y`
    between `int` and `float` more dynamically.
*   We may have multiple different ways to save and load `Point` that users want
    to enable in different contexts. In such cases, placing all that different
    logic within the single `Point` class may add too much complexity.

For such cases (and others), Orbax provides an interface called
`CheckpointableHandler`.

First, let's redefine our `Point` class and also introduce an `AbstractPoint`
class. This allows us to specify the type of `x` or `y` that should be used for
loading.

In [6]:
del Point
import asyncio
from typing import Type

Scalar = int | float


@dataclasses.dataclass
class Point:
  x: Scalar
  y: Scalar


@dataclasses.dataclass
class AbstractPoint:
  x: Type[Scalar]
  y: Type[Scalar]

In [7]:
async def _write_point(
    directory: ocp.path.Path, checkpointable: dict[str, Scalar]
):
  async with aiofiles.open(directory / 'point.txt', 'w') as f:
    contents = json.dumps(checkpointable)
    await f.write(contents)


async def _write_point_metadata(
    directory: ocp.path.Path, checkpointable: dict[str, Scalar]
):
  async with aiofiles.open(directory / 'point_metadata.txt', 'w') as f:
    contents = json.dumps(
        {k: type(v).__name__ for k, v in checkpointable.items()}
    )
    await f.write(contents)


class PointHandler(ocp.CheckpointableHandler[Point, AbstractPoint]):

  async def _background_save(
      self,
      directory: ocp.path.PathAwaitingCreation,
      checkpointable: dict[str, Scalar],
  ):
    if jax.process_index() == 0:
      directory = await directory.await_creation()
      await asyncio.gather(
          _write_point(directory, checkpointable),
          _write_point_metadata(directory, checkpointable),
      )

  async def _background_load(
      self,
      directory: ocp.path.Path,
      abstract_checkpointable: AbstractPoint | None = None,
  ) -> Point:
    async with aiofiles.open(directory / 'point.txt', 'r') as f:
      contents = json.loads(await f.read())
      if abstract_checkpointable is None:
        return Point(**contents)
      else:
        return Point(
            abstract_checkpointable.x(contents['x']),
            abstract_checkpointable.y(contents['y']),
        )

  async def save(
      self,
      directory: ocp.path.PathAwaitingCreation,
      checkpointable: Point,
      partial_save: bool = False,
  ) -> Awaitable[None]:
    return self._background_save(directory, dataclasses.asdict(checkpointable))

  async def load(
      self,
      directory: ocp.path.Path,
      abstract_checkpointable: AbstractPoint | None = None,
  ) -> Awaitable[Point]:
    return self._background_load(directory, abstract_checkpointable)

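  async def metadata(self, directory: ocp.path.Path) -> AbstractPoint:
    # Map the stored type names back to types explicitly instead of looking
    # them up on `__builtins__`, which can be a dict rather than a module.
    scalar_types = {'int': int, 'float': float}
    async with aiofiles.open(directory / 'point_metadata.txt', 'r') as f:
      contents = json.loads(await f.read())
      return AbstractPoint(
          **{k: scalar_types[v] for k, v in contents.items()}
      )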

  def is_handleable(self, checkpointable: Any) -> bool:
    return isinstance(checkpointable, Point)

  def is_abstract_handleable(self, abstract_checkpointable: Any) -> bool:
    return isinstance(abstract_checkpointable, AbstractPoint)

This class associates itself with two types, the `Checkpointable` and the
`AbstractCheckpointable` (`Point` and `AbstractPoint` in this case). `Point` is
the input for saving, and `AbstractPoint` (or `None`) is the input for loading;
both operations also receive the parent directory.

Saving logic in this class is essentially the same as in our
`StatefulCheckpointable` definition above.

Loading is different because loading is no longer stateful - it instead accepts
an optional `AbstractPoint` and returns a newly constructed `Point`. Providing
`None` as the input indicates that the object should simply be restored exactly
as it was saved. (Note that for some objects, this may not be possible, and it
may be necessary to raise an error if some input from the user is required to
know how to load.) Otherwise, the provided `AbstractCheckpointable` serves as
the guide describing how the concrete loaded object (`Point` in this case)
should be constructed.
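The construction rule described here can be isolated from any file I/O. The sketch below restates the `Point` and `AbstractPoint` definitions so that it runs on its own; `construct_point` is a hypothetical helper name introduced for illustration and is not part of Orbax.

```python
import dataclasses
from typing import Optional, Type, Union

Scalar = Union[int, float]


@dataclasses.dataclass
class Point:
  x: Scalar
  y: Scalar


@dataclasses.dataclass
class AbstractPoint:
  x: Type[Scalar]
  y: Type[Scalar]


def construct_point(
    contents: dict, abstract: Optional[AbstractPoint] = None
) -> Point:
  """Builds a `Point` from saved contents, optionally casting each field."""
  if abstract is None:
    # No guidance given: restore exactly what was saved.
    return Point(**contents)
  # Otherwise, call each requested type on the corresponding saved value.
  return Point(abstract.x(contents['x']), abstract.y(contents['y']))


saved = {'x': 1, 'y': 2.4}
as_saved = construct_point(saved)
cast = construct_point(saved, AbstractPoint(x=float, y=int))
print(as_saved)  # prints Point(x=1, y=2.4)
print(cast)  # prints Point(x=1.0, y=2)
```

Note that casting with `int` truncates `2.4` to `2`, so a requested abstract type can lose information relative to what was saved.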

We also have the capability of defining a `metadata` method in this class. In
the case of `Point`, the object is obviously quite lightweight already. For real
use cases, the checkpoint itself may be expensive to load fully, and some
metadata describing important properties that can be loaded cheaply is
essential. The `metadata` method should return an instance of
`AbstractCheckpointable`.
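As a rough illustration of why such metadata is cheap, the type information for a `Point` can be stored as a tiny JSON document and read back without touching the values themselves. The helper names and the `_SCALAR_TYPES` table below are assumptions made for this sketch, not Orbax APIs.

```python
import json

# Explicit table of the types this sketch knows how to restore.
_SCALAR_TYPES = {'int': int, 'float': float}


def encode_metadata(values: dict) -> str:
  """Records only the type name of each field, not its value."""
  return json.dumps({k: type(v).__name__ for k, v in values.items()})


def decode_metadata(text: str) -> dict:
  """Maps the stored type names back to Python types."""
  return {k: _SCALAR_TYPES[v] for k, v in json.loads(text).items()}


encoded = encode_metadata({'x': 1, 'y': 2.4})
decoded = decode_metadata(encoded)
print(decoded)  # prints {'x': <class 'int'>, 'y': <class 'float'>}
```

An unknown type name raises a `KeyError` here, which is arguably preferable to silently resolving an arbitrary name from the checkpoint.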

Finally, two additional methods, `is_handleable` and `is_abstract_handleable`,
should be defined. These methods accept any object, and decide whether the given
object is an acceptable input for saving or loading, respectively. In most
cases, a simple `isinstance` check will suffice, but for more generic
constructs, like `PyTree`s, more involved logic is necessary.
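To give a sense of what "more involved logic" might look like, here is a hypothetical check for a nested container of numeric leaves. It is only a sketch of the general shape of such a predicate; it is not how Orbax decides whether something is a PyTree, and `is_scalar_tree` is a name invented for this example.

```python
from typing import Any


def is_scalar_tree(obj: Any) -> bool:
  """Returns True for nested dicts/lists/tuples whose leaves are numbers."""
  if isinstance(obj, bool):
    # `bool` is a subclass of `int`; this sketch chooses to reject it.
    return False
  if isinstance(obj, (int, float)):
    return True
  if isinstance(obj, dict):
    return all(
        isinstance(k, str) and is_scalar_tree(v) for k, v in obj.items()
    )
  if isinstance(obj, (list, tuple)):
    return all(is_scalar_tree(v) for v in obj)
  return False


print(is_scalar_tree({'a': [1, 2.0], 'b': {'c': (3,)}}))  # prints True
print(is_scalar_tree({'a': 'text'}))  # prints False
```

Unlike a single `isinstance` call, the predicate has to walk the whole structure, since one unsupported leaf anywhere makes the object unacceptable.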

We can now register `PointHandler` in order to deal with `Point` objects.

In [8]:
ocp.handlers.register_handler(PointHandler)

In [9]:
ocp.save_checkpointables(
    '/tmp/ckpt2',
    dict(model_state=model_state, point=Point(1, 2.4)),
)

Since the `AbstractPoint` is optional, we do not need to specify any arguments
to load everything successfully.

In [10]:
ocp.load_checkpointables('/tmp/ckpt2')

{'model_state': {'opt_state': array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
  'params': array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])},
 'point': Point(x=1, y=2.4)}

However, if desired, we can specify an abstract checkpointable to customize the
types of the restored values.

In [11]:
ocp.load_checkpointables(
    '/tmp/ckpt2', dict(point=AbstractPoint(x=float, y=int))
)

{'point': Point(x=1.0, y=2)}

We can use `checkpointables_metadata` to load the metadata, in the form of an
`AbstractPoint`.

In [12]:
ocp.checkpointables_metadata('/tmp/ckpt2').metadata['point']

AbstractPoint(x=<class 'int'>, y=<class 'float'>)