# Checkpoint Format Guide

## What is an Orbax checkpoint?

An Orbax checkpoint is a directory containing an empty file named
`orbax.checkpoint`. All Orbax checkpoints saved with the V1 API include this
file. A directory without it is not a valid V1 checkpoint, though it may still
be a valid checkpoint saved with the older V0 API.
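
As a minimal sketch of the rule above, the check below tests for the marker file using only the standard library. The helper name `has_v1_marker` is hypothetical and not part of Orbax. It assumes nothing beyond the file name stated above.

```python
import pathlib
import tempfile

# Marker file name taken from the description above.
MARKER_NAME = 'orbax.checkpoint'


def has_v1_marker(path) -> bool:
  """Returns True if `path` is a directory containing the marker file."""
  path = pathlib.Path(path)
  return path.is_dir() and (path / MARKER_NAME).is_file()


# Build a throwaway directory to demonstrate both outcomes.
with tempfile.TemporaryDirectory() as tmp:
  ckpt = pathlib.Path(tmp) / 'ckpt'
  ckpt.mkdir()
  before = has_v1_marker(ckpt)  # No marker file yet.
  (ckpt / MARKER_NAME).touch()
  after = has_v1_marker(ckpt)  # Marker file now present.

print(before, after)
```

A `False` result only means the marker is missing. As noted above, the directory could still hold a checkpoint saved with the V0 API.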

## Overview

Consider the following directory tree:

```
/path/to/my/checkpoints/
  0/
    pytree/
      ...
    dataset/
      ...
  100/
    ...
  200/
    ...
```

What does each level represent?

The top-level directory is called a **root directory**.

Within the root directory is a sequence of individual **checkpoints**. In a
training context, each of these checkpoints corresponds to an integer step.

Within each checkpoint is a set of **checkpointables**, which correspond to
individual elements such as the PyTree train state, the dataset iterator, and
so on.
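
The three levels can be made concrete with a small directory walk. This is only a sketch using the standard library. `describe_root` is a hypothetical helper, and it assumes that every integer-named subdirectory is a checkpoint and that every subdirectory inside it is a checkpointable.

```python
import pathlib
import tempfile


def describe_root(root) -> dict[int, list[str]]:
  """Maps each integer step to the sorted names of its subdirectories."""
  layout = {}
  for step_dir in pathlib.Path(root).iterdir():
    # Only integer-named subdirectories are treated as steps here.
    if step_dir.is_dir() and step_dir.name.isdigit():
      layout[int(step_dir.name)] = sorted(
          p.name for p in step_dir.iterdir() if p.is_dir()
      )
  return dict(sorted(layout.items()))


# Recreate the example tree in a temporary root directory.
with tempfile.TemporaryDirectory() as tmp:
  for step in (0, 100, 200):
    for name in ('pytree', 'dataset'):
      (pathlib.Path(tmp) / str(step) / name).mkdir(parents=True)
  layout = describe_root(tmp)

print(layout)
```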

Let's take a closer look at these elements.

### Singular Checkpoints

A checkpoint is a persistent representation of an ML model present in a storage
location, typically on disk. When a model is saved using Orbax, it becomes a
checkpoint. When a checkpoint is loaded using Orbax, it becomes a model.

Concretely, in Orbax, a checkpoint is composed of a collection of
**checkpointables**. That means if we save using the following:

```
ocp.save_checkpointables(
  '/path/to/my/checkpoint/',
  dict(pytree=..., dataset=..., other_checkpointable=...),
)
```

We get a checkpoint on disk with a structure similar to the following:

```
/path/to/my/checkpoint/  # The checkpoint path.
  pytree/  # A directory containing the PyTree piece of the checkpoint.
  dataset/  # A directory containing the dataset piece of the checkpoint.
  other_checkpointable/  # Another checkpointable
```

Each checkpointable is represented by a subdirectory.

Similarly, we can use a different API:

```
ocp.save_pytree(
  '/path/to/my/checkpoint/',
  pytree_of_arrays,
)
```

This produces a checkpoint where `pytree` is the only subdirectory.

```
/path/to/my/checkpoint/  # The checkpoint path.
  pytree/  # A directory containing the PyTree piece of the checkpoint.
```

### Sequence of Checkpoints

Make sure not to confuse a "checkpoint" with a "sequence of checkpoints". For
example, when using `training.Checkpointer`, multiple checkpoints representing
steps will be saved to a **root directory**.

For example, if we save a sequence of steps using the following:

```
with ocp.training.Checkpointer('/path/to/my/root_directory/') as ckptr:
  for step in range(start_step, num_steps):
    ckptr.save_checkpointables(step, ...)
```

Our root directory will look like the following, where each integer-numbered
subdirectory represents a single checkpoint, corresponding to a step.

```
/path/to/my/root_directory/
  0/
  100/
  200/
  ...
```
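
When inspecting a root directory by hand, note that step names need to be compared as integers rather than strings (otherwise `100` would sort before `20`). The sketch below uses only the standard library. `list_steps` is a hypothetical helper, and it assumes that every integer-named subdirectory is a completed checkpoint, which may not hold while a save is still in progress.

```python
import pathlib
import tempfile


def list_steps(root) -> list[int]:
  """Returns integer step names under `root`, sorted numerically."""
  return sorted(
      int(p.name)
      for p in pathlib.Path(root).iterdir()
      if p.is_dir() and p.name.isdigit()
  )


with tempfile.TemporaryDirectory() as tmp:
  for step in (0, 100, 20):
    (pathlib.Path(tmp) / str(step)).mkdir()
  steps = list_steps(tmp)
  # The latest step is the numerically largest one, if any exist.
  latest = steps[-1] if steps else None

print(steps, latest)
```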

## Format Details

Now that we understand the checkpoint format abstractly, let's get to some
concrete details.

First, some setup:

In [None]:
import json
import pprint
from etils import epath
import jax
import numpy as np
from orbax.checkpoint import v1 as ocp

In [None]:
directory = epath.Path('/tmp/my-checkpoints')
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))

pytree = {
    'params': {
        'layer0': {
            'kernel': np.random.uniform(size=(2, 2)),
            'bias': np.ones(2),
        }
    },
    'opt_state': {'0': np.random.random(size=(2,))},
}
pytree = jax.device_put(pytree, sharding)

In [None]:
def print_directory(directory: epath.PathLike, level: int = 0):
  """Prints a directory tree for debugging purposes."""
  directory = epath.Path(directory)
  assert directory.is_dir()
  level_str = '..' * level
  if level == 0:
    print(f'{directory}/')
  else:
    print(f'{level_str}{directory.name}/')

  level_str = '..' * (level + 1)
  for p in directory.iterdir():
    if p.is_dir():
      print_directory(p, level=level + 1)
    else:
      print(f'{level_str}{p.name}')

### Generic Checkpoints

Let's create a checkpoint with two checkpointables, `pytree` and
`extra_properties`. Let's also pass some custom metadata, which allows users to
provide JSON-serializable properties. For demonstration purposes, let's save
`extra_properties` as a JSON checkpointable.

In [None]:
# Note that the example would work even without the extra step of forcing
# `extra_properties` to be handled by `JsonHandler`. We just want to ensure it
# gets JSON-encoded for demonstration purposes.
with ocp.Context(
    checkpointables_options=ocp.options.CheckpointablesOptions.create_with_handlers(
        extra_properties=ocp.handlers.JsonHandler
    )
):
  ocp.save_checkpointables(
      directory / 'ckpt-0',
      dict(pytree=pytree, extra_properties={'foo': 'bar'}),
      custom_metadata={'version': 1.0},
  )

In [None]:
!ls {directory / 'ckpt-0'}

As we expected, each checkpointable gets its own subdirectory. There is also a
`_CHECKPOINT_METADATA` file created, which contains JSON-encoded metadata.

In [None]:
pprint.pp(
    json.loads((directory / 'ckpt-0' / '_CHECKPOINT_METADATA').read_text())
)

This file contains a number of internal properties recorded by Orbax. The most
important of these is `item_handlers`, which records the handler used to save
each checkpointable, to facilitate later loading.

Notice that our `custom_metadata` is also stored in this file.
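
If we only care about these two properties, a small reader keeps the inspection focused. This is a sketch under assumptions: `read_checkpoint_metadata` is a hypothetical helper, the key names `item_handlers` and `custom_metadata` are taken from the description above, and the stand-in file below contains placeholder values rather than real Orbax output.

```python
import json
import pathlib
import tempfile


def read_checkpoint_metadata(checkpoint_dir) -> tuple[dict, dict]:
  """Returns (item_handlers, custom_metadata) from `_CHECKPOINT_METADATA`."""
  path = pathlib.Path(checkpoint_dir) / '_CHECKPOINT_METADATA'
  contents = json.loads(path.read_text())
  # Fall back to empty dicts, since the file layout is internal to Orbax
  # and these keys are assumed rather than guaranteed.
  handlers = contents.get('item_handlers') or {}
  custom = contents.get('custom_metadata') or {}
  return handlers, custom


with tempfile.TemporaryDirectory() as tmp:
  # Stand-in file with placeholder values, not real Orbax output.
  fake = {
      'item_handlers': {'pytree': '<handler name>'},
      'custom_metadata': {'version': 1.0},
  }
  (pathlib.Path(tmp) / '_CHECKPOINT_METADATA').write_text(json.dumps(fake))
  handlers, custom = read_checkpoint_metadata(tmp)

print(handlers, custom)
```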

### PyTree Checkpointables

Using the same checkpoint, let's dig into the `pytree` subdirectory.

In [None]:
print_directory(directory / 'ckpt-0' / 'pytree')

The `_METADATA` file provides a complete description of the PyTree structure,
including custom and empty nodes.

The tree is represented as a flattened dictionary. Each key is a tuple whose
successive elements denote successive levels of nesting. For example, for the
dict `{'a': {'b': [1, 2]}}`, the metadata file would contain two entries with
keys `('a', 'b', '0')` and `('a', 'b', '1')`.

Each key element also records its type, i.e. whether it is a dict key or a
sequential key.

Finally, metadata about the value type is stored (e.g. `jax.Array`,
`np.ndarray`, etc.) in order to allow for later reconstruction without
explicitly requiring the object type to be provided.
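
To make the flattening concrete, here is a plain-Python sketch that reproduces the tuple keys from the example above, along with a key type per level. It is an illustration only, not Orbax's implementation: `flatten_with_key_types` and the labels `'dict'` and `'sequence'` are hypothetical, and unlike the real `_METADATA` file this sketch does not record empty or custom nodes.

```python
def flatten_with_key_types(tree, prefix=()):
  """Yields (key_tuple, key_types, leaf) for every leaf in a nested tree."""
  if isinstance(tree, dict):
    children = [(str(k), 'dict', v) for k, v in tree.items()]
  elif isinstance(tree, (list, tuple)):
    children = [(str(i), 'sequence', v) for i, v in enumerate(tree)]
  else:
    # Leaf: split the accumulated (name, type) pairs into two tuples.
    names = tuple(name for name, _ in prefix)
    types = tuple(kind for _, kind in prefix)
    yield names, types, tree
    return
  for name, kind, child in children:
    yield from flatten_with_key_types(child, prefix + ((name, kind),))


entries = list(flatten_with_key_types({'a': {'b': [1, 2]}}))
for names, types, leaf in entries:
  # The leaf's Python type stands in for the stored value-type metadata.
  print(names, types, type(leaf).__name__)
```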

In [None]:
pprint.pp(
    json.loads((directory / 'ckpt-0' / 'pytree' / '_METADATA').read_text())
)

The exact structure of the metadata is an internal implementation detail and is
subject to change, but it can still be useful to inspect the tree structure
manually. In most cases, however, it is preferable to rely on the public
methods intended for obtaining metadata.

In [None]:
pprint.pp(ocp.pytree_metadata(directory / 'ckpt-0').metadata)

In [None]:
print_directory(directory / 'ckpt-0' / 'pytree')

Aside from the `_METADATA` file, most other files are not human-readable.

The `_sharding` file stores information about the shardings used when saving
`jax.Array`s in the tree. Similarly, `array_metadatas` records array properties
separately on each process, so that these properties may later be compared and
validated.

Orbax uses the [TensorStore](https://google.github.io/tensorstore/) library to
save individual arrays. Actual array data is stored within the `d/`
subdirectory, while TensorStore metadata is recorded in the `manifest.ocdbt`
file. These files are not human-readable and require TensorStore APIs to parse
(see below).

Finally, you'll notice the presence of the directory `ocdbt.process_0/`, which
also has a `manifest.ocdbt` and its own `d/` subdirectory. One such folder
exists for every process on which the checkpoint was saved. This exists because
each process first writes its own data independently to its corresponding
subdirectory.

When all processes have finished, Orbax runs a finalization pass to cheaply
merge the metadata from all per-process subdirectories into a global view (note
that this view still references data in the original subdirectories). This
allows checkpoint saving to scale as the number of concurrent processes
increases.
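
The idea of merging by reference can be illustrated with a toy model. The sketch below is conceptual only and does not reflect the actual OCDBT manifest format: `merge_manifests`, the chunk keys, the data file names, and the `ocdbt.process_1` directory name are all hypothetical.

```python
def merge_manifests(
    per_process: dict[str, dict[str, str]],
) -> dict[str, tuple[str, str]]:
  """Builds a global index that points back into per-process directories."""
  merged = {}
  for process_dir, manifest in per_process.items():
    for key, data_file in manifest.items():
      # Record a reference only; no array data is copied or moved.
      # This toy assumes keys are unique across processes.
      merged[key] = (process_dir, data_file)
  return merged


# Hypothetical per-process manifests: key -> data file within that directory.
per_process = {
    'ocdbt.process_0': {'chunk_a': 'd/file0'},
    'ocdbt.process_1': {'chunk_b': 'd/file1'},
}
merged = merge_manifests(per_process)
print(merged)
```

Because the merged index stores only references, the cost of this toy merge grows with the number of keys rather than with the amount of array data.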

#### Working with TensorStore

Sometimes, it is helpful to work directly with the
[TensorStore](https://google.github.io/tensorstore/) API to debug individual
parameters in a checkpoint.

In [None]:
import tensorstore as ts

pytree_path = directory / 'ckpt-0' / 'pytree'

We can list the keys present in the checkpoint. They should match the
information we gathered earlier from the Orbax `metadata` API.

In [None]:
ts.KvStore.open(
    {"driver": "ocdbt", "base": f"file://{pytree_path.as_posix()}"}
).result().list().result()

To read using TensorStore, we need to construct a TensorStore Spec. Here, we
write the spec out by hand as a dictionary. The spec points to a base path, as
well as a particular parameter name (`params.layer0.kernel` in this case). It
contains further options related to the checkpoint format.

In [None]:
tspec = {
    'driver': 'zarr3',
    'kvstore': {
        'driver': 'ocdbt',
        'base': {'driver': 'file', 'path': pytree_path.as_posix()},
        'path': 'params.layer0.kernel',
    },
}

Finally, we can directly restore the array using TensorStore.

In [None]:
t = ts.open(ts.Spec(tspec), open=True).result()
result = t.read().result()
result
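
To inspect other parameters, the spec dictionary can be generated rather than edited by hand. The helper below is a sketch: `make_array_spec` is a hypothetical name, and joining nested key names with `.` is an assumption inferred from the `params.layer0.kernel` example above. It only builds the dictionary and does not open anything.

```python
import pathlib


def make_array_spec(pytree_path, key_names) -> dict:
  """Builds a spec dict shaped like `tspec` above for one parameter."""
  return {
      'driver': 'zarr3',
      'kvstore': {
          'driver': 'ocdbt',
          'base': {
              'driver': 'file',
              'path': pathlib.Path(pytree_path).as_posix(),
          },
          # Assumption: nested key names are joined with '.', as in the
          # `params.layer0.kernel` example.
          'path': '.'.join(key_names),
      },
  }


spec = make_array_spec('/tmp/my-checkpoints/ckpt-0/pytree', ('opt_state', '0'))
print(spec['kvstore']['path'])
```

The resulting dictionary can be passed to `ts.Spec(...)` in the same way as `tspec`.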

### Other Checkpointables

Finally, let's return to the other checkpointable in our example, called
`extra_properties`. Since we explicitly required the use of `JsonHandler` to
save this object, this piece of the checkpoint is easily human-readable.

In [None]:
print_directory(directory / 'ckpt-0' / 'extra_properties')

In [None]:
pprint.pp(
    json.loads(
        (directory / 'ckpt-0' / 'extra_properties' / 'data.json').read_text()
    )
)