# Tutorial: Exporting StableHLO from JAX

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)][jax-tutorial-colab]
[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)][jax-tutorial-kaggle]

JAX is a Python library for high-performance numerical computing. This tutorial shows how to export JAX and Flax (JAX-powered neural network library) models to StableHLO, and directly to TensorFlow SavedModel.

## Tutorial Setup

### Install required dependencies

We use `jax` and `jaxlib` (JAX's support library with compiled binaries), along with `flax` and `transformers` for some models to export.
We also need to install `tensorflow` to work with SavedModel, and recommend using `tensorflow-cpu` or `tf-nightly` for this tutorial.

[jax-tutorial-colab]: https://colab.research.google.com/github/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb
[jax-tutorial-kaggle]: https://kaggle.com/kernels/welcome?src=https://github.com/openxla/stablehlo/blob/main/docs/tutorials/jax-export.ipynb

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
!pip install -U jax jaxlib flax transformers tensorflow-cpu

Collecting jax
  Downloading jax-0.4.35-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib
  Downloading jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl.metadata (983 bytes)
Collecting flax
  Downloading flax-0.10.1-py3-none-any.whl.metadata (11 kB)
Collecting tensorflow-cpu
  Downloading tensorflow_cpu-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting tensorboard<2.19,>=2.18 (from tensorflow-cpu)
  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)
Collecting keras>=3.5.0 (from tensorflow-cpu)
  Downloading keras-3.6.0-py3-none-any.whl.metadata (5.8 kB)
Downloading jax-0.4.35-py3-none-any.whl (2.2 MB)
Downloading jaxlib-0.4.35-cp310-cp310-manylinux2014_x86_64.whl (87.3 MB)
...

In [2]:
#@title Define `get_stablehlo_asm` to help with MLIR printing
from jax._src.interpreters import mlir as jax_mlir
from jax._src.lib.mlir import ir

# Returns prettyprint of StableHLO module without large constants
def get_stablehlo_asm(module_str):
  with jax_mlir.make_ir_context():
    stablehlo_module = ir.Module.parse(module_str, context=jax_mlir.make_ir_context())
    return stablehlo_module.operation.get_asm(large_elements_limit=20)

# Disable logging for better tutorial rendering
import logging
logging.disable(logging.WARNING)

_Note: This helper uses a JAX internal API that may break at any time, but it serves no functional purpose in the tutorial aside from readability._

## Export JAX model to StableHLO using `jax.export`

In this section we'll export a basic JAX function and a Flax model to StableHLO.

The preferred API for export is [`jax.export`](https://jax.readthedocs.io/en/latest/jax.export.html#module-jax.export). The function to export must be JIT-transformed, that is, it must be the result of `jax.jit`, before it can be exported to StableHLO.

### Export basic JAX model to StableHLO

Let's start by exporting a basic `plus` function to StableHLO, using `np.int32` argument types to trace the function.

Export requires specifying shapes using `jax.ShapeDtypeStruct`, which can be constructed from NumPy values.

In [3]:
import jax
from jax import export
import jax.numpy as jnp
import numpy as np

# Create a JIT-transformed function
@jax.jit
def plus(x, y):
  return jnp.add(x, y)

# Create abstract input shapes
inputs = (np.int32(1), np.int32(1))
input_shapes = [jax.ShapeDtypeStruct(arg.shape, arg.dtype) for arg in inputs]

# Export the function to StableHLO
stablehlo_add = export.export(plus)(*input_shapes).mlir_module()
print(get_stablehlo_asm(stablehlo_add))

module @jit_plus attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<i32>
    return %0 : tensor<i32>
  }
}
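
The object returned by `export.export(...)(...)` carries more than the MLIR text. The sketch below re-creates the `plus` example and inspects a few attributes of the exported object, then runs it on concrete values. It assumes a recent JAX release where `jax.export` is available; the variable names `scalar_i32` and `exported` are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

@jax.jit
def plus(x, y):
    return jnp.add(x, y)

# A scalar int32 specification, equivalent to tracing with np.int32 values.
scalar_i32 = jax.ShapeDtypeStruct((), np.int32)
exported = export.export(plus)(scalar_i32, scalar_i32)

print(exported.fun_name)   # name of the exported function
print(exported.in_avals)   # abstract inputs (shape and dtype)
print(exported.out_avals)  # abstract outputs (shape and dtype)
print(exported.platforms)  # platforms the module was lowered for

# Run the exported computation on concrete values.
result = exported.call(np.int32(2), np.int32(3))
print(result)
```

Checking `in_avals` and `out_avals` before handing the module to another tool is a cheap way to confirm that the traced shapes and dtypes are the ones you intended.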



### Export Hugging Face FlaxResNet18 to StableHLO

Now let's look at a simple model that appears in the wild, `resnet18`.

We'll export a `flax` model from the Hugging Face `transformers` ResNet page, [FlaxResNetModel](https://huggingface.co/docs/transformers/en/model_doc/resnet#transformers.FlaxResNetModel). The setup for this step was copied from the Hugging Face documentation.

The documentation also states: _"Finally, this model supports inherent JAX features such as: **Just-In-Time (JIT) compilation** ..."_ which means it is perfect for export.

Similar to our very basic example, our steps for export are:

1. Instantiate a callable (model/function)
2. JIT-transform it with `jax.jit`
3. Specify shapes for export using `jax.ShapeDtypeStruct` on NumPy values
4. Use the JAX `export` API to get a StableHLO module

In [4]:
from transformers import AutoImageProcessor, FlaxResNetModel
import jax
import numpy as np

# Construct jit-transformed flax model with sample inputs
resnet18 = FlaxResNetModel.from_pretrained("microsoft/resnet-18", return_dict=False)
resnet18_jit = jax.jit(resnet18)
sample_input = np.random.randn(1, 3, 224, 224)
input_shape = jax.ShapeDtypeStruct(sample_input.shape, sample_input.dtype)

# Export to StableHLO
stablehlo_resnet18_export = export.export(resnet18_jit)(input_shape)
resnet18_stablehlo = get_stablehlo_asm(stablehlo_resnet18_export.mlir_module())
print(resnet18_stablehlo[:600], "\n...\n", resnet18_stablehlo[-345:])

module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x3x224x224xf32>) -> (tensor<1x512x7x7xf32> {jax.result_info = "[0]"}, tensor<1x512x1x1xf32> {jax.result_info = "[1]"}) {
    %c = stablehlo.constant dense<49> : tensor<i32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %cst_1 = stablehlo.constant dense<9.99999974E-6> : tensor<f32>
    %cst_2 = stablehlo.constant dense_reso 
...
 func.func private @relu_3(%arg0: tensor<1x7x7x512xf32>) -> tensor<1x7x7x512xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<1x7x7x512xf32>
    %1 = stablehlo.maximum %arg0, %0 : tensor<1x7x7x512xf32>
    return %1 : tensor<1x7x7x512xf32>
  }
}
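
The same four steps can be tried on a toy callable that needs no model download. This is a minimal sketch, not part of the original tutorial: `TinyModel`, its weight shape, and the variable names are hypothetical stand-ins for a real model.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

# 1. Instantiate a callable (a hypothetical stand-in for a real model).
class TinyModel:
    def __init__(self):
        self.w = jnp.ones((4, 2), dtype=jnp.float32)

    def __call__(self, x):
        return jnp.tanh(x @ self.w)

model = TinyModel()

# 2. JIT-transform it.
model_jit = jax.jit(model)

# 3. Describe the input using a NumPy sample value.
sample = np.zeros((8, 4), dtype=np.float32)
spec = jax.ShapeDtypeStruct(sample.shape, sample.dtype)

# 4. Export and fetch the StableHLO module text.
tiny_exported = export.export(model_jit)(spec)
tiny_stablehlo = tiny_exported.mlir_module()
print(tiny_stablehlo[:300])
```

Because the weights are closed over by the callable, they end up as constants inside the exported module, just as the ResNet weights do above.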



### Export with dynamic batch size

Now let's export that same model with a dynamic batch size!

In the previous example, we exported with a static input shape of `tensor<1x3x224x224xf32>`, which fixes every dimension. To defer the choice of concrete shapes until compilation, we can specify a symbolic shape instead. In this example, we'll export using `tensor<?x3x224x224xf32>`.

Symbolic shapes are specified using `export.symbolic_shape`, with letters representing symbolic integer dimensions. For example, a valid 2-D matrix multiplication could use the symbolic shapes `2,a` and `a,5`, so that the contracted dimension matches in any refined program. Symbolic dimension names are tracked by an `export.SymbolicScope` to avoid unintentional name clashes.

In [None]:
# Construct dynamic sample inputs
dyn_scope = export.SymbolicScope()
dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape("a,3,224,224", scope=dyn_scope), np.float32)

# Export to StableHLO
dyn_resnet18_export = export.export(resnet18_jit)(dyn_input_shape)
dyn_resnet18_stablehlo = get_stablehlo_asm(dyn_resnet18_export.mlir_module())
print(dyn_resnet18_stablehlo[:1900], "\n...\n", dyn_resnet18_stablehlo[-1000:])

module @jit__unnamed_wrapped_function_ attributes {jax.uses_shape_polymorphism = true, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<?x3x224x224xf32>) -> (tensor<?x512x7x7xf32> {jax.result_info = "[0]"}, tensor<?x512x1x1xf32> {jax.result_info = "[1]"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x3x224x224xf32>) -> tensor<i32>
    %1 = stablehlo.compare  GE, %0, %c,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
    stablehlo.custom_call @shape_assertion(%1, %0) {api_version = 2 : i32, error_message = "Input shapes do not match the polymorphic shapes specification. Expected value >= 1 for dimension variable 'a'. Using the following polymorphic shapes specifications: args[0].shape = (a, 3, 224, 224). Obtained dimension variables: 'a' = {0} from specification 'a' for dimension args[0].shape[0] (= {0}), . Please see https://jax.readthedocs.io/en/latest/export
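
The matrix multiplication constraint mentioned earlier (`2,a` times `a,5`) can be sketched as follows. This is an illustrative example under the assumption that both operands share one `SymbolicScope`, so that `a` denotes the same dimension in each; the function and variable names are not from the original tutorial.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

@jax.jit
def matmul(x, y):
    return jnp.matmul(x, y)

# Both shapes use the same scope, so "a" is one shared symbolic dimension.
scope = export.SymbolicScope()
lhs_shape = export.symbolic_shape("2, a", scope=scope)
rhs_shape = export.symbolic_shape("a, 5", scope=scope)

matmul_exported = export.export(matmul)(
    jax.ShapeDtypeStruct(lhs_shape, np.float32),
    jax.ShapeDtypeStruct(rhs_shape, np.float32),
)
print(matmul_exported.in_avals)   # inputs mention the symbolic dimension
print(matmul_exported.out_avals)  # output shape is fully static

# Any concrete value of "a" that is consistent across both operands works.
out = matmul_exported.call(np.ones((2, 3), np.float32),
                           np.ones((3, 5), np.float32))
print(out.shape)
```

Since the contracted dimension disappears from the result, the output shape is static even though both inputs are dynamic.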

In [8]:
import jax
import jax.numpy as jnp
from jax import export
import numpy as np
from jax.nn import relu, softmax

# Define a simple CNN class
class SimpleCNN:
    def __init__(self):
        # Initialize weights for convolution and dense layers
        self.conv1_weight = jnp.ones((32, 3, 3, 3))  # 32 filters, 3x3 kernel, 3 input channels
        self.conv2_weight = jnp.ones((64, 32, 3, 3))  # 64 filters, 3x3 kernel, 32 input channels

        # Define the output size after convolutions with input (batch, 3, 224, 224)
        # Adjusting the fully connected layer shape accordingly
        self.fc_weight = jnp.ones((64 * 224 * 224, 10))  # Flattened feature size to 10 classes

    # Define forward pass
    def __call__(self, x):
        x = jax.lax.conv(x, self.conv1_weight, (1, 1), 'SAME')
        x = relu(x)
        x = jax.lax.conv(x, self.conv2_weight, (1, 1), 'SAME')
        x = relu(x)

        # Flatten for fully connected layer
        x = x.reshape((x.shape[0], -1))  # Ensure flattened size matches fc_weight
        x = jnp.dot(x, self.fc_weight)
        return softmax(x)

# Instantiate the model and apply JIT
cnn_model = SimpleCNN()
cnn_jit = jax.jit(cnn_model)

# Export with a dynamic batch dimension
try:
    # Set a dynamic input shape; 'a' is a batch size symbol, 3 is the color channel, and 224x224 is the image resolution
    dyn_scope = export.SymbolicScope()
    dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape("a,3,224,224", scope=dyn_scope), np.float32)

    # Export the CNN model to StableHLO
    dyn_cnn_export = export.export(cnn_jit)(dyn_input_shape)
    dyn_cnn_stablehlo = dyn_cnn_export.mlir_module()

    # Directly print the MLIR module content
    print("MLIR Representation:\n", dyn_cnn_stablehlo)

except Exception as e:
    print("Error in generating StableHLO MLIR module:", e)


MLIR Representation:
 #loc = loc(unknown)
#loc1 = loc("<ipython-input-8-b72e4635e514>":41:0)
#loc2 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3553:0)
#loc3 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3473:0)
#loc4 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3257:0)
#loc5 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py":78:0)
#loc6 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3030:0)
#loc7 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":2975:0)
#loc8 = loc("/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py":539:0)
#loc9 = loc("/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py":302:0)
#loc11 = loc("x")
#loc13 = loc("<ipython-input-8-b72e4635e514>":21:0)
#loc15 = loc("<ipython-input-8-b72e4635e514>":23:0)
#loc19 = loc("<cell line: 35>"(#loc1))
#loc20 = loc("r

In [2]:
import jax
import jax.numpy as jnp
from jax import export
import numpy as np
from jax.nn import relu, softmax

# Define 10 structurally distinct models
class CNNModel1:
    def __call__(self, x):
        x = jax.lax.conv(x, jnp.ones((16, 3, 3, 3)), (1, 1), 'SAME')
        x = relu(x)
        x = jnp.max(x, axis=(2, 3))  # Global max pooling over H and W (NCHW layout)
        return softmax(x)

class CNNModel2:
    def __call__(self, x):
        x = jax.lax.conv(x, jnp.ones((8, 3, 5, 5)), (1, 1), 'SAME')
        x = relu(x)
        x = jnp.mean(x, axis=(2, 3))  # Global average pooling over H and W (NCHW layout)
        x = x.reshape((x.shape[0], -1))  # Flatten to match fully connected layer
        x = jnp.dot(x, jnp.ones((x.shape[-1], 10)))  # Fully connected layer with output size 10
        return softmax(x)

class CNNModel3:
    def __call__(self, x):
        x = jax.lax.conv(x, jnp.ones((32, 3, 3, 3)), (2, 2), 'VALID')
        x = relu(x)
        x = jax.lax.conv(x, jnp.ones((64, 32, 3, 3)), (2, 2), 'VALID')
        x = relu(x)
        x = x.reshape((x.shape[0], -1))
        return softmax(x)

class CNNModel4:
    def __call__(self, x):
        x = jax.lax.conv(x, jnp.ones((16, 3, 7, 7)), (2, 2), 'SAME')
        x = relu(x)
        x = jnp.mean(x, axis=(2, 3))  # Average over H and W (NCHW layout)
        return softmax(x)

class CNNModel5:
    def __call__(self, x):
        x = jax.lax.conv(x, jnp.ones((16, 3, 3, 3)), (1, 1), 'SAME')
        x = relu(x)
        x = jax.lax.conv(x, jnp.ones((32, 16, 3, 3)), (2, 2), 'VALID')
        x = relu(x)
        x = x.reshape((x.shape[0], -1))
        return softmax(x)

class LinearModel1:
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))  # Flatten
        x = jnp.dot(x, jnp.ones((x.shape[-1], 50)))
        x = relu(x)
        x = jnp.dot(x, jnp.ones((50, 10)))
        return softmax(x)

class LinearModel2:
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))  # Flatten
        x = jnp.dot(x, jnp.ones((x.shape[-1], 10)))  # Output size modified to 10 for compatibility
        return softmax(x)

class CNNModel6:
    def __call__(self, x):
        x = jax.lax.conv(x, jnp.ones((128, 3, 3, 3)), (1, 1), 'SAME')
        x = relu(x)
        x = jax.lax.conv(x, jnp.ones((256, 128, 3, 3)), (2, 2), 'VALID')
        x = relu(x)
        x = x.reshape((x.shape[0], -1))
        return softmax(x)

class CNNModel7:
    def __call__(self, x):
        x = jax.lax.conv(x, jnp.ones((8, 3, 1, 1)), (1, 1), 'SAME')
        x = relu(x)
        x = jax.lax.conv(x, jnp.ones((16, 8, 1, 1)), (1, 1), 'SAME')
        x = jnp.mean(x, axis=(2, 3))  # Average over H and W (NCHW layout)
        return softmax(x)

class CNNModel8:
    def __call__(self, x):
        x = jax.lax.conv(x, jnp.ones((32, 3, 3, 3)), (1, 1), 'SAME')
        x = relu(x)
        x = jax.lax.conv(x, jnp.ones((64, 32, 3, 3)), (1, 1), 'SAME')
        x = relu(x)
        x = jnp.max(x, axis=(2, 3))  # Global max pooling over H and W (NCHW layout)
        flattened_dim = x.shape[-1]
        x = jnp.dot(x, jnp.ones((flattened_dim, 10)))  # Fully connected layer
        return softmax(x)

# List of models
models = [CNNModel1(), CNNModel2(), CNNModel3(), CNNModel4(), CNNModel5(),
          LinearModel1(), LinearModel2(), CNNModel6(), CNNModel7(), CNNModel8()]

# Define a function to print MLIR from StableHLO
def get_mlir_output(cnn_jit, model_index):
    try:
        # Define a dynamic input shape; 'a' is a batch size symbol, 3 is the color channel, and 224x224 is the image resolution
        dyn_scope = export.SymbolicScope()
        dyn_input_shape = jax.ShapeDtypeStruct(export.symbolic_shape("a,3,224,224", scope=dyn_scope), np.float32)

        # Export the model to StableHLO
        dyn_cnn_export = export.export(cnn_jit)(dyn_input_shape)
        dyn_cnn_stablehlo = dyn_cnn_export.mlir_module()

        # Print the MLIR for each model
        print(f"MLIR for Model {model_index}:\n", dyn_cnn_stablehlo, "\n")
    except Exception as e:
        print(f"Error in generating MLIR for Model {model_index}:", e)

# Iterate over each model and export MLIR
for i, model in enumerate(models, 1):
    # Apply JIT compilation
    cnn_jit = jax.jit(model)
    get_mlir_output(cnn_jit, i)


MLIR for Model 1:
 #loc = loc(unknown)
#loc1 = loc("<ipython-input-2-2e537abe9551>":103:0)
#loc2 = loc("<ipython-input-2-2e537abe9551>":115:0)
#loc3 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3553:0)
#loc4 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3473:0)
#loc5 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3257:0)
#loc6 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py":78:0)
#loc7 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3030:0)
#loc8 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":2975:0)
#loc9 = loc("/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py":539:0)
#loc11 = loc("x")
#loc13 = loc("<ipython-input-2-2e537abe9551>":11:0)
#loc16 = loc("get_mlir_output"(#loc1))
#loc17 = loc("<cell line: 112>"(#loc2))
#loc18 = loc("run_code"(#loc3))
#loc19 = loc("run_ast_nodes"

MLIR for Model 6:
 #loc = loc(unknown)
#loc1 = loc("<ipython-input-2-2e537abe9551>":103:0)
#loc2 = loc("<ipython-input-2-2e537abe9551>":115:0)
#loc3 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3553:0)
#loc4 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3473:0)
#loc5 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3257:0)
#loc6 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/async_helpers.py":78:0)
#loc7 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":3030:0)
#loc8 = loc("/usr/local/lib/python3.10/dist-packages/IPython/core/interactiveshell.py":2975:0)
#loc9 = loc("/usr/local/lib/python3.10/dist-packages/ipykernel/zmqshell.py":539:0)
#loc11 = loc("x")
#loc14 = loc("<ipython-input-2-2e537abe9551>":53:0)
#loc17 = loc("get_mlir_output"(#loc1))
#loc18 = loc("<cell line: 112>"(#loc2))
#loc19 = loc("run_code"(#loc3))
#loc20 = loc("run_ast_nodes"

A few things to note in the exported StableHLO:

1. The exported program now takes `tensor<?x3x224x224xf32>` as input. These input types can be refined in several ways: StableHLO has passes to [refine shapes](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L69-L99) and [canonicalize dynamic programs](https://github.com/openxla/stablehlo/blob/541db997e449dcfee8536043dfdd49bb13f9ed1a/stablehlo/transforms/Passes.td#L18-L28) to static programs. TensorFlow SavedModel execution also takes care of refinement, which we'll see in the next example.
2. JAX generates guards to ensure the values of `a` are valid; in this case, `a >= 1` is checked. These guards can be removed at compile time once the shapes are refined.
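
Refinement can also be observed from within JAX: calling an exported function with concrete arrays supplies the missing dimension at call time. The sketch below is a small illustration with a hypothetical `row_sums` function rather than the ResNet model, and it assumes `Exported.call` is used outside of any other transformation.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import export

@jax.jit
def row_sums(x):
    return jnp.sum(x, axis=1)

# Export once with a symbolic batch dimension "a".
spec = jax.ShapeDtypeStruct(export.symbolic_shape("a, 3"), np.float32)
dyn_exported = export.export(row_sums)(spec)

# Reuse the same exported module for several concrete batch sizes.
shapes = []
for n in (1, 4):
    out = dyn_exported.call(np.ones((n, 3), np.float32))
    shapes.append(out.shape)
    print(n, out.shape)
```

The export step runs only once; each call with a new batch size reuses the same StableHLO module with its dynamic dimension filled in.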

## Export to TensorFlow SavedModel

It is common to export a StableHLO model to SavedModel for interoperability with existing compilation pipelines, existing TensorFlow tooling, or serving via [TensorFlow Serving](https://github.com/tensorflow/serving).

JAX makes it easy to pack StableHLO into a SavedModel, and load that SavedModel in the future. For this section, we'll be using our dynamic model from the previous section.

### Export to SavedModel using `jax2tf`

`jax.experimental.jax2tf` provides a simple API for exporting StableHLO into a format that can be packaged in a SavedModel. This uses the `export` function under the hood, so the same `jit` requirements apply.

Full details on `jax2tf` can be found in the [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#jax-and-tensorflow-interoperation-jax2tfcall_tf). For this example, we'll only need to know the `polymorphic_shapes` option to specify our dynamic batch dimension.

In [None]:
from jax.experimental import jax2tf
import tensorflow as tf

exported_f = jax2tf.convert(resnet18, polymorphic_shapes=["(a,3,224,224)"])

# Copied from the jax2tf README.md > Usage: saved model
my_model = tf.Module()
my_model.f = tf.function(exported_f, autograph=False).get_concrete_function(tf.TensorSpec([None, 3, 224, 224], tf.float32))
tf.saved_model.save(my_model, '/tmp/resnet18_tf', options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))

!ls /tmp/resnet18_tf

assets  fingerprint.pb  saved_model.pb  variables


### Reload and call the SavedModel

Now we can load that SavedModel and compile using our `sample_input` from a previous example.

_Note: The restored model does *not* require JAX to run, just XLA._

In [None]:
restored_model = tf.saved_model.load('/tmp/resnet18_tf')
restored_result = restored_model.f(tf.constant(sample_input, tf.float32))
print("Result shape:", restored_result[0].shape)

Result shape: (1, 512, 7, 7)


## Troubleshooting

### `jax.jit` issues

If a function can be JIT-compiled, it can be exported. Make sure `jax.jit` works first, or look in the project you want to export for existing uses of JIT (for example, [AlphaFold's `apply`](https://github.com/google-deepmind/alphafold/blob/dbe2a438ebfc6289f960292f15dbf421a05e563d/alphafold/model/model.py#L89) can be exported easily).

See [JAX's JIT compilation documentation](https://jax.readthedocs.io/en/latest/jit-compilation.html) and [`jax.jit` API reference and examples](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) for troubleshooting JIT transformations. The most common issue is control flow, which can often be resolved with `static_argnums` / `static_argnames` as in the linked example.
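
As a minimal sketch of the control-flow issue, the hypothetical function below branches in Python on a `negate` flag. Marking the flag static lets `jax.jit` trace it. For export, one conservative option (an assumption of this sketch, not the only approach) is to bind the static value with `functools.partial` before applying `jax.jit`.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax import export

def scale_impl(x, negate):
    # Python-level branch: `negate` must be known at trace time.
    if negate:
        return -2.0 * x
    return 2.0 * x

# Marking `negate` as static lets jax.jit trace through the branch.
scale = jax.jit(scale_impl, static_argnames="negate")
y = scale(jnp.ones(3), negate=True)

# For export, bind the static value first, then JIT and export as usual.
scale_neg = jax.jit(functools.partial(scale_impl, negate=True))
spec = jax.ShapeDtypeStruct((3,), np.float32)
scale_exported = export.export(scale_neg)(spec)
out = scale_exported.call(np.ones(3, np.float32))
print(y, out)
```

Each distinct static value produces its own trace, so a separate export is needed per value of `negate`.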

### Support tickets

You can open an issue on GitHub for further help. Include a reproducible example using one of the above APIs in your issue report; this will help get the issue resolved much more quickly.