# Model Surgery

The model saved to disk often isn't exactly the model we want to work with in memory. Some examples:

 - Stacking/unstacking layers to match your training setup
 - Fine-tuning a multi-modal model from multiple uni-modal models
 - Using a frozen teacher model at each iteration of the training loop for a student model
 - Loading only the weights section of the PyTree, and ignoring things like optimizer state, when doing model evaluation

Model surgery is a toolset designed precisely for this kind of task.

Orbax Checkpointing currently exposes a Partial Loading API, which loads a subset of the PyTree leaves (a "strict subtree") from the full model on disk. More general manipulation of leaves and trees, such as loading multiple trees and merging them into one, is planned for the future.

Let's first take a look at what it's like to restore part of a PyTree, then touch on the planned Advanced Model Surgery API.

In [15]:
import jax
import numpy as np
from orbax.checkpoint import v1 as ocp
from etils import epath

path = epath.Path('/tmp/my-checkpoints/ckpt-1')
pytree = {
  'params': {
    'layer0': {
      'kernel': np.random.uniform(size=(2, 2)),
      'bias': np.ones(2),
    },
  },
  'opt_state': {
    '0': np.random.random(size=(2,)),
    '1': [np.ones(2), np.ones(2)],
  },
  'step': np.asarray(0),
}
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))
pytree = jax.tree.map(
    lambda arr: jax.make_array_from_callback(
        arr.shape,
        sharding,
        lambda idx: arr[idx],
    ),
    pytree,
)
ocp.save_pytree(path, pytree, force=True)

## Partial Loading

Partial loading addresses the most common case of loading a tree that differs from the one in the checkpoint: the case where leaves or subtrees are omitted. The canonical example is skipping the optimizer state when you're doing evaluation. There are a couple of ways to do this with the Partial Loading API. Let's take a look at both.

### Placeholder

Since we don't need the optimizer state (`opt_state`) during model evaluation, we can tell Orbax to skip loading the leaves under that node by setting each of them to the `ocp.PLACEHOLDER` (`...`) value.

In [16]:
abstract_tree = {
  'params': {
    'layer0': {
      'kernel': np.array([]),
      'bias': np.array([]),
    },
  },
  # Skip loading 'opt_state'
  'opt_state': {
    '0': ...,
    '1': [..., ...],
  },
  'step': np.array([]),
}

ocp.load_pytree(path, abstract_tree)
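
To make it easier to see which leaves an abstract tree like the one above asks for, here is a minimal pure-Python sketch that walks a nested dict/list tree and separates the requested leaves from the skipped ones. The helper `split_leaf_paths` is hypothetical and not part of Orbax; it uses the literal `...` as a stand-in for `ocp.PLACEHOLDER` and plain strings as stand-ins for array leaves, so it runs without JAX or Orbax installed.

```python
PLACEHOLDER = ...  # Stand-in for ocp.PLACEHOLDER (an assumption for this sketch).


def split_leaf_paths(tree, prefix=()):
    """Hypothetical helper: returns (requested, skipped) '/'-joined leaf paths."""
    if isinstance(tree, dict):
        children = tree.items()
    elif isinstance(tree, (list, tuple)):
        children = enumerate(tree)
    else:
        # Anything that is not a dict/list/tuple is treated as a leaf.
        path = '/'.join(str(key) for key in prefix)
        if tree is PLACEHOLDER:
            return [], [path]
        return [path], []

    requested, skipped = [], []
    for key, child in children:
        child_requested, child_skipped = split_leaf_paths(child, prefix + (key,))
        requested += child_requested
        skipped += child_skipped
    return requested, skipped


# Same shape as the abstract tree above; 'array' marks a leaf to be loaded.
example_abstract_tree = {
    'params': {'layer0': {'kernel': 'array', 'bias': 'array'}},
    'opt_state': {'0': ..., '1': [..., ...]},
    'step': 'array',
}

requested, skipped = split_leaf_paths(example_abstract_tree)
print('requested:', requested)
print('skipped:  ', skipped)
```

Every leaf under `opt_state` ends up in the skipped list, while the tree keeps its full structure.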

Note that `ocp.PLACEHOLDER` can only be used for leaves, so `opt_state: ocp.PLACEHOLDER` would not work. Keeping the structure consistent in this way is important for use cases like merging the original state with the restored state.

In [17]:
bad_abstract_tree = {
  'params': {
    'layer0': {
      'kernel': np.array([]),
      'bias': np.array([]),
    },
  },
  # Skip loading 'opt_state'
  'opt_state': ...,
  'step': np.array([]),
}

try:
  ocp.load_pytree(path, bad_abstract_tree)
except Exception as e:
  print(e)
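
The note above about keeping the structure consistent can be illustrated with a small merge sketch. It assumes that the restored tree has the same structure as the in-memory state and that skipped leaves come back as `...`; the helper `fill_placeholders` is hypothetical and not part of Orbax, and strings and integers stand in for arrays so the example runs without JAX or Orbax.

```python
def fill_placeholders(restored, original):
    """Hypothetical helper: takes leaves from `original` wherever `restored` is `...`."""
    if isinstance(restored, dict):
        return {
            key: fill_placeholders(value, original[key])
            for key, value in restored.items()
        }
    if isinstance(restored, (list, tuple)):
        return type(restored)(
            fill_placeholders(r, o) for r, o in zip(restored, original)
        )
    # Leaf: keep the restored value unless it was skipped during loading.
    return original if restored is ... else restored


# In-memory state, for example from the current training run.
original_state = {
    'params': {'w': 'old_w'},
    'opt_state': {'0': 'old_m', '1': ['old_a', 'old_b']},
    'step': 0,
}
# Assumed shape of a partially loaded tree: skipped leaves are `...`.
restored_state = {
    'params': {'w': 'new_w'},
    'opt_state': {'0': ..., '1': [..., ...]},
    'step': 7,
}

merged_state = fill_placeholders(restored_state, original_state)
print(merged_state)
```

Because both trees share one structure, the merge is a leaf-by-leaf choice. If `opt_state` were collapsed into a single placeholder, the two trees would no longer line up.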

Creating an abstract tree by hand is tedious. A more convenient way is to generate it with something like JAX's `jax.tree.map_with_path`.

In [18]:
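def _create_abstract_leaf_for_partial_load(leaf_path, _):
  # Render the key path as a '/'-separated string, e.g. 'opt_state/1/0'.
  leaf_path = jax.tree_util.keystr(leaf_path, simple=True, separator='/')
  # Skip every leaf under the top-level 'opt_state' node.
  if leaf_path.split('/')[0] == 'opt_state':
    return ocp.PLACEHOLDER
  return np.array([])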

easy_abstract_tree = jax.tree.map_with_path(
  _create_abstract_leaf_for_partial_load,
  pytree
)

ocp.load_pytree(path, easy_abstract_tree)

We may not have direct access to the original PyTree when creating the abstract counterpart, and in that case, we'll need to use the on-disk `pytree_metadata`.

In [19]:
on_disk_pytree_structure = ocp.pytree_metadata(path).metadata

abstract_tree_from_metadata = jax.tree.map_with_path(
  _create_abstract_leaf_for_partial_load,
  on_disk_pytree_structure
)

ocp.load_pytree(path, abstract_tree_from_metadata)

### Omission

TODO(b/411457893): add omission mode

### Model Surgery

TODO(b/411457893): add future model surgery api