# Partial Saving

As deep learning models grow, often to hundreds of billions of parameters, managing their checkpoints becomes a significant challenge. Modifying a large checkpoint, even for a small change like adding metrics or a single layer, traditionally requires an inefficient "load-modify-save" cycle. The entire checkpoint, which may be multiple terabytes, must be loaded from storage, changed in memory, and written back out, which consumes a large amount of memory and I/O bandwidth.

Partial saving is designed to solve this problem by allowing you to modify a checkpoint without loading the entire object into host memory. It dramatically reduces peak memory usage, minimizes redundant I/O, and simplifies common model update workflows.

### The Core Concept: The Partial Save Session

Partial saving operates on a "session" or "transaction" model. Instead of overwriting your checkpoint directly, Orbax stages all changes in a temporary, in-progress location. The workflow consists of two stages:

1. **Incremental Updates**: Calls to functions like `ocp.partial.save_pytree` contribute data to an in-progress checkpointing session. These changes are staged in a temporary location and are not yet visible at the final checkpoint path. From the user's perspective, the first save call simply begins this incremental process, and subsequent calls add to it.
2. **Finalization**: A concluding call to `ocp.partial.finalize` completes the session. This action commits all the staged changes, making the checkpoint available at its final destination and ready for consumption.

This approach ensures that the modification process is safe and atomic. If the process is interrupted before finalization, your original checkpoint remains untouched.
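The session model can be illustrated with a small, self-contained sketch. This is not how Orbax is implemented. The `StagedSession` class, its methods, and the one-JSON-file-per-save layout are hypothetical. They exist only to show why staging writes and then doing a single rename keeps unfinished work invisible at the final path.

```python
import json
import os
import tempfile
from pathlib import Path


class StagedSession:
    """Hypothetical stage-then-commit session; an illustration, not Orbax code."""

    def __init__(self, final_path: Path):
        self.final_path = final_path
        # Stage next to the final path, mirroring the `<name>.partial_save`
        # naming used in the examples in this guide.
        self.staging_path = final_path.parent / (final_path.name + '.partial_save')

    def save(self, name: str, data: dict) -> None:
        # Every save writes into the staging directory; final_path is untouched.
        self.staging_path.mkdir(parents=True, exist_ok=True)
        (self.staging_path / f'{name}.json').write_text(json.dumps(data))

    def finalize(self) -> None:
        # One directory rename publishes all staged files together. On POSIX,
        # rename is atomic when both paths are on the same filesystem.
        os.rename(self.staging_path, self.final_path)


root = Path(tempfile.mkdtemp())
session = StagedSession(root / 'ckpt')
session.save('initial', {'step': 10000})
session.save('addition', {'layer1': [1.0, 1.0, 1.0, 1.0]})
assert not session.final_path.exists()  # Staged changes are not visible yet.
session.finalize()
assert session.final_path.exists() and not session.staging_path.exists()
```

If the process dies before `finalize()` runs, only the staging directory exists. Anything at the final path is left as it was.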

> **Note**: Partial saving currently does *NOT* support replacing data written in previous save calls. If you need Partial Saving Replacement (as opposed to the currently supported Partial Saving Addition), please reach out to the Orbax Checkpointing team so that development of this feature can be prioritized.
>
> The canonical way to replace data without partial saving is to load the model, update values in memory, and then save it back out.

### API and Basic Usage: Adding to a Checkpoint

The partial saving API is available in the `orbax.checkpoint.v1.partial` module, but you'll likely access it via `ocp.partial`.

The most common (and currently the only supported) use case is adding new data (leaves or subtrees) to an existing PyTree checkpoint. The PyTree passed to a save call represents a set of updates. A key that does not exist in the on-disk checkpoint is treated as an addition. A key that already exists is treated as a replacement, which is not currently supported and raises a `NotImplementedError`.
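The addition-only rule can be illustrated on plain nested dictionaries. `merge_additions` below is a hypothetical sketch of the semantics described above, not Orbax's implementation. In particular, treating a mismatch between a leaf and a subtree as a replacement is an assumption.

```python
def merge_additions(on_disk: dict, updates: dict) -> dict:
    """Hypothetical addition-only merge; an illustration, not Orbax code."""
    merged = dict(on_disk)  # Shallow copy so the input is not mutated.
    for key, value in updates.items():
        if key not in merged:
            # New leaf or subtree: an addition.
            merged[key] = value
        elif isinstance(merged[key], dict) and isinstance(value, dict):
            # Both sides are subtrees: merge their children recursively.
            merged[key] = merge_additions(merged[key], value)
        else:
            # The key already holds data: this would be a replacement.
            raise NotImplementedError(f'Replacing existing key {key!r} is not supported.')
    return merged


on_disk = {'params': {'layer0': [0, 1, 2]}, 'step': 10000}
merged = merge_additions(on_disk, {'params': {'layer1': [1.0, 1.0]}})
print(merged)  # {'params': {'layer0': [0, 1, 2], 'layer1': [1.0, 1.0]}, 'step': 10000}

try:
    merge_additions(merged, {'step': 20000})
except NotImplementedError as e:
    message = str(e)
print(message)  # Replacing existing key 'step' is not supported.
```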

### Code Example: A Simple Addition Workflow

Let's start with an initial training state, then update that state with new data in a separate step.

```python
from orbax.checkpoint import v1 as ocp
import numpy as np
import jax
from etils import epath
```

#### Initial Save

Let's say we have an initial training state. The first call creates a temporary directory (e.g., `/tmp/partial_save/ckpt.partial_save`) and saves the initial state there.

```python
path = epath.Path('/tmp/partial_save/ckpt')
path.parent.rmtree(missing_ok=True)  # Start from a clean directory.

initial_state = {
    'params': {
        'layer0': np.arange(8),
    },
    'step': 10000,
}

ocp.partial.save_pytree(path, initial_state)
# Nothing exists at the final path yet; the data is staged in `ckpt.partial_save`.
assert not path.exists()
assert (path.parent / (path.name + '.partial_save')).exists()
```

#### Add More Data

After training some more, we have a new layer ready to be added. A subsequent call adds the new layer to the same temporary directory. Orbax merges the new PyTree with the existing one.

```python
new_state = {
    'params': {
        'layer1': np.ones(4),
    },
}

ocp.partial.save_pytree(path, new_state)
assert not path.exists()
assert (path.parent / (path.name + '.partial_save')).exists()
```

#### Aside: Loading Before Finalizing

Before finalizing the checkpoint, let's see what happens if we try to load it. Because the staged data is not yet visible at `path`, the load is expected to fail.

```python
try:
  ocp.load_pytree(path)
except Exception as e:
  print("LOAD ERROR")
  print(e)
```

#### Finalize the Checkpoint

Calling `ocp.partial.finalize` atomically renames the temporary directory to the final path, making it a complete, readable checkpoint.

```python
ocp.partial.finalize(path)
assert not (path.parent / (path.name + '.partial_save')).exists()
assert path.exists()
```

#### Verify the Result

Now, we can load the checkpoint and see the merged result.

```python
restored_state = ocp.load_pytree(path)

expected_state = {
  'params': {
    'layer0': np.array([0, 1, 2, 3, 4, 5, 6, 7]),
    'layer1': np.array([1., 1., 1., 1.])
  },
  'step': 10000,
}

def is_equal(x, y):
  # Dispatch on the expected value so restored arrays compare elementwise
  # whether they come back as NumPy or JAX arrays.
  if isinstance(y, np.ndarray):
    assert np.allclose(x, y)
  else:
    assert x == y

jax.tree.map(is_equal, restored_state, expected_state)
restored_state
```

### API Reference

 - `ocp.partial.save_pytree()` / `ocp.partial.save_pytree_async()`: Saves a PyTree to the temporary partial save location. These functions can be called multiple times.
 - `ocp.partial.finalize()`: Commits the transaction, making the checkpoint permanent at the specified path. This must be called to complete the process.
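The calling pattern these two functions imply is "stage every piece, then commit exactly once." The sketch below captures that pattern with injectable callables so it can run without Orbax. `save_in_pieces` is a hypothetical helper, not an Orbax API.

```python
from pathlib import Path
from typing import Callable, Iterable


def save_in_pieces(
    path: Path,
    updates: Iterable[dict],
    save_fn: Callable[[Path, dict], None],
    finalize_fn: Callable[[Path], None],
) -> int:
    """Hypothetical helper: stage every update, then commit exactly once."""
    count = 0
    for update in updates:
        save_fn(path, update)  # Each call adds to the same in-progress session.
        count += 1
    # Reached only if every save succeeded; an exception above skips
    # finalization, so nothing is committed at `path`.
    finalize_fn(path)
    return count


# Stand-in save/finalize functions that just record the calls they receive.
calls = []
n = save_in_pieces(
    Path('/tmp/partial_save/ckpt'),
    [{'step': 10000}, {'params': {'layer1': [1.0, 1.0]}}],
    save_fn=lambda p, tree: calls.append(('save', tree)),
    finalize_fn=lambda p: calls.append(('finalize', None)),
)
print(n, [name for name, _ in calls])  # 2 ['save', 'save', 'finalize']
```

With Orbax, the two callables would be `ocp.partial.save_pytree` and `ocp.partial.finalize`, called with the same `(path, pytree)` and `(path)` arguments used in the example above.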

### Advanced Workflow: Combining Partial Saving and Partial Restore

When combined with Partial Restore, this feature can enable highly efficient, targeted updates to massive checkpoints with a minimal memory footprint. You can use Partial Restore for a memory-efficient *read*, perform modifications, and then use Partial Save for a flexible and efficient *write*.

#### Use Case: Creating an Inference-Ready Checkpoint

Imagine you have a 2TB training checkpoint containing model `params` and a bulky `optimizer_state`. You want to create a smaller, inference-ready checkpoint that:
 - Contains only the `params`.
 - Has an updated `encoder_stack` within the `params`, taken from a recent fine-tuning run.

This entire process can be done without ever loading the massive `optimizer_state` into memory.

```python
from orbax.checkpoint import v1 as ocp
import numpy as np
import jax
from etils import epath
```

#### Setup

Create a large, multi-part "base" checkpoint to simulate a real scenario. It stands in for a very large model: we write it to disk and, apart from one full load below to print its contents for comparison, never load it all at once.

```python
base_path = epath.Path('/tmp/base_model/ckpt')
base_path.rmtree(missing_ok=True)

# Shard arrays along a 1-D 'model' mesh axis spanning all available devices.
sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec('model'),
)
create_sharded_array = lambda x: jax.device_put(x, sharding)
base_model_state = {
    'params': {
        'large_embedding_table': np.ones((1024, 1024)),  # A large array.
        'encoder_stack': {f'layer_{i}': np.random.rand(2) for i in range(4)},  # The part we will update.
        'classification_head': np.random.rand(8),
    },
    'optimizer_state': [np.random.rand(128) for _ in range(16)],
}
base_model_state = jax.tree.map(create_sharded_array, base_model_state)
ocp.save_pytree(base_path, base_model_state)

# Load the full checkpoint once, only to show its contents for comparison.
abstract_base_model_state = jax.tree.map(
    ocp.arrays.to_shape_dtype_struct,
    base_model_state
)
init_ckpt = ocp.load_pytree(base_path, abstract_base_model_state)
print("\n--- Setup ---")
print(f"Optimizer state exists in initial checkpoint: {'optimizer_state' in init_ckpt}")
print(f"Model version exists in initial checkpoint: {'model_version' in init_ckpt}")
for layer, weights in init_ckpt['params']['encoder_stack'].items():
    print(f"Original {layer}: {weights}")
```

#### The Efficient Update Workflow

Use Partial Restore in omission mode to load only the `params`. Create a reference PyTree that contains only the `params` structure, which tells Orbax to skip everything else (such as `optimizer_state`). Omitting nodes must be enabled explicitly by setting `partial_load=True` in the loading options passed through `ocp.Context`.

```python
inference_path = epath.Path('/tmp/inference_model/ckpt')
inference_path.parent.rmtree(missing_ok=True)

abstract_params = jax.tree.map(
    ocp.arrays.to_shape_dtype_struct, {'params': base_model_state['params']}
)

with ocp.Context(
    pytree_options=ocp.options.PyTreeOptions(
        loading=ocp.options.PyTreeOptions.Loading(partial_load=True)
    )
):
    loaded_params = ocp.load_pytree(base_path, abstract_params)
```

At this point, `params` is in memory, but `optimizer_state` was never loaded.
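Conceptually, omission mode walks the reference tree rather than the checkpoint, so any subtree the reference leaves out is never read. The sketch below illustrates that idea on a simulated key-value store. `restore_subset`, the flat `on_disk` dictionary, and the use of `None` as a placeholder leaf are assumptions for illustration, not Orbax internals. The reference trees in this guide use abstract arrays built with `ocp.arrays.to_shape_dtype_struct` instead.

```python
from typing import Any, Callable


def restore_subset(
    read_leaf: Callable[[tuple], Any], reference: dict, prefix: tuple = ()
) -> dict:
    """Hypothetical omission-style restore: read only leaves named in `reference`."""
    restored = {}
    for key, ref in reference.items():
        keypath = prefix + (key,)
        if isinstance(ref, dict):
            # Recurse into subtrees the reference asks for.
            restored[key] = restore_subset(read_leaf, ref, keypath)
        else:
            # A placeholder leaf: fetch the real value from storage.
            restored[key] = read_leaf(keypath)
    return restored


# Simulated checkpoint storage, keyed by the path to each leaf.
on_disk = {
    ('params', 'encoder_stack', 'layer_0'): [0.1, 0.2],
    ('params', 'classification_head'): [0.3, 0.4],
    ('optimizer_state', '0'): [0.0] * 128,
}
reads = []


def read_leaf(keypath: tuple) -> Any:
    reads.append(keypath)  # Record every leaf actually fetched.
    return on_disk[keypath]


reference = {'params': {'encoder_stack': {'layer_0': None}, 'classification_head': None}}
restored = restore_subset(read_leaf, reference)
print(len(reads), any(k[0] == 'optimizer_state' for k in reads))  # 2 False
```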

#### Update and Partial Save

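Next, modify the loaded parameters in memory and write the new checkpoint piece by piece with Partial Save:

1. Save new metadata that is useful for inference.
2. Carry over the parameters that are not being changed.
3. Update each `encoder_stack` layer and save it with its own partial save call.
4. Finalize the new, smaller, inference-ready checkpoint.

In [13]:

```python
# Track everything written so the result can be verified afterwards.
save_params = {}

# 1. Initial partial save: inference metadata.
metadata = {'model_version': 'v1.2-finetuned'}
save_params = ocp.tree.merge(save_params, metadata)
ocp.partial.save_pytree(inference_path, metadata)

# 2. Parameters other than the encoder stack are copied over unchanged.
unchanged_params = {
    'params': {
        k: v for k, v in loaded_params['params'].items() if k != 'encoder_stack'
    }
}
save_params = ocp.tree.merge(save_params, unchanged_params)
ocp.partial.save_pytree(inference_path, unchanged_params)

# 3. One partial save per updated layer. Adding random noise stands in for
#    the fine-tuned weights.
for layer, weights in loaded_params['params']['encoder_stack'].items():
    new_weights = weights + np.random.rand(2)
    stack_layer = {
        'params': {
            'encoder_stack': {
                layer: jax.tree.map(create_sharded_array, new_weights),
            }
        },
    }
    save_params = ocp.tree.merge(save_params, stack_layer)
    ocp.partial.save_pytree(inference_path, stack_layer)

# 4. Commit all staged changes.
ocp.partial.finalize(inference_path)
```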
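Each `stack_layer` update in the loop above is a nested dictionary containing a single new leaf. A small helper could build such a single-leaf update from a key path. `nested_update` below is a hypothetical convenience function, not part of Orbax.

```python
from functools import reduce
from typing import Any, Sequence


def nested_update(keypath: Sequence[str], value: Any) -> dict:
    """Hypothetical helper: wrap `value` in nested dicts along `keypath`."""
    if not keypath:
        raise ValueError('keypath must name at least one key')
    # Build from the innermost key outward: value -> {'layer_0': value} -> ...
    return reduce(lambda inner, key: {key: inner}, reversed(keypath), value)


update = nested_update(('params', 'encoder_stack', 'layer_0'), [0.5, 0.5])
print(update)  # {'params': {'encoder_stack': {'layer_0': [0.5, 0.5]}}}
```

Inside the loop, `nested_update(('params', 'encoder_stack', layer), jax.tree.map(create_sharded_array, new_weights))` would produce the same structure as the hand-written `stack_layer`.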

#### Verification

```python
# Build the reference tree from what was saved: arrays become abstract
# shape/dtype structs, and the string leaf (the model version) is
# represented by an empty string.
abstract_params = jax.tree.map(
    lambda x: (
        str()
        if isinstance(x, str)
        else ocp.arrays.to_shape_dtype_struct(x)
    ),
    save_params
)
final_ckpt = ocp.load_pytree(inference_path, abstract_params)

print("\n--- Verification ---")
print(f"Optimizer state exists in final checkpoint: {'optimizer_state' in final_ckpt}")
print(f"Model version: {final_ckpt['model_version']}")
for layer, weights in final_ckpt['params']['encoder_stack'].items():
    print(f"New {layer}: {weights}")
```

In this workflow, we created a new, pruned, and modified checkpoint. The key efficiency gain came from using Partial Restore to load only the params, completely avoiding the memory cost of the massive `optimizer_state`.

### Atomicity Guarantees

The use of a temporary directory and an atomic rename operation during finalization guarantees safety. If your program crashes mid-save, the original checkpoint (if any) is unharmed, and the temporary directory can be safely deleted.
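The cleanup step can be sketched as follows. This assumes staging directories follow the `<name>.partial_save` naming from the examples above and sit directly under a known root directory. `remove_stale_partial_saves` is a hypothetical helper, not an Orbax API. Run it only when no partial-save session is active, because it cannot tell an abandoned staging directory from one that is still being written.

```python
import shutil
import tempfile
from pathlib import Path


def remove_stale_partial_saves(root: Path) -> list[Path]:
    """Hypothetical cleanup: delete leftover `*.partial_save` directories under root."""
    removed = []
    for staging in sorted(root.glob('*.partial_save')):
        if staging.is_dir():
            shutil.rmtree(staging)  # Uncommitted data only; finalized checkpoints are skipped.
            removed.append(staging)
    return removed


root = Path(tempfile.mkdtemp())
(root / 'ckpt').mkdir()                # A finalized checkpoint: left alone.
(root / 'ckpt2.partial_save').mkdir()  # Staging left behind by a crash.
removed = remove_stale_partial_saves(root)
print([p.name for p in removed])  # ['ckpt2.partial_save']
```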