<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#ORT-Inferencing" data-toc-modified-id="ORT-Inferencing-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>ORT Inferencing</a></span></li></ul></div>

In [1]:
# 1. magic to print version
# 2. magic so that the notebook will reload external python modules
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2

import numpy as np
from onnxruntime import InferenceSession, SessionOptions

%watermark -a 'Ethen' -d -t -v -p numpy,onnxruntime

Ethen 2021-06-14 07:35:49 

CPython 3.6.4
IPython 7.15.0

numpy 1.18.5
onnxruntime 1.5.1


# ORT Inferencing

Given a [text (a.k.a. sequence) classification task](http://ethen8181.github.io/machine-learning/model_deployment/onnxruntime/text_classification_onnxruntime.html), running inference with ONNX Runtime's Python API often looks like the following.

In [2]:
def create_inference_session(
    model_path: str,
    intra_op_num_threads: int = 4,
    provider: str = 'CPUExecutionProvider'
) -> InferenceSession: 

    # properties that might affect performance (provided by MS)
    options = SessionOptions()
    options.intra_op_num_threads = intra_op_num_threads

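    # load the model as an ONNX graph
    session = InferenceSession(model_path, options, providers=[provider])
    # disable ORT's fallback mechanism, so failures with `provider` raise
    # instead of silently switching to another execution provider
    session.disable_fallback()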
    return session

In [3]:
intra_op_num_threads = 4

onnx_model_path = "text_classification.onnx"

input_id = [
    101, 3183, 2079, 2017, 2293, 1996, 2087, 1998, 2339, 1029,
    102, 3183, 2079, 2017, 2293, 2087, 1998, 2339, 1029, 102
]

In [4]:
# create a session
session = create_inference_session(onnx_model_path, intra_op_num_threads)

# perform inferencing
input_feed = {'input_ids': [input_id]}
onnx_output = session.run(['output'], input_feed)[0]
onnx_output

array([[-1.2110593,  1.7904084]], dtype=float32)
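The session returns one row of two raw scores per example. Assuming these are unnormalized logits (the notebook doesn't say whether the exported graph already applies a softmax), a numerically stable row-wise softmax is a minimal sketch of how to turn them into class probabilities:

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max keeps np.exp from overflowing."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# scores copied from the output above, assumed to be logits
logits = np.array([[-1.2110593, 1.7904084]], dtype=np.float32)
probs = softmax(logits)
predicted_class = probs.argmax(axis=1)
print(probs, predicted_class)
```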

This works well for a single example, but when we have multiple examples, each with a different sequence length, passing them directly to our `InferenceSession` results in an error: a ragged nested list cannot be converted into a single rectangular input tensor.

In [5]:
input_ids = [
    [101, 3183, 2079, 2017, 2293, 1996, 2087, 1998, 2339, 1029,  102, 3183,
     2079, 2017, 2293, 2087, 1998, 2339, 1029,  102],
    [101, 2129, 2116, 9646, 2515, 1996, 5304, 2428, 2031, 1029,  102, 2129,
     2116, 9646, 2024, 2045, 1029,  102]
]
try:
    input_feed = {'input_ids': input_ids}
    onnx_output = session.run(['output'], input_feed)[0]
except RuntimeError as e:
    print(e)

Could not create tensor from given input list


To avoid this error, we can either pad the inputs to the same sequence length or loop through them one at a time, running the graph on each. The next code chunk takes the latter approach.
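For comparison, here is a minimal sketch of the padding alternative. `pad_batch` is a hypothetical helper, not part of the original code. It assumes pad id `0` (the `[PAD]` id in BERT-style uncased vocabularies; check your tokenizer) and `int64` ids, and it also builds an attention mask. Note that the model above is fed only `input_ids`, so whether padding changes its scores depends on whether the exported graph accepts an `attention_mask` input.

```python
from typing import List, Tuple

import numpy as np


def pad_batch(input_ids: List[List[int]], pad_id: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """Right-pad variable-length sequences into one rectangular int64 array.

    Returns the padded ids and a mask with 1 for real tokens and 0 for padding.
    """
    max_len = max(len(ids) for ids in input_ids)
    padded = np.full((len(input_ids), max_len), pad_id, dtype=np.int64)
    mask = np.zeros((len(input_ids), max_len), dtype=np.int64)
    for row, ids in enumerate(input_ids):
        padded[row, :len(ids)] = ids
        mask[row, :len(ids)] = 1
    return padded, mask


batch = [[101, 2129, 102], [101, 3183, 2079, 2017, 102]]
padded, mask = pad_batch(batch)
print(padded)
print(mask.sum(axis=1))  # [3 5]
```

The padded array could then go into a single `session.run` call as `{'input_ids': padded}`, assuming the graph's batch and sequence dimensions are dynamic.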

In [6]:
def batch_predict(session, input_ids):
    batch_scores = []
    for input_id in input_ids:
        input_feed = {'input_ids': [input_id]}
        onnx_output = session.run(['output'], input_feed)[0]
        batch_scores.append(onnx_output)

    return np.concatenate(batch_scores)

In [7]:
onnx_output = batch_predict(session, input_ids)
onnx_output

array([[-1.2110593,  1.7904084],
       [-1.1729933,  1.6591723]], dtype=float32)
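To see what `batch_predict` does without loading a model, we can swap in a stub. `FakeSession` is purely hypothetical: it mimics only the `run(output_names, input_feed)` call pattern used above and returns a `(1, 2)` array per example. The sketch shows that the loop issues one graph call per example and that `np.concatenate` stacks the per-example rows into a `(batch_size, 2)` array.

```python
import numpy as np


class FakeSession:
    """Hypothetical stand-in for InferenceSession, used only to illustrate shapes."""

    def __init__(self):
        self.num_calls = 0

    def run(self, output_names, input_feed):
        self.num_calls += 1
        ids = np.asarray(input_feed['input_ids'])  # shape (1, seq_len)
        seq_len = float(ids.shape[1])
        # fake two-class scores derived from the sequence length
        return [np.array([[-seq_len, seq_len]], dtype=np.float32)]


def batch_predict(session, input_ids):
    batch_scores = []
    for input_id in input_ids:
        input_feed = {'input_ids': [input_id]}
        onnx_output = session.run(['output'], input_feed)[0]
        batch_scores.append(onnx_output)

    return np.concatenate(batch_scores)


fake = FakeSession()
scores = batch_predict(fake, [[101, 7, 102], [101, 7, 8, 9, 102]])
print(scores.shape, fake.num_calls)  # (2, 2) 2
```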

The `SequenceClassificationOrtInference` class from the `ort_inference` module lets us feed batches with dynamic sequence lengths directly to its `.batch_predict` method.
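The notebook doesn't show how `SequenceClassificationOrtInference` handles dynamic lengths internally, so the following is not its implementation. As one illustration of how a Python-side batcher could reduce the number of graph calls without padding, the hypothetical `bucketed_predict` groups sequences of equal length, runs each group as a single batch, and scatters the scores back into the original order. It assumes the graph's batch dimension is dynamic; `StubSession` is a hypothetical stand-in so the snippet runs without a model.

```python
from collections import defaultdict
from typing import Dict, List

import numpy as np


def bucketed_predict(session, input_ids: List[List[int]], output_name: str = 'output') -> np.ndarray:
    """Hypothetical helper: one session.run call per distinct sequence length."""
    # map each sequence length to the positions of the examples with that length
    buckets: Dict[int, List[int]] = defaultdict(list)
    for idx, ids in enumerate(input_ids):
        buckets[len(ids)].append(idx)

    results = [None] * len(input_ids)
    for indices in buckets.values():
        # all sequences in a bucket share a length, so the batch is rectangular
        batch = [input_ids[i] for i in indices]
        scores = session.run([output_name], {'input_ids': batch})[0]
        # scatter each row back to its example's original position
        for row, i in enumerate(indices):
            results[i] = scores[row]
    return np.stack(results)


class StubSession:
    """Hypothetical stand-in for InferenceSession: scores each row as [-seq_len, first_token]."""

    def __init__(self):
        self.num_calls = 0

    def run(self, output_names, input_feed):
        self.num_calls += 1
        batch = np.asarray(input_feed['input_ids'], dtype=np.int64)
        seq_len = batch.shape[1]
        scores = np.column_stack([np.full(len(batch), -seq_len), batch[:, 0]])
        return [scores.astype(np.float32)]


stub_session = StubSession()
out = bucketed_predict(stub_session, [[1, 2, 3], [4, 5], [6, 7, 8], [9, 10]])
print(out.tolist())            # [[-3.0, 1.0], [-2.0, 4.0], [-3.0, 6.0], [-2.0, 9.0]]
print(stub_session.num_calls)  # 2
```

Whether this beats the per-example loop depends on how many distinct lengths a batch contains: when every length is unique, it degenerates into the loop above.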

In [8]:
from ort_inference import SequenceClassificationOrtInference


ort_inference = SequenceClassificationOrtInference(onnx_model_path, intra_op_num_threads)
ort_inference

<ort_inference.SequenceClassificationOrtInference at 0x7ff82cfb94c8>

In [9]:
# returns a list of lists by default
batch_score = ort_inference.batch_predict(input_ids)
ort_output = np.array(batch_score, dtype=np.float32)
ort_output

array([[-1.2110593,  1.7904084],
       [-1.1729934,  1.6591723]], dtype=float32)

We can confirm that the outputs from both methods match up to floating-point tolerance (the second row differs only in the last printed digit).

In [10]:
np.allclose(onnx_output, ort_output)

True
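`np.allclose` tests closeness rather than bitwise equality. Per NumPy's documentation, the elementwise check is `|a - b| <= atol + rtol * |b|` with defaults `rtol=1e-05` and `atol=1e-08`. The sketch below spells that out using the values printed above:

```python
import numpy as np

# values copied from the two outputs printed above
loop_output = np.array([[-1.2110593, 1.7904084], [-1.1729933, 1.6591723]], dtype=np.float32)
class_output = np.array([[-1.2110593, 1.7904084], [-1.1729934, 1.6591723]], dtype=np.float32)

rtol, atol = 1e-05, 1e-08  # np.allclose defaults
manual = bool(np.all(np.abs(loop_output - class_output) <= atol + rtol * np.abs(class_output)))
print(manual, np.allclose(loop_output, class_output))  # True True
print(np.abs(loop_output - class_output).max())
```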

Avoiding the Python-level for loop also makes inference faster with larger batch sizes. To check, we time both approaches on a batch of 64 copies of the same example.

In [11]:
batch_size = 64
input_ids = [input_id for _ in range(batch_size)]
len(input_ids)

64

In [12]:
%%timeit
batch_predict(session, input_ids)

689 ms ± 22 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
%%timeit
batch_score = ort_inference.batch_predict(input_ids)
ort_output = np.array(batch_score, dtype=np.float32)
ort_output

517 ms ± 10.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
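`%%timeit` is an IPython magic. In a plain Python script, the standard-library `timeit` module gives a comparable measurement. The hypothetical `time_callable` below is a minimal sketch that times a placeholder workload so it runs without a model. In practice, you would pass something like `lambda: batch_predict(session, input_ids)` or `lambda: ort_inference.batch_predict(input_ids)` instead.

```python
import statistics
import timeit


def time_callable(fn, repeat: int = 7, number: int = 1):
    """Return the mean and population std. dev. of seconds per call over `repeat` runs.

    Each run executes `fn` `number` times, matching the
    "mean ± std. dev. of 7 runs, 1 loop each" setup reported by %%timeit above.
    """
    timings = timeit.repeat(fn, repeat=repeat, number=number)
    per_call = [t / number for t in timings]
    return statistics.mean(per_call), statistics.pstdev(per_call)


# placeholder workload standing in for a batch_predict call
mean_s, std_s = time_callable(lambda: sum(range(10_000)))
print(f"{mean_s * 1e3:.3f} ms ± {std_s * 1e3:.3f} ms per loop")
```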
