# Working with PyTree Checkpoints

A [PyTree](https://jax.readthedocs.io/en/latest/pytrees.html) is the most common way of representing a training state in JAX. While Orbax is designed to be as generic as possible, and provides customization options for all manner of checkpointable objects, PyTrees naturally have pride of place. Furthermore, the standard object used to represent large, sharded arrays is the `jax.Array`. This, too, has extensive first-class support.

## Exclusive APIs to checkpoint PyTrees

The following APIs are dedicated to checkpointing PyTrees.

To save:

* `ocp.save_pytree(...)`
* `ocp.save_pytree_async(...)`
* `training.Checkpointer.save_pytree(...)`
* `training.Checkpointer.save_pytree_async(...)`

To load:

* `ocp.load_pytree(...)`
* `ocp.load_pytree_async(...)`
* `training.Checkpointer.load_pytree(...)`
* `training.Checkpointer.load_pytree_async(...)`

Of course, the `save_checkpointables(...)` and `load_checkpointables(...)`
family of APIs can also be used to save and load a PyTree.

Let's set up a PyTree of `jax.Array`s to experiment with these APIs.

In [None]:
from etils import epath
import jax
import numpy as np
from orbax import checkpoint as ocp_v0
import orbax.checkpoint.experimental.v1 as ocp

In [None]:
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec('model'),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
pytree = {
    'a': np.arange(16),
    'b': np.ones(16),
}
pytree = jax.tree_util.tree_map(create_sharded_array, pytree)
pytree
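To see how `PartitionSpec('model')` splits each leaf, one can inspect the per-device shards of a `jax.Array`. The sketch below rebuilds a one-axis mesh like the one in the cell above under new names (`demo_mesh`, `demo_sharding`) so it does not disturb the notebook state; like that cell, it assumes the number of devices divides 16.

```python
import jax
import numpy as np

# A one-axis mesh over all local devices, mirroring the cell above.
demo_mesh = jax.sharding.Mesh(np.array(jax.devices()), ('model',))
demo_sharding = jax.sharding.NamedSharding(
    demo_mesh, jax.sharding.PartitionSpec('model')
)

demo_x = jax.device_put(np.arange(16), demo_sharding)

# Each addressable shard holds a contiguous slice of the leading axis.
for shard in demo_x.addressable_shards:
  print(shard.device, shard.index, shard.data.shape)

# The per-device shape implied by the sharding, computed without any data.
per_device_shape = demo_sharding.shard_shape(demo_x.shape)
print(per_device_shape)
```

With a single device, the only shard holds all 16 elements; with N devices, each holds 16 / N.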

In [None]:
abstract_pytree = jax.tree_util.tree_map(
    ocp_v0.utils.to_shape_dtype_struct, pytree
)
abstract_pytree

## Basic Checkpointing

Let's use `ocp.save_*`/`ocp.load_*` to work with the pytree created earlier.

In [None]:
path = epath.Path('/tmp/basic/')
path.rmtree(missing_ok=True)

# Simple save using default options:
ocp.save_pytree(path, pytree)

loaded = ocp.load_pytree(path)
loaded

We can also read the PyTree metadata stored in the checkpoint and use it as the target when loading.

In [None]:
pytree_metadata = ocp.pytree_metadata(path).metadata
pytree_metadata

In [None]:
loaded = ocp.load_pytree(path, pytree_metadata)
loaded

In [None]:
(loaded['a'].sharding, loaded['b'].sharding)

Specifying `abstract_pytree` as the target loads each leaf with the given dtype, shape, and sharding.

In [None]:
# Load using abstract_pytree.
loaded = ocp.load_pytree(path, abstract_pytree)
loaded

In [None]:
(loaded['a'].sharding, loaded['b'].sharding)

You can do the same with a "concrete" target instead of an "abstract" one. However, this requires fully initializing the target train state
before loading from the checkpoint, which is inefficient. It is better practice to initialize only metadata (either by manually creating `jax.ShapeDtypeStruct`s or by using `jax.eval_shape`).

In [None]:
ocp.load_pytree(path, pytree)
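One way to build an abstract target without allocating arrays is `jax.eval_shape`, which traces a function and returns `jax.ShapeDtypeStruct` leaves. In this sketch, `init_train_state` is a hypothetical stand-in for a real initializer, and a replicated sharding is attached to every leaf purely for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def init_train_state():
  # Hypothetical initializer; a real one would build model params,
  # optimizer state, and so on.
  return {'a': jnp.arange(16), 'b': jnp.ones(16)}


# Traces init_train_state without allocating device memory.
abstract_state = jax.eval_shape(init_train_state)

# Attach a sharding to every leaf so the target also specifies placement.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('model',))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
abstract_state = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=replicated),
    abstract_state,
)
print(abstract_state)
```

The resulting tree plays the same role as `abstract_pytree`: it describes shapes, dtypes, and shardings without holding any data.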

### Customizing Loaded Properties

#### Array dtype

In [None]:
def set_loading_dtype(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  return x.update(dtype=np.int16)


cast_dtype_abstract_pytree = jax.tree_util.tree_map(
    set_loading_dtype, abstract_pytree
)

In [None]:
ocp.load_pytree(path, cast_dtype_abstract_pytree)
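`set_loading_dtype` casts every leaf. If only some leaves should change, the mapping function can inspect each leaf first. This sketch narrows integer leaves to `int16` and leaves floating-point leaves alone; `example_abstract` is a hypothetical stand-in for `abstract_pytree`, and each leaf is rebuilt through the `jax.ShapeDtypeStruct` constructor so its existing sharding is kept.

```python
import jax
import numpy as np

# Stand-in for abstract_pytree: one integer leaf and one floating-point leaf.
example_abstract = {
    'a': jax.ShapeDtypeStruct((16,), np.int32),
    'b': jax.ShapeDtypeStruct((16,), np.float32),
}


def cast_integer_leaves(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  # Narrow integer leaves to int16; return all other leaves unchanged.
  if np.issubdtype(x.dtype, np.integer):
    return jax.ShapeDtypeStruct(x.shape, np.int16, sharding=x.sharding)
  return x


cast_abstract = jax.tree_util.tree_map(cast_integer_leaves, example_abstract)
print({k: v.dtype for k, v in cast_abstract.items()})
```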

#### Change sharding

**NOTE: This can often be a particularly sharp edge.**

Sharding commonly needs to be changed when loading a checkpoint saved on one topology to a different topology.

**If changing topologies, you MUST specify sharding when loading.**

Unless you are loading on the exact same topology, Orbax does not make any decisions about shardings on your behalf. If you have the exact same topology,
however, it is possible to avoid specifying the sharding when loading. This is demonstrated below:

In [None]:
loaded = ocp.load_pytree(path)

In [None]:
(loaded['a'].sharding, loaded['b'].sharding)

In the example below, we alter the sharding while loading.

In [None]:
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)),
    jax.sharding.PartitionSpec(),
)


def set_sharding(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  return x.update(sharding=sharding)


change_sharding_abstract_pytree = jax.tree_util.tree_map(
    set_sharding, abstract_pytree
)
loaded = ocp.load_pytree(path, change_sharding_abstract_pytree)

In [None]:
(loaded['a'].sharding, loaded['b'].sharding)

We can also build the target from the PyTree metadata instead of the abstract PyTree.

In [None]:
pytree_metadata = ocp.pytree_metadata(path).metadata
change_sharding_pytree_metadata = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(
        shape=x.shape, dtype=x.dtype, sharding=sharding
    ),
    pytree_metadata,
)
loaded = ocp.load_pytree(path, change_sharding_pytree_metadata)
(loaded['a'].sharding, loaded['b'].sharding)
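The examples above apply one sharding to every leaf. When moving to a new topology, a per-leaf rule is often more practical. The following is a hypothetical policy, not something Orbax provides: shard the leading axis over `'model'` when it divides evenly across the available devices, and otherwise replicate. `example_abstract` and `shard_for_current_topology` are illustrative names.

```python
import jax
import numpy as np

topology_mesh = jax.sharding.Mesh(np.array(jax.devices()), ('model',))
num_devices = topology_mesh.devices.size


def shard_for_current_topology(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  # Hypothetical policy: split axis 0 across 'model' if it divides evenly,
  # otherwise replicate the leaf (scalars are always replicated).
  if x.shape and x.shape[0] % num_devices == 0:
    spec = jax.sharding.PartitionSpec('model')
  else:
    spec = jax.sharding.PartitionSpec()
  return jax.ShapeDtypeStruct(
      x.shape, x.dtype, sharding=jax.sharding.NamedSharding(topology_mesh, spec)
  )


example_abstract = {
    'a': jax.ShapeDtypeStruct((16,), np.int32),
    'step': jax.ShapeDtypeStruct((), np.int32),
}
target = jax.tree_util.tree_map(shard_for_current_topology, example_abstract)
print({k: v.sharding.spec for k, v in target.items()})
```

A tree built this way could be passed to `ocp.load_pytree` in the same way as `change_sharding_abstract_pytree`.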

### Partial Loading

You may wish to load part of a PyTree contained within a saved checkpoint. For example, consider the following item:

In [None]:
original_item = {
    'params': {
        'layer1': {
            'kernel': np.arange(8),
            'bias': np.arange(8),
        },
        'layer2': {
            'kernel': np.arange(8),
            'bias': np.arange(8),
        },
    },
    'opt_state': [np.arange(8), np.arange(8)],
    'step': 101,
}

path = epath.Path('/tmp/partial/')
path.rmtree(missing_ok=True)

ocp.save_pytree(path / '1', original_item)

If we want to load only a subset of PyTree nodes (`params.layer2` and `step`, for example), we can use placeholder values.

#### Placeholder

To load part of a PyTree item, we can specify which nodes to ignore during loading by using `...` (`ocp.PLACEHOLDER`).

In [None]:
reference_item = {
    'params': {
        'layer1': {
            'kernel': ...,
            'bias': ...,
        },
        'layer2': {
            'kernel': np.arange(8),
            'bias': np.arange(8),
        },
    },
    'opt_state': [..., ...],
    'step': 101,
}

ocp.load_pytree(path / '1', reference_item)
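Writing the reference tree out by hand gets tedious for large models. The sketch below derives one from a full tree with `jax.tree_util.tree_map_with_path`: leaves whose key path starts with one of a set of prefixes are kept, and every other leaf is replaced with the `...` placeholder. `keep_prefixes`, `to_reference`, and `derived_reference` are hypothetical names chosen for this example.

```python
import jax
import numpy as np

full_item = {
    'params': {
        'layer1': {'kernel': np.arange(8), 'bias': np.arange(8)},
        'layer2': {'kernel': np.arange(8), 'bias': np.arange(8)},
    },
    'opt_state': [np.arange(8), np.arange(8)],
    'step': 101,
}

# Key-path prefixes to keep, written as jax.tree_util.keystr renders them.
keep_prefixes = ("['params']['layer2']", "['step']")


def to_reference(path, leaf):
  # Keep matching leaves; mark everything else as a placeholder.
  return leaf if jax.tree_util.keystr(path).startswith(keep_prefixes) else ...


derived_reference = jax.tree_util.tree_map_with_path(to_reference, full_item)
print(derived_reference)
```

`derived_reference` has the same structure as `reference_item` above.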

## Advanced Customizations

`ocp.Context` enables further customization.

To customize save/load behavior, invoke these APIs within an `ocp.Context`
instance, which can be configured with a number of options, such as PyTree,
array, and file options, with separate saving and loading settings where applicable.

The usage pattern is as follows:
```python
with ocp.Context(
  pytree_options=ocp.options.PyTreeOptions(...),
  file_options=ocp.options.FileOptions(...),
):
  ocp.save_pytree(path, pytree)
```

Let's explore a few examples. See the API reference for details on specific options.

### Saving

#### Customizing Array dtype

We can customize the on-disk dtype used to save individual arrays. First, let's save and load as normal.

In [None]:
path = epath.Path('/tmp/advanced/')
path.rmtree(missing_ok=True)

In [None]:
ocp.save_pytree(path / '1', pytree)

In [None]:
loaded = ocp.load_pytree(path / '1')

In [None]:
(loaded['a'].dtype, loaded['b'].dtype)

Now, let's set the on-disk dtype of a selected array (`'a'`) when saving.

In [None]:
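def create_array_storage_options_fn(keypath, value):
  del value  # Unused; the leaf is selected by its key path.
  last_key = keypath[-1]
  # `pytree` is a dict, so its key paths end in `DictKey`. `GetAttrKey` covers
  # attribute-based containers such as dataclasses.
  is_a = (
      isinstance(last_key, jax.tree_util.DictKey) and last_key.key == 'a'
  ) or (isinstance(last_key, jax.tree_util.GetAttrKey) and last_key.name == 'a')
  if is_a:
    return ocp.options.ArrayOptions.Saving.StorageOptions(
        dtype=np.dtype(np.int16)
    )
  else:
    return ocp.options.ArrayOptions.Saving.StorageOptions()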


with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(
        saving=ocp.options.PyTreeOptions.Saving(
            create_array_storage_options_fn=create_array_storage_options_fn,
        )
    )
):
  ocp.save_pytree(path / '2', pytree, overwrite=True)

In [None]:
loaded = ocp.load_pytree(path / '2')

In [None]:
(loaded['a'].dtype, loaded['b'].dtype)
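`create_array_storage_options_fn` receives a JAX key path for each leaf, and the type of the last entry depends on the container: plain dicts produce `jax.tree_util.DictKey` (name in `.key`), while namedtuples and registered dataclasses produce `jax.tree_util.GetAttrKey` (name in `.name`). A small helper that handles both can make leaf selection robust; `leaf_name` and `State` are hypothetical names, not Orbax APIs.

```python
import collections

import jax
import numpy as np


def leaf_name(keypath):
  # Return the name of the last key in a JAX key path, or None if the last
  # entry is not a named key (for example, a list index).
  last_key = keypath[-1]
  if isinstance(last_key, jax.tree_util.DictKey):
    return last_key.key
  if isinstance(last_key, jax.tree_util.GetAttrKey):
    return last_key.name
  return None


State = collections.namedtuple('State', ['a', 'b'])
for tree in (
    {'a': np.arange(4), 'b': np.ones(4)},
    State(np.arange(4), np.ones(4)),
):
  leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
  print([leaf_name(kp) for kp, _ in leaves_with_paths])  # ['a', 'b']
```

Inside `create_array_storage_options_fn`, a check such as `leaf_name(keypath) == 'a'` then selects the leaf regardless of container type.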

Now, let's set the dtype of all arrays when saving.

In [None]:
create_array_storage_options_fn = (
    lambda k, v: ocp.options.ArrayOptions.Saving.StorageOptions(
        dtype=np.dtype(np.int16)
    )
)
with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(
        saving=ocp.options.PyTreeOptions.Saving(
            create_array_storage_options_fn=create_array_storage_options_fn
        )
    )
):
  ocp.save_pytree(path / '3', pytree, overwrite=True)

In [None]:
loaded = ocp.load_pytree(path / '3')

In [None]:
(loaded['a'].dtype, loaded['b'].dtype)
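Narrowing the on-disk dtype (for example, 8-byte `int64` to 2-byte `int16`) reduces storage, but values survive only if they are representable in the narrower type. Below is a NumPy sketch of a hypothetical pre-save check; `fits_in_dtype` is an illustrative helper, not part of Orbax.

```python
import numpy as np


def fits_in_dtype(x, dtype) -> bool:
  # True if every value in x is an integer within the range of `dtype`.
  info = np.iinfo(dtype)
  x = np.asarray(x)
  if x.size == 0:
    return True
  if np.issubdtype(x.dtype, np.floating) and not np.all(np.mod(x, 1) == 0):
    return False  # Fractional, NaN, or infinite values would be lost.
  return bool(x.min() >= info.min and x.max() <= info.max)


print(fits_in_dtype(np.arange(16), np.int16))      # True
print(fits_in_dtype(np.array([40000]), np.int16))  # False
# Storage for 16 elements: 8 bytes each vs 2 bytes each.
print(np.arange(16, dtype=np.int64).nbytes, np.arange(16, dtype=np.int16).nbytes)  # 128 32
```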

#### High Throughput with `ocdbt` option

To achieve high throughput and avoid creating a separate subdirectory for each leaf, enable `use_ocdbt`. Note that it is enabled by default.

In [None]:
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        saving=ocp.options.ArrayOptions.Saving(
            use_ocdbt=True,
        )
    )
):
  ocp.save_pytree(path / '4', pytree, overwrite=True)

A checkpoint created with this option enabled can be identified by the presence of a `manifest.ocdbt` file and subdirectories such as `ocdbt.process_*`.

In [None]:
!ls /tmp/advanced/4/pytree

However, for use cases like large stacked models, disabling this option may be more efficient.

In [None]:
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        saving=ocp.options.ArrayOptions.Saving(
            use_ocdbt=False,
        )
    )
):
  ocp.save_pytree(path / '5', pytree, overwrite=True)

!ls /tmp/advanced/5/pytree

Note that each leaf is written to its own subdirectory when `use_ocdbt=False`.
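Based on the layout described in this section, a quick heuristic for telling the two formats apart is to look for `manifest.ocdbt` in the PyTree directory. This sketch relies only on that observed layout, not on any Orbax API. The demo builds fake directories in a temporary location; the per-leaf subdirectory names `a` and `b` are assumptions made for the demo.

```python
import pathlib
import tempfile


def looks_like_ocdbt(pytree_dir) -> bool:
  # OCDBT checkpoints contain a top-level manifest.ocdbt file (alongside
  # ocdbt.process_* subdirectories); per-leaf checkpoints do not.
  return (pathlib.Path(pytree_dir) / 'manifest.ocdbt').is_file()


with tempfile.TemporaryDirectory() as tmp:
  ocdbt_dir = pathlib.Path(tmp, 'ocdbt')
  (ocdbt_dir / 'ocdbt.process_0').mkdir(parents=True)
  (ocdbt_dir / 'manifest.ocdbt').touch()

  per_leaf_dir = pathlib.Path(tmp, 'per_leaf')
  (per_leaf_dir / 'a').mkdir(parents=True)
  (per_leaf_dir / 'b').mkdir()

  results = (looks_like_ocdbt(ocdbt_dir), looks_like_ocdbt(per_leaf_dir))
  print(results)  # (True, False)
```

Applied to the checkpoints saved above, one would check `/tmp/advanced/4/pytree` and `/tmp/advanced/5/pytree`.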

### Loading

#### Pad / truncate shape

Ordinarily, specifying a target array with a different shape than in the
checkpoint results in an error.

In [None]:
# Original shape.
loaded = ocp.load_pytree(path / '1')

(loaded['a'].shape, loaded['b'].shape)

In [None]:
different_shape_abstract_pytree = {
    'a': jax.ShapeDtypeStruct(
        shape=(8,),
        dtype=abstract_pytree['a'].dtype,
        sharding=abstract_pytree['a'].sharding,
    ),
    'b': jax.ShapeDtypeStruct(
        shape=(32,),
        dtype=abstract_pytree['b'].dtype,
        sharding=abstract_pytree['b'].sharding,
    ),
}

try:
  ocp.load_pytree(path / '1', different_shape_abstract_pytree)
except Exception as e:  # Print the shape-mismatch error instead of failing the cell.
  print(e)

We can pad or truncate arrays as they are loaded by specifying `enable_padding_and_truncation=True`.

In [None]:
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        loading=ocp.options.ArrayOptions.Loading(
            enable_padding_and_truncation=True
        )
    )
):
  loaded = ocp.load_pytree(path / '1', different_shape_abstract_pytree)

In [None]:
(loaded['a'].shape, loaded['b'].shape)
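Conceptually, padding and truncation reconcile each axis of the stored array with the target shape. The NumPy sketch below illustrates that idea for arrays of equal rank, assuming padded elements are filled with zeros. It is not Orbax's implementation, and the fill value Orbax actually uses is not stated here; `pad_or_truncate` is a hypothetical helper.

```python
import numpy as np


def pad_or_truncate(x, target_shape):
  # Assumes x and target_shape have the same rank.
  x = np.asarray(x)
  # Truncate: keep at most the target length along every axis.
  x = x[tuple(slice(0, min(s, t)) for s, t in zip(x.shape, target_shape))]
  # Pad: append zeros at the end of each axis up to the target length.
  pad_width = [(0, t - s) for s, t in zip(x.shape, target_shape)]
  return np.pad(x, pad_width)


print(pad_or_truncate(np.arange(16), (8,)))       # [0 1 2 3 4 5 6 7]
print(pad_or_truncate(np.ones(16), (32,)).shape)  # (32,)
```

These two calls mirror the `'a'` (16 to 8) and `'b'` (16 to 32) targets in `different_shape_abstract_pytree`.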