## Training a scikit-learn model and converting it to ONNX
Reference: [sklearn-onnx documentation](http://onnx.ai/sklearn-onnx/index.html)

In [1]:
# Train a model.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Fixed seed: both the train/test split and the forest are stochastic, so without
# it every Restart & Run All produces a different model and different outputs below.
RANDOM_STATE = 42

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=RANDOM_STATE)
clr = RandomForestClassifier(random_state=RANDOM_STATE)
clr.fit(X_train, y_train)

RandomForestClassifier()

In [2]:
# Convert into ONNX format
import os

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# Input signature of the ONNX graph: a float tensor with a dynamic batch dimension
# (None) and one column per training feature. The feature count is taken from the
# training data rather than hard-coded, so it always matches the fitted model.
n_features = X_train.shape[1]
initial_type = [('float_input', FloatTensorType([None, n_features]))]
onx = convert_sklearn(clr, initial_types=initial_type)

filename = '../models/rf_iris.onnx'
# open(..., "wb") does not create missing parent directories; do it explicitly so
# the cell works on a fresh checkout.
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "wb") as f:
    f.write(onx.SerializeToString())

In [3]:
# Compute the prediction with ONNX Runtime
import onnxruntime as rt
import numpy

# Request the CPU execution provider explicitly: since onnxruntime 1.9, builds that
# ship more than one execution provider (e.g. GPU builds) raise if `providers` is omitted.
sess = rt.InferenceSession(filename, providers=["CPUExecutionProvider"])

# The graph's first input is the 'float_input' tensor declared at conversion time;
# its first output is used here as the predicted labels.
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name

# The input was declared as FloatTensorType, and ONNX Runtime rejects inputs whose
# dtype does not match, so cast the test features to float32 before running.
pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]

In [4]:
# Labels predicted by the ONNX Runtime session (first graph output) for X_test.
pred_onx

array([0, 1, 2, 0, 0, 2, 0, 0, 2, 1, 2, 2, 1, 1, 2, 2, 1, 0, 1, 1, 2, 2,
       0, 1, 0, 0, 2, 2, 0, 0, 1, 1, 0, 2, 0, 2, 1, 1], dtype=int64)

In [5]:
# Reference predictions from the original scikit-learn model. They must agree with
# the ONNX Runtime predictions element for element; assert it rather than compare
# the printed arrays by eye (the dtypes may differ, e.g. int64 vs the platform int,
# but assert_array_equal compares values only).
pred_skl = clr.predict(X_test)
numpy.testing.assert_array_equal(pred_onx, pred_skl)
pred_skl

array([0, 1, 2, 0, 0, 2, 0, 0, 2, 1, 2, 2, 1, 1, 2, 2, 1, 0, 1, 1, 2, 2,
       0, 1, 0, 0, 2, 2, 0, 0, 1, 1, 0, 2, 0, 2, 1, 1])

## Converting ONNX models to ORT format
Reference: [ONNX Runtime — converting ONNX models to ORT format](https://onnxruntime.ai/docs/tutorials/mobile/model-conversion.html)

In [7]:
! python -m onnxruntime.tools.convert_onnx_models_to_ort "../models/rf_iris.onnx"

Converting optimized ONNX model C:\Users\weldl\Workspace\onnx-android-ml-pipeline\models\rf_iris.onnx to ORT format model C:\Users\weldl\Workspace\onnx-android-ml-pipeline\models\rf_iris.all.ort
Converted 1 models. 0 failures.


2021-10-11 10:51:37,368 ort_format_model.utils [INFO] - Processed C:\Users\weldl\Workspace\onnx-android-ml-pipeline\models\rf_iris.all.ort
2021-10-11 10:51:37,368 ort_format_model.utils [INFO] - Created config in C:\Users\weldl\Workspace\onnx-android-ml-pipeline\models\rf_iris.all.required_operators.config
