Skip to content

Commit

Permalink
Fixed label binariser output for binary dataset to align with scikit (#…
Browse files Browse the repository at this point in the history
…228)

* Fixed label binariser output for binary dataset to align with scikit

* Supress ScatterElements related build error

* Supress Unique op related build error

* Supress GatherElements op related build error

* Suppress the error in the build for now
  • Loading branch information
Prabhat committed Aug 6, 2019
1 parent a889074 commit d026f89
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 2 deletions.
17 changes: 16 additions & 1 deletion skl2onnx/operator_converters/label_binariser.py
Expand Up @@ -52,7 +52,22 @@ def convert_sklearn_label_binariser(scope, operator, container):
[equal_condition_tensor_name, unit_tensor_name, zeros_tensor_name],
where_result_name,
name=scope.get_unique_operator_name('where'))
apply_cast(scope, where_result_name, operator.output_full_names, container,
where_res = where_result_name
if len(binariser_op.classes_) == 2:
array_feature_extractor_result_name = scope.get_unique_variable_name(
'array_feature_extractor_result')
pos_class_index_name = scope.get_unique_variable_name(
'pos_class_index')

container.add_initializer(
pos_class_index_name, onnx_proto.TensorProto.INT64, [], [1])

container.add_node(
'ArrayFeatureExtractor', [where_result_name, pos_class_index_name],
array_feature_extractor_result_name, op_domain='ai.onnx.ml',
name=scope.get_unique_operator_name('ArrayFeatureExtractor'))
where_res = array_feature_extractor_result_name
apply_cast(scope, where_res, operator.output_full_names, container,
to=onnx_proto.TensorProto.INT64)


Expand Down
5 changes: 4 additions & 1 deletion tests/test_algebra_meta_onnx.py
Expand Up @@ -32,7 +32,7 @@ def test_mul(self):
@unittest.skipIf(StrictVersion(onnx.__version__) < StrictVersion("1.5.0"),
reason="too unstable with older versions")
@unittest.skipIf(StrictVersion(onnxruntime.__version__) <
StrictVersion("0.4.0"),
StrictVersion("0.5.0"),
reason="too unstable with older versions")
def test_onnx_spec(self):
untested = {'AveragePool', # issue with ceil_mode
Expand All @@ -48,6 +48,7 @@ def test_onnx_spec(self):
'DequantizeLinear',
'Equal', # opset 11
'Expand', # shape inference fails
'GatherElements', # opset 11
'MatMulInteger',
'MaxPool', # issue with ceil_mode
'Mod',
Expand All @@ -58,6 +59,8 @@ def test_onnx_spec(self):
'Scan', # Graph attribute inferencing returned type
# information for 2 outputs. Expected 1
# Node () has input size 5 not in range [min=1, max=1].
'ScatterElements', # opset 11
'Unique', # opset 11
"Upsample",
}
folder = os.path.dirname(onnx.__file__)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_sklearn_label_binariser_converter.py
Expand Up @@ -87,6 +87,25 @@ def test_model_label_binariser_neg_pos_label(self):
"<= StrictVersion('0.2.1')",
)

def test_model_label_binariser_binary_labels(self):
X = np.array([1, 0, 0, 0, 1])
model = LabelBinarizer().fit(X)
model_onnx = convert_sklearn(
model,
"scikit-learn label binariser",
[("input", Int64TensorType([None]))],
)
self.assertTrue(model_onnx is not None)
dump_data_and_model(
X.astype(np.int64),
model,
model_onnx,
basename="SklearnLabelBinariserBinaryLabels",
allow_failure="StrictVersion("
"onnxruntime.__version__)"
"<= StrictVersion('0.2.1')",
)


if __name__ == "__main__":
unittest.main()

0 comments on commit d026f89

Please sign in to comment.