Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SVC output #209

Merged
merged 7 commits into from
Jul 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion skl2onnx/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ def _parse_sklearn_column_transformer(scope, model, inputs,
def _parse_sklearn_classifier(scope, model, inputs, custom_parsers=None):
probability_tensor = _parse_sklearn_simple_model(
scope, model, inputs, custom_parsers=custom_parsers)
if model.__class__ in [NuSVC, SVC] and not model.probability:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No unit test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return probability_tensor
this_operator = scope.declare_local_operator('SklearnZipMap')
this_operator.inputs = probability_tensor
classes = model.classes_
Expand Down Expand Up @@ -387,7 +389,7 @@ def build_sklearn_parsers_map():
map_parser[ColumnTransformer] = _parse_sklearn_column_transformer

for tmodel in sklearn_classifier_list:
if tmodel not in [LinearSVC, SVC, NuSVC]:
if tmodel not in [LinearSVC]:
map_parser[tmodel] = _parse_sklearn_classifier
return map_parser

Expand Down
1 change: 1 addition & 0 deletions tests/test_algebra_meta_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_onnx_spec(self):
'BitShift', # opset 11
'Cast', # unsupported type
'Compress', # shape inference fails
'CumSum', # opset 11
# Input X must be 4-dimensional. X: {1,1,3}
'ConvInteger',
'ConvTranspose',
Expand Down
14 changes: 13 additions & 1 deletion tests/test_sklearn_svm_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from skl2onnx.operator_converters.support_vector_machines import (
convert_sklearn_svm)
from skl2onnx.shape_calculators.svm import calculate_sklearn_svm_output_shapes
from onnxruntime import __version__ as ort_version
from test_utils import dump_data_and_model, fit_regression_model


Expand Down Expand Up @@ -149,6 +150,8 @@ def test_convert_svc_binary_linear_ptrue(self):
model,
model_onnx,
basename="SklearnBinSVCLinearPT",
allow_failure="StrictVersion(onnxruntime.__version__)"
" <= StrictVersion('0.4.0')"
)

def test_convert_svc_multi_linear_pfalse(self):
Expand Down Expand Up @@ -206,7 +209,7 @@ def test_convert_svc_multi_linear_ptrue(self):
model_onnx,
basename="SklearnMclSVCLinearPT-Dec4",
allow_failure="StrictVersion(onnxruntime.__version__)"
" < StrictVersion('0.5.0')"
" <= StrictVersion('0.4.0')"
)

def test_convert_svr_linear(self):
Expand Down Expand Up @@ -259,6 +262,9 @@ def test_convert_nusvc_binary_pfalse(self):
" < StrictVersion('0.5.0')"
)

@unittest.skipIf(
StrictVersion(ort_version) <= StrictVersion("0.4.0"),
reason="use of recent Cast operator")
def test_convert_nusvc_binary_ptrue(self):
model, X = self._fit_binary_classification(NuSVC(probability=True))
model_onnx = convert_sklearn(
Expand All @@ -283,6 +289,8 @@ def test_convert_nusvc_binary_ptrue(self):
model,
model_onnx,
basename="SklearnBinNuSVCPT",
allow_failure="StrictVersion(onnxruntime.__version__)"
" <= StrictVersion('0.4.0')"
)

def test_convert_nusvc_multi_pfalse(self):
Expand Down Expand Up @@ -361,6 +369,8 @@ def test_convert_svc_multi_ptrue_4(self):
model,
model_onnx,
basename="SklearnMcSVCPF4",
allow_failure="StrictVersion(onnxruntime.__version__)"
" <= StrictVersion('0.4.0')"
)

def test_convert_nusvc_multi_ptrue(self):
Expand Down Expand Up @@ -388,6 +398,8 @@ def test_convert_nusvc_multi_ptrue(self):
model,
model_onnx,
basename="SklearnMclNuSVCPT",
allow_failure="StrictVersion(onnxruntime.__version__)"
" <= StrictVersion('0.4.0')"
)

def test_convert_nusvr(self):
Expand Down