<img width="150" alt="Logo_ER10" src="https://user-images.githubusercontent.com/3244249/151994514-b584b984-a148-4ade-80ee-0f88b0aefa45.png">

## TensorFlow to ONNX conversion
This notebook shows how to convert your TensorFlow model to ONNX, the generic format supported by DIANNA. <br>
The conversion is done with the tf2onnx Python package, which supports TensorFlow 1.x, 2.x, tf.keras, and tflite.

In [3]:
import os
import numpy as np
import tensorflow as tf
import onnx
import onnxruntime as ort
# In addition to these imports, this notebook
# depends on tf2onnx. It is used from the command line.

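Before running the remaining cells, it can help to confirm that every required package is importable. The following is a minimal sketch and not part of the original notebook: the helper name `missing_packages` is hypothetical, and the check only verifies that each module can be located, not that the installed versions are compatible with each other.

```python
import importlib.util

# Packages used by this notebook; tf2onnx is only called from the command line.
REQUIRED = ["numpy", "tensorflow", "onnx", "onnxruntime", "tf2onnx"]


def missing_packages(names):
    """Return the subset of `names` that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages(REQUIRED)
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All required packages are available.")
```

Any package reported as missing needs to be installed into the environment that the notebook kernel uses.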

Download the TensorFlow model in GraphDef format.

In [2]:
fname = tf.keras.utils.get_file(
    'mobilenet.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz',
    extract=True)
graph_file = os.path.join(os.path.dirname(fname), 'mobilenet_v1_1.0_224/frozen_graph.pb')


Create an inference function from the frozen graph. (TensorFlow 2 is used here.)

In [3]:
graph_input = 'input'
graph_output = 'MobilenetV1/Predictions/Softmax'

# helper function to load graph in tf2
# taken from https://www.tensorflow.org/guide/migrate
def wrap_frozen_graph(graph_def, inputs, outputs):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")
        
    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs)
    )

graph_def = tf.compat.v1.GraphDef()
with open(graph_file, 'rb') as f:
    graph_def.ParseFromString(f.read())
    
func = wrap_frozen_graph(graph_def, inputs=graph_input+':0', outputs=graph_output+':0')

2022-01-31 15:08:11.337241: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
2022-01-31 15:08:11.337536: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-01-31 15:08:11.340018: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
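The cell above appends `:0` to the graph input and output names. In TensorFlow graphs, a tensor name has the form `<op_name>:<output_index>`, so `input:0` refers to the first output of the operation named `input`. The sketch below illustrates this convention in plain Python; the helper names `tensor_name` and `split_tensor_name` are hypothetical and are not part of TensorFlow.

```python
def tensor_name(op_name, output_index=0):
    """Build a tensor name from an operation name and an output index."""
    return f"{op_name}:{output_index}"


def split_tensor_name(name):
    """Split '<op_name>:<index>' into (op_name, index); a missing index means 0."""
    op_name, sep, index = name.rpartition(":")
    if not sep:
        # No colon present: treat the whole string as the operation name.
        return name, 0
    return op_name, int(index)


graph_input = 'input'
graph_output = 'MobilenetV1/Predictions/Softmax'

print(tensor_name(graph_input))
print(split_tensor_name(tensor_name(graph_output)))
```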


Evaluate the model on random input.

In [4]:
input_shape = func.inputs[0].shape
input_data = tf.random.normal(shape=input_shape, dtype=tf.float32)
pred = func(input_data)

2022-01-31 15:08:14.693775: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2022-01-31 15:08:14.705357: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2304005000 Hz
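Because the output tensor is named `MobilenetV1/Predictions/Softmax`, each row of `pred` would be expected to behave like a probability distribution. A quick sanity check along those lines could look like the sketch below. The helper `looks_like_softmax` is hypothetical, and the example runs on synthetic data rather than on the actual model output; applying it to `pred` assumes the tensor can be converted with `np.asarray`.

```python
import numpy as np


def looks_like_softmax(values, atol=1e-4):
    """Check that values lie in [0, 1] and sum to 1 along the last axis."""
    values = np.asarray(values, dtype=np.float64)
    in_range = np.all((values >= 0.0) & (values <= 1.0))
    sums_to_one = np.allclose(values.sum(axis=-1), 1.0, atol=atol)
    return bool(in_range and sums_to_one)


# Synthetic stand-in for a model output: softmax over random logits.
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 10))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

print(looks_like_softmax(probs))
```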


Convert the model to tflite and SavedModel format.

In [5]:
# convert to tflite
tflite_file = '../../dianna/dianna/models/mobilenet.tflite'
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=graph_file,
    input_arrays=[graph_input],
    input_shapes={graph_input: input_shape},
    output_arrays=[graph_output]
)

# Save the model
with open(tflite_file, 'wb') as f:
    f.write(converter.convert())

2022-01-31 15:08:34.284342: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
2022-01-31 15:08:34.481577: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:316] Ignored output_format.
2022-01-31 15:08:34.481647: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:319] Ignored drop_control_dependency.
2022-01-31 15:08:34.527677: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.


In [6]:
# create a Trackable object that can be saved as SavedModel
class Model(tf.Module):
    def __init__(self, function):
        super().__init__()
        self.function = function
    
    def __call__(self, x):
        return self.function(x)
    
model = Model(func)

# save the model
savedmodel_dir = 'mobilenet_savedmodel'
tf.saved_model.save(model, savedmodel_dir)

INFO:tensorflow:Assets written to: mobilenet_savedmodel/assets



Convert GraphDef/tflite/SavedModel to ONNX.

In [7]:
# graphdef to onnx
onnx_graphdef = 'mobilenet_graph.onnx'
!python -m tf2onnx.convert --graphdef {graph_file} --output {onnx_graphdef} --inputs {graph_input}:0 --outputs {graph_output}:0

# tflite to onnx
onnx_tflite = 'mobilenet_tflite.onnx'
!python -m tf2onnx.convert --tflite {tflite_file} --output {onnx_tflite}

# SavedModel to onnx
onnx_savedmodel = 'mobilenet_savedmodel.onnx'
!python -m tf2onnx.convert --saved-model {savedmodel_dir} --output {onnx_savedmodel} --signature_def serving_default --tag serve

# For completeness, this is how to convert a tf.keras model to ONNX:
# !python -m tf2onnx.convert --keras {model_dir} --output {output_file}

Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
2022-01-31 15:09:07,497 - INFO - Using tensorflow=2.4.1, onnx=1.9.0, tf2onnx=1.9.3/1190aa
2022-01-31 15:09:07,497 - INFO - Using opset <onnx, 9>
2022-01-31 15:09:08,222 - INFO - Computed 0 values for constant folding
2022-01-31 15:09:08,796 - INFO - Optimizing ONNX model
2022-01-31 15:09:09,597 - INFO - After optimization: Add -14 (27->13), Cast -1 (1->0), Const -13 (83->70), Identity -2 (2->0), Reshape -13 (14->1), Transpose -70 (71->1)
2022-01-31 15:09:09,669 - INFO - 
2022-01-31 15:09:09,669 - INFO - Successfully converted TensorFlow model /home/yangliu/.keras/datasets/mobilenet_v1_1.0_224/frozen_graph.pb to ONNX
2022-01-31 15:09:09,669 - INFO - Model inputs: ['i
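The conversion cell relies on the notebook `!` shell syntax. Outside a notebook, the same commands could be assembled as an argument list and passed to `subprocess.run`, which raises an exception on failure when `check=True` is set. The sketch below only builds and prints the command; the helper `tf2onnx_command` is hypothetical, the file names are placeholders, and only flags that already appear in the cell above are used.

```python
import sys


def tf2onnx_command(source_flag, source_path, output_path, extra_args=()):
    """Assemble the argument list for `python -m tf2onnx.convert`."""
    return [
        sys.executable, "-m", "tf2onnx.convert",
        source_flag, str(source_path),
        "--output", str(output_path),
        *extra_args,
    ]


# Placeholder paths; in the notebook these would be graph_file and onnx_graphdef.
cmd = tf2onnx_command(
    "--graphdef", "frozen_graph.pb", "mobilenet_graph.onnx",
    extra_args=["--inputs", "input:0",
                "--outputs", "MobilenetV1/Predictions/Softmax:0"],
)
print(" ".join(cmd))

# To actually run the conversion (requires tf2onnx to be installed):
# import subprocess
# subprocess.run(cmd, check=True)
```

Passing the arguments as a list avoids shell quoting issues when paths contain spaces.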

Evaluate the ONNX models and compare them with the TensorFlow output.

In [9]:
models = {'graphdef': onnx_graphdef, 'tflite': onnx_tflite, 'SavedModel': onnx_savedmodel}

# use loop variable names that do not shadow `model` and `fname` from earlier cells
for name, onnx_file in models.items():

    # verify the ONNX model is valid
    onnx_model = onnx.load(onnx_file)
    onnx.checker.check_model(onnx_model)

    # get ONNX predictions
    sess = ort.InferenceSession(onnx_file)
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    onnx_input = {input_name: input_data.numpy()}
    pred_onnx = sess.run([output_name], onnx_input)[0]

    # compare with the TensorFlow prediction
    print(f"{name}: {np.allclose(pred_onnx, pred, atol=1e-5)}")

graphdef: True
tflite: True
SavedModel: True
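The comparison above reports only a boolean per model. `np.allclose(a, b, atol=..., rtol=...)` checks `|a - b| <= atol + rtol * |b|` element-wise, with `rtol` defaulting to `1e-5`. When a check fails, the largest absolute difference is often more informative than the boolean. The sketch below shows one way to report both; `compare_outputs` is a hypothetical helper, and the arrays are synthetic stand-ins for `pred` and `pred_onnx`.

```python
import numpy as np


def compare_outputs(reference, candidate, atol=1e-5, rtol=1e-5):
    """Return (is_close, max_abs_diff) for two arrays of the same shape."""
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    max_abs_diff = float(np.max(np.abs(candidate - reference)))
    # np.allclose tests |candidate - reference| <= atol + rtol * |reference|
    is_close = bool(np.allclose(candidate, reference, atol=atol, rtol=rtol))
    return is_close, max_abs_diff


# Synthetic stand-ins for the TensorFlow and ONNX predictions.
reference = np.array([0.1, 0.2, 0.7])
close_candidate = reference + 1e-7
far_candidate = reference + 1e-3

print(compare_outputs(reference, close_candidate))
print(compare_outputs(reference, far_candidate))
```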
