Skip to content

Commit

Permalink
Fixed LabelEncoder converter (#224)
Browse files Browse the repository at this point in the history
* Fixed LabelEncoder converter

* Fixed LabelEncoder op_version and typos

* Allow unit test failure in onnxruntime <= 0.5.0
  • Loading branch information
Prabhat committed Aug 19, 2019
1 parent 07fcd67 commit d313123
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 17 deletions.
20 changes: 9 additions & 11 deletions skl2onnx/operator_converters/label_encoder.py
Expand Up @@ -4,28 +4,26 @@
# license information.
# --------------------------------------------------------------------------

from ..common.data_types import StringTensorType, Int64TensorType
import numpy as np
from ..common._registration import register_converter


def convert_sklearn_label_encoder(scope, operator, container):
op = operator.raw_operator
op_type = 'LabelEncoder'
attrs = {'name': scope.get_unique_operator_name(op_type)}
attrs['classes_strings'] = [str(c) for c in op.classes_]

if isinstance(operator.inputs[0].type, Int64TensorType):
attrs['default_int64'] = -1
elif isinstance(operator.inputs[0].type, StringTensorType):
attrs['default_string'] = '__unknown__'
classes = op.classes_
if np.issubdtype(classes.dtype, np.floating):
attrs['keys_floats'] = classes
elif np.issubdtype(classes.dtype, np.signedinteger):
attrs['keys_int64s'] = classes
else:
raise RuntimeError(
'Unsupported input type: %s. It must be int64 or dtring.'
'' % type(operator.inputs[0].type))
attrs['keys_strings'] = np.array([s.encode('utf-8') for s in classes])
attrs['values_int64s'] = np.arange(len(classes))

container.add_node(op_type, operator.input_full_names,
operator.output_full_names, op_domain='ai.onnx.ml',
**attrs)
op_version=2, **attrs)


register_converter('SklearnLabelEncoder', convert_sklearn_label_encoder)
4 changes: 3 additions & 1 deletion skl2onnx/shape_calculators/label_encoder.py
Expand Up @@ -6,6 +6,7 @@

import copy
from ..common._registration import register_shape_calculator
from ..common.data_types import FloatTensorType
from ..common.data_types import Int64TensorType, StringTensorType
from ..common.utils import check_input_and_output_numbers
from ..common.utils import check_input_and_output_types
Expand All @@ -18,7 +19,8 @@ def calculate_sklearn_label_encoder_output_shapes(operator):
"""
check_input_and_output_numbers(operator, output_count_range=1)
check_input_and_output_types(operator, good_input_types=[
Int64TensorType, StringTensorType])
FloatTensorType, Int64TensorType,
StringTensorType])

input_shape = copy.deepcopy(operator.inputs[0].type.shape)
operator.outputs[0].type = Int64TensorType(copy.deepcopy(input_shape))
Expand Down
54 changes: 49 additions & 5 deletions tests/test_sklearn_label_encoder_converter.py
@@ -1,8 +1,7 @@
"""
Tests scikit-labebencoder converter.
"""
"""Tests scikit-LabelEncoder converter"""

import unittest
import numpy
import numpy as np
from sklearn.preprocessing import LabelEncoder
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import StringTensorType
Expand All @@ -22,10 +21,55 @@ def test_model_label_encoder(self):
self.assertTrue(model_onnx is not None)
self.assertTrue(model_onnx.graph.node is not None)
dump_data_and_model(
numpy.array(data),
np.array(data),
model,
model_onnx,
basename="SklearnLabelEncoder",
allow_failure="StrictVersion("
"onnxruntime.__version__)"
"<= StrictVersion('0.5.0')",
)

def test_model_label_encoder_float(self):
model = LabelEncoder()
data = np.array([1.2, 3.4, 5.4, 1.2])
model.fit(data)
model_onnx = convert_sklearn(
model,
"scikit-learn label encoder",
[("input", StringTensorType([1, 1]))],
)
self.assertTrue(model_onnx is not None)
self.assertTrue(model_onnx.graph.node is not None)
dump_data_and_model(
data,
model,
model_onnx,
basename="SklearnLabelEncoderFloat",
allow_failure="StrictVersion("
"onnxruntime.__version__)"
"<= StrictVersion('0.5.0')",
)

def test_model_label_encoder_int(self):
model = LabelEncoder()
data = np.array([10, 3, 5, -34, 0])
model.fit(data)
model_onnx = convert_sklearn(
model,
"scikit-learn label encoder",
[("input", StringTensorType([1, 1]))],
)
self.assertTrue(model_onnx is not None)
self.assertTrue(model_onnx.graph.node is not None)
dump_data_and_model(
data,
model,
model_onnx,
basename="SklearnLabelEncoderInt",
allow_failure="StrictVersion("
"onnxruntime.__version__)"
"<= StrictVersion('0.5.0')",
)


Expand Down

0 comments on commit d313123

Please sign in to comment.