## Keras to ONNX conversion
This notebook shows how to convert a trained Keras model to ONNX, the generic model format supported by DIANNA.

The conversion is done with the tf2onnx Python package, which supports both the SavedModel format and the older HDF5 (.h5 or .keras) format. It can convert multi-backend Keras models as well as tf.keras models.
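tf2onnx takes a SavedModel directory and an HDF5 Keras file through different command-line options. The sketch below picks the input option from the path. It assumes tf2onnx's `--keras` input option for HDF5 files and treats any path without an HDF5 suffix as a SavedModel directory. The function `tf2onnx_input_args` is hypothetical and not part of tf2onnx.

```python
from pathlib import Path
from typing import List

# Assumption: these suffixes mark an HDF5 Keras file (per the formats listed above);
# any other path is treated as a SavedModel directory.
HDF5_SUFFIXES = {".h5", ".hdf5", ".keras"}


def tf2onnx_input_args(model_path: str) -> List[str]:
    """Return the tf2onnx.convert input option and path for a saved Keras model."""
    if Path(model_path).suffix.lower() in HDF5_SUFFIXES:
        # HDF5 Keras files go through the --keras input option (assumed)
        return ["--keras", model_path]
    # SavedModel directories go through --saved-model, as in this notebook
    return ["--saved-model", model_path]


print(tf2onnx_input_args("resnet.h5"))          # ['--keras', 'resnet.h5']
print(tf2onnx_input_args("resnet_savedmodel"))  # ['--saved-model', 'resnet_savedmodel']
```

Checking the suffix rather than the filesystem keeps the helper usable before the model file exists, at the cost of misclassifying an HDF5 file saved under an unusual extension.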

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

import onnx
import onnxruntime as ort
# In addition to these imports, this notebook
# depends on tf2onnx, which is invoked from the command line.
```

Download ResNet50 with pretrained ImageNet weights from `keras.applications`.

```python
# ResNet50 with ImageNet weights; include_top=True (the classification head) is the default
model = keras.applications.resnet50.ResNet50(weights='imagenet')
```

```
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5
```


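Evaluate the model on random input. The model's input shape has `None` as its batch dimension, so a batch size of 1 is substituted to create a concrete input array.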

```python
# Use a batch size of 1 in place of the undefined (None) batch dimension
input_shape = [1] + list(model.inputs[0].shape[1:])
input_data = np.random.normal(size=input_shape).astype(np.float32)
pred = model.predict(input_data)
```

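Keras reports the batch dimension of an input as `None`, which cannot be used to allocate an array. A small hypothetical helper (not part of Keras) can make the substitution explicit and fail loudly if any other dimension is also undefined:

```python
from typing import Optional, Sequence, Tuple


def with_batch_size(shape: Sequence[Optional[int]], batch_size: int = 1) -> Tuple[int, ...]:
    """Replace the leading (batch) dimension of an input shape with batch_size."""
    if len(shape) == 0:
        raise ValueError("expected a shape with a leading batch dimension")
    rest = tuple(shape[1:])
    # Only the batch dimension may be undefined; other None dimensions cannot be filled in
    if any(dim is None for dim in rest):
        raise ValueError(f"non-batch dimensions must be fixed, got {tuple(shape)}")
    return (batch_size,) + rest


print(with_batch_size((None, 224, 224, 3)))  # (1, 224, 224, 3)
```

In the notebook it could be called as `with_batch_size(model.input_shape)`; for ResNet50 with the classification head, `model.input_shape` is `(None, 224, 224, 3)`.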


Save the Keras model in the SavedModel format.

```python
savedmodel_dir = 'resnet_savedmodel'
tf.saved_model.save(model, savedmodel_dir)
```


```
INFO:tensorflow:Assets written to: ../models/resnet_savedmodel/assets
```


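Convert the SavedModel to ONNX with the tf2onnx command-line tool, selecting the `serving_default` signature and the `serve` tag.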

```python
onnx_savedmodel = 'resnet_savedmodel.onnx'
# The leading ! runs a shell command from Jupyter/IPython
!python -m tf2onnx.convert --saved-model {savedmodel_dir} --output {onnx_savedmodel} --signature_def serving_default --tag serve
```

```
2022-02-01 14:10:56,018 - INFO - Signatures found in model: [serving_default].
2022-02-01 14:10:56,019 - INFO - Output names: ['predictions']
2022-02-01 14:10:59,963 - INFO - Using tensorflow=2.4.1, onnx=1.9.0, tf2onnx=1.9.3/1190aa
2022-02-01 14:10:59,963 - INFO - Using opset <onnx, 9>
2022-02-01 14:11:03,452 - INFO - Computed 0 values for constant folding
2022-02-01 14:11:05,876 - INFO - Optimizing ONNX model
2022-02-01 14:11:11,212 - INFO - After optimization: Add -1 (18->17), BatchNormalization -53 (53->0), Const -162 (270->108), GlobalAveragePool +1 (0->1), Identity -57 (57->0), ReduceMean -1 (1->0), Squeeze +1 (0->1), Transpose -213 (214->1)
2022-02-01 14:11:11,764 - INFO - 
2022-02-01 14:11:11,764 - INFO - Successfully converted TensorFlow model ../models/resnet_savedmodel to ONNX
2022-02-01 14:11:11,765 - INFO - Model inputs: ['inp
```
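The `!` prefix is IPython syntax, so the conversion cell only works inside Jupyter. Below is a minimal sketch for running the same conversion from a plain Python script, assuming tf2onnx is installed for the current interpreter. The helpers `tf2onnx_command` and `convert` are hypothetical; `--opset` is an optional tf2onnx flag for requesting a specific ONNX opset instead of the one chosen above (opset 9 in the log).

```python
import subprocess
import sys
from typing import List, Optional


def tf2onnx_command(saved_model_dir: str, output_path: str,
                    opset: Optional[int] = None) -> List[str]:
    """Build the same tf2onnx.convert command as the notebook cell."""
    cmd = [
        sys.executable, "-m", "tf2onnx.convert",
        "--saved-model", saved_model_dir,
        "--output", output_path,
        "--signature_def", "serving_default",
        "--tag", "serve",
    ]
    if opset is not None:
        # Optional: request a specific ONNX opset
        cmd += ["--opset", str(opset)]
    return cmd


def convert(saved_model_dir: str, output_path: str, opset: Optional[int] = None) -> None:
    """Run the conversion; check=True raises CalledProcessError if tf2onnx fails."""
    subprocess.run(tf2onnx_command(saved_model_dir, output_path, opset), check=True)
```

For example, `convert("resnet_savedmodel", "resnet_savedmodel.onnx", opset=13)`. Using `sys.executable` instead of a bare `python` runs tf2onnx with the same interpreter as the calling script, which avoids picking up a different Python installation from `PATH`.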

Check that the ONNX model is valid, then compare its output to that of the Keras model.

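```python
# Verify that the ONNX model is valid
onnx_model = onnx.load(onnx_savedmodel)
onnx.checker.check_model(onnx_model)

# Get ONNX predictions; providers is set explicitly because onnxruntime 1.9+
# requires it when the installed build has more than one execution provider
sess = ort.InferenceSession(onnx_savedmodel, providers=['CPUExecutionProvider'])
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

onnx_input = {input_name: input_data}
pred_onnx = sess.run([output_name], onnx_input)[0]

# Compare to the Keras predictions
print(np.allclose(pred_onnx, pred, atol=1e-5))
```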

```
True
```
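`np.allclose(pred_onnx, pred, atol=1e-5)` checks elementwise that `|pred_onnx - pred| <= atol + rtol * |pred|`, with `rtol` left at its default of `1e-5`. A single boolean hides how large the differences are, so a slightly fuller check might also report the maximum absolute difference and whether both models predict the same top-1 class. A sketch, using a hypothetical helper `compare_predictions`:

```python
import numpy as np


def compare_predictions(pred_keras: np.ndarray, pred_onnx: np.ndarray,
                        atol: float = 1e-5) -> dict:
    """Summarize how closely two (batch, classes) prediction arrays agree."""
    diff = np.abs(pred_keras - pred_onnx)
    return {
        # Same test as the notebook: |pred_onnx - pred_keras| <= atol + rtol * |pred_keras|
        "allclose": bool(np.allclose(pred_onnx, pred_keras, atol=atol)),
        "max_abs_diff": float(diff.max()),
        # The predicted class per sample should match even if probabilities differ slightly
        "same_top1": bool(np.array_equal(pred_keras.argmax(axis=-1),
                                         pred_onnx.argmax(axis=-1))),
    }
```

In the notebook this would be called as `compare_predictions(pred, pred_onnx)`. Checking the top-1 class in addition to `allclose` is useful because small numerical differences between runtimes matter less for a classifier than a change in the predicted label.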
