# Checkpointing in a Training Loop

This guide covers the usage of the `training` module, designed around the basic
concept of a training loop.

Note: We use the `--xla_force_host_platform_device_count=8` flag to emulate
multiple devices in our single-CPU environment.

In [None]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

from orbax.checkpoint import v1 as ocp
from etils import epath

training = ocp.training

## Getting Started

Let's dive in with a simple training loop example.

We will use the `Checkpointer` API provided by the `training` module. The
`Checkpointer` must be configured with a **root directory**, which represents a
working directory where all checkpoints will be saved throughout the course of
an experiment.

The root directory is not itself a checkpoint; rather, it is a *container* of
checkpoints.

In [None]:
root_directory = epath.Path('/tmp/my-checkpoints-1')
root_directory.rmtree(missing_ok=True)

We will assume the existence of a training state containing the keys `params`
and `opt_state`, which are trees of `jax.Array`. The state also contains a key
`step`, which is represented as an integer.

Note that the arrays in the state will be sharded using a fully-replicated
sharding, but the example would work equally well with any other sharding.

In [None]:
import jax
import numpy as np

pytree = {
    'params': {
        'layer0': np.arange(16).reshape((8, 2)),
    },
    'opt_state': [np.arange(16)],
}
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)), jax.sharding.PartitionSpec()
)
pytree = jax.tree.map(lambda x: jax.device_put(x, sharding), pytree)
pytree['step'] = 0

Let's set up our fake training loop. We will add a "training step function" that
just increments the step. In reality, this would also compute gradients and
update model parameters.

In [None]:
def train_step(state):
  state['step'] += 1
  return state

Now, we can create a `Checkpointer` to begin saving a sequence of checkpoints.

In [None]:
with training.Checkpointer(root_directory) as ckptr:
  num_steps = 10
  for step in range(num_steps):
    saved = ckptr.save_pytree(step, pytree)
    assert saved
    pytree = train_step(pytree)

Calling `load_pytree` with no arguments will automatically restore the latest
saved checkpoint.

In [None]:
with training.Checkpointer(root_directory) as ckptr:
  print(ckptr.load_pytree())

## Checkpointer APIs

Now, let's get into a bit more detail about how to interact with the
`Checkpointer`.

In general, we recommend using `Checkpointer` as a context manager, as shown in
the examples below.

```
with Checkpointer(...) as ckptr:
  ...
```

You can use it without the context manager, but make sure to call `close()`
before the program exits to ensure the completion of any outstanding operations
and to ensure resource cleanup.

```
ckptr = Checkpointer(...)
...
ckptr.close()
```
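
If the context manager does not fit your program's structure, a `try`/`finally`
block gives the same guarantee that `close()` runs even when the loop raises.
The sketch below uses a hypothetical stand-in class, `FakeCheckpointer`, rather
than the real `Checkpointer`, purely to illustrate the control flow.

```python
class FakeCheckpointer:
  """Hypothetical stand-in for a checkpointer; not part of Orbax."""

  def __init__(self):
    self.closed = False
    self.saved_steps = []

  def save(self, step):
    self.saved_steps.append(step)

  def close(self):
    # A real checkpointer would wait for outstanding operations here.
    self.closed = True


def run_loop(ckptr, num_steps, fail_at=None):
  """Runs a toy loop and always closes the checkpointer afterwards."""
  try:
    for step in range(num_steps):
      if step == fail_at:
        raise RuntimeError(f'simulated failure at step {step}')
      ckptr.save(step)
  finally:
    ckptr.close()


ckptr = FakeCheckpointer()
try:
  run_loop(ckptr, num_steps=5, fail_at=3)
except RuntimeError:
  pass
print(ckptr.saved_steps, ckptr.closed)  # [0, 1, 2] True
```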

### Saving

Calling a save method such as `save_pytree` in the training loop automatically
calls `should_save`, which determines whether or not a checkpoint should be
saved at the given step, based on the configured saving frequency. If a save is
performed, the save method returns `True`; otherwise, it returns `False`.

Whether or not a save should be performed can be controlled via
`SaveDecisionPolicy`.

By default, `ContinuousCheckpointingPolicy` is configured, which always saves
*unless* a save is already ongoing.

Other pre-configured policies include:

-   `FixedIntervalPolicy`: Saves every `n` steps.
-   `InitialSavePolicy`: Saves on the first step.
-   `PreemptionCheckpointingPolicy`: Saves on a step where a preemption signal
    is received by the JAX distributed system. This is useful for saving
    whenever a job is automatically restarted by the system.
-   `SpecificStepsPolicy`: Saves on the specific set of configured steps.

The policies can be used in conjunction via `AnySavePolicy`, which performs a
save if any of the sub-policies would perform a save at the given step.

You may always implement your own policy. See `SaveDecisionPolicy` for details.
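
Conceptually, a save decision policy is a predicate over the current step (in
Orbax, it may also receive additional context). The plain-Python sketch below
models that idea with hypothetical helper functions. It does not use the real
`SaveDecisionPolicy` interface, and it assumes that an interval policy saves on
steps divisible by the interval.

```python
from typing import Callable, Iterable

StepPredicate = Callable[[int], bool]


def fixed_interval(n: int) -> StepPredicate:
  """Saves when the step is a multiple of n (assumed semantics)."""
  return lambda step: step % n == 0


def specific_steps(steps: Iterable[int]) -> StepPredicate:
  """Saves only on the listed steps."""
  wanted = frozenset(steps)
  return lambda step: step in wanted


def any_of(*policies: StepPredicate) -> StepPredicate:
  """Saves if at least one sub-policy would save at this step."""
  return lambda step: any(policy(step) for policy in policies)


policy = any_of(fixed_interval(4), specific_steps([5, 9]))
saved = [step for step in range(10) if policy(step)]
print(saved)  # [0, 4, 5, 8, 9]
```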

In [None]:
root_directory = epath.Path('/tmp/my-checkpoints-2')
root_directory.rmtree(missing_ok=True)
with training.Checkpointer(
    root_directory,
    save_decision_policy=training.save_decision_policies.FixedIntervalPolicy(3),
) as ckptr:
  for step in range(10):
    ckptr.save_pytree(step, pytree)

In [None]:
!ls {root_directory}

Now let's exercise some additional save features. These include:

-   `custom_metadata`: A JSON-formatted object intended for storing any
    user-specified properties. Custom metadata can be specified at both the root
    directory level and the individual checkpoint level. At the root level, the
    metadata should pertain to all checkpoints. For example, the experiment name
    is shared by all checkpoints within the root directory, while a property
    like `is_final` has different values for different checkpoints.
-   `override`: Deletes and overwrites any existing checkpoint at the provided
    step.
-   `force`: Performs a save at the current step regardless of what would
    ordinarily be dictated by the `SaveDecisionPolicy`.
-   `metrics`: A JSON-formatted object storing evaluation metrics for the
    current step. This can be useful for ordering and garbage collecting
    checkpoints; more on that below.
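
To make the interaction between `force` and the save decision policy concrete,
here is a small plain-Python model. The function `steps_saved` is hypothetical
and assumes a fixed-interval policy that saves on multiples of the interval; it
only mimics the decision logic and performs no checkpointing.

```python
def steps_saved(num_steps: int, interval: int, force_final: bool) -> list:
  """Returns the steps at which a save would happen in this toy model."""
  saved = []
  for step in range(num_steps):
    is_final = step == num_steps - 1
    policy_says_save = step % interval == 0
    # `force` bypasses the policy; otherwise the policy decides.
    if policy_says_save or (force_final and is_final):
      saved.append(step)
  return saved


print(steps_saved(10, 3, force_final=True))   # [0, 3, 6, 9]
print(steps_saved(9, 3, force_final=True))    # [0, 3, 6, 8]
print(steps_saved(9, 3, force_final=False))   # [0, 3, 6]
```

With ten steps and an interval of 3, the final step (9) already satisfies the
policy, so forcing it changes nothing; with nine steps, the forced final step
(8) is saved only because of `force`.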

In [None]:
root_directory = epath.Path('/tmp/my-checkpoints-3')
root_directory.rmtree(missing_ok=True)
with training.Checkpointer(
    root_directory,
    save_decision_policy=training.save_decision_policies.FixedIntervalPolicy(3),
    custom_metadata={'experiment_name': 'my-experiment'},
) as ckptr:
  num_steps = 10
  for step in range(num_steps):
    is_final = step == num_steps - 1
    ckptr.save_pytree(
        step,
        pytree,
        metrics={'accuracy': 0.85},
        custom_metadata={'is_final': is_final},
        force=is_final,
    )

In [None]:
!ls {root_directory}

We will learn more about how to access some of the attributes that we saved in
the sections below.

### Querying Available Checkpoints

We can learn about which checkpoints are available by using `latest` and
`checkpoints`.

In [None]:
ckptr = training.Checkpointer(root_directory)

Each of these APIs returns `CheckpointMetadata` objects, which store a number of
properties describing each checkpoint. Some metadata properties are more
expensive to retrieve than others, however, so the `latest` and `checkpoints`
APIs populate only a limited set of cheaply-retrievable properties, like the
`step`. These APIs also make use of caching as much as possible, to avoid
repeated disk reads.

In [None]:
# Returns CheckpointMetadata or None, if no checkpoints are found.
latest = ckptr.latest
assert latest is not None
print(latest.step)
print(latest)

### Inspecting Checkpoint Metadata

In many cases, we wish to cheaply gain information about checkpoint properties
without loading the entire model. Using the `pytree_metadata` API, we can learn
about the tree structure of our PyTree, as well as information about each array
in the tree.

Like loading methods, metadata methods accept either no argument, or an argument
representing the step to retrieve metadata for.

For example:

In [None]:
# Loads metadata from the latest checkpoint.
ckptr.pytree_metadata()
# Loads metadata corresponding to the first step.
ckptr.pytree_metadata(ckptr.checkpoints[0])
# Loads metadata from a specific integer step.
ckptr.pytree_metadata(3)

print()

Let's examine the output.

In [None]:
ckptr.pytree_metadata()

Let's dig into a few specific fields. In particular, we can access
`custom_metadata` and `metrics` that were saved previously.

In [None]:
print(ckptr.pytree_metadata().metrics)
print(ckptr.pytree_metadata().custom_metadata)

Within the metadata object, there is another field called `metadata`. This
stores information specific to the structure of the object we saved. In this
case, it describes the structure of the PyTree and array properties.

In [None]:
import pprint

pprint.pprint(ckptr.pytree_metadata().metadata)

Finally, we can also retrieve the root-level metadata. Recall that this metadata
is intended to describe the entire sequence of checkpoints, rather than just a
single checkpoint.

In [None]:
ckptr.root_metadata()

### Garbage Collection

Garbage collection is important to avoid accumulating too many old checkpoints
and running out of disk space.

To control this behavior, we have an object, fairly similar to the
`SaveDecisionPolicy` described above, called `PreservationPolicy`. This class
tells the `Checkpointer` which checkpoints should be protected from garbage
collection.

The `PreservationPolicy` defaults to `PreserveAll` (no garbage collection),
because we do not want users to lose any valuable data. However, for anything
other than toy use cases, you should make sure to configure a more restrictive
`PreservationPolicy`.

Our `Checkpointer` below is implicitly configured with `PreserveAll`, so all 10
steps should be present at first.

In [None]:
root_directory = epath.Path('/tmp/my-checkpoints-gc')
root_directory.rmtree(missing_ok=True)
with training.Checkpointer(root_directory) as ckptr:
  for step in range(10):
    ckptr.save_pytree(step, pytree)
  print([c.step for c in ckptr.checkpoints])

If we create a new `Checkpointer` with a new `PreservationPolicy` configured,
the same 10 checkpoints should still be present. Once we
save a new step, any steps indicated for cleanup by the policy will be removed.

In [None]:
with training.Checkpointer(
    root_directory,
    preservation_policy=training.preservation_policies.AnyPreservationPolicy([
        training.preservation_policies.LatestN(2),
        training.preservation_policies.EveryNSteps(4),
    ]),
) as ckptr:
  print([c.step for c in ckptr.checkpoints])
  assert ckptr.latest.step == 9
  ckptr.save_pytree(10, pytree)
  print([c.step for c in ckptr.checkpoints])

Typically, the latest `n` checkpoints are preserved (`LatestN`) along with
checkpoints at some regular, but longer interval (`EveryNSteps` or
`EveryNSeconds`). The latter can be useful for performing evals and maintaining
a record of the experiment's progress.
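
The combination of "latest `n`" and "every `n` steps" can be modeled in a few
lines of plain Python. The helper `preserved_steps` below is hypothetical and
assumes that the interval rule keeps steps divisible by the interval; the real
policies may differ in detail, so treat this as an illustration of the union
semantics rather than a description of Orbax's implementation.

```python
def preserved_steps(steps, latest_n, every_n_steps):
  """Returns the steps kept by either rule, in ascending order."""
  ordered = sorted(steps)
  # Rule 1: keep the most recent `latest_n` steps.
  keep_latest = set(ordered[-latest_n:]) if latest_n > 0 else set()
  # Rule 2: keep steps on a regular interval (assumed: divisible by n).
  keep_interval = {s for s in ordered if s % every_n_steps == 0}
  # A step survives if any rule preserves it.
  return sorted(keep_latest | keep_interval)


print(preserved_steps(range(11), latest_n=2, every_n_steps=4))
# [0, 4, 8, 9, 10]
```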

### Loading

As we saw above with the `metadata` methods, we can load in a variety of ways.

In [None]:
# Loads from the latest checkpoint.
ckptr.load_pytree()
# Loads the first available checkpoint in the root directory.
ckptr.load_pytree(ckptr.checkpoints[0])
# Loads from a specific integer step.
ckptr.load_pytree(4)

print()

When dealing with PyTrees, particularly PyTrees with sharded `jax.Array` leaves,
it is important for any non-toy use cases to specify an "abstract PyTree" that
is used to guide restoration. Checkpoints are complicated objects. The abstract
PyTree acts as an assertion to verify that the checkpoint has the structure you
expect and that arrays have the correct shapes.

The abstract PyTree can also be used to instruct Orbax how to load the PyTree.
The `dtype` property may be used to cast arrays, while the `sharding` property
is used to correctly place array shards on devices.

We should define an abstract tree with the same structure as the tree we
originally saved. For the leaves, we specify different shardings than we
originally saved with, and different dtypes as well, causing the loaded arrays
to be cast and resharded when loading.

In [None]:
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)), jax.sharding.PartitionSpec('x')
)
abstract_pytree = {
    'params': {
        'layer0': jax.ShapeDtypeStruct((8, 2), np.float32, sharding=sharding),
    },
    'opt_state': [jax.ShapeDtypeStruct((16,), np.float32, sharding=sharding)],
    'step': 0,
}

In [None]:
ckptr.load_pytree(None, abstract_pytree)

More details on working with PyTrees in such a manner can be found at
TODO(link).

### Checkpointables and Dataset Checkpointing

`Checkpointer` supports the concept of `checkpointables`. See the documentation
on "Working with Checkpointables" for more information.

In simplified terms, a "checkpointable" refers to a distinct piece of the
overall checkpoint, which can be thought of as a bundle. The `PyTree` training
state is one such checkpointable. The dataset iterator is another. Checkpointing
the position of the dataset iterator can be useful to ensure training resumes
where we were interrupted not just for the model parameters, but for the data as
well.

We can see this concept in concrete terms using a Grain dataset iterator. See
[Grain documentation](https://google-grain.readthedocs.io/en/latest/index.html)
for more information. For our purposes, we can construct a toy dataset iterator.

In [None]:
import grain

dataset = iter(grain.MapDataset.range(30).batch(3).map(lambda x: x.tolist()))

pytree = {
    'params': {
        'layer0': np.arange(16).reshape((8, 2)),
    },
    'opt_state': [np.arange(16)],
}
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)), jax.sharding.PartitionSpec()
)
pytree = jax.tree.map(lambda x: jax.device_put(x, sharding), pytree)
pytree['step'] = 0

In [None]:
def train_step(state, ds):
  next(ds)  # Advances the dataset iterator
  state['step'] += 1
  return state

We can save ten checkpoints in sequence, including the dataset iterator,
advancing the iterator once per step. At each step, the dataset iterator points
to `[step*3, step*3+1, step*3+2]`.

In [None]:
root_directory = epath.Path('/tmp/my-checkpoints-4')
root_directory.rmtree(missing_ok=True)
num_steps = 10

with training.Checkpointer(root_directory) as ckptr:
  for step in range(num_steps):
    ckptr.save_checkpointables(step, dict(pytree=pytree, dataset=dataset))
    pytree = train_step(pytree, dataset)

After loading at step `5`, `new_dataset` points to position `5` of the iterator.

In [None]:
new_dataset = iter(
    grain.MapDataset.range(30).batch(3).map(lambda x: x.tolist())
)
print(f'Initial position: {next(new_dataset)}')

with training.Checkpointer(root_directory) as ckptr:
  ckptr.load_checkpointables(5, dict(pytree=None, dataset=new_dataset))
print(f'Loaded from checkpoint: {next(new_dataset)}')

It's important to note that dataset loading is stateful. You need to instantiate
an iterator object, pass it to `load_checkpointables`, and the checkpoint state
will be restored into the iterator state of the dataset object.
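
The "restore into an existing object" pattern can be illustrated without Grain.
The class `BatchIterator` below is a hypothetical stand-in for a dataset
iterator, and its `get_state`/`set_state` methods are assumptions made for this
sketch, not a claim about Grain's or Orbax's interfaces.

```python
class BatchIterator:
  """Hypothetical stateful iterator yielding batches of consecutive ints."""

  def __init__(self, size: int, batch_size: int):
    self._size = size
    self._batch_size = batch_size
    self._position = 0  # Index of the next batch to yield.

  def __iter__(self):
    return self

  def __next__(self):
    start = self._position * self._batch_size
    if start >= self._size:
      raise StopIteration
    self._position += 1
    return list(range(start, min(start + self._batch_size, self._size)))

  def get_state(self) -> dict:
    return {'position': self._position}

  def set_state(self, state: dict) -> None:
    # Restoring mutates this object in place; nothing new is returned.
    self._position = state['position']


original = BatchIterator(30, 3)
for _ in range(5):
  next(original)
saved_state = original.get_state()  # Captured after five batches.

restored = BatchIterator(30, 3)
first = next(restored)           # A fresh iterator starts at the beginning.
restored.set_state(saved_state)  # Restore into the existing object.
after = next(restored)
print(first, after)  # [0, 1, 2] [15, 16, 17]
```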

## Training with MNIST Data

In the previous examples, we used toy data to demonstrate Orbax functionality. Here, we demonstrate checkpointing in a training loop that uses real-world training data.

We'll define our own loss function and simulate an MNIST model training loop. These functions are pulled from the [Flax docs](https://flax.readthedocs.io/en/latest/mnist_tutorial.html).

In [None]:
import flax
from flax import nnx
import jax
import optax
from orbax.checkpoint.experimental.v1._src.training.model_helpers import DotReluDot  # Don't rely on this module!

flax.config.update('flax_always_shard_variable', False)


def loss_fn(model: DotReluDot, batch: dict) -> tuple[jax.Array, jax.Array]:
  logits = model(batch['image'])
  loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label']
  ).mean()
  return loss, logits


@nnx.jit
def train_step(
    model: DotReluDot,
    optimizer: nnx.ModelAndOptimizer,
    batch: dict,
):
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  optimizer.update(grads)

It's worth mentioning that `DotReluDot` is an example layer that subclasses `nnx.Module` and is capable of incorporating sharding information to accelerate the training process. Let's recreate a sharded environment for the sake of demonstration.

In [None]:
from jax.sharding import Mesh
import numpy as np

print(f'You have 8 “fake” JAX devices now: {jax.devices()}')

mesh = Mesh(
    devices=np.array(jax.devices()).reshape(4, 2), # Can be customized to run across multiple devices
    axis_names=('data', 'model'),
)
print(mesh)

Let's go ahead and import the MNIST dataset using the `Grain` library:

In [None]:
%%capture
from orbax.checkpoint.experimental.v1._src.training.model_helpers import create_dataset

batch_size = 32

train_ds = create_dataset('train', batch_size)
test_ds = create_dataset('test', batch_size)

Before we start the training loop, let's initialize an abstract version of our checkpoint state, without initializing any real array values. We do this for both the model and optimizer.

In [None]:
# Create an abstract model
abs_model = nnx.eval_shape(lambda: DotReluDot(1024, rngs=nnx.Rngs(0)))
abs_model_state = nnx.state(abs_model)
abs_model_state = jax.tree.map(
    lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
    abs_model_state,
    nnx.get_named_sharding(abs_model_state, mesh),
)

# Create an abstract optimizer
abs_model_tmp = DotReluDot(1024, rngs=nnx.Rngs(0))
abs_optimizer = nnx.eval_shape(
    lambda: nnx.ModelAndOptimizer(abs_model_tmp, optax.adamw(0.005, 0.9))
)

# Store the abstract model and optimizer as one object
abs_state = {
    'params': abs_model_state,
    'optimizer': nnx.state(abs_optimizer, nnx.optimizer.OptState),
}

Now, we can define our main `train()` function, throughout which we will demonstrate checkpointing. A couple of notes:

*   We use `FixedIntervalPolicy` so that our checkpoint is saved every 60 training steps.
*   We use `nnx.state()` to convert the model object (`DotReluDot`) and optimizer to a checkpointable PyTree, which can then be checkpointed with `ckptr.save_pytree()`.

When actually loading a checkpoint, we do the following:
* If a checkpoint exists in our current checkpoints directory, we restore the latest one.
* If no checkpoints exist in our directory and the user provides a path to another directory, we load the checkpoint saved at that path.
* If no checkpoints have been saved, this indicates we're entering the training loop for the first time, so we don't restore a model or optimizer.
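
The three cases above amount to a simple priority order. As a rough sketch, the
hypothetical helper `choose_restore_source` below returns a label for the
branch that would be taken; it performs no loading and exists only to make the
precedence explicit.

```python
from typing import Optional


def choose_restore_source(
    has_local_checkpoint: bool, ckpt_path: Optional[str]
) -> str:
  """Returns which restore branch applies: 'latest', 'path', or 'fresh'."""
  if has_local_checkpoint:
    # A checkpoint in the current directory takes precedence.
    return 'latest'
  if ckpt_path:
    # Otherwise fall back to the user-provided path.
    return 'path'
  # Nothing to restore: first entry into the training loop.
  return 'fresh'


print(choose_restore_source(True, '/tmp/other/840'))   # latest
print(choose_restore_source(False, '/tmp/other/840'))  # path
print(choose_restore_source(False, None))              # fresh
```

Note that a user-provided path is ignored whenever the current directory
already contains a checkpoint.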

In [None]:
from typing import Callable
from orbax.checkpoint import v1 as ocp

training = ocp.training
root_directory = epath.Path('/tmp/my-checkpoints-5')
root_directory.rmtree(missing_ok=True)
train_steps = 1200
learning_rate = 0.005
momentum = 0.9
model_depth = 1024


def init_or_restore(
    ckptr: training.Checkpointer, abs_state: dict, ckpt_path: str | None
) -> tuple[DotReluDot, nnx.ModelAndOptimizer, int]:
  model = DotReluDot(model_depth, rngs=nnx.Rngs(0))
  optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate, momentum))

  if ckpt_path or ckptr.latest:
    # If a checkpoint already exists, we restore it.
    if ckptr.latest:
      loaded_state = ckptr.load_pytree(abstract_pytree=abs_state)
    else:
      loaded_state = ocp.load_pytree(path=ckpt_path, abstract_pytree=abs_state)
    # Update model and optimizer separately
    nnx.update(model, loaded_state['params'])
    nnx.update(optimizer, loaded_state['optimizer'])
    last_step = loaded_state['optimizer']['step'].value
  else:
    last_step = 0

  return model, optimizer, last_step


def train(
    fail_fn: Callable[[int], bool] | None = None, ckpt_path: str | None = None
):
  # If a checkpoint exists in root_directory, restore the latest one. Otherwise,
  # if ckpt_path is provided, load the checkpoint at that path; if neither
  # applies, start from scratch.
  with training.Checkpointer(
      root_directory,
      save_decision_policy=training.save_decision_policies.FixedIntervalPolicy(
          60
      ),
  ) as ckptr:
    model, optimizer, last_step = init_or_restore(ckptr, abs_state, ckpt_path)

    # Main training loop
    with mesh:
      for step, batch in enumerate(train_ds, start=last_step + 1):
        if step >= train_steps or (fail_fn and fail_fn(step)):
          break

        train_step(model, optimizer, batch)
        # Save the combined state
        state = {
            'params': nnx.state(model),
            'optimizer': nnx.state(optimizer, nnx.optimizer.OptState),
        }
        ckptr.save_pytree(step, state)

We invoke the `train()` function. In doing so, we demonstrate a "failure" causing our main training loop to stop at step 600.

In [None]:
# Call train()
def simulate_failure(step):
    return step >= 600

train(fail_fn=simulate_failure)

We continue training until step 900. The last saved checkpoint (step 540, the final save before the simulated failure) is restored automatically.

In [None]:
%%capture
train(fail_fn = lambda x: x >= 900)

Now that we are at step 900, let's assume we want to fine-tune the model. We restore from our latest checkpoint, which is step 840, and continue training until we reach step 1200.

In [None]:
train(ckpt_path=root_directory / '840')
# Here, we can inspect specific model parameters and show the checkpoints saved in directory
# !ls {root_directory}
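
The step numbers quoted in this walkthrough (540 and 840) follow directly from
the save interval of 60. The plain-Python simulation below reproduces them
under the assumption that the interval policy saves on steps divisible by 60;
`simulate_run` is a hypothetical helper that mirrors only the loop bounds of
`train()` and does no training or checkpointing.

```python
def simulate_run(last_step, stop_step, interval=60):
  """Returns the steps that would be saved in one run of the toy loop."""
  saved = []
  # Training resumes at last_step + 1 and breaks once step >= stop_step.
  for step in range(last_step + 1, stop_step):
    if step % interval == 0:
      saved.append(step)
  return saved


first = simulate_run(0, 600)            # Fresh start, failure at step 600.
second = simulate_run(first[-1], 900)   # Resume from the latest save.
third = simulate_run(second[-1], 1200)  # Resume again, train to the end.
print(first[-1], second[-1], third[-1])  # 540 840 1140
```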