# Checkpointing with Orbax

<a href="http://colab.research.google.com/github/google/orbax/blob/main/orbax//checkpoint/orbax_checkpoint.ipynb" ><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Take a look at https://github.com/google/orbax/blob/main/docs/checkpoint.md for additional documentation on Orbax checkpointing APIs.



In [None]:
# Notebook shell commands: install Orbax and nest_asyncio into the runtime.
!pip install orbax
!pip install nest_asyncio
# Needed to enable asyncio in colab environment.
import nest_asyncio
nest_asyncio.apply()

In [None]:
import time
from collections import namedtuple
import jax
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.pjit import pjit
import numpy as np
import os
import portpicker
from etils import epath
from flax import traverse_util
import orbax.checkpoint as orbax

In [None]:
# Opt in to jax.Array.
# NOTE(review): newer JAX releases may no longer need (or accept) this flag —
# confirm against the JAX version this notebook is pinned to.
jax.config.update('jax_array', True)
# Bring up a single-process distributed runtime on a free local port.
port = portpicker.pick_unused_port()
jax.distributed.initialize(f'localhost:{port}', num_processes=1, process_id=0)

In [None]:
# One-dimensional mesh over every visible device, with a single axis named
# 'data'. `axes` is the partition spec used throughout the notebook.
devices = np.asarray(jax.devices())
mesh = Mesh(devices, ('data',))
axes = PartitionSpec('data',)

In [None]:
# Root directory (relative path) under which every example below creates its
# own sub-directory. Creating it is a no-op when it already exists.
directory = epath.Path('checkpoint_data')
directory.mkdir(parents=True, exist_ok=True)

In [None]:
# Show the root checkpoint directory.
print(directory)

# A Basic Example

In [None]:
# Sub-directory for the basic CheckpointManager example.
# NOTE(review): mkdir() is called without exist_ok=True, so re-running this
# cell presumably fails once the directory exists — confirm this is intended.
basic_dir = directory / 'basic'
basic_dir.mkdir()

In [None]:
# Save on every step (save_interval_steps=1) and retain at most three
# checkpoints (max_to_keep=3).
options = orbax.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=3)
mngr = orbax.CheckpointManager(
    basic_dir, orbax.Checkpointer(orbax.PyTreeCheckpointHandler()), options)

The CheckpointManager is constructed with `Checkpointer` and `CheckpointHandler` objects. We will discuss these further below, but at a high level, the `Checkpointer` controls the *manner in which* the object is saved while the `CheckpointHandler` deals with type-specific logic and provides extra options for customization.

First, we'll need to perform some setup to create a train state that mimics, in a very basic form, how a real model might look. We use jax.Array for this example, but it is also possible to use scalars or numpy arrays, assuming they are *replicated* or *not sharded*.

In [None]:
def create_initial_state():
  """Builds a small mock train state of sharded arrays plus a step counter.

  Relies on the notebook-level `mesh` and `axes` defined in an earlier cell.

  Returns:
    A dict with two entries, `layer_0` and `layer_1`, each holding a `bias`
    and a `kernel` array produced by a `pjit` identity whose outputs are
    partitioned according to `axes`, plus an integer `step` set to 0.
  """
  state = {
      'layer_0': {
          'bias': np.zeros(16),
          'kernel': np.arange(16),
      },
      'layer_1': {
          'bias': np.zeros(8),
          'kernel': np.arange(8),
      },
  }

  # Identity function whose outputs are laid out according to `axes`; mapping
  # it over the tree converts every host numpy array into a sharded array.
  create_sharded_array = pjit(lambda x: x, in_axis_resources=None, out_axis_resources=axes)
  with Mesh(mesh.devices, mesh.axis_names):
    # Use jax.tree_util.tree_map, as the rest of this notebook does; the
    # top-level jax.tree_map alias is deprecated.
    state = jax.tree_util.tree_map(create_sharded_array, state)

  # Added after the tree_map so that the step is not passed through pjit.
  state['step'] = 0
  return state

def state_shape(state):
  """Returns `jax.eval_shape` applied to a thunk that yields `state`.

  Args:
    state: PyTree whose abstract description is wanted.

  Returns:
    Whatever `jax.eval_shape` produces for a zero-argument function that
    returns `state` unchanged.
  """
  def _return_state():
    return state

  return jax.eval_shape(_return_state)

Here's our mock training step. At every step, we save a checkpoint. Since we specified `max_to_keep=3` in our options, we expect to only have the latest 3 checkpoints at the end of training.

In [None]:
state = create_initial_state()

def train(step, state):
  # Mock training step: records `step` in the state, hands the state to the
  # notebook-level `mngr`, and returns the same (mutated) state object.
  # do some training, modify state
  state['step'] = step

  mngr.save(step, state)

  return state

# Run five mock steps (0..4), saving at each one.
for step in range(5):
  state = train(step, state)

# List the steps the manager currently knows about.
print(f'Steps: {mngr.all_steps()}')

After saving, we can restore the latest step. Since this `CheckpointManager` is
only managing a single item, the arguments provided to `restore` should just
match the `state` (see the below cell). See 'Multi-Object Checkpointing' for
further details on how these arguments are used when there are multiple items.



In [None]:
# Abstract description of the state, passed to restore() as the target item.
dummy_state = state_shape(state)
# One ArrayRestoreArgs per leaf, asking for a jax.Array placed on `mesh` and
# partitioned along `axes`.
restore_args = jax.tree_util.tree_map(
    lambda _:
        orbax.ArrayRestoreArgs(
            restore_type=jax.Array,
            mesh=mesh, 
            mesh_axes=axes), 
    state)
# `step` is a plain integer rather than an array, so override its entry.
restore_args['step'] = orbax.RestoreArgs(restore_type=int)
mngr.restore(mngr.latest_step(), items=dummy_state, 
             restore_kwargs={'restore_args': restore_args})

We can achieve the same result by just using `Checkpointer`.

In [None]:
ckptr = orbax.PyTreeCheckpointer()
# restore_args can also be constructed "automatically" from a target PyTree.
# The third argument is a tree shaped like `state` holding `axes` at every
# leaf.
restore_args = orbax.checkpoint_utils.restore_args_from_target(mesh, 
                    state, jax.tree_util.tree_map(lambda _: axes, state))
print(restore_args)
print()
# CheckpointManager saved the checkpoint under /<directory>/<step>/default.
# 'default' is used as the subdirectory name when the CheckpointManager has a
# single item. See below for information on how to use multiple items or how
# to customize this name.
ckpt_path = mngr.directory / str(mngr.latest_step()) / 'default'
ckptr.restore(ckpt_path, item=dummy_state, restore_args=restore_args)

# Tracking Metrics

When saving checkpoints across many steps, we are often interested in keeping only the best *n* checkpoints based on some metric.



In [None]:
# Sub-directory for the metric-tracking example.
ckpt_metrics_dir = directory / 'ckpt_with_metrics'
ckpt_metrics_dir.mkdir()

In [None]:
# Keep at most three checkpoints, ranked by the 'loss' entry of the metrics
# passed to save(); best_mode='min' means a lower loss is better.
options = orbax.CheckpointManagerOptions(
    max_to_keep=3, best_fn=lambda metrics: metrics['loss'], best_mode='min')
# This manager saves through an AsyncCheckpointer.
mngr = orbax.CheckpointManager(
    ckpt_metrics_dir, orbax.AsyncCheckpointer(orbax.PyTreeCheckpointHandler()),
    options)

In [None]:
def get_metrics(step):
  """Returns mock metrics for `step`.

  Accuracy is always 1.0, while the loss is `step * 1.5` and therefore grows
  with the step number.
  """
  loss = step * 1.5
  return dict(accuracy=1.0, loss=loss)


def train(step, state):
  """Mock training step that checkpoints `state` together with its metrics.

  Args:
    step: current step number; also written into `state['step']`.
    state: train-state dict, updated in place.

  Returns:
    The same `state` object.
  """
  # do some training, modify state
  metrics = get_metrics(step)
  state['step'] = step
  # The metrics are handed to the notebook-level manager alongside the state.
  # (The previously computed `state_save_args` tree was never used and has
  # been removed.)
  mngr.save(step, state, metrics=metrics)
  return state


# Start from a fresh state and run five mock steps with metric tracking.
state = create_initial_state()
for step in range(5):
  state = train(step, state)

# The manager uses an AsyncCheckpointer, so block until outstanding saves are
# done before listing the steps.
mngr.wait_until_finished()
print(f'Steps: {mngr.all_steps()}')

Now that we track metrics, we will only keep the best checkpoints saved, while deleting the rest. Since our loss is getting progressively worse, and best_mode='min', we will keep the first checkpoints, rather than the most recent ones. The metrics may be an arbitrary PyTree; it is up to you to define how it is interpreted.

# Multi-Object Checkpointing

In the following example, we will show how to checkpoint multiple different objects at once using CheckpointManager.

In [None]:
# Sub-directory for the multi-object example.
multi_dir = directory / 'multi'
multi_dir.mkdir()

In [None]:
# Save every 3 steps.
options = orbax.CheckpointManagerOptions(save_interval_steps=3)
# Two named items, each with its own Checkpointer/CheckpointHandler pair:
# 'state' is handled as a PyTree and 'metadata' as JSON.
mngr = orbax.CheckpointManager(
    multi_dir, {
        'state': orbax.Checkpointer(orbax.PyTreeCheckpointHandler()),
        'metadata': orbax.Checkpointer(orbax.JsonCheckpointHandler())
    }, options)

We can save multiple objects simply by specifying a Checkpointer/CheckpointHandler combination for each of them. While Checkpointers can typically be used with any CheckpointHandler, you'll need to ensure that your object can be saved and restored by the given CheckpointHandler.

In [None]:
NUM_STEPS = 5

def train(step, state):
  """Mock training step that saves `state` and the global `metadata` together.

  Args:
    step: current step number; also written into `state['step']`.
    state: train-state dict, updated in place.

  Returns:
    The same `state` object.
  """
  # do some training, modify state
  state['step'] = step
  metadata['timestamp'] = time.time()

  # Save with default arguments for all params except 'step', which is marked
  # with aggregate=True instead of using the default storage.
  state_save_args = jax.tree_util.tree_map(lambda _: orbax.SaveArgs(), state)
  state_save_args['step'] = orbax.SaveArgs(aggregate=True)

  # with `force` a save will be performed even if it would not ordinarily do so,
  # based on the step number. Here the final step is always forced.
  force = step == NUM_STEPS - 1
  save_performed = mngr.save(
      step,
      items={
          'state': state,
          'metadata': metadata
      },
      # save_kwargs must be a dict with the same keys as items.
      # Not all keys in items have to be provided, in which case default kwargs
      # are used. Each value must be a dict with keyword args passed to the
      # underlying CheckpointHandler for that item (see CheckpointManager
      # object construction).
      save_kwargs={'state': {
          'save_args': state_save_args
      }},
      force=force)
  print(f'Save performed: {save_performed}')

  return state


state = create_initial_state()
# JSON-serializable side information saved next to the state; train()
# refreshes 'timestamp' on every step.
metadata = {
    'version': 1.1,
    'exp_name': 'my_test_exp',
    'timestamp': 0,
}
for step in range(NUM_STEPS):
  state = train(step, state)

print(f'Steps: {mngr.all_steps()}')

We also have extra arguments to customize saving for the `step` parameter within the train state. Because this is only an integer, using the default storage mechanism, [Tensorstore](https://google.github.io/tensorstore/) might be somewhat overkill. It would be more efficient to store it, along with any other similarly small parameters, into a single file using [flax.serialization](https://flax.readthedocs.io/en/latest/flax.serialization.html). The parameters passed here should match the optional arguments for the provided `CheckpointHandler`.

Let's try restoring the latest step. In the `items` argument, make sure to provide a key for each of the items you want to restore. The values of this `items` dictionary are what gets provided to `CheckpointHandler.restore` as the argument of `item`. As a result, it may be `None` if the `CheckpointHandler`
does not depend on the value of `item`.

The `restore_kwargs` argument should be a dictionary with matching top-level keys, but keys can be omitted if no arguments are needed. The values of the
`restore_kwargs` dictionary are provided as keyword args to the matching `CheckpointHandler.restore`. For example, if a `CheckpointHandler` subclass called `FooBarCheckpointHandler` takes kwargs `foo` and `bar` (in addition to standard args like `path` and `item`), the `restore_kwargs` for `CheckpointManager` would be:

```
restore_kwargs = {'foobar_item': {'foo': ???, 'bar': ???}}
```

In [None]:
# `restore_args` is the tree built in an earlier cell of this notebook.
mngr.restore(mngr.latest_step(), 
             # Safe to provide None for `items` values because they are not
             # really needed in this case. `restore_args` is needed though.
             items={'state': None, 'metadata': None}, 
             restore_kwargs={'state': {'restore_args': restore_args}})

If we skip providing 'metadata' in `items`, it will not be returned in the result.

In [None]:
# Only 'state' is requested here; 'metadata' is left out of `items`.
mngr.restore(mngr.latest_step(), 
             items={'state': None}, 
             restore_kwargs={'state': {'restore_args': restore_args}})

# Asynchronous Checkpointing

You may be wondering what the point of Checkpointer is, and why it is separate from CheckpointHandler. The reason for this is that Checkpointer may have different subclasses, each of which handles certain common logic that we would not want to reimplement for every different CheckpointHandler.

This logic may include ensuring save operation atomicity and managing a background thread for asynchronous saving.

In [None]:
# Sub-directory for the asynchronous checkpointing example.
async_dir = directory / 'async'
async_dir.mkdir()

In [None]:
# 'state' goes through an AsyncCheckpointer while 'metadata' uses the plain
# Checkpointer. No options are passed, so the manager's defaults apply.
mngr = orbax.CheckpointManager(
    async_dir, {
        'state': orbax.AsyncCheckpointer(orbax.PyTreeCheckpointHandler()),
        'metadata': orbax.Checkpointer(orbax.JsonCheckpointHandler())
    })

With the above configuration, our metadata will be saved synchronously, but our state will be saved in a background thread. After calling save, all files for the state may not have been written yet. In the meantime, we may continue training. 

However, we need to call wait_until_finished before ending our training program to block for any outstanding save operations. Calling save again will do this automatically - you cannot have multiple saves for multiple steps running concurrently.

In [None]:
def train(step, state):
  """Mock training step that saves `state` and the global `metadata`.

  Args:
    step: current step number; also written into `state['step']`.
    state: train-state dict, updated in place.

  Returns:
    The same `state` object.
  """
  # do some training, modify state
  state['step'] = step
  metadata['timestamp'] = time.time()

  # Item names must match the keys the manager was constructed with.
  items_to_save = {'state': state, 'metadata': metadata}
  mngr.save(step, items=items_to_save)

  return state


state = create_initial_state()
# Side information saved next to the state; train() refreshes 'timestamp'.
metadata = {
    'version': 1.1,
    'exp_name': 'my_test_exp',
    'timestamp': 0,
}
for step in range(NUM_STEPS):
  state = train(step, state)

# Block on any outstanding background save before listing the steps.
mngr.wait_until_finished()
print(f'Steps: {mngr.all_steps()}')

# Checkpointer

`Checkpointer` allows you to save an object to a specified directory without providing any of the structure or extra features that `CheckpointManager` does.

In [None]:
# A fresh state is only used here as a template for building restore_args.
state = create_initial_state()
ckptr = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())

# Step 0 of the multi-object example, item 'state'.
existing_checkpoint_dir = multi_dir / '0' / 'state'
# Use jax.tree_util.tree_map, as the rest of this notebook does; the
# top-level jax.tree_map alias is deprecated.
restore_args = jax.tree_util.tree_map(
    lambda _: orbax.ArrayRestoreArgs(mesh=mesh, mesh_axes=axes), state)
# `step` is not a sharded array, so it gets plain RestoreArgs.
restore_args['step'] = orbax.RestoreArgs()
restored = ckptr.restore(existing_checkpoint_dir, restore_args=restore_args)

# Print every restored leaf as '<slash-joined key path>: <value>'.
d = orbax.utils.to_state_dict(restored)
for k, v in traverse_util.flatten_dict(d, keep_empty_nodes=True).items():
  k_str = '/'.join(k)
  print(f'{k_str}: {v}')

As shown in this example, this object may be useful for restoring a pre-existing checkpoint without requiring a `CheckpointManager`.

# CheckpointHandler

**Important: `CheckpointHandler` should not be used independently of `Checkpointer` or `CheckpointManager`.**


As shown above, `PyTreeCheckpointHandler` provides support for most standard use-cases, where a `PyTree` consisting of jax.Array, scalars, or numpy arrays can be saved using a combination of Tensorstore and msgpack.

However, we may also wish to provide custom support for a novel type or storage medium. Below is an example of how `CheckpointHandler` can be overridden to support a sharded file concept.

In [None]:
from typing import Any, List, Mapping, Optional
import asyncio
from concurrent import futures
import datetime

In [None]:
# Shared thread pool; ShardedFileWriter.commit submits its write job here.
executor = futures.ThreadPoolExecutor(max_workers=2)
def ctime():
  """Returns the current UTC time as an ISO-8601 string.

  The date and time are separated by a space and the time is given to
  millisecond precision.
  """
  now_utc = datetime.datetime.now(tz=datetime.timezone.utc)
  return now_utc.isoformat(sep=' ', timespec='milliseconds')

In [None]:
class ShardedFile:
  """Toy stand-in for a file whose contents are split across devices.

  An instance records the devices that belong to the current process and can
  render a text payload made of one placeholder shard per such device.
  """

  def __init__(self):
    # `Device.process_index` is the current name for the deprecated
    # `Device.host_id` attribute.
    this_process = jax.process_index()
    self.devices = [d for d in jax.devices() if d.process_index == this_process]

  def get_device_shard(self, device):
    """Returns the placeholder payload for a single `device`."""
    return f'Data associated with device: {device}... '

  def get(self):
    """Returns a 'Process <index>: ' header followed by every local shard."""
    shards = ''.join(self.get_device_shard(d) for d in self.devices)
    return f'Process {jax.process_index()}: {shards}'

class ShardedFileWriter:
  """Writes a `ShardedFile` in two phases: an awaited copy, then a commit.

  `copy` snapshots the file contents into memory; `commit` schedules the
  write to disk on the shared `executor` and returns the resulting future.
  """

  def __init__(self, path: str, sharded_file: ShardedFile):
    self.path = path
    self.sharded_file = sharded_file

  async def copy(self):
    """Captures the current contents of the sharded file in `local_file`."""
    # The one-second sleep presumably stands in for a slow copy.
    await asyncio.sleep(1)
    self.local_file = self.sharded_file.get()
    print(f'{ctime()}: done copy')

  def commit(self):
    """Schedules the write of `local_file` and returns its future.

    The data is written to a file named after the current process index,
    inside the directory given by `self.path`.
    """
    target = epath.Path(self.path) / str(jax.process_index())

    def _write():
      # The five-second sleep presumably stands in for a slow write.
      time.sleep(5)
      target.write_text(self.local_file)
      print(f'{ctime()}: done commit')

    pending = executor.submit(_write)
    print(f'{ctime()}: started commit')
    return pending

In [None]:
class ShardedFileCheckpointHandler(orbax.async_checkpoint_handler.AsyncCheckpointHandler):
  """Handler that stores a `ShardedFile` as one text file per process."""

  async def async_save(self, directory: epath.Path, item: ShardedFile) -> List[futures.Future]:
    """Copies the item, then starts committing it to `directory`.

    Args:
      directory: save location directory.
      item: the `ShardedFile` to save.

    Returns:
      A single-element list holding the future of the commit operation.
    """
    writer = ShardedFileWriter(os.fspath(directory), item)
    await writer.copy()
    commit_future = writer.commit()
    return [commit_future]

  def save(self, directory: epath.Path, item: Any, *args, **kwargs):
    """Saves `item` synchronously by driving `async_save` to completion.

    Args:
      directory: save location directory.
      item: the object to save.
      *args: extra positional arguments forwarded to `async_save`.
      **kwargs: extra keyword arguments forwarded to `async_save`.
    """
    async def _save_and_wait():
      pending = await self.async_save(directory, item, *args, **kwargs)
      # Futures are already running, so sequential waiting is equivalent to
      # concurrent waiting.
      for fut in pending:
        fut.result()  # Block on result.

    asyncio.run(_save_and_wait())
    orbax.utils.sync_global_devices('ShardedFileCheckpointHandler:save')

  def restore(self,
              directory: epath.Path,
              item: Optional[bytes] = None) -> str:
    """Returns the text that the current process wrote under `directory`.

    Args:
      directory: directory the checkpoint was saved to.
      item: ignored.
    """
    del item
    process_file = directory / str(jax.process_index())
    return process_file.read_text()

  def structure(self, directory: epath.Path) -> int:
    """Returns the number of entries found directly inside `directory`."""
    return sum(1 for _ in directory.iterdir())

In [None]:
# Target directories for the synchronous and the asynchronous save below.
handler_dir = directory / 'sync_handler'
async_handler_dir = directory / 'async_handler'

# The object saved by both checkpointers.
file = ShardedFile()

In [None]:
# Save through the plain Checkpointer, then log the time once save() returns.
checkpointer = orbax.Checkpointer(ShardedFileCheckpointHandler())
checkpointer.save(handler_dir, file)
print(f'{ctime()}: done save')

In [None]:
# Same handler, now driven by an AsyncCheckpointer. The timestamps printed
# here can be compared with those printed by the writer's copy/commit steps.
checkpointer = orbax.AsyncCheckpointer(ShardedFileCheckpointHandler())
checkpointer.save(async_handler_dir, file)
print(f'{ctime()}: processing save')
# do something else
checkpointer.wait_until_finished()
print(f'{ctime()}: done save')

In [None]:
# Release the thread pool used by ShardedFileWriter.commit.
executor.shutdown()

# TypeHandler

`TypeHandler` as a concept exists in conjunction with `PyTreeCheckpointHandler` to provide additional customization options for advanced users with custom types they wish to save as part of a PyTree.

By default, Orbax includes handler implementations for `jax.Array`, `np.ndarray`, scalars, strings, and others. These are simply the leaf types supported by `PyTreeCheckpointHandler`.

By implementing a subclass of `TypeHandler` and registering a type using `register_type_handler`, we can add support for a novel type.

In [None]:
class Foo:
  """Minimal three-field record used to demonstrate a custom TypeHandler.

  Realistically we would use a dataclass for this, but this is just for
  illustration purposes.
  """

  def __init__(self, a, b, c):
    self.a, self.b, self.c = a, b, c

  def __str__(self):
    # Semicolon-separated fields, in the order a;b;c.
    return '{};{};{}'.format(self.a, self.b, self.c)

In [None]:
class FooHandler(orbax.type_handlers.TypeHandler):
  """TypeHandler that stores a `Foo` as text in `<info.path>/data.txt`."""

  async def serialize(
      self,
      value: Foo,
      info: orbax.type_handlers.ParamInfo,
      args: Optional[orbax.SaveArgs] = None) -> List[orbax.future.Future]:
    """Writes `str(value)` to the data file and returns no pending futures."""
    # A more sophisticated implementation would make this write asynchronous.
    data_file = info.path / 'data.txt'
    data_file.write_text(str(value))
    return []

  async def deserialize(
      self,
      info: orbax.type_handlers.ParamInfo,
      args: Optional[orbax.RestoreArgs] = None) -> Foo:
    """Rebuilds a `Foo` from the data file.

    The fields come from splitting the stored text on ';', so they are
    strings regardless of the types originally held by the saved `Foo`.
    """
    data_file = info.path / 'data.txt'
    fields = data_file.read_text().split(';')
    assert len(fields) == 3
    return Foo(*fields)

In [None]:
type_handler_dir = directory / 'type_handler'
# Associate Foo leaves with FooHandler; override=True is passed so that the
# registration is applied with the override flag set (e.g. on cell re-runs —
# presumably replacing an existing registration; confirm in the Orbax docs).
orbax.type_handlers.register_type_handler(Foo, FooHandler(), override=True)

In [None]:
# A PyTree whose leaves are Foo instances.
foo_tree = {
    'one_foo': Foo(2, 4, 6),
    'two_foo': Foo(1, 2, 3),
}
ckptr = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())
ckptr.save(type_handler_dir, foo_tree)

In [None]:
# Request every leaf to be restored as a Foo.
restore_args = jax.tree_util.tree_map(lambda _: orbax.RestoreArgs(restore_type=Foo), foo_tree)
restored = ckptr.restore(type_handler_dir, restore_args=restore_args)
# Render each restored leaf through str() for display.
jax.tree_util.tree_map(str, restored)

# Transformations

A key component of the Orbax checkpointing library is PyTree [transformations](https://github.com/google/orbax/tree/main/orbax/checkpoint/transform_utils.py). While this functionality is designed to be as flexible as possible, it can be used to support:


*   Partial restoration of checkpoints where some keys can be dropped and replaced with randomly initialized values.
*   Checkpoint version compatibility where newer checkpoints may have different structures than old ones.
*   Mappings over keys, including one-to-one, many-to-one, one-to-many, and many-to-many transformations.

The transformations library is discussed in detail [here](https://github.com/google/orbax/blob/main/docs/checkpoint.md#transformations), so we will avoid discussing all possible features and will instead focus on concrete examples.



Let's start with a simple example first.

In [None]:
from orbax.checkpoint.transform_utils import Transform

In [None]:
# The tree being transformed.
original_tree = {
  'a': 1,
  'b': {
    'c': 5,
    'd': [0, 1, 2, 3]
  },
  'f': 2,
}
# Keys refer to entries of the result; each Transform says how to obtain the
# value for that key.
transformations = {
  'a1': Transform(original_key='a'),  # rename
  'b': {
    # doubled original
    'c': Transform(value_fn=lambda v: v * 2)
    # drop b/d
  },
  # one to many mapping
  'x': Transform(multi_value_fn=lambda kv: kv['b']['d'][0]),
  'y': Transform(multi_value_fn=lambda kv: kv['b']['d'][1:]),
  # many to one mapping
  'z': Transform(multi_value_fn=lambda kv: kv['a'] * 2 + sum(kv['b']['d'])),
}
new_tree = {  # defines the structure of the result
  'a1': ...,
  'b': {
    'c': ...,
  },
  'x': ...,
  'y': ...,
  'z': ...,
  # 'f' defined in original_tree and new_tree, but not in transforms. Value
  # carried over from original_tree.
  'f': ...,
  # This value matters since it is not present in original_tree or
  # transformations, so the value here will simply be preserved in the result.
  'g': 5,
}

orbax.apply_transformations(original_tree, transformations, new_tree)

An important rule of thumb to remember is that the output of `apply_transformations` will always match the structure of `new_tree`. This provides an easy way to know exactly what your result will look like after applying transformations.

Often, users have to deal with very large PyTrees and it would become very burdensome to specify transformations for large numbers of keys. Our library provides two solutions to this: regexes and implicit transformations.

Implicit transformations have been alluded to in our first example, but let's focus on them specifically.

In the following example, transformations is an empty dictionary, so we rely exclusively on implicit transformations. Key/value pairs present in `new_tree` but not in `original_tree` simply remain in place in the result, while a key present in `original_tree` but not in `new_tree` will be dropped from the result.

In [None]:
original_tree = {
  'a': 1,
  'b': {
    'c': 5,
    'd': 6,
    'e': 7,
  },
  'f': 2,
}
# No explicit transformations: only the implicit rules apply.
transformations = {}
new_tree = {
  'a': ...,
  'b': {
    # 'd' and 'e' are absent here although they exist in original_tree.
    'c': ...,
  },
  'f': ...,
  # 'g' has no counterpart in original_tree.
  'g': {
      'h': 3,
      'i': 4,
  },
}

orbax.apply_transformations(original_tree, transformations, new_tree)

We can also change the `default_to_original` argument to customize the behavior when keys are unspecified in the `transformations` tree. Setting `default_to_original=False` means that unspecified keys will be taken from `new_tree`, **not** `original_tree`.

This can be useful if we just want to take a few values from our original checkpoint, while using the rest from our new state.

In [None]:
original_tree = {
  'a': 1,
  'b': {
    'c': 5,
    'd': 6,
    'e': 7,
  },
  'f': 2,
}
# Only 'a' has an explicit transformation (original value times 10).
transformations = {'a': Transform(value_fn=lambda x: x*10)}
# Same keys as original_tree, with different values at every leaf.
new_tree = {
  'a': 11,
  'b': {
    'c': 12,
    'd': 13,
    'e': 14,
  },
  'f': 15,
}
orbax.apply_transformations(original_tree, transformations, new_tree, default_to_original=False)

Returning to our other feature: support for regexes. Real model states often represent a conceptual parameter (a single layer, perhaps) with multiple actual key/value pairs. In this case, it can be useful to use a regex to refer to the parameter in question.

In the following example, we have one model (conceptually a pretrained checkpoint) with two layers, and another model (conceptually our in-memory model state) with four layers. We would like to insert the two layers of the checkpoint as the middle two layers of the new state, while leaving the bottom and top layers randomly initialized.

In [None]:
import flax.linen as nn
from orbax.checkpoint import test_utils

In [None]:
from flax.training.train_state import TrainState
import optax
from jax import numpy as jnp

def init_flax_model(model):
  """Initialises `model` and wraps it in a `TrainState` with numpy leaves.

  Args:
    model: module exposing `init` and `apply`.

  Returns:
    A `TrainState` built with an AdamW optimizer (learning rate 0.001), in
    which every leaf has been passed through `np.asarray`.
  """
  rng = jax.random.PRNGKey(0)
  sample_input = jnp.ones([8, 8])
  variables = model.init(rng, sample_input)
  optimizer = optax.adamw(learning_rate=0.001)
  train_state = TrainState.create(
      apply_fn=model.apply, params=variables, tx=optimizer)
  return jax.tree_util.tree_map(np.asarray, train_state)

class SmallModel(nn.Module):
  """Two Dense layers of 8 features each with a sigmoid in between.

  The layers are created inline, in call order; the regex transformations
  later in this notebook refer to them as Dense_0 and Dense_1.
  """
  @nn.compact
  def __call__(self, x):
    x = x.reshape((x.shape[0], -1))  # flatten all but the leading dimension
    x = nn.Dense(features=8)(x)
    x = nn.sigmoid(x)
    x = nn.Dense(features=8)(x)
    return x

# State of the small model, standing in for a pretrained checkpoint.
old_state = init_flax_model(SmallModel())
# multiply by 100 to represent "training"
old_state = jax.tree_util.tree_map(lambda x: x * 100, old_state)

class LargeModel(nn.Module):
  """Four Dense layers (8, 4, 4 and 2 features) separated by sigmoids.

  The layers are created inline, in call order; the regex transformations
  later in this notebook refer to the first three as Dense_0 .. Dense_2.
  """
  @nn.compact
  def __call__(self, x):
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=8)(x)
    x = nn.sigmoid(x)
    x = nn.Dense(features=4)(x)
    x = nn.sigmoid(x)
    x = nn.Dense(features=4)(x)
    x = nn.sigmoid(x)
    x = nn.Dense(features=2)(x)
    return x

# Freshly initialised state of the large model (no "training" scaling).
new_state = init_flax_model(LargeModel())

In [None]:
transformations = {
    # LargeModel's Dense_0 is a newly inserted layer, thus use_fallback=True.
    # The "fallback" tree in this case is new_state.
    # Since Dense_0 has the same name in old and new, we need to provide
    # an indication that the value of Dense_0 should come from new_state rather
    # than old_state.
    r'(.*)Dense_0(.*)': Transform(use_fallback=True),
    # SmallModel's Dense_0 maps to LargeModel's Dense_1. The capture groups
    # keep whatever surrounds the layer name in the key.
    r'(.*)Dense_1(.*)': Transform(original_key=r'\1Dense_0\2'),
    # SmallModel's Dense_1 maps to LargeModel's Dense_2.
    r'(.*)Dense_2(.*)': Transform(original_key=r'\1Dense_1\2')
}  # Note: LargeModel's Dense_3 is newly added and has no entry here.
restored_state = orbax.apply_transformations(old_state, transformations, new_state)
print(restored_state.params['params'])

We can see in the result that layers 0 and 3 are still small, while layers 1 and 2 have large values, coming from the original checkpoint.