Skip to content

Commit

Permalink
Fixes #184, #201, fixes GradientBoostingClassifier for multi class (#226
Browse files Browse the repository at this point in the history
)

* add a unit test for issue #201, fixes #184
* update test_sklearn_gradient_boosting_converters.py
* extend fix to sklearn < 0.21
  • Loading branch information
xadupre committed Jul 30, 2019
1 parent aae4f7a commit 33abc90
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 12 deletions.
1 change: 0 additions & 1 deletion skl2onnx/_parse.py
Expand Up @@ -296,7 +296,6 @@ def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
this_operator.classlabels_strings = classes
label_type = StringType()

print("***", scope.tensor_type())
output_label = scope.declare_local_variable('output_label', label_type)
output_probability = scope.declare_local_variable(
'output_probability',
Expand Down
16 changes: 13 additions & 3 deletions skl2onnx/operator_converters/gradient_boosting.py
Expand Up @@ -53,10 +53,20 @@ def convert_sklearn_gradient_boosting_classifier(scope, operator, container):
op.loss))
else:
# class_prior_ was introduced in scikit-learn 0.21.
if hasattr(op.init_, 'class_prior_'):
base_values = op.init_.class_prior_
x0 = np.zeros((1, op.estimators_[0, 0].n_features_))
if hasattr(op, '_raw_predict_init'):
# sklearn >= 0.21
base_values = op._raw_predict_init(x0).ravel()
elif hasattr(op, '_init_decision_function'):
# sklearn >= 0.21
base_values = op._init_decision_function(x0).ravel()
else:
base_values = op.init_.priors
raise RuntimeError("scikit-learn < 0.19 is not supported.")

# if hasattr(op.init_, 'class_prior_'):
# base_values = op.init_.class_prior_
# else:
# base_values = op.init_.priors
else:
raise NotImplementedError(
'Setting init to an estimator is not supported, you may raise an '
Expand Down
48 changes: 40 additions & 8 deletions tests/test_sklearn_gradient_boosting_converters.py
Expand Up @@ -4,26 +4,37 @@
# license information.
# --------------------------------------------------------------------------

from logging import getLogger
import unittest
import numpy as np
from distutils.version import StrictVersion
from pandas import DataFrame
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import (
GradientBoostingClassifier,
GradientBoostingRegressor
)
from sklearn.model_selection import train_test_split
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType, Int64TensorType
from skl2onnx.common.data_types import onnx_built_with_ml
from test_utils import dump_binary_classification, dump_multiple_classification
from test_utils import dump_data_and_model, fit_regression_model
from onnxruntime import InferenceSession, __version__

threshold = "0.4.0"


class TestSklearnGradientBoostingModels(unittest.TestCase):

def setUp(self):
log = getLogger('skl2onnx')
log.disabled = True

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
@unittest.skipIf(
StrictVersion(__version__) <= StrictVersion("0.5.0"),
StrictVersion(__version__) <= StrictVersion(threshold),
reason="Depends on PR #1015 onnxruntime.")
def test_gradient_boosting_classifier1Deviance(self):
model = GradientBoostingClassifier(n_estimators=1, max_depth=2)
Expand All @@ -41,8 +52,10 @@ def test_gradient_boosting_classifier1Deviance(self):
sess = InferenceSession(model_onnx.SerializeToString())
res = sess.run(None, {'input': X.astype(np.float32)})
pred = model.predict_proba(X)
if res[1][0][0] != pred[0, 0]:
rows = ["X", str(X),
delta = abs(res[1][0][0] - pred[0, 0])
if delta > 1e-5:
rows = ["diff", str(delta),
"X", str(X),
"base_values_", str(model.init_.class_prior_),
"predicted_label", str(model.predict(X)),
"expected", str(pred),
Expand All @@ -52,7 +65,7 @@ def test_gradient_boosting_classifier1Deviance(self):
dump_binary_classification(
model, suffix="1Deviance",
allow_failure="StrictVersion(onnxruntime.__version__)"
" <= StrictVersion('0.5.0')")
" <= StrictVersion('%s')" % threshold)

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -61,7 +74,7 @@ def test_gradient_boosting_classifier3(self):
dump_binary_classification(
model, suffix="3",
allow_failure="StrictVersion(onnxruntime.__version__)"
" <= StrictVersion('0.5.0')")
" <= StrictVersion('%s')" % threshold)

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
Expand All @@ -70,7 +83,7 @@ def test_gradient_boosting_classifier_multi(self):
dump_multiple_classification(
model,
allow_failure="StrictVersion(onnxruntime.__version__)"
"<= StrictVersion('0.5.0')",
"<= StrictVersion('%s')" % threshold,
)

def test_gradient_boosting_regressor_ls_loss(self):
Expand Down Expand Up @@ -182,6 +195,25 @@ def test_gradient_boosting_regressor_zero_init(self):
" <= StrictVersion('0.2.1')"
)

@unittest.skipIf(
StrictVersion(__version__) <= StrictVersion(threshold),
reason="Depends on PR #1015 onnxruntime.")
def test_gradient_boosting_regressor_learning_rate(self):
X, y = make_classification(
n_features=100, n_samples=1000, n_classes=2, n_informative=8)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.5, random_state=42)
model = GradientBoostingClassifier().fit(X_train, y_train)
onnx_model = convert_sklearn(
model, 'lr2', [('input', FloatTensorType(X_test.shape))])
sess = InferenceSession(onnx_model.SerializeToString())
res = sess.run(None, input_feed={'input': X_test.astype(np.float32)})
r1 = np.mean(np.isclose(model.predict_proba(X_test),
list(map(lambda x: list(map(lambda y: x[y], x)),
res[1])), atol=1e-4))
r2 = np.mean(res[0] == model.predict(X_test))
assert r1 == r2


if __name__ == "__main__":
unittest.main()

0 comments on commit 33abc90

Please sign in to comment.