
# Checkpointing and Exporting JAX Models: An End-to-End Guide with Orbax

This guide demonstrates a complete, end-to-end workflow for managing JAX models with the Orbax library, from robust training-time checkpointing to final model export. We simulate a Flax/Optax setup to show how the `Checkpointer` API enables policy-based management and restoration of training states. We then use the standalone `save_pytree` function to save the final parameters for inference. Finally, we export these parameters to a TensorFlow SavedModel with `orbax-export`.

## 1. Setup

First, we set up the necessary environment by installing the required packages and importing the modules we'll use throughout this guide.

Note: The following cells install the packages required for this guide. If you are running this within an internal Google environment where these dependencies are already available, these installation steps can be safely skipped.

### Installation

Install `orbax-checkpoint` for core checkpointing, `flax` and `optax` for the JAX model and optimizer, and `orbax-export` with `tensorflow` for exporting to the SavedModel format.

In [None]:
!pip install orbax-checkpoint flax optax



In [None]:
!pip install orbax-export tensorflow



### Imports

In [None]:
from orbax.checkpoint import v1 as ocp
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
import optax
import os
import shutil
from etils import epath
from jax import tree_util

### Helper for Directory Management

A utility function to ensure a clean state for our checkpointing directories during each run of this tutorial.

In [None]:
def cleanup_directory_if_exists(path_str):
    """Removes a directory if it exists."""
    path = epath.Path(path_str)
    if path.exists():
        shutil.rmtree(path)

tutorial_base_dir = epath.Path('/tmp/orbax_tutorial')
cleanup_directory_if_exists(str(tutorial_base_dir))
tutorial_base_dir.mkdir(parents=True, exist_ok=True)
print(f"Tutorial artifacts will be saved under: {tutorial_base_dir}")

Tutorial artifacts will be saved under: /tmp/orbax_tutorial


## 2. Define a Simulated JAX State

We'll construct a PyTree representing our model's training state. This typically includes model parameters, optimizer state, and the current training step.

### Define a Model and Training State

We will define a basic Flax model, initialize its parameters, and create an Optax optimizer. The complete training state (model parameters, optimizer state, and step count) is stored in a Python dictionary. Sharding is applied to array elements using `jax.device_put`.

In [None]:
# Model Hyperparameters
input_dim = 64
hidden_dim = 32
output_dim = 10
batch_size_for_init = 4

class SimpleFlaxModel(nn.Module):
    hidden_dim: int
    output_dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim, name="d1")(x)
        x = nn.relu(x)
        return nn.Dense(features=self.output_dim, name="d2")(x)

key = jax.random.PRNGKey(0)
model_instance = SimpleFlaxModel(hidden_dim, output_dim)

# Initialize model parameters with dummy data.
dummy_input_for_flax_init = jnp.ones((batch_size_for_init, input_dim))
initial_model_params_template = model_instance.init(key, dummy_input_for_flax_init)['params']
np_params = jax.tree_util.tree_map(np.array, initial_model_params_template)

# Initialize the optimizer state.
optimizer_instance = optax.adam(1e-3)
np_opt_state_template = optimizer_instance.init(initial_model_params_template)
# Convert all array-like elements to NumPy arrays, leaving others (like `count`) as-is.
np_opt_state = jax.tree_util.tree_map(lambda x: np.array(x) if hasattr(x, 'shape') else x, np_opt_state_template)

# Define sharding for the model (replicated across all devices).
mesh = jax.sharding.Mesh(jax.devices(), ('data',))
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

# Group components and apply sharding to all NumPy arrays in the PyTree.
pytree_components_np = {
    'params': np_params,
    'opt_state': np_opt_state,
}
pytree_components_jax = jax.tree_util.tree_map(lambda x: jax.device_put(x, replicated_sharding) if isinstance(x, np.ndarray) else x, pytree_components_np)

# Combine everything into the final training state PyTree.
simulated_train_state = {**pytree_components_jax, 'step': 0}
print("Initialized JAX training state PyTree with explicit sharding.")

Initialized JAX training state PyTree with explicit sharding.
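To make the sharding step more concrete, here is a minimal sketch that places a toy PyTree on devices with a fully replicated `NamedSharding` and then summarizes each leaf. The `toy_state` dictionary and the `place` helper are hypothetical stand-ins for this illustration; the real state above holds Flax parameters and Optax optimizer state.

```python
import jax
import numpy as np

# Hypothetical miniature state; the guide's real state is much larger.
toy_state = {
    "params": {
        "w": np.ones((4, 2), dtype=np.float32),
        "b": np.zeros((2,), dtype=np.float32),
    },
    "step": 0,
}

# A 1-D mesh over all available devices; an empty PartitionSpec means
# "do not partition any axis", i.e. replicate the array on every device.
mesh = jax.sharding.Mesh(np.array(jax.devices()), ("data",))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())


def place(leaf):
    """Moves NumPy arrays onto devices; leaves other values (e.g. ints) alone."""
    return jax.device_put(leaf, replicated) if isinstance(leaf, np.ndarray) else leaf


placed_state = jax.tree_util.tree_map(place, toy_state)

# Summarize every array leaf as (shape, dtype, is_fully_replicated).
summary = jax.tree_util.tree_map(
    lambda leaf: (leaf.shape, str(leaf.dtype), leaf.sharding.is_fully_replicated)
    if isinstance(leaf, jax.Array)
    else leaf,
    placed_state,
)
print(summary)
```

Inspecting the state this way before the first save is a cheap check that every array leaf carries the sharding you expect to restore with later.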


## 3. Orbax Checkpointing Workflow

This section covers managing checkpoints during a simulated training loop using `Checkpointer`. This API is designed for common training scenarios and allows for powerful configuration through save policies.

### Create a Checkpointing Directory

We'll create a dedicated directory to store our training checkpoints and define a constant for our save interval.

In [None]:
training_ckpt_dir = tutorial_base_dir / 'simulated_training_ckpts'
cleanup_directory_if_exists(str(training_ckpt_dir))
training_ckpt_dir.mkdir(parents=True, exist_ok=True)

SAVE_INTERVAL_STEPS = 2

### Checkpointing During a Simulated Training Loop

We use `Checkpointer` as a context manager and configure it with a `FixedIntervalPolicy`. Inside the loop, `save_pytree(...)` is called on every step, but the policy ensures that a checkpoint is only written to disk when the condition (e.g., `step % 2 == 0`) is met.

In [None]:
# A simplified function to simulate a single training step.
def train_step_for_loop(state):
    # Shallow-copy the state dict so the caller's dict is not mutated.
    new_state = state.copy()
    new_state['step'] += 1
    # For this demo, we simulate param changes by adding small random noise.
    # The same key is reused for every leaf, which is acceptable only for a demo.
    key_for_noise = jax.random.PRNGKey(state['step'])
    new_state['params'] = jax.tree_util.tree_map(
        lambda p: p + 0.001 * jax.random.normal(key_for_noise, p.shape, p.dtype),
        state['params']
    )
    return new_state

current_loop_state = tree_util.tree_map(lambda x: x, simulated_train_state) # Start with a fresh copy.
num_training_steps = 7

print(f"Simulating {num_training_steps} training steps...")

with ocp.training.Checkpointer(
    directory=str(training_ckpt_dir),
    save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(SAVE_INTERVAL_STEPS)
) as ckptr:
    for _ in range(num_training_steps):
        step_to_save_at = current_loop_state['step']

        # `save_pytree` takes the current step, the state to save, and optional metrics.
        saved = ckptr.save_pytree(step_to_save_at, current_loop_state, metrics={'accuracy': 0.85})

        if saved: # Will be True if the save_decision_policy decided to save.
            print(f"  Saved checkpoint for step {step_to_save_at}...")

        current_loop_state = train_step_for_loop(current_loop_state)

Simulating 7 training steps...
  Saved checkpoint for step 0...
  Saved checkpoint for step 2...
  Saved checkpoint for step 4...
  Saved checkpoint for step 6...
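The save pattern in the output above follows from the interval rule. The sketch below reproduces that decision in plain Python; `should_save` is a hypothetical helper written for this illustration and is not Orbax's implementation of `FixedIntervalPolicy`.

```python
def should_save(step: int, interval: int) -> bool:
    """Hypothetical stand-in for a fixed-interval save decision."""
    return step % interval == 0


num_training_steps = 7
interval = 2

# The loop offers steps 0..6 to the checkpointer; only multiples of the
# interval are accepted.
saved_steps = [s for s in range(num_training_steps) if should_save(s, interval)]
print(saved_steps)  # [0, 2, 4, 6]
```

Calling save on every step and letting a policy decide keeps the training loop free of interval arithmetic, so changing the save cadence only touches the checkpointer configuration.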


### Resuming from a Checkpoint

To resume training, we use `training.Checkpointer.load_pytree`. Orbax-checkpoint can automatically find the latest completed checkpoint. We provide an `abstract_pytree` (an empty or example version of our state) to guide the restoration process and ensure the data is loaded with the correct structure and sharding.

In [None]:
with ocp.training.Checkpointer(directory=str(training_ckpt_dir)) as ckptr:
    print(f"Restore from the latest checkpoint in {training_ckpt_dir}...")

    # It returns None if no checkpoint is found.
    resumed_train_state = ckptr.load_pytree(
        abstract_pytree=simulated_train_state # Provide an abstract state for structure and sharding.
    )

# If a checkpoint was successfully loaded, resumed_train_state will not be None.
if resumed_train_state is not None:
    print(f"Restored state successfully. Resuming from step: {resumed_train_state['step']}")
    with ocp.training.Checkpointer(directory=str(training_ckpt_dir)) as ckptr:
        assert resumed_train_state['step'] == ckptr.latest.step
else:
    # If no checkpoint was found, fall back to the initial state.
    print("No checkpoint found to restore; using initial state.")
    resumed_train_state = simulated_train_state

Restore from the latest checkpoint in /tmp/orbax_tutorial/simulated_training_ckpts...
Restored state successfully. Resuming from step: 6
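The example above passes a concrete state as the template. An abstract template carries only shapes, dtypes, and shardings, which in JAX is commonly expressed with `jax.ShapeDtypeStruct`. The sketch below builds such a tree from a hypothetical `concrete_state`; whether a given Orbax version accepts this form for `abstract_pytree` is an assumption to confirm against its documentation.

```python
import jax
import jax.numpy as jnp

# Hypothetical concrete state used only to derive the abstract template.
concrete_state = {
    "params": {"w": jnp.ones((4, 2), dtype=jnp.float32)},
    "step": jnp.asarray(3, dtype=jnp.int32),
}

# Replace every array with a ShapeDtypeStruct that records its shape, dtype,
# and sharding but holds no data.
abstract_state = jax.tree_util.tree_map(
    lambda leaf: jax.ShapeDtypeStruct(leaf.shape, leaf.dtype, sharding=leaf.sharding),
    concrete_state,
)
print(abstract_state)
```

An abstract template avoids allocating real arrays just to describe the target layout, which matters once the parameters no longer fit comfortably in memory twice.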


## 4. Saving Final JAX Parameters for Export

After training, you often need to save just the final model parameters for inference or export. For this, Orbax provides the simple `save_pytree` function, which is ideal for one-off saves without the overhead of training policies.

### Extract Final Parameters for Saving

We extract the learned parameters from our final training state, as this is the only part we need for inference.

In [None]:
final_params_save_dir = tutorial_base_dir / 'exported_model_params_orbax'
final_model_params_to_save = current_loop_state['params']
print("Final model parameters extracted for saving.")

Final model parameters extracted for saving.


### Using `save_pytree` for the Final Save

`save_pytree` directly saves the given PyTree to the specified directory. It's a straightforward way to persist the final artifacts of a training process.

In [None]:
# Ensure a clean state by removing the directory if it exists from a previous run.
cleanup_directory_if_exists(str(final_params_save_dir))

print(f"Saving final parameters to: {final_params_save_dir}...")
ocp.save_pytree(
    path=final_params_save_dir,
    pytree=final_model_params_to_save,
    overwrite=True  # Overwrite any existing checkpoint at this path.
)
print("Final model parameters saved via `save_pytree`.")

Saving final parameters to: /tmp/orbax_tutorial/exported_model_params_orbax...
Final model parameters saved via `save_pytree`.


### Loading Exported Parameters (Verification)

We can use `load_pytree` to load the parameters back and verify that the save operation was successful. Again, we can pass an `abstract_pytree` to help guide the restoration.

In [None]:
if final_params_save_dir.exists() and len(os.listdir(str(final_params_save_dir))) > 0:
    print(f"Loading parameters from {final_params_save_dir} for verification...")
    loaded_final_params = ocp.load_pytree(
        final_params_save_dir,
        abstract_pytree=final_model_params_to_save # Use instance as a template for structure and sharding.
    )
    # Check that the loaded parameters match the original ones.
    params_match = jax.tree_util.tree_all(
        jax.tree_util.tree_map(jnp.array_equal, final_model_params_to_save, loaded_final_params)
    )
    print(f"Verification: {'PASSED' if params_match else 'FAILED'}")
else:
    print("Saved parameters directory not found or empty. Skipping verification.")

Loading parameters from /tmp/orbax_tutorial/exported_model_params_orbax for verification...
Verification: PASSED


## 5. Exporting to TensorFlow SavedModel

This section demonstrates converting the saved JAX model parameters into the TensorFlow [SavedModel](https://www.tensorflow.org/guide/saved_model) format using the [`orbax-export`](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html) library.

In [None]:
from orbax.export import ExportManager, JaxModule, ServingConfig
from orbax.export.validate.validation_manager import ValidationManager
import tensorflow as tf
import traceback
import sys

### Define JAX Model Apply Function and Pre/Post-processing for Export

For [`orbax-export`](https://orbax.readthedocs.io/en/latest/guides/export/orbax_export_101.html), we need to provide a JAX function that takes `(params, inputs)`. We can also define TensorFlow-based pre-processing and post-processing functions, which will be included in the SavedModel's computation graph.

In [None]:
# `model_instance` was defined in Section 2 (the SimpleFlaxModel instance).
# `final_model_params_to_save` contains the parameters we want to export from Section 4.

# JAX Apply Function: The core JAX logic for the model's forward pass.
@jax.jit
def jax_model_apply_fn_for_export(params, inputs):
  """A JAX function with the signature (params, inputs) for orbax-export."""
  return model_instance.apply({'params': params}, inputs)


# Optional: TF Pre-processing Function.
def tf_preprocess_fn_for_export(input_tensor: tf.Tensor) -> tf.Tensor:
  """Normalizes the raw input tensor. Orbax-export will trace this into a graph."""
  return tf.cast(input_tensor, tf.float32) / 255.0


# Optional: TF Post-processing Function.
def tf_postprocess_fn_for_export(output_tensor: tf.Tensor) -> dict[str, tf.Tensor]:
  """Packages the model output into a dictionary. Orbax-export will trace this."""
  return {'predictions': output_tensor}

print("JAX apply function and plain TF pre/post-processing functions defined for export.")

JAX apply function and plain TF pre/post-processing functions defined for export.
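Conceptually, a serving signature chains the three functions as postprocess(apply(params, preprocess(raw))). The NumPy sketch below mimics that composition with hypothetical stand-ins (`preprocess`, `apply_fn`, `postprocess`, `serve`, and a single dense layer instead of the Flax model); it illustrates the data flow only and does not use `orbax-export`.

```python
import numpy as np


def preprocess(raw):
    """Mirrors the guide's preprocessor: cast to float32 and scale by 1/255."""
    return raw.astype(np.float32) / 255.0


def apply_fn(params, inputs):
    """Hypothetical forward pass: one dense layer instead of the Flax model."""
    return inputs @ params["kernel"] + params["bias"]


def postprocess(outputs):
    """Mirrors the guide's postprocessor: wrap outputs in a named dict."""
    return {"predictions": outputs}


def serve(params, raw):
    """Composes the three stages in the order a serving signature would."""
    return postprocess(apply_fn(params, preprocess(raw)))


params = {
    "kernel": np.ones((64, 10), dtype=np.float32),
    "bias": np.zeros((10,), dtype=np.float32),
}
raw = np.full((4, 64), 255, dtype=np.uint8)  # raw inputs, before normalization

result = serve(params, raw)
print(result["predictions"].shape)
```

Keeping normalization inside the exported graph means clients send raw inputs and cannot drift out of sync with the preprocessing used during training.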


### Create `JaxModule` and `ServingConfig`

[`JaxModule`](https://orbax.readthedocs.io/en/latest/api_reference/export.jax_module.html#id1) wraps the JAX function and its parameters. [`ServingConfig`](https://orbax.readthedocs.io/en/latest/api_reference/export.serving_config.html#id1) defines the input signature for the SavedModel and specifies which pre/post-processing functions to use for a given serving signature key (e.g., `serving_default`).

In [None]:
# Create the JaxModule, which encapsulates the JAX function and its parameters.
jax_module_for_export = JaxModule(
    params=final_model_params_to_save,
    apply_fn=jax_model_apply_fn_for_export,
    input_polymorphic_shape=f'(b, {input_dim})',
    jax2tf_kwargs={'with_gradient': False, 'native_serialization': False}
)

# This tells orbax-export how to trace the Python preprocessor function.
tf_input_signature = [
    tf.TensorSpec(shape=[None, input_dim], dtype=tf.float32)
]

# Create a serving configuration that bundles the signature key, input specs,
# and our Python processing functions.
serving_config = ServingConfig(
    signature_key='serving_default',
    input_signature=tf_input_signature,
    tf_preprocessor=tf_preprocess_fn_for_export,
    tf_postprocessor=tf_postprocess_fn_for_export
)
print("JaxModule and ServingConfig created successfully.")

JaxModule and ServingConfig created successfully.


### Export to TensorFlow SavedModel

The [`ExportManager`](https://orbax.readthedocs.io/en/latest/api_reference/export.export_manager.html#id1) takes the `JaxModule` and a list of `ServingConfig` objects to build and save the final TensorFlow SavedModel.

In [None]:
# Define the directory to save the final exported model.
saved_model_dir = tutorial_base_dir / 'tf_saved_model_orbax_export'
cleanup_directory_if_exists(str(saved_model_dir))
saved_model_dir.mkdir(parents=True, exist_ok=True)

# The ExportManager orchestrates the JAX-to-TF conversion and saving process.
export_manager = ExportManager(jax_module_for_export, [serving_config])
print(f"Exporting SavedModel to: {saved_model_dir}")
try:
    export_manager.save(str(saved_model_dir))
    print("Model exported successfully to SavedModel format.")
    print(f"Contents of {saved_model_dir}: {os.listdir(str(saved_model_dir))}")
except Exception as e:
    print(f"ERROR during SavedModel export: {e}")
    # `traceback` is already imported at the top of this section.
    traceback.print_exc()

Exporting SavedModel to: /tmp/orbax_tutorial/tf_saved_model_orbax_export
Model exported successfully to SavedModel format.
Contents of /tmp/orbax_tutorial/tf_saved_model_orbax_export: ['fingerprint.pb', 'assets', 'saved_model.pb', 'variables']


### Validate the Exported Model

A critical final step is to verify that the exported TensorFlow model produces the same results as the original JAX model. We use the [`ValidationManager`](https://orbax.readthedocs.io/en/latest/api_reference/export.validate.validation_manager.html#orbax.export.validate.validation_manager.ValidationManager), which compares the outputs of the JAX model and the loaded TF SavedModel for a given batch of inputs and generates a detailed report.
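The numeric side of such a comparison can be sketched in NumPy. The helper below, `float_diff_report`, is hypothetical: it borrows field names from the report format (`total`, `max_diff`, `max_rel_diff`, `all_close`), but the formulas are assumptions for illustration and not the `ValidationManager` implementation.

```python
import numpy as np


def float_diff_report(baseline, candidate, atol=1e-7, rtol=1e-7):
    """Hypothetical summary of element-wise differences between two outputs."""
    baseline = np.asarray(baseline, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    abs_diff = np.abs(baseline - candidate)
    # Avoid dividing by zero when a baseline element is exactly 0.
    denom = np.maximum(np.abs(baseline), np.finfo(np.float64).tiny)
    return {
        "total": int(baseline.size),
        "max_diff": float(abs_diff.max()),
        "max_rel_diff": float((abs_diff / denom).max()),
        # np.allclose checks |candidate - baseline| <= atol + rtol * |baseline|.
        "all_close": bool(np.allclose(candidate, baseline, atol=atol, rtol=rtol)),
    }


baseline = np.array([[1.0, 2.0], [4.0, 8.0]])
candidate = baseline + np.array([[0.0, 1e-9], [-2e-9, 0.0]])

report = float_diff_report(baseline, candidate)
bad_report = float_diff_report(baseline, baseline + 0.5)
print(report)
```

Because the tolerance check combines an absolute and a relative term, a maximum difference slightly above the absolute tolerance can still count as close when the compared values are large enough.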

In [None]:
# Prepare a batch of test inputs. These should be "raw" (pre-preprocessing).
validation_batch_size = 4
raw_validation_inputs = np.random.rand(validation_batch_size, input_dim).astype(np.float32) * 255.0

# To match the positional signature, pass the inputs as a list of lists.
validation_mgr = ValidationManager(
    module=jax_module_for_export,
    serving_configs=[serving_config],
    model_inputs=[[raw_validation_inputs]]
)

# Load the candidate model we want to validate.
loaded_tf_model = tf.saved_model.load(str(saved_model_dir))

# Run the validation, which compares the JAX and TF outputs.
print("\nRunning validation...")
validation_reports = validation_mgr.validate(loaded_tf_model)

# Check the report. The report is a dict keyed by the signature_key.
report = validation_reports['serving_default']

# The report status is an enum. We check its string name for a simple pass/fail result.
if report.status.name == 'Pass':
    print(f"VERIFICATION PASSED! Status: {report.status.name}")
else:
    print(f"VERIFICATION FAILED! Status: {report.status.name}")

# The report can be printed as a JSON string for detailed inspection of differences and latencies.
print("\nValidation Report:")
print(report.to_json(indent=2))


Running validation...
VERIFICATION PASSED! Status: Pass

Validation Report:
{
  "outputs": {
    "FloatingPointDiffReport": {
      "total": 40,
      "max_diff": 1.4901161193847656e-07,
      "max_rel_diff": 5.774800229119137e-06,
      "all_close": true,
      "all_close_absolute_tolerance": 1e-07,
      "all_close_relative_tolerance": 1e-07
    },
    "NonFloatingPointDiffReport": {
      "total_flattened_tensors": 0,
      "mismatches": 0,
      "mismatch_ratio": 0.0,
      "max_non_floating_mismatch_ratio": 0.01
    }
  },
  "latency": {
    "baseline": {
      "num_batches": 1,
      "avg_in_ms": 2.8274059295654297,
      "p90_in_ms": 2.8274059295654297,
      "p99_in_ms": 2.8274059295654297
    },
    "candidate": {
      "num_batches": 1,
      "avg_in_ms": 3.104686737060547,
      "p90_in_ms": 3.104686737060547,
      "p99_in_ms": 3.104686737060547
    }
  },
  "xprof_url": {
    "baseline": "N/A",
    "candidate": "N/A"
  },
  "metadata": {
    "baseline": {},
    "candidate