# Working with PyTree Checkpoints

A [PyTree](https://jax.readthedocs.io/en/latest/pytrees.html) is the most common way of representing a training state in JAX. While Orbax is designed to be as generic as possible, and provides customization options for all manner of checkpointable objects, PyTrees naturally have pride of place. Furthermore, the standard object used to represent large, sharded arrays is the `jax.Array`. This, too, has extensive first-class support.

## Exclusive APIs to checkpoint PyTrees

The following APIs are specific to checkpointing PyTrees.

To save:

* `ocp.save_pytree(...)`
* `ocp.save_pytree_async(...)`
* `training.Checkpointer.save_pytree(...)`
* `training.Checkpointer.save_pytree_async(...)`

To load:

* `ocp.load_pytree(...)`
* `ocp.load_pytree_async(...)`
* `training.Checkpointer.load_pytree(...)`
* `training.Checkpointer.load_pytree_async(...)`

The more general `save_checkpointables(...)` and `load_checkpointables(...)`
APIs can also be used to save and load a PyTree.

Let's set up a PyTree of `jax.Array`s to experiment with these APIs.

In [None]:
from etils import epath
import jax
import numpy as np
import orbax.checkpoint.experimental.v1 as ocp

In [None]:
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(
        'model',
    ),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
pytree = {
    'a': np.arange(16, dtype=np.int32),
    'b': np.ones(16, dtype=np.int32),
}
pytree = jax.tree_util.tree_map(create_sharded_array, pytree)
pytree

## Basic Checkpointing

Let's use `ocp.save_*`/`ocp.load_*` to work with the pytree created earlier.

In [None]:
path = epath.Path('/tmp/basic/')
path.rmtree(missing_ok=True)

# Simple save using default options:
ocp.save_pytree(path, pytree)

We can easily restore using the following snippet.

**Warning: do not use for production-sensitive cases.**

In [None]:
loaded = ocp.load_pytree(path)
loaded

Loading this way is not recommended for production-sensitive cases because nothing constrains what gets loaded. If the shapes of some arrays in the model have changed since the checkpoint was saved, errors will surface later, when the loaded values are used to create the model. If the device topology has changed, errors will occur when Orbax attempts to place the arrays on devices.

It is therefore recommended that users always specify an **abstract pytree** when loading.

### Understanding Abstract Trees and Leaves

An **abstract PyTree** is just a normal PyTree, but with abstract leaves. An **abstract leaf** is a cheap representation of a leaf type (such as an array) that contains only metadata, and does not represent the real values. (Contrast with a *concrete* PyTree, which contains real data in the form of large arrays, and other types.)
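
As a minimal sketch in plain JAX (independent of Orbax), an abstract PyTree can be derived from a concrete one by mapping each array leaf to a `jax.ShapeDtypeStruct` that keeps only the shape and dtype, plus the sharding for `jax.Array` leaves. The helper name `to_abstract` is hypothetical.

```python
import jax
import numpy as np


def to_abstract(tree):
  """Maps every array leaf to its metadata only (hypothetical helper)."""

  def leaf_to_abstract(x):
    # jax.Array leaves carry a sharding; NumPy arrays do not.
    sharding = x.sharding if isinstance(x, jax.Array) else None
    return jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding)

  return jax.tree_util.tree_map(leaf_to_abstract, tree)


concrete_tree = {
    'a': np.arange(16, dtype=np.int32),
    'b': np.ones(16, dtype=np.int32),
}
abstract_tree = to_abstract(concrete_tree)
print(abstract_tree['a'])  # Shape and dtype only; no array data is held.
```

Note that this still requires the concrete tree to exist; it only illustrates the relationship between concrete and abstract leaves.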

Let's create an abstract PyTree matching the structure of the PyTree we originally saved.

In [None]:
abstract_pytree = {
    'a': jax.ShapeDtypeStruct(shape=(16,), dtype=np.int32, sharding=sharding),
    'b': jax.ShapeDtypeStruct(shape=(16,), dtype=np.int32, sharding=sharding),
}
abstract_pytree

In [None]:
# Load using abstract_pytree.
loaded = ocp.load_pytree(path, abstract_pytree)
loaded

In [None]:
(loaded['a'].sharding, loaded['b'].sharding)

The `ocp.pytree_metadata` function returns a `CheckpointMetadata` object with a number of properties; its core `metadata` property is simply an abstract PyTree, which can also be used as a loading target, as shown below.

In [None]:
pytree_metadata = ocp.pytree_metadata(path).metadata
pytree_metadata

In [None]:
loaded = ocp.load_pytree(path, pytree_metadata)
loaded

In [None]:
(loaded['a'].sharding, loaded['b'].sharding)

Note that it is also valid to provide a "concrete" PyTree as the loading target instead of an "abstract" one, since concrete leaves, by definition, carry all the properties that abstract leaves provide.

However, this requires fully initializing the target train state before loading
from the checkpoint, which wastes time and memory and is often impractical for real use cases. It is better practice to initialize only the metadata (either by manually creating `jax.ShapeDtypeStruct`s or by using `jax.eval_shape`).
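
A minimal sketch of the `jax.eval_shape` route: the function `init_train_state` below is a hypothetical stand-in for a real model's initializer. `jax.eval_shape` traces it without executing it, so no large buffers are allocated and every leaf comes back as a `jax.ShapeDtypeStruct`. Before loading onto a sharded topology you would still attach the target shardings, as the "Change sharding" section does.

```python
import jax
import jax.numpy as jnp


def init_train_state(rng):
  # Hypothetical initializer standing in for a real model's init function.
  w_key, _ = jax.random.split(rng)
  return {
      'params': {
          'kernel': jax.random.normal(w_key, (1024, 1024), dtype=jnp.float32),
          'bias': jnp.zeros((1024,), dtype=jnp.float32),
      },
      'step': jnp.zeros((), dtype=jnp.int32),
  }


# eval_shape only traces init_train_state, so the 1024x1024 kernel is never
# materialized; each leaf is returned as a jax.ShapeDtypeStruct.
abstract_state = jax.eval_shape(init_train_state, jax.random.PRNGKey(0))
print(abstract_state['params']['kernel'])
```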

In [None]:
ocp.load_pytree(path, pytree)

### Standard Leaf Types

The following standard leaf types are supported by Orbax by default. Each concrete leaf type has a corresponding abstract leaf type. Most abstract types are implemented as `Protocol`s, so that any object implementing the required properties is accepted as a valid abstract type.

| `Leaf` Type | `AbstractLeaf` Type | Properties |
| :------- | :-------- | :-------- |
| `jax.Array` | `AbstractShardedArray` (`jax.ShapeDtypeStruct`) | `shape`, `dtype`, `sharding` |
| `np.ndarray` | `AbstractArray` (`np.ndarray`) | `shape`, `dtype` |
| `int` | `int` | |
| `float` | `float` | |
| `bytes` | `bytes` | |
| `str` | `str` | |

`None` is always a valid abstract leaf; it serves as an indication that the leaf should be restored using metadata stored in the checkpoint.

`Type[AbstractLeaf]` is also always a valid abstract leaf; it again serves as an indication that the leaf should be restored using the metadata, but with the additional constraint to load as the indicated type. For example, instead of specifying `jax.ShapeDtypeStruct(shape=..., dtype=..., sharding=...)`, it is sufficient to pass `jax.ShapeDtypeStruct`. Similarly, instead of passing `0` to restore as an `int`, the type itself may be passed.

To summarize, here are the ways you can load a PyTree using abstract leaves, with the way we most recommend at the top, and the way we least recommend at the bottom.

**1. Fully-specified abstract values**

This enables the most validation at load time and avoids unnecessary
metadata reads.

```
abstract_pytree = {
  'a': jax.ShapeDtypeStruct(shape=..., dtype=..., sharding=jax.sharding.NamedSharding(...))
}
```

**2. Only types specified**

This guarantees that each leaf will be loaded with the indicated type, but metadata
will be used to restore specific properties for each leaf.

```
abstract_pytree = {
  'a': jax.ShapeDtypeStruct,
  'b': int,
  'c': np.ndarray,
}
```

**3. `None` specified (per-leaf)**

This is essentially the same as (2), but metadata will also be used to decide
which type each leaf should be loaded as.

```
abstract_pytree = {
  'a': None,
  'b': None,
}
```

**4. `None` specified**

This loads the PyTree structure without any checks, and can lead to errors later
in your code if the checkpoint does not have the structure you expect.

```
abstract_pytree = None
```
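
Options 2 and 3 can be generated from an existing reference tree rather than written by hand. The following is a minimal sketch in plain JAX; the helpers `types_only` and `none_per_leaf` are hypothetical, and mapping `jax.Array` leaves to `jax.ShapeDtypeStruct` follows the abstract-type table above.

```python
import jax
import jax.numpy as jnp
import numpy as np


def types_only(tree):
  """Option 2: keep only the desired type of each leaf (hypothetical helper)."""

  def leaf_type(x):
    # Sharded jax.Array leaves are requested via their abstract type.
    if isinstance(x, jax.Array):
      return jax.ShapeDtypeStruct
    return type(x)

  return jax.tree_util.tree_map(leaf_type, tree)


def none_per_leaf(tree):
  """Option 3: defer every leaf to checkpoint metadata (hypothetical helper)."""
  return jax.tree_util.tree_map(lambda _: None, tree)


example = {'a': jnp.arange(4), 'b': 3, 'c': np.zeros(2)}
print(types_only(example))
print(none_per_leaf(example))
```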

### Customizing Loaded Properties for Arrays

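#### Array dtype

To load an array with a different dtype than the one it was saved with, set the desired `dtype` on its abstract leaf. Here, both arrays are loaded as `int16`.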

In [None]:
def set_loading_dtype(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  return x.update(dtype=np.int16)


cast_dtype_abstract_pytree = jax.tree_util.tree_map(
    set_loading_dtype, abstract_pytree
)

In [None]:
ocp.load_pytree(path, cast_dtype_abstract_pytree)
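
The cast can also be applied selectively. As a sketch in plain JAX, `jax.tree_util.tree_map_with_path` exposes each leaf's key path, so a hypothetical rule such as "load only leaf `'a'` as `int16`" can be expressed as below. The resulting tree would be passed to `ocp.load_pytree(path, ...)` in the same way as `cast_dtype_abstract_pytree`.

```python
import jax
import numpy as np


def cast_only_a(path, x):
  # Hypothetical rule: change the dtype of the leaf keyed 'a' only.
  last = path[-1]
  if isinstance(last, jax.tree_util.DictKey) and last.key == 'a':
    return jax.ShapeDtypeStruct(x.shape, np.dtype(np.int16), sharding=x.sharding)
  return x  # All other leaves keep their saved dtype.


abstract_tree_int32 = {
    'a': jax.ShapeDtypeStruct((16,), np.int32),
    'b': jax.ShapeDtypeStruct((16,), np.int32),
}
selective_cast_tree = jax.tree_util.tree_map_with_path(
    cast_only_a, abstract_tree_int32
)
print(selective_cast_tree)
```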

#### Change sharding

**NOTE: This can often be a particularly sharp edge.**

Sharding commonly needs to be changed when loading a checkpoint saved on one topology to a different topology.

**If changing topologies, you MUST specify sharding when loading.**

Orbax makes no sharding decisions on your behalf unless you are loading on the exact same topology the checkpoint was saved on. In that case, however, you can omit the sharding when loading, as demonstrated below:

In [None]:
loaded = ocp.load_pytree(path)

In [None]:
(loaded['a'].sharding, loaded['b'].sharding)

In the example below, we alter the sharding while loading.

In [None]:
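# Target sharding for loading: a new mesh axis name and an empty
# PartitionSpec, i.e. every array is fully replicated across devices.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('x',)),
    jax.sharding.PartitionSpec(),
)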


def set_sharding(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  return x.update(sharding=sharding)


change_sharding_abstract_pytree = jax.tree_util.tree_map(
    set_sharding, abstract_pytree
)
loaded = ocp.load_pytree(path, change_sharding_abstract_pytree)

In [None]:
(loaded['a'].sharding, loaded['b'].sharding)
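
Because Orbax does not choose shardings when the topology changes, the per-leaf choice has to come from your own code. A minimal sketch of one hypothetical policy in plain JAX: shard the leading dimension over a mesh axis `'x'` when it divides evenly across devices, and otherwise replicate. The mesh layout and the policy are assumptions for illustration only.

```python
import jax
import numpy as np

# Hypothetical target topology: all local devices along a single 'x' axis.
target_mesh = jax.sharding.Mesh(jax.devices(), ('x',))
num_devices = len(jax.devices())


def choose_sharding(x: jax.ShapeDtypeStruct) -> jax.ShapeDtypeStruct:
  """Hypothetical policy: shard dim 0 over 'x' if divisible, else replicate."""
  if len(x.shape) > 0 and x.shape[0] % num_devices == 0:
    spec = jax.sharding.PartitionSpec('x')
  else:
    spec = jax.sharding.PartitionSpec()  # Scalars and uneven sizes: replicate.
  return jax.ShapeDtypeStruct(
      x.shape, x.dtype, sharding=jax.sharding.NamedSharding(target_mesh, spec)
  )


abstract_for_new_topology = jax.tree_util.tree_map(
    choose_sharding,
    {
        'a': jax.ShapeDtypeStruct((16,), np.int32),
        'b': jax.ShapeDtypeStruct((16,), np.int32),
        'step': jax.ShapeDtypeStruct((), np.int32),
    },
)
print(abstract_for_new_topology)
```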

Alternatively, we can start from the checkpoint's PyTree metadata instead of the hand-written abstract PyTree.

In [None]:
pytree_metadata = ocp.pytree_metadata(path).metadata
change_sharding_pytree_metadata = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(
        shape=x.shape, dtype=x.dtype, sharding=sharding
    ),
    pytree_metadata,
)
loaded = ocp.load_pytree(path, change_sharding_pytree_metadata)
(loaded['a'].sharding, loaded['b'].sharding)

#### Change leaf type

The abstract leaf type determines the type each leaf is loaded as. If a value is saved as a `jax.Array` but the abstract leaf lacks the required `sharding` property (for example, a NumPy array), Orbax loads it as an `np.ndarray`. Scalars can change type in the same way: below, a 0-d array is loaded as a `float` and a `float` is loaded as an `int`, simply by specifying those types as the abstract leaves.

In [None]:
path = epath.Path('/tmp/change-type/')
path.rmtree(missing_ok=True)

pytree_with_scalars = {
    'a': np.asarray(12),
    'b': 13.5,
    'c': create_sharded_array(np.arange(8)),
}
ocp.save_pytree(path, pytree_with_scalars)

In [None]:
abstract_pytree_with_scalars = {
    'a': float,
    'b': int,
    'c': np.empty((8,)),
}
ocp.load_pytree(path, abstract_pytree_with_scalars)

### Partial Loading

You may wish to load part of a PyTree contained within a saved checkpoint. For example, consider the following item:

In [None]:
original_item = {
    'params': {
        'layer1': {
            'kernel': np.arange(8),
            'bias': np.arange(8),
        },
        'layer2': {
            'kernel': np.arange(8),
            'bias': np.arange(8),
        },
    },
    'opt_state': [np.arange(8), np.arange(8)],
    'step': 101,
}

path = epath.Path('/tmp/partial/')
path.rmtree(missing_ok=True)

ocp.save_pytree(path / '1', original_item)

If we want to load only a subset of PyTree nodes (for example, `params.layer2` and `step`), we can use placeholder values.

#### Placeholder

To load part of a PyTree item, we can specify which nodes to ignore during loading by using `...` (`ocp.PLACEHOLDER`).

In [None]:
reference_item = {
    'params': {
        'layer1': {
            'kernel': ...,
            'bias': ...,
        },
        'layer2': {
            'kernel': np.arange(8),
            'bias': np.arange(8),
        },
    },
    'opt_state': [..., ...],
    'step': 101,
}

ocp.load_pytree(path / '1', reference_item)
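
For larger trees, writing `...` by hand for every skipped leaf gets tedious. A minimal sketch in plain JAX that builds such a reference tree from a list of path prefixes to keep; the helpers `keep_subtrees` and `_key_name` are hypothetical. The result has the same form as `reference_item` and would be passed to `ocp.load_pytree` in the same way.

```python
import jax
import numpy as np


def _key_name(key):
  """Returns a readable name for one key path entry."""
  if isinstance(key, jax.tree_util.DictKey):
    return str(key.key)
  if isinstance(key, jax.tree_util.SequenceKey):
    return str(key.idx)
  if isinstance(key, jax.tree_util.GetAttrKey):
    return key.name
  return str(key)


def keep_subtrees(tree, keep):
  """Replaces every leaf outside the `keep` prefixes with `...` (hypothetical)."""

  def select(path, leaf):
    names = tuple(_key_name(k) for k in path)
    if any(names[: len(prefix)] == prefix for prefix in keep):
      return leaf
    return ...  # Placeholder: this leaf is skipped when loading.

  return jax.tree_util.tree_map_with_path(select, tree)


item = {
    'params': {
        'layer1': {'kernel': np.arange(8), 'bias': np.arange(8)},
        'layer2': {'kernel': np.arange(8), 'bias': np.arange(8)},
    },
    'opt_state': [np.arange(8), np.arange(8)],
    'step': 101,
}
reference = keep_subtrees(item, keep=[('params', 'layer2'), ('step',)])
print(reference)
```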

## Advanced Customizations

`ocp.Context` enables further customization.

To customize save/load behavior, invoke these APIs within an `ocp.Context`,
which can be configured with a number of options, such as `PyTreeOptions`,
`ArrayOptions`, and `FileOptions`.

The usage pattern is as follows:
```
with ocp.Context(
  pytree_options=ocp.options.PyTreeOptions(...),
  file_options=ocp.options.FileOptions(...),
):
  ocp.save_pytree(path, pytree)
```

Let's explore a few examples. See the API reference for details on specific options.

### Saving

#### Customizing Array dtype

We can customize the on-disk dtype used to save individual arrays. First, let's save and load as normal.

In [None]:
path = epath.Path('/tmp/advanced/')
path.rmtree(missing_ok=True)

In [None]:
ocp.save_pytree(path / '1', pytree)

In [None]:
loaded = ocp.load_pytree(path / '1')

In [None]:
(loaded['a'].dtype, loaded['b'].dtype)

Now, let's set the dtype of a selected array (`'a'`) when saving.

In [None]:
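def create_array_storage_options_fn(keypath, value):
  del value  # Unused: the choice depends only on the leaf's key path.
  last_key = keypath[-1]
  # `pytree` is a dict, so its key path entries are `jax.tree_util.DictKey`s;
  # attribute-based containers produce `GetAttrKey`s instead.
  is_leaf_a = (
      isinstance(last_key, jax.tree_util.DictKey) and last_key.key == 'a'
  ) or (
      isinstance(last_key, jax.tree_util.GetAttrKey) and last_key.name == 'a'
  )
  if is_leaf_a:
    return ocp.options.ArrayOptions.Saving.StorageOptions(
        dtype=np.dtype(np.int16)
    )
  else:
    return ocp.options.ArrayOptions.Saving.StorageOptions()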


with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(
        saving=ocp.options.PyTreeOptions.Saving(
            create_array_storage_options_fn=create_array_storage_options_fn,
        )
    )
):
  ocp.save_pytree(path / '2', pytree, overwrite=True)

In [None]:
loaded = ocp.load_pytree(path / '2')

In [None]:
(loaded['a'].dtype, loaded['b'].dtype)
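
Because `create_array_storage_options_fn` decides per leaf from its key path, it helps to check what those key paths look like for your container types. This sketch uses `jax.tree_util.tree_flatten_with_path` from plain JAX; it assumes Orbax passes key paths of the same form, which is worth verifying for your version.

```python
import jax
import numpy as np

tree = {'a': np.arange(16, dtype=np.int32), 'b': np.ones(16, dtype=np.int32)}
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
for keypath, _ in leaves_with_paths:
  # For a dict, each entry is a DictKey; keystr renders the path as "['a']".
  print(jax.tree_util.keystr(keypath), type(keypath[-1]).__name__)
```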

Now, let's set the dtype of all arrays when saving.

In [None]:
create_array_storage_options_fn = (
    lambda k, v: ocp.options.ArrayOptions.Saving.StorageOptions(
        dtype=np.dtype(np.int16)
    )
)
with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(
        saving=ocp.options.PyTreeOptions.Saving(
            create_array_storage_options_fn=create_array_storage_options_fn
        )
    )
):
  ocp.save_pytree(path / '3', pytree, overwrite=True)

In [None]:
loaded = ocp.load_pytree(path / '3')

In [None]:
(loaded['a'].dtype, loaded['b'].dtype)

#### High Throughput with `ocdbt` option

To get high throughput and avoid creating a separate subdirectory for each leaf, enable `use_ocdbt`. Note that it is already enabled by default.

In [None]:
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        saving=ocp.options.ArrayOptions.Saving(
            use_ocdbt=True,
        )
    )
):
  ocp.save_pytree(path / '4', pytree, overwrite=True)

A checkpoint created with this option enabled can be identified by the presence of a `manifest.ocdbt` file and subdirectories such as `ocdbt.process_*`.

In [None]:
!ls /tmp/advanced/4/pytree
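
The same check can be scripted. A minimal sketch using only the standard library; `uses_ocdbt` is a hypothetical helper that relies solely on the `manifest.ocdbt` file described above, so treat it as a heuristic rather than an official format check.

```python
import pathlib


def uses_ocdbt(pytree_dir) -> bool:
  """Heuristic: an OCDBT checkpoint directory contains `manifest.ocdbt`."""
  return (pathlib.Path(pytree_dir) / 'manifest.ocdbt').is_file()


# Example usage on the checkpoint saved above:
# uses_ocdbt('/tmp/advanced/4/pytree')
```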

However, for use cases like large stacked models, disabling this option may be more efficient.

In [None]:
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        saving=ocp.options.ArrayOptions.Saving(
            use_ocdbt=False,
        )
    )
):
  ocp.save_pytree(path / '5', pytree, overwrite=True)

!ls /tmp/advanced/5/pytree

Note that each leaf is written to its own subdirectory when `use_ocdbt=False`.

### Loading

#### Pad / truncate shape

Ordinarily, specifying a target array with a different shape than in the
checkpoint results in an error.

In [None]:
# Original shape.
loaded = ocp.load_pytree(path / '1')

(loaded['a'].shape, loaded['b'].shape)

In [None]:
different_shape_abstract_pytree = {
    'a': jax.ShapeDtypeStruct(
        shape=(8,),
        dtype=abstract_pytree['a'].dtype,
        sharding=abstract_pytree['a'].sharding,
    ),
    'b': jax.ShapeDtypeStruct(
        shape=(32,),
        dtype=abstract_pytree['b'].dtype,
        sharding=abstract_pytree['b'].sharding,
    ),
}

try:
  ocp.load_pytree(path / '1', different_shape_abstract_pytree)
except Exception as e:  # Report the shape-mismatch error without crashing.
  print(e)

We can pad or truncate arrays as they are loaded by setting `enable_padding_and_truncation=True` in `ocp.options.ArrayOptions.Loading`.

In [None]:
with ocp.Context(
    array_options=ocp.options.ArrayOptions(
        loading=ocp.options.ArrayOptions.Loading(
            enable_padding_and_truncation=True
        )
    )
):
  loaded = ocp.load_pytree(path / '1', different_shape_abstract_pytree)

In [None]:
(loaded['a'].shape, loaded['b'].shape)