<a href="https://colab.research.google.com/github/yinweisu/gluon-cv/blob/onnx/tools/onnx/notebooks/onnx_classification_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip3 install --upgrade numpy mxnet gluoncv onnxruntime

Requirement already up-to-date: numpy in /usr/local/lib/python3.7/dist-packages (1.20.1)
Requirement already up-to-date: mxnet in /usr/local/lib/python3.7/dist-packages (1.7.0.post2)
Requirement already up-to-date: gluoncv in /usr/local/lib/python3.7/dist-packages (0.9.4.post1)
Requirement already up-to-date: onnxruntime in /usr/local/lib/python3.7/dist-packages (1.7.0)


In [None]:
import csv
import os
import urllib.parse
import urllib.request
from ast import literal_eval

import gluoncv as gcv
import mxnet as mx
import numpy as np
import onnxruntime as rt

In [None]:
def fetch_model_and_shape(model_name, model_list_link='Not Ready Yet'):
  """Download the exported ONNX model for ``model_name`` and return its input shape.

  The model list is a CSV whose first four columns are
  ``type, name, input_shape, onnx_url``; ``input_shape`` is a Python tuple
  literal such as ``"(1, 224, 224, 3)"``.

  Args:
    model_name: Name of the model to look up in the model list.
    model_list_link: URL of the model-list CSV. The default is a placeholder
      (the list has not been published yet) and must be replaced before use.

  Returns:
    A ``(onnx_model_filename, input_shape)`` tuple. The model is saved as
    ``<model_name>.onnx`` in the current directory.

  Raises:
    ValueError: if ``model_name`` is not in the list, or its type is not
      'Obj Classification'.
  """
  model_list_fn = 'model_list.csv'
  urllib.request.urlretrieve(model_list_link, filename=model_list_fn)
  # newline='' is what the csv module expects for files it reads.
  with open(model_list_fn, 'r', newline='') as csvfile:
    for row in csv.reader(csvfile, delimiter=','):
      if len(row) < 4:
        # Skip blank or malformed lines instead of failing with IndexError.
        continue
      mtype, mname, mshape, mlink = row[:4]
      if mname != model_name:
        continue
      if mtype != 'Obj Classification':
        raise ValueError(f'{mtype} not supported. Please checkout the corresponding notebook')
      onnx_model_fn = model_name + '.onnx'
      urllib.request.urlretrieve(mlink, filename=onnx_model_fn)
      # literal_eval only accepts Python literals, so the shape string cannot run code.
      return onnx_model_fn, literal_eval(mshape)
  # Previously this fell through to an UnboundLocalError on `onnx_model_fn`.
  raise ValueError(f'{model_name} not found in the model list ({model_list_link})')

def prepare_img(img_url, input_shape):
  """Download an image and turn it into a single-image float32 batch.

  Args:
    img_url: URL of the image. It is saved in the current directory under the
      last path component of the URL (query string and fragment ignored).
    input_shape: Model input shape in BHWC order; only H (index 1) and
      W (index 2) are used.

  Returns:
    An MXNet NDArray of shape (1, H, W, C) with dtype float32.
  """
  height, width = input_shape[1], input_shape[2]
  # Use only the URL path so '?raw=true'-style queries do not end up in the
  # filename; fall back to a fixed name if the path has no last component.
  img_name = os.path.basename(urllib.parse.urlparse(img_url).path) or 'input_image'
  # Pass an explicit filename: without it urlretrieve writes to a temporary
  # file, and reading `img_name` from the working directory would fail.
  img_path, _ = urllib.request.urlretrieve(img_url, filename=img_name)
  img = mx.image.imread(img_path)
  # mx.image.imresize takes (width, height), in that order.
  img = mx.image.imresize(img, width, height)
  # Add the batch axis and match the float32 input the model is fed.
  img = img.expand_dims(0).astype('float32')

  return img

def prepare_label(model_name):
  """Return the class names of the pretrained GluonCV model ``model_name``.

  The network is constructed only so its ``classes`` attribute can be read.
  """
  return gcv.model_zoo.get_model(model_name, pretrained=True).classes


## Prepare the data

**Replace `model_name` and `img_url` below with the model and image you want to use.**

In [None]:
# Inputs to edit: which exported model to run, and which image to classify.
model_name = 'resnet18_v1'
img_url = 'https://raw.githubusercontent.com/dmlc/web-data/master/gluoncv/classification/mt_baker.jpg'

# Download the ONNX model (and its input shape), the image batch and the class names.
model, input_shape = fetch_model_and_shape(model_name)
img = prepare_img(img_url, input_shape)
labels = prepare_label(model_name)

In [None]:
# Create an ONNX Runtime inference session and look up the model's input name.
onnx_session = rt.InferenceSession(model, None)
input_name = onnx_session.get_inputs()[0].name
# Run the model; `None` requests every output and we keep the first one.
pred = onnx_session.run(None, {input_name: img.asnumpy()})[0]
# `pred` is a NumPy array, while MXNet operators such as mx.nd.softmax and
# mx.nd.topk expect mx.nd.NDArray inputs, so post-process in NumPy instead.
# The output is assumed to be (batch, num_classes); take the single image's logits.
logits = pred[0]
# Numerically stable softmax: subtracting the max does not change the result.
exp_logits = np.exp(logits - logits.max())
prob = exp_logits / exp_logits.sum()
# Indices of the (up to) 5 highest-scoring classes, best first.
top_k = min(5, logits.shape[0])
ind = np.argsort(logits)[::-1][:top_k].tolist()
# Print the class name and predicted probability.
print('The input picture is classified to be')
for i in ind:
    print('- [%s], with probability %.3f.' % (labels[i], prob[i]))