# Asynchronous Checkpointing

## What is this?

Orbax supports asynchronous checkpointing: a checkpoint can be saved in a background thread while training proceeds concurrently, so the main thread is blocked only for the short portion of the save that must happen synchronously.

## Why should I care?

Training jobs that would ordinarily spend time blocking for arrays to be written to disk, often via slow network connections, can proceed without waiting. This typically results in faster training progress. Furthermore, expensive devices like TPUs or GPUs which would have previously been left idle for the entire duration of the save are put to productive use for a higher proportion of the time during the training run.

Because we only need to wait for the blocking portion of the save, checkpointing becomes significantly faster. Consider some concrete numbers:

*   On a **300M** parameter model, saving time decreased by **~40%**
*   On an **8B** parameter model, saving time decreased by **~85%**
*   On a **340B** parameter model, saving time decreased by **~97%**

In short, **async checkpointing adoption is highly encouraged**. It can result in improved training throughput and substantial resource savings.

## Usage

Some setup first:

```python
import numpy as np
import orbax.checkpoint.experimental.v1 as ocp
from etils import epath

train_state = {
    'layer0': {
        'kernel': np.ones((8, 8), dtype=np.float32),
        'bias': np.ones((8,), dtype=np.float32),
    }
}
```

Using async checkpointing in Orbax is simple. For comparison, a blocking save looks like this:

```python
### PREFER NOT TO USE THIS. ###
### PREFER TO USE ASYNC CHECKPOINTING INSTEAD (SEE BELOW). ###

path = epath.Path('/tmp/sync_checkpoint')
path.rmtree(missing_ok=True)

ocp.save_pytree(path, train_state)
```

```
!ls /tmp/sync_checkpoint
```

For async save, simply use `save_pytree_async(...)` instead of `save_pytree(...)`. Calling it will kick off the checkpoint save in a background thread, and return a `response` object without waiting for completion. At this point, other work can be performed in the main thread, and `response.result()` can be called to block until completion.

```python
path = epath.Path('/tmp/async_checkpoint')
path.rmtree(missing_ok=True)

response = ocp.save_pytree_async(path, train_state)
### Do some other work...
response.result()
```

```
!ls /tmp/async_checkpoint
```
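
In a training loop, one conservative way to apply the kick-off-then-`result()` pattern is to wait for the previous save just before starting the next one, and once more at the end. The sketch below is illustrative only and does not depend on Orbax: `train_with_async_saves`, `fake_save_async`, and the `/tmp/demo` path are hypothetical names, and `save_async` is assumed to be any callable that, like `save_pytree_async(...)` above, takes a path and a state and returns an object with a `result()` method.

```python
from concurrent.futures import ThreadPoolExecutor


def train_with_async_saves(num_steps, save_every, train_step, save_async,
                           state, checkpoint_dir):
    """Runs a toy training loop, saving asynchronously every `save_every` steps."""
    pending = None  # Response object of the most recent save, if any.
    for step in range(1, num_steps + 1):
        state = train_step(state)
        if step % save_every == 0:
            if pending is not None:
                # Block until the previous save finishes before starting another.
                pending.result()
            pending = save_async(f"{checkpoint_dir}/step_{step}", state)
    if pending is not None:
        # Make sure the final save has completed before returning.
        pending.result()
    return state


# Stand-in for an async save function (hypothetical, for demonstration only).
executor = ThreadPoolExecutor(max_workers=1)
saved = []


def fake_save_async(path, state):
    snapshot = dict(state)  # Blocking copy taken on the calling thread.
    # The "write" happens on the background thread; here it only records the save.
    return executor.submit(lambda: saved.append((path, snapshot)))


final_state = train_with_async_saves(
    num_steps=6,
    save_every=2,
    train_step=lambda s: {"step": s["step"] + 1},
    save_async=fake_save_async,
    state={"step": 0},
    checkpoint_dir="/tmp/demo",
)
executor.shutdown()
```

With a real setup, `ocp.save_pytree_async` would be passed as `save_async`, assuming the per-step path layout used here suits your checkpoint directory.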

To save multiple checkpointables together, Orbax provides free functions in both blocking and async flavors: `save_checkpointables(...)` and `save_checkpointables_async(...)`.

The same goes for the `training.Checkpointer` class:
* `training.Checkpointer.save_pytree(...)`
* `training.Checkpointer.save_pytree_async(...)`
* `training.Checkpointer.save_checkpointables(...)`
* `training.Checkpointer.save_checkpointables_async(...)`


## Additional Details

From start to finish, async checkpointing for a train state of arrays works by first performing a blocking copy of the arrays from device to host. (If an array is already in host memory, a copy is still created.) This step is necessary because the values cannot be written directly from device to storage. It also needs to be blocking: if training proceeded on the main thread before the copy finished, updates to the train state could corrupt the checkpoint. Once the copy is complete, the remaining work of writing to storage is carried out in the background thread while training continues.
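
The role of the blocking copy can be illustrated with a small standard-library sketch. This is not Orbax's implementation: `save_async_sketch` and the in-memory `storage` dict are hypothetical stand-ins, `copy.deepcopy` plays the part of the device-to-host copy, and `time.sleep` simulates a slow write to storage.

```python
import copy
import time
from concurrent.futures import ThreadPoolExecutor

_executor = ThreadPoolExecutor(max_workers=1)


def save_async_sketch(storage, key, state, write_delay_s=0.05):
    """Snapshots `state` synchronously, then writes the snapshot in the background."""
    # Blocking portion: take a private copy before returning to the caller.
    snapshot = copy.deepcopy(state)

    def _write():
        time.sleep(write_delay_s)  # Simulates a slow write to storage.
        storage[key] = snapshot

    # Non-blocking portion: the write runs on a background thread.
    return _executor.submit(_write)


state = {"kernel": [1.0, 1.0], "bias": [1.0]}
storage = {}

future = save_async_sketch(storage, "ckpt_0", state)
# "Training" continues and mutates the state in place while the write is pending.
state["kernel"][0] = 99.0
future.result()  # Block until the background write has finished.
```

Because the snapshot is taken before `save_async_sketch` returns, the later in-place update to `state` does not leak into what is written. Skipping the copy and writing `state` directly from the background thread would let the update race with the write.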


The examples shown above work well for PyTrees of `jax.Array`s present on TPUs or GPUs. However, Orbax also provides a more general API that allows you to save any object asynchronously. In practice, custom async checkpointing logic can be implemented with `CheckpointableHandler`.