Skip to content

Commit

Permalink
Support int features in BernoulliNB, MultinomialNB and SGDClassifier converters (#185)
Browse files (browse the repository at this point in the history)

* Support int features in BernoulliNB, MultinomialNB and SGDClassifier converters

* Avoid unit test failures
  • Loading branch information
Prabhat committed Jun 21, 2019
1 parent 6fc5a0e commit a760bf0
Show file tree
Hide file tree
Showing 6 changed files with 287 additions and 60 deletions.
14 changes: 11 additions & 3 deletions skl2onnx/operator_converters/naive_bayes.py
Expand Up @@ -8,6 +8,7 @@
from ..proto import onnx_proto
from ..common._apply_operation import apply_add, apply_cast, apply_exp
from ..common._apply_operation import apply_log, apply_sub, apply_reshape
from ..common.data_types import Int64TensorType
from ..common._registration import register_converter


Expand Down Expand Up @@ -207,11 +208,19 @@ def convert_sklearn_naive_bayes(scope, operator, container):
else:
sum_op_version = 8

input_name = operator.inputs[0].full_name
if type(operator.inputs[0].type) == Int64TensorType:
cast_input_name = scope.get_unique_variable_name('cast_input')

apply_cast(scope, operator.input_full_names, cast_input_name,
container, to=onnx_proto.TensorProto.FLOAT)
input_name = cast_input_name

if operator.type == 'SklearnMultinomialNB':
matmul_result_name = scope.get_unique_variable_name('matmul_result')

container.add_node(
'MatMul', [operator.inputs[0].full_name, feature_log_prob_name],
'MatMul', [input_name, feature_log_prob_name],
matmul_result_name, name=scope.get_unique_operator_name('MatMul'))
apply_add(scope, [matmul_result_name, class_log_prior_name],
sum_result_name, container, broadcast=1)
Expand All @@ -229,7 +238,6 @@ def convert_sklearn_naive_bayes(scope, operator, container):

container.add_initializer(constant_name, onnx_proto.TensorProto.FLOAT,
[], [1.0])
input_name = operator.inputs[0].full_name

if nb.binarize is not None:
threshold_name = scope.get_unique_variable_name('threshold')
Expand All @@ -250,7 +258,7 @@ def convert_sklearn_naive_bayes(scope, operator, container):
np.zeros((M, num_features)).ravel())

container.add_node(
'Greater', [operator.inputs[0].full_name, threshold_name],
'Greater', [input_name, threshold_name],
condition_name, name=scope.get_unique_operator_name('Greater'),
op_version=9)
apply_cast(scope, condition_name, cast_values_name, container,
Expand Down
11 changes: 10 additions & 1 deletion skl2onnx/operator_converters/sgd_classifier.py
Expand Up @@ -8,6 +8,7 @@
from ..common._apply_operation import (
apply_add, apply_cast, apply_clip, apply_concat, apply_div, apply_exp,
apply_identity, apply_mul, apply_reciprocal, apply_reshape, apply_sub)
from ..common.data_types import Int64TensorType
from ..common._registration import register_converter
from ..proto import onnx_proto

Expand All @@ -28,8 +29,16 @@ def _decision_function(scope, operator, container, model):
container.add_initializer(intercept_name, onnx_proto.TensorProto.FLOAT,
model.intercept_.shape, model.intercept_)

input_name = operator.inputs[0].full_name
if type(operator.inputs[0].type) == Int64TensorType:
cast_input_name = scope.get_unique_variable_name('cast_input')

apply_cast(scope, operator.input_full_names, cast_input_name,
container, to=onnx_proto.TensorProto.FLOAT)
input_name = cast_input_name

container.add_node(
'MatMul', [operator.inputs[0].full_name, coef_name],
'MatMul', [input_name, coef_name],
matmul_result_name,
name=scope.get_unique_operator_name('MatMul'))
apply_add(scope, [matmul_result_name, intercept_name],
Expand Down
172 changes: 144 additions & 28 deletions tests/test_sklearn_naive_bayes_converter.py
@@ -1,36 +1,20 @@
import numpy as np
import onnx
import unittest
from distutils.version import StrictVersion
from sklearn.datasets import load_digits, load_iris
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx.common.data_types import FloatTensorType, Int64TensorType
from skl2onnx.common.data_types import onnx_built_with_ml
from test_utils import dump_data_and_model
from test_utils import dump_data_and_model, fit_classification_model


class TestNaiveBayesConverter(unittest.TestCase):
# Helper: collapse the labels of `data` into two classes, fit `model` on them,
# and return the fitted model together with the features cast to float32.
def _fit_model_binary_classification(self, model, data):
X = data.data
# y is bound directly to data.target (no copy is made here), so the
# relabelling below is applied to that same object.
y = data.target
# Threshold is half the number of entries in data.target_names.
mid_point = len(data.target_names) / 2
# Labels below the threshold become 0, the remaining ones become 1.
y[y < mid_point] = 0
y[y >= mid_point] = 1
model.fit(X, y)
return model, X.astype(np.float32)

# Helper: fit `model` on data.data / data.target with the labels left as they
# are, and return the fitted model together with the features cast to float32.
def _fit_model_multiclass_classification(self, model, data):
X = data.data
y = data.target
model.fit(X, y)
return model, X.astype(np.float32)

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_multinomial_nb_binary_classification(self):
model, X = self._fit_model_binary_classification(
MultinomialNB(), load_iris())
model, X = fit_classification_model(
MultinomialNB(), 2, pos_features=True)
model_onnx = convert_sklearn(
model,
"multinomial naive bayes",
Expand All @@ -41,7 +25,7 @@ def test_model_multinomial_nb_binary_classification(self):
X,
model,
model_onnx,
basename="SklearnBinMultinomialNB",
basename="SklearnBinMultinomialNB-Dec4",
allow_failure="StrictVersion(onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)
Expand All @@ -53,8 +37,8 @@ def test_model_multinomial_nb_binary_classification(self):
@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_bernoulli_nb_binary_classification(self):
model, X = self._fit_model_binary_classification(
BernoulliNB(), load_digits())
model, X = fit_classification_model(
BernoulliNB(), 2)
model_onnx = convert_sklearn(
model,
"bernoulli naive bayes",
Expand All @@ -73,8 +57,8 @@ def test_model_bernoulli_nb_binary_classification(self):
@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_multinomial_nb_multiclass(self):
model, X = self._fit_model_multiclass_classification(
MultinomialNB(), load_iris())
model, X = fit_classification_model(
MultinomialNB(), 5, pos_features=True)
model_onnx = convert_sklearn(
model,
"multinomial naive bayes",
Expand All @@ -85,7 +69,27 @@ def test_model_multinomial_nb_multiclass(self):
X,
model,
model_onnx,
basename="SklearnMclMultinomialNB",
basename="SklearnMclMultinomialNB-Dec4",
allow_failure="StrictVersion(onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

# Test: convert a MultinomialNB built with alpha=0.5 and fit_prior=False to
# ONNX, assert the conversion result is not None, and hand data, model and
# ONNX model to dump_data_and_model.
# Skipped when onnx_built_with_ml() returns a falsy value.
@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_multinomial_nb_multiclass_params(self):
# NOTE(review): the positional 5 is presumably the number of classes (the
# test is named "multiclass") -- confirm against test_utils.
model, X = fit_classification_model(
MultinomialNB(alpha=0.5, fit_prior=False), 5, pos_features=True)
# A single input named "input" is declared as a float tensor of X's shape.
model_onnx = convert_sklearn(
model,
"multinomial naive bayes",
[("input", FloatTensorType(X.shape))],
)
self.assertIsNotNone(model_onnx)
# The two adjacent string literals passed as allow_failure concatenate into
# one expression string about onnxruntime versions <= 0.2.1; how the helper
# interprets it is not visible here.
dump_data_and_model(
X,
model,
model_onnx,
basename="SklearnMclMultinomialNBParams-Dec4",
allow_failure="StrictVersion(onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)
Expand All @@ -97,8 +101,8 @@ def test_model_multinomial_nb_multiclass(self):
@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_bernoulli_nb_multiclass(self):
model, X = self._fit_model_multiclass_classification(
BernoulliNB(), load_digits())
model, X = fit_classification_model(
BernoulliNB(), 4)
model_onnx = convert_sklearn(
model,
"bernoulli naive bayes",
Expand All @@ -114,6 +118,118 @@ def test_model_bernoulli_nb_multiclass(self):
"<= StrictVersion('0.2.1')",
)

# Test: convert a BernoulliNB built with alpha=0, binarize=1.0 and
# fit_prior=False to ONNX, assert the conversion result is not None, and hand
# data, model and ONNX model to dump_data_and_model.
# Skipped when the installed onnx version is <= 1.3 ("Requires opset 9.") or
# when onnx_built_with_ml() returns a falsy value.
@unittest.skipIf(
StrictVersion(onnx.__version__) <= StrictVersion("1.3"),
reason="Requires opset 9.",
)
@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_bernoulli_nb_multiclass_params(self):
# NOTE(review): the positional 4 is presumably the number of classes (the
# test is named "multiclass") -- confirm against test_utils.
model, X = fit_classification_model(
BernoulliNB(alpha=0, binarize=1.0, fit_prior=False), 4)
# A single input named "input" is declared as a float tensor of X's shape.
model_onnx = convert_sklearn(
model,
"bernoulli naive bayes",
[("input", FloatTensorType(X.shape))],
)
self.assertIsNotNone(model_onnx)
# The two adjacent string literals passed as allow_failure concatenate into
# one expression string about onnxruntime versions <= 0.2.1; how the helper
# interprets it is not visible here.
dump_data_and_model(
X,
model,
model_onnx,
basename="SklearnMclBernoulliNBParams",
allow_failure="StrictVersion(onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

# Test: convert a default MultinomialNB to ONNX with the input declared as an
# int64 tensor, assert the conversion result is not None, and hand data, model
# and ONNX model to dump_data_and_model.
# Skipped when onnx_built_with_ml() returns a falsy value.
@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_multinomial_nb_binary_classification_int(self):
# NOTE(review): the positional 2 is presumably the number of classes and
# is_int=True presumably makes the helper return integer features --
# confirm both against test_utils.
model, X = fit_classification_model(
MultinomialNB(), 2, is_int=True, pos_features=True)
# A single input named "input" is declared as an int64 tensor of X's shape.
model_onnx = convert_sklearn(
model,
"multinomial naive bayes",
[("input", Int64TensorType(X.shape))],
)
self.assertIsNotNone(model_onnx)
# The two adjacent string literals passed as allow_failure concatenate into
# one expression string about onnxruntime versions <= 0.2.1; how the helper
# interprets it is not visible here.
dump_data_and_model(
X,
model,
model_onnx,
basename="SklearnBinMultinomialNBInt-Dec4",
allow_failure="StrictVersion(onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

# Test: convert a default BernoulliNB to ONNX with the input declared as an
# int64 tensor, assert the conversion result is not None, and hand data, model
# and ONNX model to dump_data_and_model.
# Skipped when the installed onnx version is <= 1.3 ("Requires opset 9.") or
# when onnx_built_with_ml() returns a falsy value.
@unittest.skipIf(
StrictVersion(onnx.__version__) <= StrictVersion("1.3"),
reason="Requires opset 9.",
)
@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_bernoulli_nb_binary_classification_int(self):
# NOTE(review): the positional 2 is presumably the number of classes and
# is_int=True presumably makes the helper return integer features --
# confirm both against test_utils.
model, X = fit_classification_model(
BernoulliNB(), 2, is_int=True)
# A single input named "input" is declared as an int64 tensor of X's shape.
model_onnx = convert_sklearn(
model,
"bernoulli naive bayes",
[("input", Int64TensorType(X.shape))],
)
self.assertIsNotNone(model_onnx)
# The two adjacent string literals passed as allow_failure concatenate into
# one expression string about onnxruntime versions <= 0.2.1; how the helper
# interprets it is not visible here.
dump_data_and_model(
X,
model,
model_onnx,
basename="SklearnBinBernoulliNBInt",
allow_failure="StrictVersion(onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

# Test: convert a default MultinomialNB to ONNX with the input declared as an
# int64 tensor, assert the conversion result is not None, and hand data, model
# and ONNX model to dump_data_and_model.
# Skipped when onnx_built_with_ml() returns a falsy value.
@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_multinomial_nb_multiclass_int(self):
# NOTE(review): the positional 5 is presumably the number of classes and
# is_int=True presumably makes the helper return integer features --
# confirm both against test_utils.
model, X = fit_classification_model(
MultinomialNB(), 5, is_int=True, pos_features=True)
# A single input named "input" is declared as an int64 tensor of X's shape.
model_onnx = convert_sklearn(
model,
"multinomial naive bayes",
[("input", Int64TensorType(X.shape))],
)
self.assertIsNotNone(model_onnx)
# The two adjacent string literals passed as allow_failure concatenate into
# one expression string about onnxruntime versions <= 0.2.1; how the helper
# interprets it is not visible here.
dump_data_and_model(
X,
model,
model_onnx,
basename="SklearnMclMultinomialNBInt-Dec4",
allow_failure="StrictVersion(onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)

# Test: convert a default BernoulliNB to ONNX with the input declared as an
# int64 tensor, assert the conversion result is not None, and hand data, model
# and ONNX model to dump_data_and_model.
# Skipped when the installed onnx version is <= 1.3 ("Requires opset 9.") or
# when onnx_built_with_ml() returns a falsy value.
@unittest.skipIf(
StrictVersion(onnx.__version__) <= StrictVersion("1.3"),
reason="Requires opset 9.",
)
@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_model_bernoulli_nb_multiclass_int(self):
# NOTE(review): the positional 4 is presumably the number of classes and
# is_int=True presumably makes the helper return integer features --
# confirm both against test_utils.
model, X = fit_classification_model(
BernoulliNB(), 4, is_int=True)
# A single input named "input" is declared as an int64 tensor of X's shape.
model_onnx = convert_sklearn(
model,
"bernoulli naive bayes",
[("input", Int64TensorType(X.shape))],
)
self.assertIsNotNone(model_onnx)
# The two adjacent string literals passed as allow_failure concatenate into
# one expression string about onnxruntime versions <= 0.2.1; how the helper
# interprets it is not visible here.
dump_data_and_model(
X,
model,
model_onnx,
basename="SklearnMclBernoulliNBInt-Dec4",
allow_failure="StrictVersion(onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)


# Run the tests through unittest when this file is executed as a script.
if __name__ == "__main__":
unittest.main()

0 comments on commit a760bf0

Please sign in to comment.