## Using the ONNX model file for predictions
### In this notebook we load the trained model, saved in ONNX format, and run some predictions on a sampled test dataset

In [10]:
import time
import os
import random as rn
import json

import numpy as np
import pandas as pd
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

import tensorflow as tf

# conda env: mlcpuv1
import keras2onnx
import onnxruntime as rt

In [2]:
# report the versions of the libraries in use (TF is expected to be > 2.3)
for label, module in (
    ('TF version', tf),
    ('ONNX runtime version', rt),
    ('keras2onnx version', keras2onnx),
):
    print(label, module.__version__)

TF version 2.3.1
ONNX runtime version 1.4.0
keras2onnx version 1.7.0


### prepare the test dataset

In [3]:
# prepare the dataset for test

# seed for the shuffle below, so that every run produces the same split
# NOTE(review): to keep training rows out of the test set, this split should
# reproduce the one used when the model was trained - confirm against the
# training notebook
SEED = 42

# we take the dataset from Sklearn
data = load_breast_cancer(as_frame=True)

# I prefer working with Dataframe
orig_df = data.frame

# we must rename columns, to remove spaces in names
# otherwise we get problems with ONNX
# (every space in a column name is replaced with _)
dict_columns = {col: col.replace(" ", "_") for col in orig_df.columns}

orig_df = orig_df.rename(columns=dict_columns)

# Split the dataset in train, valid, test

N_TOTAL = orig_df.shape[0]
FRAC_TRAIN = 0.7
FRAC_VALID = 0.15

N_TRAIN = int(N_TOTAL * FRAC_TRAIN)
N_VALID = int(N_TOTAL * FRAC_VALID)
# the test split takes whatever is left, so the three sizes always sum to N_TOTAL
N_TEST = N_TOTAL - N_TRAIN - N_VALID

print('Numbers of samples for (total, train, valid, test):', N_TOTAL, N_TRAIN, N_VALID, N_TEST)

# shuffle the data (seeded: the same rows end up in each split on every run)
orig_df = orig_df.sample(frac=1., random_state=SEED)

df_train = orig_df.iloc[:N_TRAIN]
df_valid = orig_df.iloc[N_TRAIN:N_TRAIN+N_VALID]
df_test = orig_df.iloc[N_TRAIN+N_VALID:]

# sanity check: the three slices have exactly the announced sizes
assert (len(df_train), len(df_valid), len(df_test)) == (N_TRAIN, N_VALID, N_TEST)

Numbers of samples for (total, train, valid, test): 569 398 85 86


In [4]:
# turn a pandas DataFrame into a batched TF dataset
# adapted from https://www.tensorflow.org/tutorials/structured_data/feature_columns
def df_to_dataset(df, predictor, batch_size=32, shuffle=True):
    """Build a batched tf.data.Dataset of (features, labels) from a DataFrame.

    Args:
        df: source DataFrame; it is copied, so the caller's frame is not modified.
        predictor: name of the column that is removed from the features and
            used as the labels.
        batch_size: number of rows per batch.
        shuffle: when True the rows are shuffled (buffer as large as the frame)
            before batching; pass False for a test set.

    Returns:
        A dataset whose elements are (dict of feature columns, labels) batches.
    """
    features = df.copy()
    target = features.pop(predictor)
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), target))

    if not shuffle:
        # keep the original row order
        return dataset.batch(batch_size)

    return dataset.shuffle(buffer_size=len(features)).batch(batch_size)

In [5]:
# only the test split is needed here; shuffle=False keeps its rows in order
ds_test = df_to_dataset(
    df_test,
    'target',
    batch_size=16,
    shuffle=False,
)

### load the ONNX model

In [44]:
# path of the saved ONNX model file (a relative path)
ONNX_MODEL_FILE = 'modelbc-artifact/modelbc.onnx'

# first, a function to load the model and create an ONNX session
def create_session(onnx_file_name, print_info=False):
    """Create an ONNX Runtime inference session for a model file.

    Args:
        onnx_file_name: path of the ONNX model file to load.
        print_info: when True, also print how many inputs the model expects
            and the name of each one.

    Returns:
        The inference session created for the model.
    """
    session = rt.InferenceSession(onnx_file_name)

    print('Loading OK')

    if not print_info:
        return session

    model_inputs = session.get_inputs()
    print("ONNX model expects ", len(model_inputs), 'features:')
    # one line per input (feature) name
    for node in model_inputs:
        print(node.name)

    return session

# build the input as expected from ONNX runtime
def build_input_feed(f_batch):
    """Convert a batch of features into the input dictionary for the ONNX model.

    Args:
        f_batch: the features batch as extracted from the TF dataset; a mapping
            from feature name to an object whose .numpy() returns the values.

    Returns:
        A dict with the same keys, where every feature is a column vector of
        shape (n_rows, 1).
    """
    def as_column(tensor):
        # each feature holds a single value per row, so it becomes one column
        values = tensor.numpy()
        return values.reshape((len(values), 1))

    # the dict must be built the way ONNX expects it: one entry per feature
    return {name: as_column(f_batch[name]) for name in f_batch.keys()}

# One default session shared by both prediction helpers. It is created once,
# when this cell runs: the model file is loaded a single time, and the loading
# cost stays outside of any (possibly timed) prediction call.
_DEFAULT_SESSION = create_session(ONNX_MODEL_FILE)


def _run_probs(sess, input_feed):
    """Run the session on input_feed and return its first output under 'probs'."""
    # None as first argument: no specific output names are requested
    pred_onnx = sess.run(None, input_feed)

    return {'probs': pred_onnx[0]}


def onnx_predict(f_batch, sess=_DEFAULT_SESSION):
    """Predict on a features batch extracted from the TF dataset.

    Args:
        f_batch: mapping from feature name to a tensor-like object exposing
            .numpy() (it is converted with build_input_feed).
        sess: inference session to use; defaults to the shared session
            created for ONNX_MODEL_FILE.

    Returns:
        A dict with a single key 'probs' holding the first output of the model.
    """
    return _run_probs(sess, build_input_feed(f_batch))


def predict(input_dict, sess=_DEFAULT_SESSION):
    """Predict on plain Python data, e.g. a dictionary parsed from a JSON file.

    Args:
        input_dict: mapping from feature name to the values of that feature,
            one per row. A bare scalar is accepted and treated as a single row.
        sess: inference session to use; defaults to the shared session
            created for ONNX_MODEL_FILE.

    Returns:
        A dict with a single key 'probs' holding the first output of the model.
    """
    input_feed = {}

    for col, raw_values in input_dict.items():
        # ndmin=1 turns a bare scalar into a one-element array, so len() works
        np_values = np.array(raw_values, ndmin=1)
        n_rows = len(np_values)

        # for every feature a column vector (n_rows, 1)
        input_feed[col] = np_values.reshape((n_rows, 1))

    return _run_probs(sess, input_feed)

Loading OK
Loading OK


### do the test

In [7]:
# every dataset batch is a (features, labels) pair: only the features are
# passed to the model, so the pair is unpacked directly in the loop header

# do the test on the entire test dataset
for i, (f_batch, _labels) in enumerate(ds_test):
    print('Batch n.', i+1)

    # perf_counter is a monotonic, high-resolution clock: better suited than
    # time.time() for measuring intervals of a few milliseconds
    tStart = time.perf_counter()

    onnx_probs = onnx_predict(f_batch)
    tEla = time.perf_counter() - tStart
    print('Predictions:', onnx_probs['probs'].ravel())

    print('Time (sec.) for batch prediction:', round(tEla, 3))
    print('')

print('')
print('ONNX test OK.')

Batch n. 1
Predictions: [3.4209043e-02 9.9927604e-01 7.8657955e-02 1.9699335e-05 9.9996507e-01
 9.9995029e-01 8.5532665e-05 9.9990296e-01 9.3904138e-04 0.0000000e+00
 9.9976373e-01 9.9999046e-01 4.3094158e-05 9.9998891e-01 9.9797368e-01
 1.1874348e-02]
Time (sec.) for batch prediction: 0.002

Batch n. 2
Predictions: [0.0000000e+00 0.0000000e+00 0.0000000e+00 2.3551297e-01 2.4139881e-06
 1.4573336e-05 1.2516975e-06 2.6822090e-07 8.0625677e-01 7.2739691e-02
 9.9982810e-01 9.9993289e-01 2.8401023e-01 6.2039793e-03 2.9720610e-01
 9.9933505e-01]
Time (sec.) for batch prediction: 0.002

Batch n. 3
Predictions: [9.9850005e-01 3.9249659e-05 9.9999654e-01 7.7486038e-06 9.9508357e-01
 9.7657365e-01 9.9995804e-01 4.9226868e-01 9.9983120e-01 9.9553525e-01
 9.9377042e-01 9.9997973e-01 9.9999344e-01 4.3253601e-03 6.7058474e-02
 9.9996686e-01]
Time (sec.) for batch prediction: 0.002

Batch n. 4
Predictions: [6.2584877e-07 9.8074281e-01 9.5055419e-01 0.0000000e+00 9.9995172e-01
 1.4218688e-04 9.999983

#### OK: a prediction on a batch of 16 samples takes around 2 ms.

### Now check the model using the JSON file sample1.json as input

In [47]:
# read the data from file
# (the with block closes the file handle as soon as the content has been read)
with open('sample1.json', 'r') as json_file:
    data_str = json_file.read()

# parse the JSON text into a dictionary
input_dict = json.loads(data_str)

In [48]:
# run the ONNX model on the dictionary read from the JSON file
onnx_probs = predict(input_dict)

In [49]:
# display the output dictionary returned by predict
onnx_probs

{'probs': array([[0.0000000e+00],
        [9.9976540e-01],
        [2.3906797e-02],
        [9.9380422e-01],
        [8.5192496e-01],
        [9.9996710e-01],
        [9.9933505e-01],
        [9.9992895e-01],
        [9.9999237e-01],
        [2.3245811e-06],
        [9.9956256e-01],
        [1.7464161e-05],
        [2.9324174e-02],
        [1.1874348e-02],
        [9.9999654e-01],
        [9.6708000e-01]], dtype=float32)}