# Keras imagenet to ONNX
This notebook downloads a ResNet50 model pretrained on ImageNet, saves the model to disk, and converts it to the ONNX model format.

In [None]:
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np

### First, download the pretrained Keras ImageNet model (ResNet50)

In [None]:
# Build ResNet50 together with its ImageNet classification head; Keras
# downloads the pretrained 'imagenet' weights if they are not cached yet.
model = ResNet50(include_top=True, weights='imagenet')

### Preprocess image

In [None]:
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
# Image -> array, prepend a batch axis of size 1, then apply ResNet50's
# own ImageNet preprocessing so the input matches what the model expects.
x = preprocess_input(np.expand_dims(image.img_to_array(img), axis=0))

### Run prediction (Keras)

In [None]:
# Score the single-image batch with the Keras model, then print the three
# most likely ImageNet classes for that image (index 0 of the batch).
preds = model.predict(x)
top_3 = decode_predictions(preds, top=3)[0]
print('Predicted:', top_3)

Save the Keras model to disk

In [None]:
# NOTE(review): absolute path under /tf/ (presumably the working directory
# of the TensorFlow Docker image) — adjust it for other environments. The
# ONNX part of this notebook reloads the model from this exact path.
keras_model_path = '/tf/keras-model.h5'
model.save(keras_model_path)

# ONNX part
### Import ONNX, onnxmltools, onnxruntime
#### Restart the Python kernel (optional) — the model is then reloaded from the file saved above

In [None]:
import IPython

# Shut down the running kernel with restart enabled. Every variable defined
# so far (model, x, preds, ...) is lost, which is why the cells below
# re-import their dependencies and reload the model from disk.
app = IPython.Application.instance()
app.kernel.do_shutdown(restart=True)

In [None]:
import os
# NOTE(review): TF_KERAS=1 appears intended to tell the Keras->ONNX converter
# (onnxmltools / keras2onnx) to work with tf.keras models rather than
# standalone Keras. It is set before `import onnxmltools`, presumably because
# the flag is read at import time — keep this ordering.
os.environ['TF_KERAS'] = '1'
import onnx
import onnxmltools
from tensorflow import keras
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image
import onnxruntime
import numpy as np

In [None]:
# Reload the Keras model written before the kernel restart (same absolute path).
keras_model_file = '/tf/keras-model.h5'
model_keras = keras.models.load_model(keras_model_file)

# Converting Keras to ONNX

In [None]:
# Convert the reloaded tf.keras model to ONNX using opset 10, then write the
# serialized model to a path relative to the notebook's working directory.
onnx_file = 'onnx-model.onnx'
onnx_model = onnxmltools.convert_keras(model_keras, target_opset=10)
onnxmltools.utils.save_model(onnx_model, onnx_file)

# Model converted. Now let's test it with ONNX Runtime.

In [None]:
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
# Same preprocessing as in the Keras part (repeated because the kernel was
# restarted): image array -> batch of one -> ResNet50 preprocessing.
pixels = image.img_to_array(img)
x = preprocess_input(pixels[np.newaxis, ...])

In [None]:
# Load the serialized ONNX model. onnxruntime >= 1.9 requires `providers` to be
# passed explicitly on builds that ship more than the CPU execution provider;
# requesting all available providers keeps the earlier default behaviour.
sess = onnxruntime.InferenceSession(
    'onnx-model.onnx',
    providers=onnxruntime.get_available_providers(),
)

In [None]:
# Pair each ONNX graph input, in declaration order, with one input array.
# A single array is wrapped in a list so single- and multi-input models are
# handled the same way; the notebook's `x` itself is left unchanged.
inputs = x if isinstance(x, list) else [x]
input_nodes = sess.get_inputs()
if len(inputs) != len(input_nodes):
    # Fail loudly instead of silently dropping extra arrays or raising an
    # opaque IndexError when too few are supplied.
    raise ValueError(
        f'model expects {len(input_nodes)} input(s), got {len(inputs)}'
    )
feed = {node.name: value for node, value in zip(input_nodes, inputs)}
# `None` requests every model output; the first one is decoded as ImageNet
# class scores, mirroring the Keras prediction earlier in the notebook.
pred_onnx = sess.run(None, feed)
print('Predicted:', decode_predictions(pred_onnx[0], top=3)[0])