From a760bf00832820ac31bdf6d553aa17436749541a Mon Sep 17 00:00:00 2001 From: Prabhat Date: Fri, 21 Jun 2019 23:52:34 +0100 Subject: [PATCH] =?UTF-8?q?Support=20int=20features=20in=20BernoulliNB,=20?= =?UTF-8?q?MultinomialNB=20and=20SGDClassifier=20=E2=80=A6=20(#185)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Support int features in BernoulliNB, MultinomialNB and SGDClassifier converters * Avoid unit test failures --- skl2onnx/operator_converters/naive_bayes.py | 14 +- .../operator_converters/sgd_classifier.py | 11 +- tests/test_sklearn_naive_bayes_converter.py | 172 +++++++++++++++--- .../test_sklearn_sgd_classifier_converter.py | 135 +++++++++++--- tests/test_utils/__init__.py | 1 + tests/test_utils/tests_helper.py | 14 ++ 6 files changed, 287 insertions(+), 60 deletions(-) diff --git a/skl2onnx/operator_converters/naive_bayes.py b/skl2onnx/operator_converters/naive_bayes.py index dd59133b3..56c79bc77 100644 --- a/skl2onnx/operator_converters/naive_bayes.py +++ b/skl2onnx/operator_converters/naive_bayes.py @@ -8,6 +8,7 @@ from ..proto import onnx_proto from ..common._apply_operation import apply_add, apply_cast, apply_exp from ..common._apply_operation import apply_log, apply_sub, apply_reshape +from ..common.data_types import Int64TensorType from ..common._registration import register_converter @@ -207,11 +208,19 @@ def convert_sklearn_naive_bayes(scope, operator, container): else: sum_op_version = 8 + input_name = operator.inputs[0].full_name + if type(operator.inputs[0].type) == Int64TensorType: + cast_input_name = scope.get_unique_variable_name('cast_input') + + apply_cast(scope, operator.input_full_names, cast_input_name, + container, to=onnx_proto.TensorProto.FLOAT) + input_name = cast_input_name + if operator.type == 'SklearnMultinomialNB': matmul_result_name = scope.get_unique_variable_name('matmul_result') container.add_node( - 'MatMul', [operator.inputs[0].full_name, feature_log_prob_name], + 'MatMul', [input_name, feature_log_prob_name], matmul_result_name, name=scope.get_unique_operator_name('MatMul')) apply_add(scope, [matmul_result_name, class_log_prior_name], sum_result_name, container, broadcast=1) @@ -229,7 +238,6 @@ def convert_sklearn_naive_bayes(scope, operator, container): container.add_initializer(constant_name, onnx_proto.TensorProto.FLOAT, [], [1.0]) - input_name = operator.inputs[0].full_name if nb.binarize is not None: threshold_name = scope.get_unique_variable_name('threshold') @@ -250,7 +258,7 @@ def convert_sklearn_naive_bayes(scope, operator, container): np.zeros((M, num_features)).ravel()) container.add_node( - 'Greater', [operator.inputs[0].full_name, threshold_name], + 'Greater', [input_name, threshold_name], condition_name, name=scope.get_unique_operator_name('Greater'), op_version=9) apply_cast(scope, condition_name, cast_values_name, container, diff --git a/skl2onnx/operator_converters/sgd_classifier.py b/skl2onnx/operator_converters/sgd_classifier.py index c56784aa8..87f902fa6 100644 --- a/skl2onnx/operator_converters/sgd_classifier.py +++ b/skl2onnx/operator_converters/sgd_classifier.py @@ -8,6 +8,7 @@ from ..common._apply_operation import ( apply_add, apply_cast, apply_clip, apply_concat, apply_div, apply_exp, apply_identity, apply_mul, apply_reciprocal, apply_reshape, apply_sub) +from ..common.data_types import Int64TensorType from ..common._registration import register_converter from ..proto import onnx_proto @@ -28,8 +29,16 @@ def _decision_function(scope, operator, container, model): 
container.add_initializer(intercept_name, onnx_proto.TensorProto.FLOAT, model.intercept_.shape, model.intercept_) + input_name = operator.inputs[0].full_name + if type(operator.inputs[0].type) == Int64TensorType: + cast_input_name = scope.get_unique_variable_name('cast_input') + + apply_cast(scope, operator.input_full_names, cast_input_name, + container, to=onnx_proto.TensorProto.FLOAT) + input_name = cast_input_name + container.add_node( - 'MatMul', [operator.inputs[0].full_name, coef_name], + 'MatMul', [input_name, coef_name], matmul_result_name, name=scope.get_unique_operator_name('MatMul')) apply_add(scope, [matmul_result_name, intercept_name], diff --git a/tests/test_sklearn_naive_bayes_converter.py b/tests/test_sklearn_naive_bayes_converter.py index e8278e41d..399761d99 100644 --- a/tests/test_sklearn_naive_bayes_converter.py +++ b/tests/test_sklearn_naive_bayes_converter.py @@ -1,36 +1,20 @@ -import numpy as np import onnx import unittest from distutils.version import StrictVersion -from sklearn.datasets import load_digits, load_iris from sklearn.naive_bayes import MultinomialNB, BernoulliNB from skl2onnx import convert_sklearn -from skl2onnx.common.data_types import FloatTensorType +from skl2onnx.common.data_types import FloatTensorType, Int64TensorType from skl2onnx.common.data_types import onnx_built_with_ml -from test_utils import dump_data_and_model +from test_utils import dump_data_and_model, fit_classification_model class TestNaiveBayesConverter(unittest.TestCase): - def _fit_model_binary_classification(self, model, data): - X = data.data - y = data.target - mid_point = len(data.target_names) / 2 - y[y < mid_point] = 0 - y[y >= mid_point] = 1 - model.fit(X, y) - return model, X.astype(np.float32) - - def _fit_model_multiclass_classification(self, model, data): - X = data.data - y = data.target - model.fit(X, y) - return model, X.astype(np.float32) @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_multinomial_nb_binary_classification(self): - model, X = self._fit_model_binary_classification( - MultinomialNB(), load_iris()) + model, X = fit_classification_model( + MultinomialNB(), 2, pos_features=True) model_onnx = convert_sklearn( model, "multinomial naive bayes", @@ -41,7 +25,7 @@ def test_model_multinomial_nb_binary_classification(self): X, model, model_onnx, - basename="SklearnBinMultinomialNB", + basename="SklearnBinMultinomialNB-Dec4", allow_failure="StrictVersion(onnxruntime.__version__)" "<= StrictVersion('0.2.1')", ) @@ -53,8 +37,8 @@ def test_model_multinomial_nb_binary_classification(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_bernoulli_nb_binary_classification(self): - model, X = self._fit_model_binary_classification( - BernoulliNB(), load_digits()) + model, X = fit_classification_model( + BernoulliNB(), 2) model_onnx = convert_sklearn( model, "bernoulli naive bayes", @@ -73,8 +57,8 @@ def test_model_bernoulli_nb_binary_classification(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_multinomial_nb_multiclass(self): - model, X = self._fit_model_multiclass_classification( - MultinomialNB(), load_iris()) + model, X = fit_classification_model( + MultinomialNB(), 5, pos_features=True) model_onnx = convert_sklearn( model, "multinomial naive bayes", @@ -85,7 +69,27 @@ def test_model_multinomial_nb_multiclass(self): X, model, model_onnx, - basename="SklearnMclMultinomialNB", + basename="SklearnMclMultinomialNB-Dec4", 
+ allow_failure="StrictVersion(onnxruntime.__version__)" + "<= StrictVersion('0.2.1')", + ) + + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_model_multinomial_nb_multiclass_params(self): + model, X = fit_classification_model( + MultinomialNB(alpha=0.5, fit_prior=False), 5, pos_features=True) + model_onnx = convert_sklearn( + model, + "multinomial naive bayes", + [("input", FloatTensorType(X.shape))], + ) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, + model, + model_onnx, + basename="SklearnMclMultinomialNBParams-Dec4", allow_failure="StrictVersion(onnxruntime.__version__)" "<= StrictVersion('0.2.1')", ) @@ -97,8 +101,8 @@ def test_model_multinomial_nb_multiclass(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_bernoulli_nb_multiclass(self): - model, X = self._fit_model_multiclass_classification( - BernoulliNB(), load_digits()) + model, X = fit_classification_model( + BernoulliNB(), 4) model_onnx = convert_sklearn( model, "bernoulli naive bayes", @@ -114,6 +118,118 @@ def test_model_bernoulli_nb_multiclass(self): "<= StrictVersion('0.2.1')", ) + @unittest.skipIf( + StrictVersion(onnx.__version__) <= StrictVersion("1.3"), + reason="Requires opset 9.", + ) + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_model_bernoulli_nb_multiclass_params(self): + model, X = fit_classification_model( + BernoulliNB(alpha=0, binarize=1.0, fit_prior=False), 4) + model_onnx = convert_sklearn( + model, + "bernoulli naive bayes", + [("input", FloatTensorType(X.shape))], + ) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, + model, + model_onnx, + basename="SklearnMclBernoulliNBParams", + allow_failure="StrictVersion(onnxruntime.__version__)" + "<= StrictVersion('0.2.1')", + ) + + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_model_multinomial_nb_binary_classification_int(self): + model, X = fit_classification_model( + MultinomialNB(), 2, is_int=True, pos_features=True) + model_onnx = convert_sklearn( + model, + "multinomial naive bayes", + [("input", Int64TensorType(X.shape))], + ) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, + model, + model_onnx, + basename="SklearnBinMultinomialNBInt-Dec4", + allow_failure="StrictVersion(onnxruntime.__version__)" + "<= StrictVersion('0.2.1')", + ) + + @unittest.skipIf( + StrictVersion(onnx.__version__) <= StrictVersion("1.3"), + reason="Requires opset 9.", + ) + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_model_bernoulli_nb_binary_classification_int(self): + model, X = fit_classification_model( + BernoulliNB(), 2, is_int=True) + model_onnx = convert_sklearn( + model, + "bernoulli naive bayes", + [("input", Int64TensorType(X.shape))], + ) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, + model, + model_onnx, + basename="SklearnBinBernoulliNBInt", + allow_failure="StrictVersion(onnxruntime.__version__)" + "<= StrictVersion('0.2.1')", + ) + + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_model_multinomial_nb_multiclass_int(self): + model, X = fit_classification_model( + MultinomialNB(), 5, is_int=True, pos_features=True) + model_onnx = convert_sklearn( + model, + "multinomial naive bayes", + [("input", Int64TensorType(X.shape))], + ) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, + model, + model_onnx, + 
basename="SklearnMclMultinomialNBInt-Dec4", + allow_failure="StrictVersion(onnxruntime.__version__)" + "<= StrictVersion('0.2.1')", + ) + + @unittest.skipIf( + StrictVersion(onnx.__version__) <= StrictVersion("1.3"), + reason="Requires opset 9.", + ) + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_model_bernoulli_nb_multiclass_int(self): + model, X = fit_classification_model( + BernoulliNB(), 4, is_int=True) + model_onnx = convert_sklearn( + model, + "bernoulli naive bayes", + [("input", Int64TensorType(X.shape))], + ) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, + model, + model_onnx, + basename="SklearnMclBernoulliNBInt-Dec4", + allow_failure="StrictVersion(onnxruntime.__version__)" + "<= StrictVersion('0.2.1')", + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_sklearn_sgd_classifier_converter.py b/tests/test_sklearn_sgd_classifier_converter.py index d187ca270..e052d5f86 100644 --- a/tests/test_sklearn_sgd_classifier_converter.py +++ b/tests/test_sklearn_sgd_classifier_converter.py @@ -2,29 +2,19 @@ import unittest import numpy as np -from sklearn.datasets import make_classification from sklearn.linear_model import SGDClassifier -from sklearn.model_selection import train_test_split from skl2onnx import convert_sklearn -from skl2onnx.common.data_types import FloatTensorType +from skl2onnx.common.data_types import FloatTensorType, Int64TensorType from skl2onnx.common.data_types import onnx_built_with_ml -from test_utils import dump_data_and_model +from test_utils import dump_data_and_model, fit_classification_model class TestSGDClassifierConverter(unittest.TestCase): - def _fit_model_classification(self, model, n_classes): - X, y = make_classification(n_classes=n_classes, n_features=100, - n_samples=10000, - random_state=42, n_informative=5) - X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.5, - random_state=42) - model.fit(X_train, y_train) - return model, X_test @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_binary_class_hinge(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(loss='hinge', random_state=42), 2) model_onnx = convert_sklearn( model, @@ -46,7 +36,7 @@ def test_model_sgd_binary_class_hinge(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_multi_class_hinge(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(loss='hinge', random_state=42), 5) model_onnx = convert_sklearn( model, @@ -68,7 +58,7 @@ def test_model_sgd_multi_class_hinge(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_binary_class_log(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(loss='log', random_state=42), 2) model_onnx = convert_sklearn( model, @@ -90,14 +80,14 @@ def test_model_sgd_binary_class_log(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_multi_class_log(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(loss='log', random_state=42), 5) model_onnx = convert_sklearn( model, "scikit-learn SGD multi-class classifier", [("input", FloatTensorType(X.shape))], ) - X = X[1:3] + X = np.array([X[1], X[1]]) self.assertIsNotNone(model_onnx) dump_data_and_model( 
X.astype(np.float32), @@ -113,7 +103,7 @@ def test_model_sgd_multi_class_log(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_binary_class_log_l1_no_intercept(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(loss='log', penalty='l1', fit_intercept=False, random_state=42), 2) model_onnx = convert_sklearn( @@ -136,10 +126,10 @@ def test_model_sgd_binary_class_log_l1_no_intercept(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_multi_class_log_l1_no_intercept(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(loss='log', penalty='l1', fit_intercept=False, random_state=42), 5) - X = X[1:3] + X = np.array([X[4], X[4]]) model_onnx = convert_sklearn( model, "scikit-learn SGD multi-class classifier", @@ -150,7 +140,7 @@ def test_model_sgd_multi_class_log_l1_no_intercept(self): X.astype(np.float32), model, model_onnx, - basename="SklearnSGDClassifierMultiLogL1NoIntercept", + basename="SklearnSGDClassifierMultiLogL1NoIntercept-Dec4", allow_failure="StrictVersion(onnx.__version__)" " < StrictVersion('1.2') or " "StrictVersion(onnxruntime.__version__)" @@ -160,7 +150,7 @@ def test_model_sgd_multi_class_log_l1_no_intercept(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_binary_class_elasticnet_power_t(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(penalty='elasticnet', l1_ratio=0.3, power_t=2, random_state=42), 2) model_onnx = convert_sklearn( @@ -183,7 +173,7 @@ def test_model_sgd_binary_class_elasticnet_power_t(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_multi_class_elasticnet_power_t(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(penalty='elasticnet', l1_ratio=0.3, power_t=2, random_state=42), 5) model_onnx = convert_sklearn( @@ -206,7 +196,7 @@ def test_model_sgd_multi_class_elasticnet_power_t(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_binary_class_modified_huber(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(loss='modified_huber', random_state=42), 2) model_onnx = convert_sklearn( model, @@ -228,7 +218,7 @@ def test_model_sgd_binary_class_modified_huber(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_binary_class_squared_hinge(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(loss='squared_hinge', random_state=42), 2) model_onnx = convert_sklearn( model, @@ -250,7 +240,7 @@ def test_model_sgd_binary_class_squared_hinge(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_multi_class_squared_hinge(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(loss='squared_hinge', random_state=42), 5) model_onnx = convert_sklearn( model, @@ -272,7 +262,7 @@ def test_model_sgd_multi_class_squared_hinge(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_binary_class_perceptron(self): - model, X = self._fit_model_classification( + model, X = 
fit_classification_model( SGDClassifier(loss='perceptron', random_state=42), 2) model_onnx = convert_sklearn( model, @@ -294,7 +284,7 @@ def test_model_sgd_binary_class_perceptron(self): @unittest.skipIf(not onnx_built_with_ml(), reason="Requires ONNX-ML extension.") def test_model_sgd_multi_class_preceptron(self): - model, X = self._fit_model_classification( + model, X = fit_classification_model( SGDClassifier(loss='perceptron', random_state=42), 5) model_onnx = convert_sklearn( model, @@ -313,6 +303,95 @@ def test_model_sgd_multi_class_preceptron(self): " <= StrictVersion('0.2.1')", ) + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_model_sgd_binary_class_hinge_int(self): + model, X = fit_classification_model( + SGDClassifier(loss='hinge', random_state=42), 2, is_int=True) + model_onnx = convert_sklearn( + model, + "scikit-learn SGD binary classifier", + [("input", Int64TensorType(X.shape))], + ) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, + model, + model_onnx, + basename="SklearnSGDClassifierBinaryHingeInt-Out0", + allow_failure="StrictVersion(onnx.__version__)" + " < StrictVersion('1.2') or " + "StrictVersion(onnxruntime.__version__)" + " <= StrictVersion('0.2.1')", + ) + + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_model_sgd_multi_class_hinge_int(self): + model, X = fit_classification_model( + SGDClassifier(loss='hinge', random_state=42), 5, is_int=True) + model_onnx = convert_sklearn( + model, + "scikit-learn SGD multi-class classifier", + [("input", Int64TensorType(X.shape))], + ) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, + model, + model_onnx, + basename="SklearnSGDClassifierMultiHingeInt-Out0", + allow_failure="StrictVersion(onnx.__version__)" + " < StrictVersion('1.2') or " + "StrictVersion(onnxruntime.__version__)" + " <= StrictVersion('0.2.1')", + ) + + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_model_sgd_binary_class_log_int(self): + model, X = fit_classification_model( + SGDClassifier(loss='log', random_state=42), 2, is_int=True) + model_onnx = convert_sklearn( + model, + "scikit-learn SGD binary classifier", + [("input", Int64TensorType(X.shape))], + ) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, + model, + model_onnx, + basename="SklearnSGDClassifierBinaryLogInt", + allow_failure="StrictVersion(onnx.__version__)" + " < StrictVersion('1.2') or " + "StrictVersion(onnxruntime.__version__)" + " <= StrictVersion('0.2.1')", + ) + + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_model_sgd_multi_class_log_int(self): + model, X = fit_classification_model( + SGDClassifier(loss='log', random_state=42), 5, is_int=True) + model_onnx = convert_sklearn( + model, + "scikit-learn SGD multi-class classifier", + [("input", Int64TensorType(X.shape))], + ) + X = np.array([X[0], X[0]]) + self.assertIsNotNone(model_onnx) + dump_data_and_model( + X, + model, + model_onnx, + basename="SklearnSGDClassifierMultiLogInt", + allow_failure="StrictVersion(onnx.__version__)" + " < StrictVersion('1.2') or " + "StrictVersion(onnxruntime.__version__)" + " <= StrictVersion('0.2.1')", + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 1f3013de1..11ce19d8c 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -16,6 +16,7 @@ dump_multiple_regression, 
dump_single_regression, convert_model, + fit_classification_model, fit_regression_model, ) diff --git a/tests/test_utils/tests_helper.py b/tests/test_utils/tests_helper.py index eb3c7dff2..f9f5a2daa 100644 --- a/tests/test_utils/tests_helper.py +++ b/tests/test_utils/tests_helper.py @@ -47,6 +47,20 @@ def _has_transform_model(model): return hasattr(model, "fit_transform") and hasattr(model, "score") +def fit_classification_model(model, n_classes, is_int=False, + pos_features=False): + X, y = make_classification(n_classes=n_classes, n_features=100, + n_samples=1000, + random_state=42, n_informative=5) + X = X.astype(numpy.int64) if is_int else X.astype(numpy.float32) + if pos_features: + X = numpy.abs(X) + X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.5, + random_state=42) + model.fit(X_train, y_train) + return model, X_test + + def fit_regression_model(model, is_int=False): X, y = make_regression(n_features=10, n_samples=1000, random_state=42) X = X.astype(numpy.int64) if is_int else X.astype(numpy.float32)
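
Usage sketch (not part of the patch above): the int-feature path added here can be exercised end to end the same way the new tests do it — fit a model on non-negative int64 features, declare the ONNX input as Int64TensorType, and let the converter insert the Cast-to-float from this patch before the MatMul/Greater nodes. The onnxruntime calls at the end are an assumption about the runtime API (any recent onnxruntime exposing InferenceSession/run), not something taken from this diff; names such as model_onnx are illustrative only.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.naive_bayes import MultinomialNB
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import Int64TensorType

# Non-negative int64 features, mirroring what
# fit_classification_model(..., is_int=True, pos_features=True) produces.
X, y = make_classification(n_classes=2, n_features=20, n_samples=200,
                           n_informative=5, random_state=42)
X = np.abs(X).astype(np.int64)

model = MultinomialNB().fit(X, y)

# Declaring the input as Int64TensorType triggers the new Cast-to-float
# path before the MatMul with feature_log_prob_.
model_onnx = convert_sklearn(
    model, "multinomial naive bayes",
    [("input", Int64TensorType(X.shape))])

# Optional runtime check; assumes onnxruntime is installed.
import onnxruntime as rt
sess = rt.InferenceSession(model_onnx.SerializeToString())
onnx_labels = sess.run(None, {"input": X})[0]
print("agreement with sklearn:", (onnx_labels == model.predict(X)).mean())

Design note: the converters cast the int64 input once into a shared input_name, so the existing MatMul (MultinomialNB, SGDClassifier) and Greater (BernoulliNB binarize) branches only needed their input renamed, which keeps the converter-side diff small.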