## Tensorflow to ONNX conversion

tf2onnx supports TensorFlow 1.x, TensorFlow 2.x, tf.keras, and TFLite models.

In [1]:
import os

import numpy as np
import tensorflow as tf

import onnx
import onnxruntime as ort

In addition to these imports, this notebook depends on tf2onnx. It is not imported here because it is used from the command line (`python -m tf2onnx.convert`).

## Download model in GraphDef format

In [2]:
# Fetch the frozen MobileNet v1 archive through the Keras download helper and extract it.
fname = tf.keras.utils.get_file(
    fname='mobilenet.tgz',
    origin='https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz',
    extract=True,
)

# The frozen graph is looked up relative to the directory holding the downloaded archive.
download_dir = os.path.dirname(fname)
graph_file = os.path.join(download_dir, 'mobilenet_v1_1.0_224/frozen_graph.pb')

## Create inference function from frozen graph 
TensorFlow 2 is used here; the frozen graph is loaded through the `tf.compat.v1` API and wrapped into a callable function.

In [3]:
# Names of the frozen graph's input and output, without the ':0' suffix;
# the suffix is appended wherever a cell below needs it.
graph_input = 'input'
graph_output = 'MobilenetV1/Predictions/Softmax'

# Helper for loading a frozen graph in TF2,
# taken from https://www.tensorflow.org/guide/migrate
def wrap_frozen_graph(graph_def, inputs, outputs):
    """Turn a GraphDef into a callable restricted to the given inputs and outputs.

    The graph is imported inside a `tf.compat.v1.wrap_function` wrapper, and the
    wrapper is then pruned down to the requested feeds and fetches.

    Args:
        graph_def: graph definition handed to `tf.compat.v1.import_graph_def`.
        inputs: name, or nested structure of names, of the graph elements to feed.
        outputs: name, or nested structure of names, of the graph elements to fetch.

    Returns:
        Whatever `prune` returns for the wrapped import: a function going from
        `inputs` to `outputs`.
    """
    def _load_into_wrapper_graph():
        # name="" so the element lookups below can use the names as given
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapper = tf.compat.v1.wrap_function(_load_into_wrapper_graph, [])
    lookup = wrapper.graph.as_graph_element

    # Resolve every requested name to its element in the imported graph
    feeds = tf.nest.map_structure(lookup, inputs)
    fetches = tf.nest.map_structure(lookup, outputs)
    return wrapper.prune(feeds, fetches)


# Parse the serialized frozen graph from disk into a GraphDef
graph_def = tf.compat.v1.GraphDef()
with open(graph_file, 'rb') as pb_file:
    serialized_graph = pb_file.read()
graph_def.ParseFromString(serialized_graph)

# Build a TF2 callable between the graph's input and output
func = wrap_frozen_graph(
    graph_def,
    inputs=f'{graph_input}:0',
    outputs=f'{graph_output}:0',
)

## Evaluate model on some random input

In [4]:
# Fix the global seed so the random test input, and with it the reference
# prediction used in the comparison at the end of the notebook, is the same on
# every Restart & Run All. The value itself is arbitrary.
tf.random.set_seed(42)

# Take the shape from the wrapped function's first input
input_shape = func.inputs[0].shape
input_data = tf.random.normal(input_shape, dtype=tf.float32)

# Reference TensorFlow prediction for this input
pred = func(input_data)

## Convert model to tflite and SavedModel format

In [5]:
# convert to tflite
# (uses `input_shape` from the evaluation cell above)
tflite_file = 'mobilenet.tflite'
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=graph_file,
    input_arrays=[graph_input],
    input_shapes={graph_input: input_shape},
    output_arrays=[graph_output],
)

# Run the conversion before opening the output file: opening with 'wb'
# truncates the file, so converting inside the `with` block would leave an
# empty mobilenet.tflite behind whenever the conversion raises.
tflite_model = converter.convert()

# Save the model
with open(tflite_file, 'wb') as f:
    f.write(tflite_model)

In [6]:
# A tf.Module wrapper: a Trackable object that can be saved as SavedModel
class Model(tf.Module):
    """Minimal `tf.Module` that stores a callable and forwards calls to it."""

    def __init__(self, function):
        """Store `function` as the attribute that `__call__` delegates to."""
        super(Model, self).__init__()
        self.function = function

    def __call__(self, x):
        """Apply the stored function to `x` and return its result."""
        result = self.function(x)
        return result
    
# target directory for the SavedModel export
savedmodel_dir = 'mobilenet_savedmodel'

# wrap the inference function in the module and save the model
model = Model(func)
tf.saved_model.save(obj=model, export_dir=savedmodel_dir)

INFO:tensorflow:Assets written to: mobilenet_savedmodel/assets


INFO:tensorflow:Assets written to: mobilenet_savedmodel/assets


## Convert to ONNX

In [7]:
# graphdef to onnx
onnx_graphdef = 'mobilenet_graph.onnx'
!python -m tf2onnx.convert --graphdef {graph_file} --output {onnx_graphdef} --inputs {graph_input}:0 --outputs {graph_output}:0

# tflite to onnx
onnx_tflite = 'mobilenet_tflite.onnx'
!python -m tf2onnx.convert --tflite {tflite_file} --output {onnx_tflite}

# SavedModel to onnx
onnx_savedmodel = 'mobilenet_savedmodel.onnx'
!python -m tf2onnx.convert --saved-model {savedmodel_dir} --output {onnx_savedmodel} --signature_def serving_default --tag serve

Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.convert_variables_to_constants`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
Instructions for updating:
Use `tf.compat.v1.graph_util.extract_sub_graph`
2021-03-17 12:38:54,333 - INFO - Using tensorflow=2.4.1, onnx=1.8.1, tf2onnx=1.8.3/0fbdb5
2021-03-17 12:38:54,333 - INFO - Using opset <onnx, 9>
2021-03-17 12:38:54,652 - INFO - Computed 0 values for constant folding
2021-03-17 12:38:55,414 - INFO - Optimizing ONNX model
2021-03-17 12:38:55,913 - INFO - After optimization: Add -27 (27->0), Cast -1 (1->0), Const -13 (70->57), Identity -2 (2->0), Mul -13 (13->0), Transpose -57 (58->1)
2021-03-17 12:38:55,938 - INFO - 
2021-03-17 12:38:55,939 - INFO - Successfully converted TensorFlow model /Users/loostrum/.keras/datasets/mobilenet_v1_1.0_224/frozen_graph.pb to ONNX
2021-03-17 12:38:55,994 - INFO - ONNX model is saved 

## Evaluate ONNX models and compare to tensorflow output

In [8]:
models = {'graphdef': onnx_graphdef, 'tflite': onnx_tflite, 'SavedModel': onnx_savedmodel}

# The loop variables are deliberately not called `model` / `fname`: those names
# are already bound by earlier cells (the tf.Module instance and the downloaded
# archive path) and would be silently overwritten here.
for model_format, onnx_path in models.items():

    # verify the ONNX model is valid
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    # get ONNX predictions
    # The provider is given explicitly: onnxruntime >= 1.9 refuses to create a
    # session without one when the build includes non-CPU execution providers.
    sess = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    input_name = sess.get_inputs()[0].name
    output_name = sess.get_outputs()[0].name

    onnx_input = {input_name: input_data.numpy()}
    pred_onnx = sess.run([output_name], onnx_input)[0]

    # compare against the TensorFlow reference prediction
    print(f"{model_format}: {np.allclose(pred_onnx, pred, atol=1e-5)}")

graphdef: True
tflite: True
SavedModel: True
