# Optimized Checkpointing with Tensorstore

Orbax relies on [Tensorstore](https://google.github.io/tensorstore/) to store
individual arrays in a checkpoint. Tensorstore is an efficient, scalable library for reading and writing arrays.

Until recently, however, our use of Tensorstore came with a few drawbacks. Chief among them was the fact that every parameter in a training state would be saved as a separate directory. This approach can be quite performant, even for models with hundreds of billions of parameters, *provided that model layers are stacked*. Otherwise, hundreds or thousands of directories may be created in the checkpoint.

Having so many directories can make restores very slow. That is undesirable in itself, and it is particularly painful for jobs that are preempted frequently and must restart from a checkpoint.

The problem is slightly less of a concern at save time, since writes to disk can happen asynchronously. Even so, the synchronous portion of the save can still be slow, because many directories have to be created.

Additionally, if individual parameters are small, storage may be wasted on filesystems with minimum file sizes.
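
To make the storage point concrete, here is a rough back-of-the-envelope sketch. The parameter count, parameter size, and block size are hypothetical values chosen for illustration, and the model assumes a filesystem that allocates space in whole blocks. It ignores metadata files and directory entries, which would add to the per-parameter cost.

```python
def allocated_bytes(file_size: int, block_size: int) -> int:
    """Bytes occupied on disk if files are allocated in whole blocks."""
    # Integer ceiling division, then scale back up to bytes.
    return -(-file_size // block_size) * block_size


# Hypothetical numbers, for illustration only.
num_params = 1000   # number of small parameters
param_bytes = 256   # size of each parameter's data
block_size = 4096   # assumed minimum allocation unit of the filesystem

# One file per parameter: every file is rounded up to a full block.
separate = num_params * allocated_bytes(param_bytes, block_size)

# All parameters packed into one file: only the last block is padded.
combined = allocated_bytes(num_params * param_bytes, block_size)

print(f"separate files: {separate} bytes, one combined file: {combined} bytes")
```

Under these assumptions, storing each parameter separately occupies roughly 16 times more space than packing them together.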

## Introducing OCDBT

The new, optimized checkpoint format provided by Orbax is backed by Tensorstore's [OCDBT](https://google.github.io/tensorstore/kvstore/ocdbt/index.html) driver. OCDBT stands for optionally-cooperative distributed B+tree.

For practical purposes, this means that we will no longer store one parameter per directory, but will aggregate many parameters into a smaller set of large files.

Empirically, we have observed substantial speed-ups in both save and restore when using OCDBT.

### Save Performance (sec)

<img src=https://orbax.readthedocs.io/en/latest/img/checkpoint/benchmarks/save_ocdbt.png>

### Restore Performance (sec)

<img src=https://orbax.readthedocs.io/en/latest/img/checkpoint/benchmarks/restore_ocdbt.png>

## Checkpoint Format

Concretely, what does the new checkpoint format look like in comparison to the old?

### Old Format

In [None]:
f = """
path/to/my/checkpoint/dir/
  0/
    state/
      layer0.param0/
        .zarray
        0.0
        0.1
        1.0
        1.1
      layer1.param0/
        .zarray
        0.0
      ...
    <another_item>/
      ...
  1/
    ...
  2/
    ...

Note: in this case, the file names `0.0`, `0.1`, etc. indicate how the array
was sharded when it was originally saved.
"""
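
As an illustration of how file names such as `0.0` and `1.1` relate to sharding, the sketch below enumerates chunk names for an array split into a regular grid. It assumes zarr-style naming with `.` as the index separator, which the `.zarray` files in the listing suggest. The helper `chunk_file_names` and the example shapes are hypothetical and are not part of Orbax or Tensorstore.

```python
import itertools


def chunk_file_names(shape, chunk_shape):
    """Return zarr-style chunk names (e.g. '0.1') for a regularly chunked array."""
    # Number of chunks along each dimension (ceiling division).
    grid = [-(-dim // chunk) for dim, chunk in zip(shape, chunk_shape)]
    # One name per grid position, with the indices joined by '.'.
    return [
        ".".join(str(i) for i in index)
        for index in itertools.product(*(range(n) for n in grid))
    ]


# An 8x8 array saved as four 4x4 shards, like `layer0.param0` above.
sharded = chunk_file_names((8, 8), (4, 4))

# An 8x8 array saved as a single shard, like `layer1.param0` above.
unsharded = chunk_file_names((8, 8), (8, 8))

print(sharded)
print(unsharded)
```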

### New Format

In [None]:
f = """
path/to/my/checkpoint/dir/
  0/
    state/
      checkpoint  # legacy msgpack file, stores tree structure
      tree_metadata  # (maybe) new proto file, stores tree structure
      d/  # array data stored here
        012b2c6e5c9d2a16c240a59d5f0f35c0
        056e0816bdc5496a86251e58a0ec202b
        ...
      manifest.0000000000000001
      ...
      manifest.ocdbt
    <another_item>/
      ...
  1/
    ...
  2/
    ...
"""

## Enabling OCDBT

In [None]:
import orbax.checkpoint as ocp

# Ensure that the coordinator_server is kept alive for the duration of the
# program (if not None).
# The server will only be non-None on a single process.
ocdbt_context, coordinator_server = (
    ocp.type_handlers.create_coordinator_server_and_context()
)
ocp.type_handlers.register_standard_handlers_with_options(
    use_ocdbt=True, ts_context=ocdbt_context
)

In [None]:
# Later, make sure PyTreeCheckpointHandler is initialized with `use_ocdbt=True`.
# Depending on when you read this, the option may already default to True.
ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler(use_ocdbt=True))

## Additional Notes

All checkpoints previously produced by Orbax in the old format will still be
readable when OCDBT is enabled. However, if a checkpoint is produced in the OCDBT format, it cannot be read if the OCDBT feature is disabled.
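
Because of this one-way compatibility, it can be useful to know which format an existing checkpoint item uses before disabling OCDBT. The sketch below is a heuristic based only on the directory layouts shown earlier: it treats the presence of a `manifest.ocdbt` file as a sign of the new format. The helper `looks_like_ocdbt` is hypothetical and is not an Orbax API.

```python
import tempfile
from pathlib import Path


def looks_like_ocdbt(item_dir) -> bool:
    """Heuristic: an item directory in the new format contains `manifest.ocdbt`."""
    return (Path(item_dir) / "manifest.ocdbt").is_file()


# Demonstrate on a throwaway directory that stands in for `<step>/state/`.
with tempfile.TemporaryDirectory() as tmp:
    item_dir = Path(tmp) / "0" / "state"
    item_dir.mkdir(parents=True)

    before = looks_like_ocdbt(item_dir)      # no manifest yet
    (item_dir / "manifest.ocdbt").touch()    # simulate an OCDBT checkpoint
    after = looks_like_ocdbt(item_dir)

print(before, after)
```

A check like this only inspects file names. It does not validate that the checkpoint is complete or readable.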