In [1]:
import numpy as np
import onnxruntime
from PIL import Image

## Random Forest

A Random Forest classification model from Scikit-Learn that predicts the iris class.

In [2]:
# Get input
# ndmin=2 keeps the array two-dimensional (rows x columns) even when the CSV
# holds a single row, where loadtxt would otherwise return a 1-D array.
x = np.loadtxt('iris_data.csv', delimiter=',', ndmin=2)
print(f'Input shape: {x.shape}')

# Load model and Inference
ort_session = onnxruntime.InferenceSession("rf_iris.onnx")
# Look up the tensor names from the model instead of hard-coding them.
input_name = ort_session.get_inputs()[0].name
label_name = ort_session.get_outputs()[0].name
# Only the first output (the predicted label) is requested; the input is cast to float32.
pred_onx = ort_session.run([label_name], {input_name: x.astype(np.float32)})[0]
print(f'Prediction: {pred_onx.tolist()}')

Input shape: (38, 4)
Prediction: [0, 2, 2, 2, 1, 0, 1, 2, 0, 0, 0, 2, 1, 0, 0, 2, 1, 1, 0, 2, 0, 0, 2, 2, 0, 1, 2, 0, 2, 0, 0, 2, 2, 2, 1, 2, 1, 1]


## Super Resolution Model

A PyTorch model that increases the size of an image and makes it sharper. The output is another image.

In [3]:
# Get input
# A seeded generator makes the fake image identical on every Restart & Run All.
rng = np.random.default_rng(0)
x = rng.random((1, 1, 224, 224))  # fake small image, values in [0, 1)
print(f'Input shape: {x.shape}')

# Load model and Inference
ort_session = onnxruntime.InferenceSession("super_resolution.onnx")
# Feed the model's first input by name; the data is cast to float32 before the run.
ort_inputs = {ort_session.get_inputs()[0].name: x.astype(np.float32)}
# run(None, ...) returns every model output; keep only the first one.
ort_outs = ort_session.run(None, ort_inputs)[0]
print(f'Output shape: {ort_outs.shape}')

Input shape: (1, 1, 224, 224)
Output shape: (1, 1, 672, 672)


## Faster RCNN

This is an object detection model from torchvision. The returned output lists the bounding boxes, classes, and confidence scores.

In [4]:
def get_image(image_path):
    """Load an image file as a float32 array of shape (1, 3, H, W) scaled by 1/255.

    Parameters
    ----------
    image_path : str or path-like
        Location of the image file, passed to ``PIL.Image.open``.

    Returns
    -------
    numpy.ndarray
        ``float32`` array in channels-first layout with a leading batch
        dimension of 1.
    """
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(image_path) as img:
        # Force three channels so grayscale, palette and RGBA files also yield
        # an (H, W, 3) array; the transpose below needs exactly three axes.
        pixels = np.asarray(img.convert("RGB"))
    # onnx only works with numpy, and values range from 0-1. The division
    # allocates a new writable array, so no read-only workaround is needed.
    scaled = pixels / 255
    batch = scaled.transpose((2, 0, 1))[None, :, :, :]  # HWC -> CHW, add batch size of 1
    return batch.astype("float32")  # cast to float32 (4 bytes per value) for the model input

In [5]:
# Prepare the input tensor from a real photo (fake data didn't work)
x = get_image('people.jpg')

# Create the session, bind the image to the model's first input, and run it
ort_session = onnxruntime.InferenceSession("faster-rcnn.onnx")
first_input = ort_session.get_inputs()[0]
ort_inputs = {first_input.name: x.astype(np.float32)}
# Passing None as the output list returns every model output
ort_outs = ort_session.run(None, ort_inputs)
ort_outs

[array([[0.0000000e+00, 3.1964536e+02, 4.4151535e+02, 1.3232561e+03],
        [7.6530365e+02, 3.5194049e+02, 1.1937161e+03, 1.3122842e+03],
        [1.1569447e+03, 3.9462311e+02, 1.5692056e+03, 1.3221746e+03],
        [1.5193402e+03, 4.2956503e+02, 1.9924286e+03, 1.3306053e+03],
        [3.6410751e+02, 3.8761441e+02, 7.7179010e+02, 1.3061730e+03],
        [6.1342780e+02, 2.6801114e+02, 9.1221960e+02, 1.2417460e+03],
        [1.4390378e+03, 3.3610452e+02, 1.6699921e+03, 7.9422113e+02],
        [1.7239886e+03, 2.8932889e+02, 1.9978444e+03, 8.8991754e+02],
        [1.1548141e+03, 1.8347633e+02, 1.4561610e+03, 7.0533588e+02],
        [1.6002692e+03, 1.9745711e+02, 1.7637482e+03, 6.3540698e+02],
        [8.8222968e+02, 1.9024481e+02, 1.2439409e+03, 7.4671735e+02],
        [5.4619714e+02, 2.4166563e+02, 7.1724133e+02, 5.4152289e+02],
        [2.3673456e+02, 1.0111604e+02, 4.8953281e+02, 6.6732379e+02],
        [8.3153918e+02, 2.0396672e+02, 1.1440724e+03, 6.0038452e+02],
        [2.1523401e+

## Inference Environment 

In [6]:
# List the packages installed in the `onnx-infer` conda environment,
# recording the versions used for the inference runs in this notebook.
!conda list -n onnx-infer

# packages in environment at /home/jack/anaconda3/envs/onnx-infer:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
backcall                  0.1.0                    py37_0  
blas                      1.0                         mkl  
ca-certificates           2020.1.1                      0  
certifi                   2020.4.5.1               py37_0  
cycler                    0.10.0                   py37_0  
dbus                      1.13.14              hb2f20db_0  
decorator                 4.4.2                      py_0  
entrypoints               0.3                      py37_0  
expat                     2.2.6                he6710b0_0  
fontconfig                2.13.0               h9420a91_0  
freetype                  2.9.1                h8a8886c_1  
glib                      2.63.1               h5a9c865_0  
gst-plugins-base          1.14.0               hbbd80ab_1  
gstreamer               