# Model Surgery

Often, the model we saved to disk isn't exactly the model we want to work with in memory. Some examples:

 - Stacking/unstacking layers to match your training setup
 - Fine-tuning a multi-modal model built from multiple uni-modal models
 - Using a frozen teacher model at each iteration of a student model's training loop
 - Loading only the weights from the PyTree, and ignoring things like optimizer state, during model evaluation

Model surgery is a toolset designed precisely for this kind of task.

Orbax Checkpointing currently exposes a Partial Loading API, which allows a subset of PyTree leaves (a "strict subtree") to be loaded from the full model on disk. More arbitrary manipulation of leaves and trees, such as loading multiple trees and merging them into one, is planned for the future.

Let's first take a look at what it's like to restore part of a PyTree, then touch on the planned Advanced Model Surgery API.

```python
import jax
import numpy as np
from orbax.checkpoint import v1 as ocp
from etils import epath

path = epath.Path('/tmp/my-checkpoints/ckpt-1')
pytree = {
  'params': {
    'layer0': {
      'kernel': np.random.uniform(size=(2, 2)),
      'bias': np.ones(2),
    },
  },
  'opt_state': {
    '0': np.random.random(size=(2,)),
    '1': [np.ones(2), np.ones(2)],
  },
  'step': np.asarray(0),
}
# Replicate every array across all available devices. An empty PartitionSpec
# shards no dimensions, so it is valid for arrays of any rank, including the
# scalar 'step'.
mesh = jax.sharding.Mesh(jax.devices(), ('x',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
pytree = jax.tree.map(
    lambda arr: jax.make_array_from_callback(
        arr.shape,
        sharding,
        lambda idx: arr[idx],
    ),
    pytree,
)
ocp.save_pytree(path, pytree, overwrite=True)
```

## Partial Loading

Partial loading solves the most common case of loading a different tree than the one stored in the checkpoint: omitting leaves or subtrees. The canonical example is skipping the optimizer state during evaluation. The Partial Loading API offers two ways to do this; let's look at both.

### Placeholder

Since we don't need the optimizer state (`opt_state`) during model evaluation, we can tell Orbax to skip loading every leaf under that node by setting each of those leaves to `ocp.PLACEHOLDER` (`...`).

```python
abstract_tree = {
  'params': {
    'layer0': {
      'kernel': np.array([]),
      'bias': np.array([]),
    },
  },
  # Skip loading 'opt_state'
  'opt_state': {
    '0': ...,
    '1': [..., ...],
  },
  'step': np.array([]),
}

ocp.load_pytree(path, abstract_tree)
```

Note that `ocp.PLACEHOLDER` can only replace leaves, so `'opt_state': ocp.PLACEHOLDER` would not work. Keeping the structure consistent in this way is important for use cases such as merging the original state with the restored state.

```python
bad_abstract_tree = {
  'params': {
    'layer0': {
      'kernel': np.array([]),
      'bias': np.array([]),
    },
  },
  # Invalid: PLACEHOLDER replaces the whole 'opt_state' subtree, not its leaves.
  'opt_state': ...,
  'step': np.array([]),
}

try:
  ocp.load_pytree(path, bad_abstract_tree)
except Exception as e:
  print(e)
```
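To see why per-leaf placeholders matter for merging, here is a minimal pure-Python sketch. The helper `merge_restored` is hypothetical and not part of Orbax, and it assumes the restored tree keeps `...` at each skipped position. It fills those positions from the original state. Because placeholders stand in for leaves only, both trees share one structure and can be walked in lockstep; a placeholder standing in for a whole subtree is rejected.

```python
from typing import Any

# Mirrors `ocp.PLACEHOLDER`, which is the Ellipsis value `...`.
PLACEHOLDER = ...


def merge_restored(original: Any, restored: Any) -> Any:
  """Fills PLACEHOLDER leaves in `restored` with values from `original`."""
  if restored is PLACEHOLDER:
    # A placeholder may only stand in for a leaf, never for a subtree.
    if isinstance(original, (dict, list, tuple)):
      raise ValueError('PLACEHOLDER must replace a leaf, not a subtree.')
    return original
  if isinstance(restored, dict):
    if not isinstance(original, dict) or original.keys() != restored.keys():
      raise ValueError('Trees have different structures.')
    return {k: merge_restored(original[k], v) for k, v in restored.items()}
  if isinstance(restored, (list, tuple)):
    if type(original) is not type(restored) or len(original) != len(restored):
      raise ValueError('Trees have different structures.')
    return type(restored)(
        merge_restored(o, r) for o, r in zip(original, restored)
    )
  # A real restored leaf wins over the original value.
  return restored
```

With this design, the merge never has to guess where a skipped subtree begins or ends; the shared structure answers that question.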

Creating an abstract tree by hand is tedious. A more natural approach is to generate it with a tree-mapping utility such as JAX's `jax.tree.map_with_path`.

```python
def _create_abstract_leaf_for_partial_load(leaf_path, _):
  # Convert the key path to a string such as 'opt_state/1/0'.
  leaf_path = jax.tree_util.keystr(leaf_path, simple=True, separator='/')
  # Skip every leaf under the top-level 'opt_state' node.
  if leaf_path.split('/')[0] == 'opt_state':
    return ocp.PLACEHOLDER
  else:
    return np.array([])

easy_abstract_tree = jax.tree.map_with_path(
  _create_abstract_leaf_for_partial_load,
  pytree
)

ocp.load_pytree(path, easy_abstract_tree)
```

We may not have access to the original PyTree when creating its abstract counterpart. In that case, we can build the abstract tree from the on-disk metadata returned by `ocp.pytree_metadata`.

```python
on_disk_pytree_structure = ocp.pytree_metadata(path).metadata

abstract_tree_from_metadata = jax.tree.map_with_path(
  _create_abstract_leaf_for_partial_load,
  on_disk_pytree_structure
)

ocp.load_pytree(path, abstract_tree_from_metadata)
```

### Omission

Alternatively, we can enable the `partial_load` option so that we don't have to explicitly mark the nodes to skip. Instead, we simply leave those nodes out when constructing the abstract PyTree.

```python
abstract_tree = {
  'params': {
    'layer0': {
      'kernel': np.array([]),
      'bias': np.array([]),
    },
  },
  # Note: omit 'opt_state' to avoid loading it
  'step': 0,
}

# By default, loading with nodes missing from the abstract tree is treated
# as unsafe and raises an error...
try:
  ocp.load_pytree(path, abstract_tree)
except ValueError as e:
  print(e)

# ...so partial loading must be explicitly opted into.
with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(
        loading=ocp.options.PyTreeOptions.Loading(
            partial_load=True,
        ),
    ),
):
  ocp.load_pytree(path, abstract_tree)
```
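To make the difference between the strict default and `partial_load` concrete, the following toy model mimics the behavior described above on nested dicts and lists. It is a sketch under our own assumptions, not Orbax's implementation: `select_subtree` is a hypothetical helper, only dict keys may be omitted, and Orbax's actual validation and error messages may differ.

```python
from typing import Any


def select_subtree(full: Any, abstract: Any, partial_load: bool = False) -> Any:
  """Returns the part of `full` whose structure is described by `abstract`.

  Without `partial_load`, `abstract` must mention every dict key in `full`;
  with it, dict keys missing from `abstract` are skipped.
  """
  if isinstance(full, dict):
    if not isinstance(abstract, dict):
      raise ValueError('Expected a dict in the abstract tree.')
    unknown = abstract.keys() - full.keys()
    if unknown:
      raise ValueError(f'Keys not in checkpoint: {sorted(unknown)}')
    missing = full.keys() - abstract.keys()
    if missing and not partial_load:
      raise ValueError(f'Keys missing from abstract tree: {sorted(missing)}')
    return {k: select_subtree(full[k], abstract[k], partial_load) for k in abstract}
  if isinstance(full, (list, tuple)):
    if not isinstance(abstract, (list, tuple)) or len(abstract) != len(full):
      raise ValueError('Sequence structure mismatch.')
    return type(full)(
        select_subtree(f, a, partial_load) for f, a in zip(full, abstract)
    )
  if isinstance(abstract, (dict, list, tuple)):
    raise ValueError('Abstract tree has a subtree where the checkpoint has a leaf.')
  # Leaf: "load" the stored value.
  return full
```

The strict default catches typos and stale abstract trees early, which is one reason omission is opt-in.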

## Advanced Model Surgery (Planned)

While partial loading is useful for omitting parts of a PyTree, it does not support more complex manipulations. The planned Model Surgery API, in contrast, will let users manipulate trees and leaves in arbitrary ways: restructuring trees, modifying values, and even loading multiple distinct checkpoints and merging them into a single in-memory model.

The core of this API will be user-defined transformation functions that are applied to checkpoints during the loading process.

#### Single-Model Transformations

A common use case for model surgery is transforming a single checkpoint into a different structure. For example, you might want to stack model layers that were saved individually. This could be accomplished with a `transform_fn` that takes the PyTree from the source checkpoint and returns a new, modified PyTree.

A potential API for this could look like:

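```python
import re
import jax
import jax.numpy as jnp

# Hypothetical stand-in so this sketch runs: `load_and_transform` is a
# planned API and does not exist in Orbax today.
ocp.load_and_transform = lambda *args: None

def stack_layers_transform(source_tree):
  params = source_tree['params']
  # Assumes layers are named 'layer0', 'layer1', etc. Sort numerically so
  # that 'layer10' comes after 'layer9' rather than after 'layer1'.
  layer_keys = sorted(
      (k for k in params if re.fullmatch(r'layer\d+', k)),
      key=lambda k: int(k[len('layer'):]),
  )

  # Stack every leaf of the per-layer subtrees (kernel, bias, ...) along a
  # new leading axis.
  stacked_layers = jax.tree.map(
      lambda *xs: jnp.stack(xs), *[params[k] for k in layer_keys]
  )

  new_params = {'stacked_layers': stacked_layers}
  # Bring over any other parameters that are not part of the stacking.
  for k in params:
    if k not in layer_keys:
      new_params[k] = params[k]

  # Return a new tree instead of mutating the input.
  return {**source_tree, 'params': new_params}

abstract_tree = ...

# The API would apply the transformation during loading.
restored_tree = ocp.load_and_transform(path, stack_layers_transform, abstract_tree)
```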
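The stacking step itself, together with its inverse (unstacking, from the list of use cases at the top), can be sketched with plain NumPy so that it runs without accelerators. The helpers `stack_layers` and `unstack_layers` and the `layerN` naming convention are assumptions for illustration; on device you would typically use `jnp.stack` instead of `np.stack`.

```python
import re

import numpy as np


def _layer_index(name: str) -> int:
  # Hypothetical naming convention: 'layer0', 'layer1', ..., 'layer10'.
  return int(re.fullmatch(r'layer(\d+)', name).group(1))


def stack_layers(params: dict) -> dict:
  """Stacks per-layer {'kernel', 'bias', ...} dicts along a new leading axis."""
  layer_keys = sorted(
      (k for k in params if re.fullmatch(r'layer\d+', k)), key=_layer_index
  )
  stacked = {
      name: np.stack([params[k][name] for k in layer_keys])
      for name in params[layer_keys[0]]
  }
  # Keep non-layer parameters (e.g. embeddings) unchanged.
  rest = {k: v for k, v in params.items() if k not in layer_keys}
  return {'stacked_layers': stacked, **rest}


def unstack_layers(params: dict) -> dict:
  """Inverse of `stack_layers`: splits the leading axis back into layers."""
  stacked = params['stacked_layers']
  num_layers = len(next(iter(stacked.values())))
  layers = {
      f'layer{i}': {name: arr[i] for name, arr in stacked.items()}
      for i in range(num_layers)
  }
  rest = {k: v for k, v in params.items() if k != 'stacked_layers'}
  return {**layers, **rest}
```

Sorting by the numeric suffix rather than lexicographically keeps `layer10` after `layer9`, so the stacked axis matches the model's layer order.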

#### Multi-Model Transformations

A more advanced use case is merging multiple checkpoints. A key example is creating a multi-modal model by combining two separately trained uni-modal models (e.g., an image model and a text model).

A transformation function for this scenario would accept multiple source trees and define how they should be combined.

```python
def merge_models_transform(image_model_tree, text_model_tree):
  return {
      'params': {
          'image_encoder': image_model_tree['params'],
          'text_encoder': text_model_tree['params'],
          # A new fusion layer. `np.empty` only allocates buffers of the
          # target shapes (their contents are arbitrary), so the user must
          # initialize them later.
          'fusion_layer': {
              'kernel': np.empty((512, 256)),
              'bias': np.empty((256,)),
          },
      },
      # Can also merge other things, like step counts.
      'step': image_model_tree['step'],
  }

image_model_path = ...
text_model_path = ...

# The API would take multiple paths and apply the transform.
# `ocp.load_and_transform` is hypothetical (a planned API).
final_model = ocp.load_and_transform(
    merge_models_transform,
    image_model_path,
    text_model_path,
)
```

This example also highlights an important feature: any parameters in the target structure that are not populated from a source checkpoint (like `fusion_layer`) would be initialized from scratch. This makes it easy to combine pre-trained components with new, untrained ones.
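How the planned API would perform that initialization isn't specified. As one possible approach, here is a NumPy sketch with hypothetical helpers `fill_missing` and `normal_init`: it builds a parameter tree from a template of target shapes (an assumed convention: nested dicts whose leaves are shape tuples), reusing loaded leaves where they exist and initializing the rest.

```python
import numpy as np


def normal_init(rng, shape, scale=0.02):
  # Hypothetical initializer: small Gaussian values.
  return rng.normal(scale=scale, size=shape)


def fill_missing(target_shapes: dict, loaded: dict, init_fn, rng) -> dict:
  """Builds a params tree shaped like `target_shapes`.

  Leaves present in `loaded` are reused (after a shape check); missing ones
  are created with `init_fn(rng, shape)`.
  """
  out = {}
  for k, spec in target_shapes.items():
    if isinstance(spec, dict):
      # Recurse into subtrees; a missing subtree is initialized entirely.
      out[k] = fill_missing(spec, loaded.get(k, {}), init_fn, rng)
    elif k in loaded:
      if loaded[k].shape != spec:
        raise ValueError(f'Shape mismatch for {k!r}: {loaded[k].shape} vs {spec}')
      out[k] = loaded[k]
    else:
      out[k] = init_fn(rng, spec)
  return out
```

For example, with encoders loaded from two checkpoints and no `fusion_layer` on disk, `fill_missing` returns the encoders untouched and a freshly initialized `fusion_layer`.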

This planned API aims to provide maximum flexibility, making complex restoration and fine-tuning workflows more straightforward to implement.