<a href="https://colab.research.google.com/github/google/orbax/blob/main/orbax.checkpoint.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install orbax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting orbax
  Downloading orbax-0.0.3-py3-none-any.whl (43 kB)
Collecting dataclasses
  Downloading dataclasses-0.6-py3-none-any.whl (14 kB)
Collecting flax
  Downloading flax-0.5.2-py3-none-any.whl (197 kB)
Collecting tensorstore>=0.1.20
  Downloading tensorstore-0.1.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.1 MB)
Collecting rich~=11.1
  Downloading rich-11.2.0-py3-none-any.whl (217 kB)
Collecting pyyaml
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
Collecting optax
  Downlo

In [2]:
import time
from collections import namedtuple

import jax
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit, PartitionSpec
import numpy as np
import orbax.checkpoint as orbax
import tensorflow as tf

In [5]:
# Enable GDA
jax.config.update('jax_parallel_functions_output_gda', True)
devices = np.asarray(jax.devices())

mesh = Mesh(devices, ('data',))
axes = PartitionSpec('data',)

In [6]:
#@title checkpoints_directory

directory = '/cns/gg-d/home/cpgaffney/colab' #@param {type:"string"}

print(directory)

/cns/gg-d/home/cpgaffney/colab


# A Very Basic Example

In [7]:
options = orbax.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=3)
mngr = orbax.CheckpointManager(
    directory, orbax.Checkpointer(orbax.PyTreeCheckpointHandler()), options)


The `CheckpointManager` is constructed with `Checkpointer` and `CheckpointHandler` objects. We will discuss these further below, but at a high level, the `Checkpointer` controls the *manner in which* the object is saved, while the `CheckpointHandler` deals with type-specific logic and provides extra options for customization.

First, we'll need to perform some setup to create a train state that mimics, in a very basic form, how a real model might look. We use `GlobalDeviceArray` (GDA) for this example, but it is also possible to use scalars or NumPy arrays, assuming they are *replicated* or *not sharded*.

In [None]:
def create_initial_state():
  state = {
      'layer_0': {
          'bias': np.zeros(16),
          'kernel': np.arange(16),
      },
      'layer_1': {
          'bias': np.zeros(8),
          'kernel': np.arange(8),
      },
  }

  create_gda = pjit(lambda x: x, in_axis_resources=axes, out_axis_resources=axes)
  with Mesh(mesh.devices, mesh.axis_names):
    state = jax.tree_map(create_gda, state)

  state['step'] = 0
  return state

Here's our mock training step. At every step, we save a checkpoint. Since we specified `max_to_keep=3` in our options, we expect to only have the latest 3 checkpoints at the end of training.

In [None]:
state = create_initial_state()

def train(step, state):
  # do some training, modify state
  state['step'] = step

  mngr.save(step, state)

  return state

for step in range(5):
  state = train(step, state)

print(f'Steps: {mngr.all_steps()}')



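
The retention behavior expected from `max_to_keep=3` can be illustrated with a small pure-Python sketch. This is not Orbax's implementation; `keep_latest` is a hypothetical helper that only shows which step numbers would remain if the three most recent checkpoints are kept after every save.

```python
def keep_latest(steps, max_to_keep):
  """Hypothetical helper: steps left when only the newest ones are retained."""
  if max_to_keep is None:
    # No limit: every step is kept.
    return sorted(steps)
  if max_to_keep <= 0:
    return []
  return sorted(steps)[-max_to_keep:]


saved = []
for step in range(5):
  saved.append(step)                         # a checkpoint is written
  saved = keep_latest(saved, max_to_keep=3)  # older ones are dropped

print(saved)  # [2, 3, 4]
```

With five steps saved and a limit of three, only steps 2, 3 and 4 would remain under this assumed policy.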

# A More Complicated Example

In the following example, we will checkpoint multiple objects, and also use metrics to track our best checkpoint so far. 

We will also allow the state to be saved asynchronously using `AsyncCheckpointer`. This means checkpointing will happen in a background thread, leaving us free to continue training or other tasks in the main thread.

In [None]:
options = orbax.CheckpointManagerOptions(
    save_interval_steps=1,
    max_to_keep=3,
    best_fn=lambda metrics: metrics['loss'],
    best_mode='min')
mngr = orbax.CheckpointManager(
    directory, {
        'state': orbax.AsyncCheckpointer(orbax.PyTreeCheckpointHandler()),
        'metadata': orbax.Checkpointer(orbax.JsonCheckpointHandler())
    }, options)

In [None]:
state = create_initial_state()
metadata = {
    'version': 1.1,
    'exp_name': 'my_test_exp',
    'timestamp': 0,
}


def get_metrics(step):
  return {'accuracy': 1.0, 'loss': step * 1.5}


def train(step, state):
  # do some training, modify state
  metrics = get_metrics(step)
  state['step'] = step
  metadata['timestamp'] = time.time()

  # Save with default arguments for all params except 'step', which uses flax.
  state_save_args = jax.tree_map(lambda _: orbax.SaveArgs(), state)
  state_save_args['step'] = orbax.SaveArgs(use_flax=True)

  mngr.save(
      step,
      items={
          'state': state,
          'metadata': metadata
      },
      # save_kwargs must be a dict whose keys are a subset of the keys in
      # items; any item without an entry uses default kwargs.
      # Each value must be a dict of keyword args passed to the underlying
      # CheckpointHandler for that item (see the CheckpointManager construction).
      save_kwargs={'state': {
          'save_args': state_save_args
      }},
      metrics=metrics)

  return state


for step in range(5):
  state = train(step, state)

mngr.wait_until_finished()
print(f'Steps: {mngr.all_steps()}')

Let's unpack what's happening. 

For starters, we now track metrics, which can be used to keep only the best checkpoints saved, while deleting the rest. Since our loss is getting progressively worse, and `best_mode='min'`, we will keep the first checkpoints, rather than the most recent ones. The metrics may be an arbitrary PyTree; it is up to you to define how it is interpreted. 
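
To make the effect of `best_fn` and `best_mode` concrete, here is a minimal pure-Python sketch of metric-based retention. It is an illustration under assumptions, not Orbax's code: `keep_best` is a hypothetical helper, and the metrics reuse the mock `loss = step * 1.5` from the example above.

```python
def keep_best(metrics_by_step, max_to_keep, best_fn, best_mode):
  """Hypothetical helper: steps whose score ranks among the best."""
  if best_mode not in ('min', 'max'):
    raise ValueError(f'Unsupported best_mode: {best_mode}')
  # Rank steps by score; for 'max' the largest scores come first.
  ranked = sorted(
      metrics_by_step,
      key=lambda step: best_fn(metrics_by_step[step]),
      reverse=(best_mode == 'max'))
  return sorted(ranked[:max_to_keep])


metrics_by_step = {
    step: {'accuracy': 1.0, 'loss': step * 1.5} for step in range(5)
}
kept = keep_best(
    metrics_by_step, max_to_keep=3,
    best_fn=lambda metrics: metrics['loss'], best_mode='min')

print(kept)  # [0, 1, 2]
```

Because the mock loss grows with the step number, minimizing it keeps the earliest three steps rather than the latest three.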

Our metadata will be saved synchronously, but our state will be saved in a background thread. After `save` returns, some of the state's files may not have been written yet, and in the meantime we may continue training. However, we need to call `wait_until_finished` before ending our training program to block on any outstanding save operations. Calling `save` again also waits for the previous save automatically: saves for multiple steps cannot run concurrently.
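
The save-then-wait pattern can be sketched with the standard `threading` module. `BackgroundSaver` below is a hypothetical stand-in, not `AsyncCheckpointer` itself; it only mimics the behavior described here, where a new save blocks on the previous one and a final wait is needed before exiting.

```python
import threading
import time


class BackgroundSaver:
  """Hypothetical stand-in for an asynchronous checkpointer."""

  def __init__(self):
    self._thread = None
    self.completed = []

  def _write(self, step):
    time.sleep(0.01)  # pretend to write checkpoint files
    self.completed.append(step)

  def save(self, step):
    # Block on the previous save so at most one save is in flight.
    self.wait_until_finished()
    self._thread = threading.Thread(target=self._write, args=(step,))
    self._thread.start()

  def wait_until_finished(self):
    if self._thread is not None:
      self._thread.join()
      self._thread = None


saver = BackgroundSaver()
for step in range(3):
  saver.save(step)  # returns before the write has finished
  # ... training for the next step would run here ...

saver.wait_until_finished()  # without this, the last save might be lost
print(saver.completed)  # [0, 1, 2]
```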

We also have extra arguments to customize saving for the `step` parameter within the train state. Because this is only an integer, using the default storage mechanism, [Tensorstore](https://google.github.io/tensorstore/), might be somewhat overkill. It would be more efficient to store it, along with any other similarly small parameters, in a single file using [flax.serialization](https://flax.readthedocs.io/en/latest/flax.serialization.html). The parameters passed here should match the optional arguments for the provided `CheckpointHandler`. See below for further details.
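
The per-parameter arguments form a tree with the same structure as the train state. The following pure-Python sketch shows that pattern without JAX or Orbax: `LeafOptions` and `map_leaves` are hypothetical stand-ins for `orbax.SaveArgs` and `jax.tree_map`, used only to show the shape of the data.

```python
from dataclasses import dataclass


@dataclass
class LeafOptions:
  """Hypothetical stand-in for per-parameter save options."""
  use_flax: bool = False


def map_leaves(fn, tree):
  """Apply `fn` to every non-dict value of a nested dict."""
  if isinstance(tree, dict):
    return {key: map_leaves(fn, value) for key, value in tree.items()}
  return fn(tree)


toy_state = {
    'layer_0': {'bias': [0.0] * 16, 'kernel': list(range(16))},
    'step': 0,
}

# Start from default options for every leaf, then override a single entry.
save_args = map_leaves(lambda _: LeafOptions(), toy_state)
save_args['step'] = LeafOptions(use_flax=True)

print(save_args['step'])  # LeafOptions(use_flax=True)
```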

# Checkpointer

`Checkpointer` allows you to save an object to a specified directory without providing any of the structure or extra features that `CheckpointManager` does.

In [None]:
state = create_initial_state()
ckptr = orbax.Checkpointer(orbax.PyTreeCheckpointHandler())

existing_checkpoint_dir = tf.io.gfile.join(directory, '0', 'state')
# Restore every parameter as a sharded array, except 'step' (overridden below).
restore_args = jax.tree_map(
    lambda _: orbax.RestoreArgs(mesh=mesh, mesh_axes=axes), state)
restore_args['step'] = orbax.RestoreArgs(as_gda=False)
restored = ckptr.restore(existing_checkpoint_dir, restore_args=restore_args)
print(restored)

As shown in this example, `Checkpointer` may be useful for restoring a pre-existing checkpoint without requiring a `CheckpointManager`.
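
Without a `CheckpointManager`, the path to a checkpoint has to be built by hand. The sketch below assumes the `<directory>/<step>/<item>` layout implied by the `tf.io.gfile.join(directory, '0', 'state')` call above and a local filesystem; `latest_item_dir` is a hypothetical helper, not part of Orbax.

```python
import os
import tempfile


def latest_item_dir(root, item_name):
  """Hypothetical helper: item path under the highest-numbered step dir."""
  steps = [
      int(name) for name in os.listdir(root)
      if name.isdigit() and os.path.isdir(os.path.join(root, name))
  ]
  if not steps:
    return None
  return os.path.join(root, str(max(steps)), item_name)


with tempfile.TemporaryDirectory() as root:
  # Build a fake checkpoint tree: <root>/0/state, <root>/1/state, ...
  for step in (0, 1, 2):
    os.makedirs(os.path.join(root, str(step), 'state'))
  relative = os.path.relpath(latest_item_dir(root, 'state'), root)

print(relative)  # 2/state on POSIX systems
```

The resulting path could then be passed to `ckptr.restore` in place of `existing_checkpoint_dir`.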

# CheckpointHandler

**`CheckpointHandler` should not be used independently of `Checkpointer` or `CheckpointManager`.** 

Additional documentation coming soon. See go/orbax/checkpoint.

TODO(cpgaffney) Add sections on usage of the transformations library and (later) how to customize `CheckpointHandler`.