# Conversion of Objax models to Tensorflow

This tutorial demonstrates how to convert Objax models to Tensorflow and then export them in the SavedModel format.

A SavedModel can be read and served by the [Tensorflow serving infrastructure](https://www.tensorflow.org/tfx/guide/serving) or by custom user code written in C++. Exporting to Tensorflow therefore lets users run experiments in Objax and then serve the resulting models in production using Tensorflow infrastructure.

## Installation and Imports

First, let's install Objax and import all the necessary Python modules.

```python
# Install the latest version of Objax from GitHub
%pip --quiet install git+https://github.com/google/objax.git
```

```
  Building wheel for objax (setup.py) ... done
```


```python
import math
import random
import tempfile

import numpy as np
import tensorflow as tf

import objax
from objax.zoo.wide_resnet import WideResNet
```

## Set up an Objax model

Let's create a model in Objax and a prediction operation, which we will later convert to Tensorflow.

In this tutorial we use a randomly initialized model, so we don't have to wait for training to finish. The conversion to Tensorflow works the same way for a trained model.

```python
# Model
model = WideResNet(nin=3, nclass=10, depth=4, width=1)

# Prediction operation
@objax.Function.with_vars(model.vars())
def predict_op(x):
  return objax.functional.softmax(model(x, training=False))

predict_op = objax.Jit(predict_op)
```

Now, let's generate a few random inputs and run the prediction operation on them:

```python
input_shape = (4, 3, 32, 32)

x1 = np.random.uniform(size=input_shape)
y1 = predict_op(x1)
print('y1:\n', y1)

x2 = np.random.uniform(size=input_shape)
y2 = predict_op(x2)
print('y2:\n', y2)
```

```
y1:
 [[0.06012213 0.07564961 0.13191961 0.12591475 0.08727679 0.16077745
  0.10821478 0.07958559 0.08044666 0.09009264]
 [0.05957783 0.07606611 0.13554339 0.12834728 0.08717373 0.16128528
  0.10691827 0.07874896 0.07711729 0.08922189]
 [0.06061443 0.07384451 0.13470736 0.12721166 0.0860136  0.16349576
  0.11025441 0.0766369  0.07674664 0.09047476]
 [0.05972087 0.07695024 0.1373597  0.12381804 0.08652159 0.16483182
  0.10871573 0.07637089 0.07585222 0.08985896]]
y2:
 [[0.05968136 0.08011787 0.13363907 0.12838946 0.08562963 0.15783262
  0.10989606 0.07535356 0.07875139 0.09070906]
 [0.0572416  0.07606035 0.13607582 0.12197609 0.08373585 0.16551377
  0.11429026 0.07743792 0.0776429  0.0900254 ]
 [0.06138196 0.07201274 0.13394636 0.12132262 0.08225243 0.1682174
  0.11442989 0.07763992 0.0776984  0.09109832]
 [0.05850162 0.07468063 0.13054986 0.12376051 0.08367112 0.16295902
  0.11684892 0.07688387 0.08140761 0.09073676]]
```


## Convert a model to Tensorflow

We use the `objax.util.Objax2Tf` class to convert an Objax module into a `tf.Module`.

Internally, `Objax2Tf` makes a copy of all Objax variables used by the provided module and converts the `__call__` method of that module into a [Tensorflow function](https://www.tensorflow.org/api_docs/python/tf/function).

```python
predict_op_tf = objax.util.Objax2Tf(predict_op)

print('isinstance(predict_op_tf, tf.Module) =', isinstance(predict_op_tf, tf.Module))
print('Number of variables: ', len(predict_op_tf.variables))
```

```
isinstance(predict_op_tf, tf.Module) = True
Number of variables:  39
```
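The following sketch makes the description of `Objax2Tf` more concrete. It builds a simplified `tf.Module` by hand that follows the same pattern. The weights are copied into `tf.Variable` objects, so they show up in `.variables` and can be saved. `__call__` is wrapped in `tf.function`, so it is traced into a Tensorflow graph. The example covers only a single linear layer followed by a softmax. It is not how `Objax2Tf` is actually implemented, and the class name `ToySoftmaxLinear` is made up for this sketch.

```python
import numpy as np
import tensorflow as tf

class ToySoftmaxLinear(tf.Module):
    """Hypothetical stand-in illustrating the conversion pattern (not Objax2Tf itself)."""

    def __init__(self, w, b):
        super().__init__()
        # Copies of the weights stored as tf.Variable, so tf.Module tracks them.
        self.w = tf.Variable(np.asarray(w, dtype=np.float32))
        self.b = tf.Variable(np.asarray(b, dtype=np.float32))

    @tf.function
    def __call__(self, x):
        # tf.function traces this forward pass into a Tensorflow graph.
        x = tf.cast(x, tf.float32)
        return tf.nn.softmax(tf.matmul(x, self.w) + self.b)

toy = ToySoftmaxLinear(w=np.ones((3, 2)), b=np.zeros(2))
print('Number of variables:', len(toy.variables))  # 2
print(toy(np.ones((4, 3))).numpy())  # every entry is 0.5
```

Because `toy` is a `tf.Module` with a `tf.function`-decorated `__call__`, it could be passed to `tf.saved_model.save` in the same way as `predict_op_tf` later in this tutorial.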


After the module is converted, we can run it and compare the Objax and Tensorflow results. The results are numerically very close (the maximum absolute differences printed below are under 1e-7), but they are not exactly identical because of implementation differences between JAX and Tensorflow.

```python
y1_tf = predict_op_tf(x1)
print('max(abs(y1_tf - y1)) =', np.amax(np.abs(y1_tf - y1)))

y2_tf = predict_op_tf(x2)
print('max(abs(y2_tf - y2)) =', np.amax(np.abs(y2_tf - y2)))
```

```
max(abs(y1_tf - y1)) = 4.4703484e-08
max(abs(y2_tf - y2)) = 2.2351742e-08
```
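You do not have to check the maximum absolute difference by eye. `np.testing.assert_allclose` can turn the comparison into an assertion. The tolerance `atol=1e-6` below is an assumption, chosen to comfortably cover the ~1e-8 differences printed above, so pick a value that suits your model. The helper name `check_close` is hypothetical.

```python
import numpy as np

def check_close(reference, candidate, atol=1e-6):
    """Raise AssertionError if the two outputs differ by more than `atol`;
    otherwise return their maximum absolute difference."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    np.testing.assert_allclose(candidate, reference, rtol=0, atol=atol)
    return float(np.amax(np.abs(candidate - reference)))

# Usage with small arrays; in the tutorial you would call check_close(y1, y1_tf).
print(check_close([0.1, 0.9], [0.1 + 4e-8, 0.9]))
```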


## Export Tensorflow model as SavedModel

Converting an Objax model to Tensorflow allows us to export it as [Tensorflow SavedModel](https://www.tensorflow.org/guide/saved_model).

A detailed discussion of the SavedModel format is out of the scope of this tutorial, so we only show how to save and load a SavedModel. For more details about SavedModel, please refer to the following Tensorflow documentation:

* [Using the SavedModel format](https://www.tensorflow.org/guide/saved_model) guide
* [tf.saved_model.save](https://www.tensorflow.org/api_docs/python/tf/saved_model/save) API call
* [tf.saved_model.load](https://www.tensorflow.org/api_docs/python/tf/saved_model/load) API call

### Saving the model as a SavedModel

First, let's create a new empty directory where the model will be saved:

```python
model_dir = tempfile.mkdtemp()

%ls -al $model_dir
```

```
total 8
drwx------ 2 root root 4096 Dec 17 23:28 ./
drwxrwxrwt 1 root root 4096 Dec 17 23:28 ../
```
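`tempfile.mkdtemp` leaves the directory on disk after the notebook finishes. You may need the SavedModel only temporarily, for example in a test. In that case, the standard-library context manager `tempfile.TemporaryDirectory` removes the directory and everything in it when the block exits. In the sketch below, the file being written is only a placeholder for `tf.saved_model.save`. The directory is named `tmp_model_dir` so that it does not overwrite the tutorial's `model_dir`.

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_model_dir:
    # Placeholder for tf.saved_model.save(..., tmp_model_dir, ...).
    with open(os.path.join(tmp_model_dir, 'placeholder.txt'), 'w') as f:
        f.write('saved model would go here')
    print('Exists inside the block:', os.path.isdir(tmp_model_dir))  # True

# The directory and its contents are deleted when the with-block exits.
print('Exists after the block:', os.path.isdir(tmp_model_dir))  # False
```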


Then let's use the `tf.saved_model.save` API to save our Tensorflow model.
Since `Objax2Tf` is a subclass of `tf.Module`, instances of `Objax2Tf` can be passed directly to `tf.saved_model.save`:

```python
tf.saved_model.save(
    predict_op_tf,
    model_dir,
    signatures=predict_op_tf.__call__.get_concrete_function(
        tf.TensorSpec(input_shape, tf.float32)))
```

```
INFO:tensorflow:Assets written to: /tmp/tmpppqtia2e/assets
```


Now we can list the contents of `model_dir` to see the files and subdirectories of the SavedModel:

```python
%ls -al $model_dir
```

```
total 612
drwx------ 4 root root   4096 Dec 17 23:28 ./
drwxrwxrwt 1 root root   4096 Dec 17 23:28 ../
drwxr-xr-x 2 root root   4096 Dec 17 23:28 assets/
-rw-r--r-- 1 root root 608158 Dec 17 23:28 saved_model.pb
drwxr-xr-x 2 root root   4096 Dec 17 23:28 variables/
```
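The `%ls` output contains three entries: a `saved_model.pb` file, a `variables/` directory and an `assets/` directory. The hypothetical helper below uses only the standard library to check that a directory contains those same entries. It is a heuristic structural check based on the listing, not a validation of the contents. Actually loading the model with `tf.saved_model.load` remains the real test.

```python
import pathlib
import tempfile

def looks_like_saved_model(path):
    """Heuristic check: does `path` contain saved_model.pb, variables/ and
    assets/, as in the listing of model_dir? Contents are not validated."""
    root = pathlib.Path(path)
    return ((root / 'saved_model.pb').is_file()
            and (root / 'variables').is_dir()
            and (root / 'assets').is_dir())

# Usage with a fake directory that mimics the layout; in the tutorial you
# would call looks_like_saved_model(model_dir).
with tempfile.TemporaryDirectory() as fake_dir:
    print(looks_like_saved_model(fake_dir))  # False: the directory is empty
    (pathlib.Path(fake_dir) / 'saved_model.pb').write_bytes(b'')
    (pathlib.Path(fake_dir) / 'variables').mkdir()
    (pathlib.Path(fake_dir) / 'assets').mkdir()
    print(looks_like_saved_model(fake_dir))  # True
```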


### Loading the exported SavedModel

We can load the SavedModel as a new Tensorflow object, `loaded_tf_model`, and print its exported signatures:

```python
loaded_tf_model = tf.saved_model.load(model_dir)
print('Exported signatures: ', loaded_tf_model.signatures)
```

```
Exported signatures:  _SignatureMap({'serving_default': <ConcreteFunction signature_wrapper(*, args_0) at 0x7FF247EE5F28>})
```


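Then we can run inference with the `serving_default` signature of `loaded_tf_model` and compare the results with those of `predict_op_tf`, the model converted from Objax. The signature takes `float32` inputs and returns a dictionary of outputs, keyed here by `'output_0'`: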

```python
loaded_predict_op_tf = loaded_tf_model.signatures['serving_default']

y1_loaded_tf = loaded_predict_op_tf(tf.cast(x1, tf.float32))['output_0']
print('max(abs(y1_loaded_tf - y1_tf)) =', np.amax(np.abs(y1_loaded_tf - y1_tf)))

y2_loaded_tf = loaded_predict_op_tf(tf.cast(x2, tf.float32))['output_0']
print('max(abs(y2_loaded_tf - y2_tf)) =', np.amax(np.abs(y2_loaded_tf - y2_tf)))
```

```
max(abs(y1_loaded_tf - y1_tf)) = 0.0
max(abs(y2_loaded_tf - y2_tf)) = 0.0
```
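The `serving_default` signature was exported from a concrete function traced for `tf.TensorSpec(input_shape, tf.float32)`, which means exactly 4 inputs of shape `(3, 32, 32)` in `float32`. This is why the inputs above are cast with `tf.cast(x, tf.float32)`. Inputs with a different batch size will most likely be rejected by this signature. One workaround, sketched below under that assumption, is to zero-pad the input to a multiple of the fixed batch size, run the model one chunk at a time, and then drop the padded rows. `predict_in_fixed_batches` is a hypothetical helper. `fake_predict` stands in for `loaded_predict_op_tf` so that the sketch runs without Tensorflow.

```python
import numpy as np

def predict_in_fixed_batches(predict_fn, x, batch_size):
    """Run `predict_fn`, which only accepts exactly `batch_size` examples, on `x`
    with any (non-zero) number of examples. The last chunk is zero-padded and
    the rows produced for the padding are removed from the result."""
    x = np.asarray(x, dtype=np.float32)  # the exported signature expects float32
    n = x.shape[0]
    if n == 0:
        raise ValueError('x must contain at least one example')
    num_padded = (-n) % batch_size
    if num_padded:
        padding = np.zeros((num_padded,) + x.shape[1:], dtype=np.float32)
        x = np.concatenate([x, padding], axis=0)
    outputs = [np.asarray(predict_fn(x[i:i + batch_size]))
               for i in range(0, x.shape[0], batch_size)]
    return np.concatenate(outputs, axis=0)[:n]

# Hypothetical stand-in for the loaded signature: it only accepts batches of
# exactly 4 and returns one value per example (the per-example mean).
def fake_predict(batch):
    assert batch.shape[0] == 4, 'signature only accepts a batch of 4'
    return batch.reshape(4, -1).mean(axis=1)

x_many = np.random.uniform(size=(10, 3, 32, 32))
y_many = predict_in_fixed_batches(fake_predict, x_many, batch_size=4)
print(y_many.shape)  # (10,)
```

With the real model, `predict_fn` could be `lambda batch: loaded_predict_op_tf(tf.constant(batch))['output_0']`. Padding keeps the exported signature unchanged, at the cost of some wasted computation on the padded rows.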
