## Orbax Export for PyTorch Users
This tutorial is a guide for developers familiar with PyTorch, aiming to smooth their transition to JAX and Orbax for model exporting. It complements the existing Orbax documentation by demonstrating how to map the common PyTorch practice of exporting models for inference to its JAX/Orbax equivalent.

## Core Differences
PyTorch's `torch.export` produces a self-contained `ExportedProgram` that bundles the model's computation graph and its weights, and this program can be saved for deployment. JAX, by contrast, is a functional framework: a model is a pure forward-pass function plus a separate set of parameters, and both are needed to run it. The parameters are stored in a [PyTree](https://docs.jax.dev/en/latest/pytrees.html), a nested structure of containers (such as dicts, lists, and tuples) whose leaves are arrays. Orbax's export utilities package these two pieces into the widely used TensorFlow SavedModel format, so JAX models can be served in production environments such as TensorFlow Serving.
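To make the parameter/function split concrete, here is a minimal pure-Python sketch that does not use JAX. It assumes the parameters are a nested dict (one common kind of PyTree), and it uses a hypothetical `tree_leaves` helper that mimics how `jax.tree_util.tree_leaves` flattens dicts (in sorted key order) and lists. The layer name `dense` and the `kernel`/`bias` keys are illustrative.

```python
from typing import Any, Dict, List

# A nested dict of parameters: one common kind of PyTree (illustrative layout).
params: Dict[str, Any] = {
    "dense": {
        "kernel": [[1.0, 2.0], [3.0, 4.0]],  # 2 inputs x 2 outputs
        "bias": [0.5, -0.5],
    }
}

def tree_leaves(tree: Any) -> List[Any]:
    """Hypothetical helper: collect the leaves of nested dicts/lists/tuples.

    Dict keys are visited in sorted order, as jax.tree_util does for dicts.
    """
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in tree_leaves(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in tree_leaves(item)]
    return [tree]

def apply_fn(params: Dict[str, Any], x: List[float]) -> List[float]:
    """Pure forward pass: the output depends only on (params, x)."""
    kernel = params["dense"]["kernel"]
    bias = params["dense"]["bias"]
    # y_j = sum_i x_i * kernel[i][j] + bias_j
    return [
        sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]

print(tree_leaves(params))           # [0.5, -0.5, 1.0, 2.0, 3.0, 4.0]
print(apply_fn(params, [1.0, 1.0]))  # [4.5, 5.5]

# Same function, different parameters -> different outputs.
zero_bias = {"dense": {"kernel": params["dense"]["kernel"], "bias": [0.0, 0.0]}}
print(apply_fn(zero_bias, [1.0, 1.0]))  # [4.0, 6.0]
```

Because `apply_fn` holds no state, the model is fully described by the pair `(params, apply_fn)`. That pair is exactly what an exporter has to capture.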

The following table provides a high-level, side-by-side comparison of the two approaches:

| Feature | **Orbax Export** | **PyTorch** |
| :--- | :--- | :--- |
| **Core API** | `orbax.export.JaxModule` and `orbax.export.ExportManager` | `torch.export.export()`, `torch.export.save()`, and `torch.export.load()` |
| **Data Structure** | Packages a JAX PyTree of parameters together with a Python function (e.g., `model.apply`). | Creates an `ExportedProgram` object containing the model's graph, state dictionary, and buffers. |
| **Output Format** | TensorFlow SavedModel | A self-contained `.pt2` file |
| **Basic Workflow** | Wrap the parameters and apply function in a `JaxModule`, then use `ExportManager` to save it as a SavedModel. | Pass the model and inputs to `torch.export.export()` to generate an `ExportedProgram`, then save it. |

## 1. Setup

### Installation

Start by installing the required packages: `orbax-export` for the export APIs, `jax` and `flax` for building the model, and `tensorflow` for the target SavedModel format. `torch` is included for the PyTorch comparison. The `[cuda12]` extra installs JAX with CUDA 12 GPU support. On a CPU-only machine, plain `jax` is enough.

```
!pip install -q orbax-export jax[cuda12] flax tensorflow torch
```

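To confirm what actually got installed before running the rest of the notebook, you can query package versions with the standard library's `importlib.metadata`. The sketch below is an optional sanity check and is not part of the Orbax workflow. It assumes the PyPI distribution names match the ones passed to `pip` above, for example `orbax-export` for the package imported as `orbax.export`.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names as passed to pip above.
PACKAGES = ["orbax-export", "jax", "flax", "tensorflow", "torch"]

def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return {distribution name: version string, or None if not installed}."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

for name, version in installed_versions(PACKAGES).items():
    print(f"{name:>13}: {version or 'NOT INSTALLED'}")
```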

## 2. Model Export: Preparing Models for Inference
This section covers how to prepare a trained model for deployment, starting with the standard PyTorch approach and then showing the equivalent method using JAX and Orbax.

### 2.1 PyTorch Recap: Exporting with `torch.export`
The modern way to export PyTorch models is the `torch.export` API. It traces the model's execution on sample inputs to produce an `ExportedProgram`, a portable, standardized representation that encapsulates both the computation graph and the learned weights (the `state_dict`). Unless `dynamic_shapes` is passed, the exported program is specialized to the shapes of the sample inputs (here, a batch size of 1). The program can be saved as a `.pt2` file and loaded in other environments for inference.

```python
import torch
import torch.nn as nn
import tempfile
import os

# Define a simple PyTorch model
class PyTorchSimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.linear2 = nn.Linear(5, 1)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x = self.linear2(x)
        return x

# Create model instance and dummy input for tracing
pytorch_model = PyTorchSimpleNet()
pytorch_model.eval()  # Set to evaluation mode
dummy_input = torch.randn(1, 10)  # batch size 1, 10 features

# Export the model to an ExportedProgram
exported_program = torch.export.export(pytorch_model, (dummy_input,))
print("Model successfully exported to an ExportedProgram.")

# Save the exported program to a file in a temporary directory
tmpdir = tempfile.mkdtemp()
EXPORT_PATH = os.path.join(tmpdir, 'exported_model.pt2')
torch.export.save(exported_program, EXPORT_PATH)
print(f"ExportedProgram saved to {EXPORT_PATH}")
```

```
Model successfully exported to an ExportedProgram.
ExportedProgram saved to <tmpdir>/exported_model.pt2
```
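Tracing means running a function on placeholder values that record each operation instead of computing a number. The toy tracer below illustrates the idea. It is a hypothetical sketch for intuition only: `torch.export` captures graphs with TorchDynamo and emits an FX-based graph, which is far more involved. The `Proxy` and `trace` names are made up for this example.

```python
from typing import Callable, List, Tuple

Node = Tuple[str, str, str, str]  # (output name, op, lhs name, rhs name)

class Proxy:
    """Placeholder value that records operations instead of computing them."""

    def __init__(self, name: str, graph: List[Node]):
        self.name = name
        self.graph = graph

    def _record(self, op: str, other: "Proxy") -> "Proxy":
        out = Proxy(f"v{len(self.graph)}", self.graph)
        self.graph.append((out.name, op, self.name, other.name))
        return out

    def __add__(self, other: "Proxy") -> "Proxy":
        return self._record("add", other)

    def __mul__(self, other: "Proxy") -> "Proxy":
        return self._record("mul", other)

def trace(fn: Callable[..., Proxy], n_inputs: int) -> Tuple[List[Node], str]:
    """Call fn on n_inputs proxies and return the recorded graph and output name."""
    graph: List[Node] = []
    inputs = [Proxy(f"x{i}", graph) for i in range(n_inputs)]
    output = fn(*inputs)
    return graph, output.name

graph, out = trace(lambda x, y: x * y + x, 2)
for node in graph:
    print(node)
# ('v0', 'mul', 'x0', 'x1')
# ('v1', 'add', 'v0', 'x0')
print("output:", out)  # output: v1
```

In this toy, only the operations that actually run on the placeholders are recorded. A Python `if` on an ordinary value is resolved once at trace time and never appears in the graph.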


### 2.2 Loading and Verifying the Exported Model
After saving the model, verify that it round-trips correctly. `torch.export.load()` reads the `.pt2` file back into an `ExportedProgram`, and calling `.module()` on it returns a callable `nn.Module`. Running the original and the loaded model on the same input should produce matching outputs.

```python
# Load the program and verify
loaded_program = torch.export.load(EXPORT_PATH)
print("ExportedProgram loaded successfully.")

# Run inference with both original and loaded models to ensure they match
with torch.no_grad():
    original_output = pytorch_model(dummy_input)
    loaded_output = loaded_program.module()(dummy_input)

diff = torch.max(torch.abs(original_output - loaded_output)).item()
print(f"Output difference between original and loaded models: {diff:.6f}")

# Clean up
os.remove(EXPORT_PATH)
```

```
ExportedProgram loaded successfully.
Output difference between original and loaded models: 0.000000
```


## 3. Exporting JAX/Flax Models with Orbax

### 3.1 JAX/Orbax Equivalent: Functional Export with `JaxModule`

Model export in the JAX ecosystem converts a model to the standard TensorFlow **SavedModel** format, which production serving systems such as TensorFlow Serving can load directly. The process involves two main components from `orbax.export`:

1. [`JaxModule`](https://orbax.readthedocs.io/en/latest/api_reference/export.jax_module.html): This class acts as a container that bundles your model's parameters (a PyTree) with its forward-pass function (`apply_fn`), so the model can be exported as a single unit.

2. [`ExportManager`](https://orbax.readthedocs.io/en/latest/api_reference/export.export_manager.html): This utility orchestrates the conversion of a `JaxModule` into a TensorFlow SavedModel, saving the result to a specified directory.
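The pairing that `JaxModule` performs can be sketched conceptually in plain Python. The `BundledModel` class below is hypothetical and omits everything Orbax actually does, such as converting the function to a TensorFlow graph, storing parameters as TensorFlow variables, and handling shape polymorphism. It only shows why capturing the parameters alongside `apply_fn` yields a single callable that takes just the input.

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class BundledModel:
    """Hypothetical stand-in for the params + apply_fn pairing (not Orbax's API)."""
    params: Any
    apply_fn: Callable[[Any, Any], Any]

    def __call__(self, x: Any) -> Any:
        # The params are captured here, so callers supply only the input,
        # much like a serving signature that takes a single input tensor.
        return self.apply_fn(self.params, x)

def scale_and_shift(params, x):
    """Toy forward pass: y = scale * x + shift, element-wise."""
    return [params["scale"] * v + params["shift"] for v in x]

bundle = BundledModel(params={"scale": 2.0, "shift": 1.0}, apply_fn=scale_and_shift)
print(bundle([0.0, 1.0, 2.0]))  # [1.0, 3.0, 5.0]
```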



### 3.2 Imports

```python
import jax
import jax.numpy as jnp
# Note: this rebinds `nn` from torch.nn (Section 2) to flax.linen
from flax import linen as nn
from typing import Sequence
import tensorflow as tf
from orbax.export import JaxModule
from orbax.export import ServingConfig, ExportManager
import shutil
```


### 3.3 JAX Model Definition and Preparation

To begin, define a simple MLP with Flax. Unlike a stateful PyTorch `nn.Module`, a Flax Linen module only describes the architecture. Its parameters live in a separate **PyTree** (`initial_params`), which `model.init` creates from a random key and a sample input.

```python
# Define the JAX/Flax Model
class SimpleMLP(nn.Module):
    features: Sequence[int]

    @nn.compact
    def __call__(self, x):
        # Flatten the input image
        x = x.reshape((x.shape[0], -1))
        for i, dim in enumerate(self.features):
            x = nn.Dense(features=dim, name=f'dense_{i}')(x)
            if i < len(self.features) - 1:
                x = nn.relu(x)
        return x

# Initialize model and its parameters
INPUT_DIM = (28, 28, 1)
OUTPUT_FEATURES = 10
model = SimpleMLP(features=[128, 64, OUTPUT_FEATURES])
key = jax.random.PRNGKey(42)

# Create dummy input to initialize the model's parameters
dummy_input = jnp.ones((1, *INPUT_DIM), dtype=jnp.float32)
initial_params = model.init(key, dummy_input)['params']

print("JAX Model and Parameters Initialized")
```

```
JAX Model and Parameters Initialized
```
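To see what `model.init` produces, you can work out the parameter tree by hand. The helper `mlp_param_shapes` below is hypothetical. It assumes Flax's `nn.Dense` stores a `kernel` of shape `(in_features, out_features)` and a `bias` of shape `(out_features,)`, and that the input is flattened first, as `SimpleMLP` does. In the notebook, `jax.tree_util.tree_map(jnp.shape, initial_params)` shows the real tree for comparison.

```python
from math import prod
from typing import Dict, Sequence, Tuple

Shapes = Dict[str, Dict[str, Tuple[int, ...]]]

def mlp_param_shapes(input_shape: Sequence[int], features: Sequence[int]) -> Shapes:
    """Expected params-tree shapes for SimpleMLP (assumed nn.Dense layout)."""
    shapes: Shapes = {}
    in_features = prod(input_shape)  # 28 * 28 * 1 = 784 after flattening
    for i, dim in enumerate(features):
        shapes[f"dense_{i}"] = {"kernel": (in_features, dim), "bias": (dim,)}
        in_features = dim
    return shapes

shapes = mlp_param_shapes((28, 28, 1), [128, 64, 10])
total = sum(prod(s) for layer in shapes.values() for s in layer.values())
print(shapes)
print(total)  # 109386
```

The count is 100,480 + 8,256 + 650 = 109,386 parameters. The batch dimension appears in none of these shapes, which is why the same `initial_params` can serve inputs of any batch size.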


### 3.4 Packaging the Model with JaxModule
Next, define the forward-pass function and package it with `initial_params` into a `JaxModule`. The `input_polymorphic_shape` string `"b, 28, 28, 1"` marks the leading dimension as a symbolic dimension named `b`. As a result, the exported model accepts any batch size, while the remaining dimensions stay fixed.

```python
# Define the apply function that will be exported
def apply_fn(params, x):
    return model.apply({'params': params}, x)

# Wrap the parameters and function in a JaxModule
jax_module = JaxModule(
    params=initial_params,
    apply_fn=apply_fn,
    input_polymorphic_shape="b, 28, 28, 1"
)

print("JaxModule Created")
```

```
JaxModule Created
```
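You can read the polymorphic shape string as a pattern. Integer entries must match exactly, a name such as `b` matches any size, and the same name must take the same value wherever it appears. The hypothetical `match_shape` helper below sketches that rule in plain Python. It is not how `jax2tf` implements shape polymorphism, which also accepts dimension expressions such as `2*b`.

```python
from typing import Dict, Optional, Sequence

def match_shape(spec: str, shape: Sequence[int],
                bindings: Optional[Dict[str, int]] = None) -> Optional[Dict[str, int]]:
    """Match a concrete shape against a spec like "b, 28, 28, 1".

    Returns the symbol bindings on success (e.g. {"b": 5}), or None on mismatch.
    Pass `bindings` from an earlier call to enforce consistency across inputs.
    """
    env = dict(bindings or {})
    dims = [d.strip() for d in spec.split(",")]
    if len(dims) != len(shape):
        return None  # rank mismatch
    for dim, size in zip(dims, shape):
        if dim.isdigit():
            if int(dim) != size:
                return None  # fixed dimension must match exactly
        elif env.setdefault(dim, size) != size:
            return None  # symbol already bound to a different size
    return env

print(match_shape("b, 28, 28, 1", (5, 28, 28, 1)))  # {'b': 5}
print(match_shape("b, 28, 28, 1", (5, 28, 28, 3)))  # None
```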


### 3.5 Exporting to SavedModel with `ExportManager`
With the `JaxModule` prepared, configure its serving signature using [`ServingConfig`](https://orbax.readthedocs.io/en/latest/api_reference/export.serving_config.html). The signature key names the entry point, and the `tf.TensorSpec` defines the input tensor's name, dtype, and shape. `None` in the first position leaves the batch dimension unspecified, matching the symbolic `b` above. Finally, `ExportManager` takes the module and its serving configurations and saves them as a **SavedModel**.


```python
SAVE_PATH = os.path.join(tmpdir, "orbax_exported_savedmodel")

# Define the serving signature
serving_config = ServingConfig(
    signature_key="serving_default",
    input_signature=[
        tf.TensorSpec(shape=(None, *INPUT_DIM), dtype=tf.float32, name='input_image')
    ]
)

# Use ExportManager to save the JaxModule as a TensorFlow SavedModel
export_manager = ExportManager(
    module=jax_module,
    serving_configs=[serving_config]
)

print(f"Exporting model to: {SAVE_PATH}...")
export_manager.save(model_path=SAVE_PATH)
print("Export to SavedModel Complete!")
```

```
Exporting model to: <tmpdir>/orbax_exported_savedmodel...
Export to SavedModel Complete!
```
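The serving signature fixes how clients call the exported model. Inputs are passed by name (`input_image`) with a given dtype and a shape in which `None` means "any size", and outputs come back in a dict. The sketch below mimics that contract in plain Python. `FakeTensor`, `Spec`, and `make_signature` are hypothetical and are not TensorFlow's implementation. The `output_0` key follows the name this tutorial uses when reading the result in the verification step.

```python
from typing import Callable, Dict, NamedTuple, Optional, Sequence, Tuple

class FakeTensor(NamedTuple):
    """Hypothetical stand-in for a tensor: only dtype and shape are tracked."""
    dtype: str
    shape: Tuple[int, ...]

class Spec(NamedTuple):
    """Hypothetical analogue of a tensor spec; None means 'any size'."""
    name: str
    dtype: str
    shape: Tuple[Optional[int], ...]

    def is_compatible(self, t: FakeTensor) -> bool:
        return (t.dtype == self.dtype and len(t.shape) == len(self.shape)
                and all(s is None or s == d for s, d in zip(self.shape, t.shape)))

def make_signature(specs: Sequence[Spec], fn: Callable[..., FakeTensor]):
    """Return a callable taking keyword inputs and returning {'output_0': ...}."""
    def signature(**inputs: FakeTensor) -> Dict[str, FakeTensor]:
        names = [s.name for s in specs]
        if sorted(inputs) != sorted(names):
            raise TypeError(f"expected inputs {names}, got {list(inputs)}")
        for s in specs:
            if not s.is_compatible(inputs[s.name]):
                raise ValueError(f"{s.name} is incompatible with {s}")
        return {"output_0": fn(*(inputs[n] for n in names))}
    return signature

# A fake "model" mapping (batch, 28, 28, 1) images to (batch, 10) logits.
sig = make_signature(
    [Spec("input_image", "float32", (None, 28, 28, 1))],
    lambda x: FakeTensor("float32", (x.shape[0], 10)),
)
print(sig(input_image=FakeTensor("float32", (5, 28, 28, 1))))
# {'output_0': FakeTensor(dtype='float32', shape=(5, 10))}
```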


### 3.6 Verification in TensorFlow
To verify the export, load the **SavedModel** with `tf.saved_model.load`, run inference through its `serving_default` signature on a sample batch, and compare the result with the original JAX model's output to confirm numerical consistency.

```python
import numpy as np

# Load the SavedModel from the specified path
print(f"Loading SavedModel from: {SAVE_PATH}...")
loaded_model = tf.saved_model.load(SAVE_PATH)
loaded_signature = loaded_model.signatures['serving_default']
print("SavedModel loaded successfully.")

# Use a batch size (5) different from the one used at init (1)
# to exercise the symbolic batch dimension
test_batch_size = 5
test_input_np = np.random.rand(test_batch_size, *INPUT_DIM).astype(np.float32)
tf_input = tf.constant(test_input_np)

# Run inference using the loaded TensorFlow model
tf_output = loaded_signature(input_image=tf_input)
tf_output_array = tf_output['output_0'].numpy()

print(f"Inference successful. Output shape: {tf_output_array.shape}")

# Compare outputs between the original JAX model and the loaded TF model
jax_output = apply_fn(initial_params, jnp.asarray(test_input_np))
match = np.allclose(jax_output, tf_output_array, atol=1e-5)
print(f"Numerical check (JAX vs. SavedModel): {'MATCH' if match else 'MISMATCH'}")

# Clean up the temporary directory (this also removes SAVE_PATH)
shutil.rmtree(tmpdir)
```

```
Loading SavedModel from: <tmpdir>/orbax_exported_savedmodel...
SavedModel loaded successfully.
Inference successful. Output shape: (5, 10)
Numerical check (JAX vs. SavedModel): MATCH
```
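`np.allclose(a, b, atol=...)` checks element-wise that `|a - b| <= atol + rtol * |b|`. NumPy's defaults are `rtol=1e-05` and `atol=1e-08`, and the check above raises `atol` to `1e-5`. A tolerance is used rather than exact equality because the exported graph is compiled and executed separately from the original JAX call, so results may differ in the last few bits. Below is a pure-Python restatement of that rule for flat lists as a hypothetical `allclose` helper. NumPy's real implementation also handles broadcasting, NaN, and infinity. The sample numbers are made up for illustration.

```python
from typing import Sequence

def allclose(a: Sequence[float], b: Sequence[float],
             rtol: float = 1e-05, atol: float = 1e-08) -> bool:
    """NumPy's allclose rule for equal-length flat lists:
    every |a_i - b_i| <= atol + rtol * |b_i|. NaN/inf handling is omitted."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

def max_abs_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest element-wise absolute difference (the metric printed in Section 2.2)."""
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)

jax_out = [0.1234567, -2.5, 10.0]   # made-up sample values
tf_out = [0.1234572, -2.5, 10.00001]
print(max_abs_diff(jax_out, tf_out))
print(allclose(jax_out, tf_out, atol=1e-5))           # True
print(allclose(jax_out, tf_out, rtol=0.0, atol=0.0))  # False
```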
