## Set up notebook

In [324]:
import os
import re
import absl.flags as flags
import flax.linen as nn
import jax
from jax.experimental import mesh_utils
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec
import orbax.checkpoint as ocp

In [325]:
# Run this cell before any other JAX code: it forces JAX onto the CPU backend
# and makes the host expose `num_cpu_devices` virtual devices for the demo.
num_cpu_devices = 8
device_count_flag = (
    f'--xla_force_host_platform_device_count={num_cpu_devices}'
)

# Strip any device-count flag already present in XLA_FLAGS so that ours is
# the only one, and keep every other inherited flag.
inherited_xla_flags = os.getenv('XLA_FLAGS', '')
xla_flags = re.sub(
    r'--xla_force_host_platform_device_count=\S+', '', inherited_xla_flags
).split()
os.environ['XLA_FLAGS'] = ' '.join([device_count_flag] + xla_flags)

jax.config.update('jax_platforms', 'cpu')
flags.FLAGS.jax_allow_unused_tpus = True

# Last expression: display the devices JAX now sees.
jax.devices()

In [326]:
# Sanity-check that the forced host device count took effect. Compare against
# `num_cpu_devices` (set in the configuration cell) rather than repeating the
# literal, so the two cannot drift apart.
assert len(jax.devices()) == num_cpu_devices

# Enable 64-bit types; the demo model and inputs below use float64.
jax.config.update('jax_enable_x64', True)

In [327]:
from orbax.experimental.model import core as obm
from orbax.experimental.model import jax2obm
from orbax.experimental.model.jax2obm import jax_supplemental_pb2
from orbax.experimental.model.jax2obm import obm_to_jax
from orbax.export import oex_orchestration
from orbax.export.protos import oex_orchestration_pb2

## Demo JAX roundtripping

In [328]:
# Create the model.
class Mnist(nn.Module):
  """Flax MNIST model.

  Two conv -> ReLU -> average-pool stages followed by a two-layer dense head
  whose 10 output features are passed through `nn.log_softmax`.
  """

  @nn.compact
  def __call__(self, x):
    """Applies the network to a batch of inputs.

    Args:
      x: Input batch. The leading axis is kept as-is when the convolutional
        features are flattened.

    Returns:
      The `nn.log_softmax` of the final 10-feature dense layer.
    """
    # Submodules are created in the same order as a hand-unrolled version
    # (Conv, Conv, Dense, Dense), so auto-generated parameter names match.
    for conv_features in (32, 64):
      x = nn.Conv(features=conv_features, kernel_size=(4, 4))(x)
      x = nn.relu(x)
      x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

    # Flatten everything except the leading axis.
    leading_dim = x.shape[0]
    flat = x.reshape((leading_dim, -1))

    hidden = nn.relu(nn.Dense(features=256)(flat))
    logits = nn.Dense(features=10)(hidden)
    return nn.log_softmax(logits)


# Module instance used below for parameter initialization and `model.apply`.
model = Mnist()

In [329]:
# Initialize the model parameters from an all-ones batch that has the shape
# and dtype declared by the input spec.
input_args_spec = jax.ShapeDtypeStruct((4, 28, 28, 1), jnp.float64)
init_rng = jax.random.PRNGKey(666)
example_batch = jnp.ones(input_args_spec.shape, input_args_spec.dtype)
params = model.init(init_rng, example_batch)

Here we shard the model across 8 devices arranged as a 2x2x2 mesh with axes `b`, `x`, and `y`.

In [330]:
# Shard the model.
def get_mesh():
  """Returns a 2x2x2 `jax.sharding.Mesh` with axis names ('b', 'x', 'y')."""
  mesh_shape = (2, 2, 2)
  axis_names = ('b', 'x', 'y')
  device_grid = mesh_utils.create_device_mesh(mesh_shape)
  return jax.sharding.Mesh(device_grid, axis_names)


mesh = get_mesh()

# Every parameter leaf uses the same sharding: its leading axis is split
# along mesh axis 'y'.
params_leaf_sharding = NamedSharding(mesh, PartitionSpec('y'))
params_sharding_spec = jax.tree_util.tree_map(
    lambda _: params_leaf_sharding, params
)

# Inputs are split along 'b' and 'x' on their first two axes; outputs along
# 'b' and 'y'.
input_sharding_spec = NamedSharding(mesh, PartitionSpec('b', 'x', None, None))
output_sharding_spec = NamedSharding(mesh, PartitionSpec('b', 'y'))

# Jit the forward pass with explicit argument and result shardings.
model_apply_fn = jax.jit(
    model.apply,
    in_shardings=(params_sharding_spec, input_sharding_spec),
    out_shardings=output_sharding_spec,
)

# Spec pytree for the params, passed to the converter in the next cell.
params_args_spec = jax2obm.get_shape_dtype_struct(params)

In [331]:
# Convert the jitted JAX function to SHLO using the (params, input) argument
# specs; the empty dict is the third argument, presumably the kwargs spec.
model_function_name = 'mnist_forward_fn'
obm_shlo_fn = jax2obm.convert(
    model_apply_fn, (params_args_spec, input_args_spec), {}
)

# Start the OBM module with the converted function registered under its name.
obm_module = {model_function_name: obm_shlo_fn}

In [332]:
# Directory that receives the exported model. It is created from Python
# (instead of `!mkdir`) so that the path is written only once and re-running
# the cell is a no-op when the directory already exists.
save_dir_path = "/tmp/model"
os.makedirs(save_dir_path, exist_ok=True)

In [333]:
# Save the params as an Orbax checkpoint inside the model directory; it is
# loaded again later in the notebook.
checkpoint_path = 'my_checkpoint'
checkpoint_abs_path = os.path.join(save_dir_path, checkpoint_path)
checkpointer = ocp.Checkpointer(ocp.StandardCheckpointHandler())
# force=True overwrites a checkpoint left behind by a previous run. This
# replaces the earlier `!rm -r` of a hard-coded path, which printed an error
# on a first run and could drift from `checkpoint_abs_path`.
checkpointer.save(checkpoint_abs_path, params, force=True)

# Register the checkpoint in the OBM module by its path relative to the save
# directory (the relative `checkpoint_path`, not the absolute one).
weights_name = 'my_weights'
obm_module[weights_name] = jax2obm.main_lib.convert_path_to_value(
    checkpoint_path,
    mime_type='orbax_checkpoint',
)

In [334]:
# List the contents of the checkpoint directory.
!ls /tmp/model/my_checkpoint/

In [335]:
# Save the OBM module together with an orchestration supplemental that
# records the model function's signature, its name, and the weights' name.
orchestration_signature = oex_orchestration.calculate_signature(
    model_function_signature=obm_shlo_fn.signature
)
orchestration_pipeline = oex_orchestration.create(
    signature=orchestration_signature,
    model_function_name=model_function_name,
    weights_name=weights_name,
)
save_options = obm.SaveOptions(
    version=2,
    supplemental_info=obm.GlobalSupplemental(
        orchestration_pipeline, 'my_orchestration.pb'
    ),
)
obm.save(obm_module, save_dir_path, save_options)

In [308]:
# All of this information is provided by the manifest at load time, so drop
# the save-time names to show that the loading code does not rely on them.
del (
    model_function_name,
    weights_name,
    checkpoint_path,
    checkpoint_abs_path,
)

In [309]:
# List the contents of the model save directory.
!ls /tmp/model

## Load the model from disk.

In [310]:
# Read the serialized manifest from the save directory and parse it.
manifest_file_path = os.path.join(save_dir_path, obm.MANIFEST_FILENAME)
manifest_proto = obm.manifest_pb2.Manifest()
with open(manifest_file_path, 'rb') as manifest_file:
  manifest_bytes = manifest_file.read()
  manifest_proto.ParseFromString(manifest_bytes)

In [311]:
# Display the parsed manifest proto.
manifest_proto

In [312]:
# The manifest's global supplemental info holds the orchestration file's path
# relative to the save directory; read that file and parse it.
orch_location = manifest_proto.supplemental_info.single.file_system_location
orch_filename = orch_location.string_path
orch_file_path = os.path.join(save_dir_path, orch_filename)
orch_proto = oex_orchestration_pb2.Pipeline()
with open(orch_file_path, 'rb') as orch_file:
  orch_bytes = orch_file.read()
  orch_proto.ParseFromString(orch_bytes)

In [313]:
# Display the parsed orchestration proto.
orch_proto

In [314]:
# Look up the model function in the manifest under the name recorded in the
# orchestration proto.
loaded_model_function_name = orch_proto.model_function_name
loaded_function_entry = manifest_proto.objects[loaded_model_function_name]
loaded_obm_function = loaded_function_entry.function

In [315]:
# Display the function entry looked up from the manifest.
loaded_obm_function

In [316]:
# The function's StableHLO body points at a JAX supplemental file (path
# relative to the save directory); read that file and parse it.
stable_hlo_body = loaded_obm_function.body.stable_hlo_body
jax_supplemental_filename = (
    stable_hlo_body.supplemental_info.file_system_location.string_path
)
jax_supplemental_path = os.path.join(save_dir_path, jax_supplemental_filename)
jax_supplemental_proto = jax_supplemental_pb2.Function()
with open(jax_supplemental_path, 'rb') as supplemental_file:
  supplemental_bytes = supplemental_file.read()
  jax_supplemental_proto.ParseFromString(supplemental_bytes)

In [317]:
# Display the parsed JAX supplemental proto.
jax_supplemental_proto

In [318]:
# Rebuild a JAX function from the loaded OBM function and its JAX
# supplemental; it is invoked through `.call(...)` further down.
deserialized_jax_exported = obm_to_jax.obm_functions_to_jax_function(
    loaded_obm_function, jax_supplemental_proto
)

In [319]:
# Restore the params from the saved Orbax checkpoint. The checkpoint path is
# read from the manifest entry named by the orchestration proto and is
# resolved against the save directory.
loaded_weights_name = orch_proto.weights_name
loaded_weights_value = manifest_proto.objects[loaded_weights_name].value
loaded_checkpoint_path = (
    loaded_weights_value.external.data.file_system_location.string_path
)
restored_checkpoint_dir = os.path.join(save_dir_path, loaded_checkpoint_path)
restored_params = checkpointer.restore(restored_checkpoint_dir)

## Test that the loaded and original models match.

In [320]:
# Random test batch placed with the same sharding the jitted function expects.
# Shape and dtype come from `input_args_spec` (the spec used for export)
# instead of repeating the literals, so the test input cannot drift from the
# exported signature.
test_input_data = jax.device_put(
    jax.random.uniform(
        jax.random.PRNGKey(999),
        input_args_spec.shape,
        dtype=input_args_spec.dtype,
    ),
    input_sharding_spec,
)

In [321]:
# Compare results of loaded function with loaded weights against the original JAX function with original weights.

In [322]:
# Original jitted function with the original, in-memory params.
result_from_original_jax_call = model_apply_fn(params, test_input_data)

# Deserialized function with the restored params, first placed with the same
# parameter sharding the original function was jitted with.
sharded_restored_params = jax.device_put(restored_params, params_sharding_spec)
result_from_deserialized_jax_call = deserialized_jax_exported.call(
    sharded_restored_params, test_input_data
)

In [323]:
# The roundtripped function must reproduce the original outputs exactly.
outputs_match = jnp.array_equal(
    result_from_deserialized_jax_call, result_from_original_jax_call
)
assert outputs_match