Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for double output variable for ONNX models #679

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions hummingbird/ml/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from collections import OrderedDict
from copy import deepcopy
import pprint
from uuid import uuid4

import numpy as np
from onnxconverter_common.optimizer import LinkedNode, _topological_sort
Expand All @@ -27,7 +26,12 @@
from ._topology import Topology
from ._utils import sklearn_installed, sparkml_installed
from .operator_converters import constants
from .supported import get_sklearn_api_operator_name, get_onnxml_api_operator_name, get_sparkml_api_operator_name
from .supported import (
get_sklearn_api_operator_name,
is_sklearn_models_with_two_outputs,
get_onnxml_api_operator_name,
get_sparkml_api_operator_name,
)

# Stacking is only supported starting from scikit-learn 0.22.
try:
Expand Down Expand Up @@ -249,13 +253,22 @@ def _parse_sklearn_single_model(topology, model, inputs):
if isinstance(model, str):
raise RuntimeError("Parameter model must be an object not a " "string '{0}'.".format(model))

alias = get_sklearn_api_operator_name(type(model))
model_type = type(model)
alias = get_sklearn_api_operator_name(model_type)
this_operator = topology.declare_logical_operator(alias, model)
this_operator.inputs = inputs

# We assume that all scikit-learn operators produce a single output.
variable = topology.declare_logical_variable("variable")
this_operator.outputs.append(variable)
if is_sklearn_models_with_two_outputs(model_type):
# This operator produces two outputs (e.g., label and probability)
variable = topology.declare_logical_variable("variable1")
this_operator.outputs.append(variable)

variable = topology.declare_logical_variable("variable2")
this_operator.outputs.append(variable)
else:
# We assume that all scikit-learn operators produce a single output.
variable = topology.declare_logical_variable("variable")
this_operator.outputs.append(variable)

return this_operator.outputs

Expand Down Expand Up @@ -602,7 +615,7 @@ def _parse_onnx_api(topology, model, inputs):
node_list = LinkedNode.build_from_onnx(graph.node, [], inputs_names + [in_.name for in_ in initializers], output_names)

# Make sure the entire node_list isn't only 'Identity'
if all([x.op_type == 'Identity' for x in node_list]):
if all([x.op_type == "Identity" for x in node_list]):
raise RuntimeError("ONNX model contained only Identity nodes {}.".format(node_list))

# This a new node list but with some node been removed plus eventual variable renaming.
Expand Down
71 changes: 71 additions & 0 deletions hummingbird/ml/supported.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,60 @@ def _build_sklearn_api_operator_name_map():
}


def _build_sklearn_operator_with_two_outputs():
"""
Associate Sklearn with the operator class names that have two outputs.
"""
"""
Put all supported Sklearn operators on a list.
"""
ops = set()

if sklearn_installed():
# Tree-based models
from sklearn.ensemble import (
ExtraTreesClassifier,
GradientBoostingClassifier,
HistGradientBoostingClassifier,
RandomForestClassifier,
)

from sklearn.tree import DecisionTreeClassifier

# Linear-based models
from sklearn.linear_model import (
LogisticRegression,
LogisticRegressionCV,
SGDClassifier,
)

# SVM-based models
from sklearn.svm import LinearSVC, SVC, NuSVC

ops.update(
[
# Trees
DecisionTreeClassifier,
ExtraTreesClassifier,
GradientBoostingClassifier,
HistGradientBoostingClassifier,
RandomForestClassifier,
LinearSVC,
LogisticRegression,
LogisticRegressionCV,
SGDClassifier,
# SVM
NuSVC,
SVC,
]
)

ops.update([xgb_operator_list[0]] if len(xgb_operator_list) > 0 else [])
ops.update([lgbm_operator_list[0]] if len(lgbm_operator_list) > 0 else [])

return ops


def _build_onnxml_api_operator_name_map():
"""
Associate ONNXML with the operator class names.
Expand Down Expand Up @@ -472,6 +526,22 @@ def get_sklearn_api_operator_name(model_type):
return sklearn_api_operator_name_map[model_type]


def is_sklearn_models_with_two_outputs(model_type):
"""
Get the operator name for the input model type in *scikit-learn API* format.

Args:
model_type: A scikit-learn model object (e.g., RandomForestClassifier)
or an object with scikit-learn API (e.g., LightGBM)

Returns:
A string which stands for the type of the input model in the Hummingbird conversion framework
"""
assert model_type in sklearn_api_operator_name_map

return model_type in sklearn_operator_with_two_outputs


def get_onnxml_api_operator_name(model_type):
"""
Get the operator name for the input model type in *ONNX-ML API* format.
Expand Down Expand Up @@ -513,6 +583,7 @@ def get_sparkml_api_operator_name(model_type):
prophet_operator_list = _build_prophet_operator_list()

sklearn_api_operator_name_map = _build_sklearn_api_operator_name_map()
sklearn_operator_with_two_outputs = _build_sklearn_operator_with_two_outputs()
onnxml_api_operator_name_map = _build_onnxml_api_operator_name_map()
sparkml_api_operator_name_map = _build_sparkml_api_operator_name_map()

Expand Down
41 changes: 36 additions & 5 deletions tests/test_xgboost_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sklearn.model_selection import train_test_split

import hummingbird.ml
from hummingbird.ml._utils import xgboost_installed, tvm_installed, pandas_installed
from hummingbird.ml._utils import xgboost_installed, tvm_installed, pandas_installed, onnx_runtime_installed
from hummingbird.ml import constants
from tree_utils import gbdt_implementation_map

Expand Down Expand Up @@ -253,7 +253,6 @@ def test_run_xgb_pandas(self):
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
def test_xgb_regressor_converter_torchscript(self):
warnings.filterwarnings("ignore")
import torch

for max_depth in [1, 3, 8, 10, 12]:
model = xgb.XGBRegressor(n_estimators=10, max_depth=max_depth)
Expand All @@ -272,7 +271,6 @@ def test_xgb_regressor_converter_torchscript(self):
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires XGBoost installed")
def test_xgb_classifier_converter_torchscript(self):
warnings.filterwarnings("ignore")
import torch

for max_depth in [1, 3, 8, 10, 12]:
model = xgb.XGBClassifier(n_estimators=10, max_depth=max_depth)
Expand All @@ -293,7 +291,6 @@ def test_xgb_classifier_converter_torchscript(self):
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_xgb_regressor_converter_tvm(self):
warnings.filterwarnings("ignore")
import torch

for max_depth in [1, 3, 8, 10, 12]:
model = xgb.XGBRegressor(n_estimators=10, max_depth=max_depth)
Expand All @@ -313,7 +310,6 @@ def test_xgb_regressor_converter_tvm(self):
@unittest.skipIf(not tvm_installed(), reason="TVM test requires TVM installed")
def test_xgb_classifier_converter_tvm(self):
warnings.filterwarnings("ignore")
import torch

for max_depth in [1, 3, 8, 10, 12]:
model = xgb.XGBClassifier(n_estimators=10, max_depth=max_depth)
Expand All @@ -328,6 +324,41 @@ def test_xgb_classifier_converter_tvm(self):
self.assertIsNotNone(tvm_model)
np.testing.assert_allclose(model.predict_proba(X), tvm_model.predict_proba(X), rtol=1e-06, atol=1e-06)

# Check that we can export into ONNX.
@unittest.skipIf(not onnx_runtime_installed(), reason="ONNXML test require ONNX, ORT and ONNXMLTOOLS")
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires LightGBM installed")
def test_xgb_onnx(self):
warnings.filterwarnings("ignore")

X = [[0, 1], [1, 1], [2, 0]]
X = np.array(X, dtype=np.float32)
y = np.array([100, -10, 50], dtype=np.float32)
model = xgb.XGBRegressor(n_estimators=3, min_child_samples=1)
model.fit(X, y)

# Create ONNX model
onnx_model = hummingbird.ml.convert(model, "onnx", X)

np.testing.assert_allclose(onnx_model.predict(X).flatten(), model.predict(X))

# Check output renaming with two outputs
@unittest.skipIf(not onnx_runtime_installed(), reason="ONNXML test require ONNX, ORT and ONNXMLTOOLS")
@unittest.skipIf(not xgboost_installed(), reason="XGBoost test requires LightGBM installed")
def test_xgb_onnx_two_outputs(self):
model = xgb.XGBClassifier(n_estimators=3, max_depth=5)
np.random.seed(0)
X = np.random.rand(100, 200)
X = np.array(X, dtype=np.float32)
y = np.random.randint(2, size=100)

model.fit(X, y)

torch_model = hummingbird.ml.convert(model, "onnx", X, extra_config={constants.OUTPUT_NAMES: ['labels', 'predictions']})
self.assertIsNotNone(torch_model)

self.assertTrue(torch_model.model.graph.output[0].name == 'labels')
self.assertTrue(torch_model.model.graph.output[1].name == 'predictions')


if __name__ == "__main__":
unittest.main()