# Orbax Checkpointing for PyTorch Users


This tutorial serves as an onboarding guide for developers familiar with PyTorch, aiming to smooth their transition to JAX and Orbax for checkpointing. It complements existing Orbax documentation by specifically demonstrating how to map common PyTorch practices for saving and loading models to their JAX/Orbax equivalents.


### Core Differences
The fundamental difference lies in how state is managed. PyTorch takes an object-oriented approach: each model and optimizer holds its own state and exposes it as a Python dictionary through `state_dict()`. In contrast, JAX adopts a functional paradigm where the entire training state, such as model parameters, optimizer state, and the step number, is managed explicitly as a single nested data structure called a [**PyTree**](https://docs.jax.dev/en/latest/pytrees.html). Orbax is designed to save and load these PyTrees efficiently. See [Checkpointing PyTrees](https://orbax.readthedocs.io/en/latest/guides/checkpoint/checkpointing_pytrees.html) for more detail.
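To make the contrast concrete, the following sketch (assuming only `jax` is installed) writes a hypothetical training state as a plain nested dictionary and flattens it with `jax.tree_util.tree_flatten_with_path`. Joining each key path with dots yields names that look much like PyTorch `state_dict()` keys. The layer names `linear1` and `linear2` are illustrative only, chosen to echo the PyTorch model used later.

```python
import jax
import jax.numpy as jnp

# Illustrative training state as a plain nested PyTree (names and shapes are hypothetical).
state = {
    "params": {
        "linear1": {"kernel": jnp.zeros((10, 5)), "bias": jnp.zeros((5,))},
        "linear2": {"kernel": jnp.zeros((5, 1)), "bias": jnp.zeros((1,))},
    },
    "step": jnp.array(0),
}

# Flatten into (key path, leaf) pairs. Dict keys are visited in sorted order,
# and each dotted path plays the role of a state_dict key such as "linear1.weight".
flat = {
    ".".join(str(k.key) for k in path): leaf.shape
    for path, leaf in jax.tree_util.tree_flatten_with_path(state)[0]
}
for name, shape in flat.items():
    print(name, shape)
```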

The following table provides a high-level, side-by-side comparison of the two approaches:

| Feature | **Orbax Checkpointing** | **PyTorch** |
| :--- | :--- | :--- |
| **Core API** | `orbax.checkpoint.PyTreeCheckpointer` and the high-level `CheckpointManager`. | `torch.save()` and `torch.load()`. |
| **Data Structure** | Saves a JAX **PyTree**: a nested structure that holds parameters, optimizer state, and any other metadata in a single object. | Saves a standard Python `dict`, typically containing the model's and optimizer's `state_dict()` outputs. |
| **Basic Save** | `checkpointer.save(path, args=ocp.args.PyTreeSave(state))` | `torch.save({'model_state_dict': model.state_dict()}, path)` |
| **Basic Load** | `restored_state = checkpointer.restore(path)` | `model.load_state_dict(torch.load(path)['model_state_dict'])` |


## 1. Setup

First, we set up the necessary environment by installing the required packages and importing the modules used throughout this guide.

### Installation

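Install the latest `orbax-checkpoint` for core checkpointing, along with `jax`, `flax`, and `optax` for the JAX model and optimizer. The `jax[cuda12]` extra installs the CUDA 12 GPU build of JAX; on a CPU-only machine, install plain `jax` instead. The PyTorch recap in Section 2.1 additionally requires `torch`, which this command does not install.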

In [1]:
!pip install -q --upgrade orbax-checkpoint jax[cuda12] flax optax
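After installation, a quick sanity check is to print the installed versions. This sketch uses only the standard library's `importlib.metadata`, so it runs even if a package is missing; `torch` is listed because the PyTorch recap below imports it.

```python
from importlib import metadata

# Report the installed version of each package used in this guide.
packages = ["jax", "flax", "optax", "orbax-checkpoint", "torch"]
versions = {}
for pkg in packages:
    try:
        versions[pkg] = metadata.version(pkg)
    except metadata.PackageNotFoundError:
        versions[pkg] = None  # not installed in this environment
    print(f"{pkg:18s} {versions[pkg] or 'not installed'}")
```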

## 2. Checkpointing: Saving & Loading Training Progress

### 2.1 PyTorch Recap: Checkpointing

Let's begin with the familiar PyTorch pattern. You use `torch.save()` to store a model's `state_dict()` and other training information in a dictionary, and `torch.load()` to retrieve it. You then apply the loaded parameters to a model instance using `model.load_state_dict()`.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import os
import tempfile
import shutil

# Define a simple model
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        return self.linear2(x)

# Create model and optimizer
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Simulate some training
dummy_input = torch.randn(32, 10)
dummy_target = torch.randn(32, 1)
loss_fn = nn.MSELoss()

for step in range(5):
    optimizer.zero_grad()
    output = model(dummy_input)
    loss = loss_fn(output, dummy_target)
    loss.backward()
    optimizer.step()
    print(f"Step {step}, Loss: {loss.item():.4f}")

# Save checkpoint
tmpdir = tempfile.mkdtemp()
checkpoint_path = os.path.join(tmpdir, 'pytorch_checkpoint.pth')
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'step': 5,
    'loss': loss.item()
}
torch.save(checkpoint, checkpoint_path)
print(f"\nSaved checkpoint to {checkpoint_path}")

# Load checkpoint
loaded_checkpoint = torch.load(checkpoint_path)
model.load_state_dict(loaded_checkpoint['model_state_dict'])
optimizer.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
step = loaded_checkpoint['step']
loss = loaded_checkpoint['loss']

print(f"Loaded checkpoint from step {step} with loss {loss:.4f}")
shutil.rmtree(tmpdir)

Step 0, Loss: 0.7590
Step 1, Loss: 0.7560
Step 2, Loss: 0.7531
Step 3, Loss: 0.7501
Step 4, Loss: 0.7471

Saved checkpoint to pytorch_checkpoint.pth
Loaded checkpoint from step 5 with loss 0.7471
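When moving from PyTorch, a first step is often to turn a flat `state_dict()` into a nested dictionary, which is a valid PyTree that Orbax can save. The helper `unflatten_state_dict` below is a hypothetical sketch, not part of either library, and the NumPy arrays stand in for `{k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}`. It only restructures keys: loading these weights into a Flax model would additionally require mapping names (for example `weight` to `kernel`) and transposing, since `nn.Linear` stores weights as `[out_features, in_features]` while Flax `Dense` kernels are `[in_features, features]`.

```python
import numpy as np


def unflatten_state_dict(flat):
    """Hypothetical helper: split dotted PyTorch-style keys into nested dicts.

    {"linear1.weight": w, "linear1.bias": b} -> {"linear1": {"weight": w, "bias": b}}
    """
    nested = {}
    for name, value in flat.items():
        *parents, leaf = name.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})  # create intermediate levels on demand
        node[leaf] = value
    return nested


# Stand-in for SimpleNet's state_dict(): same keys and shapes, zero-filled.
flat = {
    "linear1.weight": np.zeros((5, 10), dtype=np.float32),
    "linear1.bias": np.zeros((5,), dtype=np.float32),
    "linear2.weight": np.zeros((1, 5), dtype=np.float32),
    "linear2.bias": np.zeros((1,), dtype=np.float32),
}
tree = unflatten_state_dict(flat)
print({layer: sorted(p) for layer, p in tree.items()})
```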


### 2.2 JAX/Orbax Equivalent: Functional Checkpointing with `PyTreeCheckpointer`

Now, let's look at the equivalent workflow in JAX with Orbax. Rather than collecting separate `state_dict()`s from the model and the optimizer, we keep the entire training state (parameters, optimizer state, and step counter) in a single PyTree, here a Flax `TrainState`.

We use a [`PyTreeCheckpointer`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpointers.html#orbax.checkpoint.PyTreeCheckpointer) to save this complete training state in a single call. Because the state is an explicit value rather than something stored inside objects, restoring returns a new state object instead of loading values into an existing model in place.

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as fnn
import optax
import orbax.checkpoint as ocp
from flax.training import train_state
import tempfile
import shutil
import os

# Define a simple model using Flax
class SimpleNet(fnn.Module):
    @fnn.compact
    def __call__(self, x):
        x = fnn.Dense(5)(x)
        x = fnn.relu(x)
        x = fnn.Dense(1)(x)
        return x

# Initialize model, optimizer, and the Flax TrainState object.
# `state` is the PyTree that holds everything we need to save.
model = SimpleNet()
key = jax.random.PRNGKey(42)
dummy_input = jax.random.normal(key, (32, 10))
params = model.init(key, dummy_input)['params']
tx = optax.adam(0.001)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx
)

# Define the Training Step
@jax.jit
def train_step(state, batch_input, batch_target):
    def loss_fn(params):
        predictions = state.apply_fn({'params': params}, batch_input)
        return jnp.mean((predictions - batch_target) ** 2)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# Simulate a Few Training Steps
dummy_target = jax.random.normal(key, (32, 1))
for _ in range(5):
    state, loss = train_step(state, dummy_input, dummy_target)
    print(f"Step {state.step}, Loss: {loss:.4f}")

# Initialize a PyTreeCheckpointer, designed for saving single PyTrees.
checkpointer = ocp.PyTreeCheckpointer()
checkpoint_dir = tempfile.mkdtemp()
checkpoint_path = os.path.join(checkpoint_dir, 'my_checkpoint')

# Use `ocp.args.PyTreeSave` to wrap the state.
save_args = ocp.args.PyTreeSave(item=state)
checkpointer.save(checkpoint_path, args=save_args, force=True)
print(f"\nSaved checkpoint to {checkpoint_dir}")

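# Restore the checkpoint, passing the original `state` as a template so the result
# is rebuilt as a TrainState (including its static apply_fn and tx fields).
restored_state = checkpointer.restore(checkpoint_path, args=ocp.args.PyTreeRestore(item=state))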

print(f"Loaded checkpoint at step {restored_state.step}")

# Verify that the parameters of the original state and the restored state are identical.
are_params_equal = jax.tree_util.tree_all(
    jax.tree_util.tree_map(lambda x, y: jnp.allclose(x, y), state.params, restored_state.params)
)
print("Parameters match:", are_params_equal)

# Clean up the temporary directory.
shutil.rmtree(checkpoint_dir)

Step 1, Loss: 1.7981
Step 2, Loss: 1.7822
Step 3, Loss: 1.7664
Step 4, Loss: 1.7507
Step 5, Loss: 1.7352

Saved checkpoint to /tmp/tmpphbau1yv

Loaded checkpoint at step 5
Parameters match: True


### 2.3 Advanced Checkpoint Management with `CheckpointManager`
For robust training loops, Orbax provides the [`CheckpointManager`](https://orbax.readthedocs.io/en/latest/api_reference/checkpoint.checkpoint_manager.html#id1). This higher-level utility stores each checkpoint in a subdirectory named after its step, tracks which steps are available, and deletes old checkpoints according to rules you define in `CheckpointManagerOptions`, such as `max_to_keep` to retain only the N most recent. This keeps the checkpoint directory tidy and replaces hand-written bookkeeping in the training loop.

In [None]:
from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions

# Define state initialization
def init_state():
    key = jax.random.PRNGKey(42)
    dummy_input = jax.random.normal(key, (32, 10))
    model = SimpleNet()
    params = model.init(key, dummy_input)['params']
    tx = optax.adam(0.001)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx
    )

# Define the training step function once, outside the loop.
@jax.jit
def train_step(state, batch_input, batch_target):
    def loss_fn(params):
        predictions = state.apply_fn({'params': params}, batch_input)
        return jnp.mean((predictions - batch_target) ** 2)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# --- Main Logic ---
state = init_state()
key = jax.random.PRNGKey(42)
dummy_input = jax.random.normal(key, (32, 10))
dummy_target = jax.random.normal(key, (32, 1))
managed_checkpoint_dir = tempfile.mkdtemp()
options = CheckpointManagerOptions(max_to_keep=3, create=True)
checkpoint_manager = CheckpointManager(managed_checkpoint_dir, options=options)

print("Training with automatic checkpoint management:")
for step in range(1, 21):
    state, loss = train_step(state, dummy_input, dummy_target)
    print(f"Step {state.step}: Loss: {loss:.4f}")

    if step % 6 == 0:
        checkpoint_manager.save(step, args=ocp.args.StandardSave(state))
        print(f"  -> Saved checkpoint for step {step}")
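
# Saves may run asynchronously; wait for them to finish before querying or restoring.
checkpoint_manager.wait_until_finished()
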
print(f"\nAvailable checkpoints: {checkpoint_manager.all_steps()}")
latest_step = checkpoint_manager.latest_step()
print(f"Latest checkpoint step: {latest_step}")

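# Restore the latest checkpoint (step 18 in this run). jax.eval_shape builds an
# abstract target (shapes and dtypes only) without allocating any arrays.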
abstract_state = jax.eval_shape(init_state)
restored_state = checkpoint_manager.restore(
    latest_step,
    args=ocp.args.StandardRestore(abstract_state)
)

print(f"\nRestored checkpoint from step {restored_state.step}")

# Release the manager's resources before deleting its directory.
checkpoint_manager.close()
shutil.rmtree(managed_checkpoint_dir)

Training with automatic checkpoint management:
Step 1: Loss: 1.7981
Step 2: Loss: 1.7822
Step 3: Loss: 1.7664
Step 4: Loss: 1.7507
Step 5: Loss: 1.7352
Step 6: Loss: 1.7198
  -> Saved checkpoint for step 6
Step 7: Loss: 1.7046
Step 8: Loss: 1.6895
Step 9: Loss: 1.6746
Step 10: Loss: 1.6598
Step 11: Loss: 1.6452
Step 12: Loss: 1.6308
  -> Saved checkpoint for step 12
Step 13: Loss: 1.6165
Step 14: Loss: 1.6024
Step 15: Loss: 1.5885
Step 16: Loss: 1.5749
Step 17: Loss: 1.5614
Step 18: Loss: 1.5480
  -> Saved checkpoint for step 18
Step 19: Loss: 1.5348
Step 20: Loss: 1.5218

Available checkpoints: [6, 12, 18]
Latest checkpoint step: 18

Restored checkpoint from step 18
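The restore above passes `jax.eval_shape(init_state)` rather than a real state. `jax.eval_shape` traces a function abstractly and returns a PyTree of the same structure whose leaves are `jax.ShapeDtypeStruct`s (shape and dtype only), so no parameters are computed or allocated just to describe the restore target. A small sketch with a hypothetical `init_tree` initializer:

```python
import jax
import jax.numpy as jnp


def init_tree():
    # Hypothetical initializer standing in for init_state().
    key = jax.random.PRNGKey(0)
    return {"kernel": jax.random.normal(key, (4096, 4096)), "step": jnp.array(0)}


# Abstract evaluation: no random numbers are generated and no 4096x4096 buffer is allocated.
abstract = jax.eval_shape(init_tree)
for path, leaf in jax.tree_util.tree_flatten_with_path(abstract)[0]:
    print(jax.tree_util.keystr(path), type(leaf).__name__, leaf.shape, leaf.dtype)
```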
