# Using Orbax Export together with DTensor

[DTensor](https://www.tensorflow.org/guide/dtensor_overview), an extension to TensorFlow for synchronous distributed computing, provides a global programming model that allows developers to compose applications that operate on Tensors globally while managing the distribution across devices internally.

We can export JAX multi-host models to TF SavedModel using DTensor.


## Setup

Here we use a single CPU to emulate 8 virtual devices for testing.

In [None]:
# Make the single host CPU show up as 8 separate XLA devices. The flag is
# appended to (not substituted for) whatever XLA_FLAGS already contains, and
# this cell runs before `jax` is imported in the next cell.
import os
flags = os.environ.get('XLA_FLAGS', '')
os.environ['XLA_FLAGS'] = f"{flags} --xla_force_host_platform_device_count=8"
print(os.environ['XLA_FLAGS'])

In [None]:
import jax
# List the devices JAX can see; with the XLA flag set in the previous cell this
# is expected to show 8 CPU devices.
print(jax.devices())

## Export sharded JAX model with [DTensor](https://www.tensorflow.org/guide/dtensor_overview)

Here is a simple example that demonstrates how to do this. First, we need to initialize the accelerators and communication fabrics for DTensor.

In [None]:
from orbax.export import dtensor_utils
# Bring DTensor into a known state: if it is already initialized (e.g. the cell
# is being re-run), shut it down first, then initialize it again with
# `reset_context=True` and check that initialization succeeded.
if dtensor_utils.dtensor_initialized():
  dtensor_utils.shutdown_dtensor()
dtensor_utils.initialize_dtensor(reset_context=True)
assert(dtensor_utils.dtensor_initialized())

Define the JAX model function, model parameters and inputs.

In [None]:
import jax
import jax.numpy as jnp
# Problem sizes: the weight `w` is (dim_x, dim_y), the bias `b` is (dim_y,),
# and the example input is (batch, dim_x).
dim_x = 16000
dim_y = 8000
batch = 1

# Derive three keys from one root key: k1 -> weights, k2 -> bias, k3 -> inputs.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
example3_params =  {
    'w': jax.random.uniform(k1, (dim_x, dim_y)), 'b': jax.random.uniform(k2, (dim_y,))
}
example3_inputs = jax.random.uniform(k3, (batch, dim_x))

# model f(x) = x @ w + b, here (w, b) are the model parameters
def example3_model_fn(params, x):  # The JAX model function to export.
  """Apply a single affine (dense) layer.

  Args:
    params: mapping with a weight matrix under 'w' and a bias under 'b'.
    x: input array that is matrix-multiplied with `params['w']`.

  Returns:
    `jnp.matmul(x, params['w']) + params['b']`.
  """
  projected = jnp.matmul(x, params['w'])
  return projected + params['b']

Define the JAX sharding and create the sharded JAX arrays.

In [None]:
import numpy as np
from jax.sharding import PartitionSpec as P
from jax.experimental import pjit

# Arrange the 8 emulated devices as a 1 x 8 grid whose axes are named 'x', 'y'.
mesh_shape = (1, 8)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
mesh = jax.sharding.Mesh(devices, ('x', 'y'))

# How each parameter is partitioned over the named mesh axes.
params_pspecs = {'w': P('x', 'y'), 'b': P('y')}

# Place every parameter on the mesh according to its PartitionSpec.
sharded_params = {
    name: jax.device_put(
        example3_params[name],
        jax.sharding.NamedSharding(mesh, params_pspecs[name]),
    )
    for name in ('w', 'b')
}

# The inputs are placed on the same mesh without a partition spec.
sharded_inputs = jax.device_put(example3_inputs, jax.sharding.NamedSharding(mesh, None))

# Wrap the model function so its parameters follow `params_pspecs`; no
# partitioning is specified for the input `x` or for the output.
sharded_model_fn = pjit.pjit(
    example3_model_fn,
    in_shardings=(params_pspecs, None),
    out_shardings=None,
)

Here we use the Orbax Export and DTensor APIs to export the model as a TensorFlow SavedModel.

In [None]:
import tempfile
import tensorflow as tf
from orbax.export.validate import ValidationManager
from orbax.export import ExportManager
from orbax import export as obx_export
from orbax.export import JaxModule
from orbax.export import ServingConfig

# Fresh temporary directory that will receive the exported SavedModel.
export_dir =  tempfile.mkdtemp()

# Everything below runs with both the JAX mesh and the DTensor export context
# for that mesh active.
with mesh, dtensor_utils.maybe_enable_dtensor_export_on(mesh):
  # Wrap the sharded parameters and the pjit-ed model function; `pspecs` passes
  # along the PartitionSpec of each parameter.
  jax_module = JaxModule(sharded_params, sharded_model_fn, pspecs=params_pspecs)

  # A single serving signature named 'serving_default' that takes one float32
  # tensor `x` of shape [batch, dim_x].
  serving_configs = [
    ServingConfig(
      'serving_default',
      input_signature= [tf.TensorSpec(shape=[batch, dim_x], dtype=tf.float32, name='x')],
    ),
  ]
  export_mgr = ExportManager(jax_module, serving_configs)
  # Write the SavedModel into `export_dir`.
  export_mgr.save(export_dir)

In [None]:
# Load the SavedModel back from the directory it was exported to.
loaded_model = tf.saved_model.load(export_dir)

In [None]:
# Convert the JAX example inputs into a TensorFlow tensor.
tf_inputs = tf.convert_to_tensor(example3_inputs)

In [None]:
# Call the loaded model on the converted inputs.
loaded_model(tf_inputs)