Skip to content

Commit

Permalink
Extend coverage for OneHotEncoder (#188)
Browse files Browse the repository at this point in the history
* Update test_sklearn_one_hot_encoder_converter.py
  • Loading branch information
xadupre committed Jul 3, 2019
1 parent a64cc5b commit 064eb31
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
2 changes: 2 additions & 0 deletions skl2onnx/operator_converters/one_hot_encoder.py
Expand Up @@ -46,6 +46,8 @@ def convert_sklearn_one_hot_encoder(scope, operator, container):
raise TypeError("Categories must be int or strings "
"not {0}.".format(cat.dtype))
else:
# Relies on n_values: deprecated in 0.20,
# removed in 0.22.
if op.categorical_features == 'all':
categorical_feature_indices = [i for i in range(C)]
elif isinstance(op.categorical_features, collections.Iterable):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_sklearn_one_hot_encoder_converter.py
Expand Up @@ -155,6 +155,56 @@ def test_one_hot_encoder_one_string_one_int_cat(self):
basename="SklearnOneHotEncoderOneStringOneIntCat",
)

@unittest.skipIf(
    not one_hot_encoder_supports_string(),
    reason="OneHotEncoder does not support this in 0.19",
)
def test_model_one_hot_encoder_list_sparse(self):
    """Convert a OneHotEncoder with explicit categories and sparse output.

    The encoder is created with one explicit list of integer categories
    per input column and ``sparse=True``, fitted on a 4x3 int64 matrix,
    converted with ``convert_sklearn`` and then passed, together with the
    data and the fitted model, to ``dump_data_and_model``.
    """
    model = OneHotEncoder(
        categories=[[0, 1, 4, 5], [1, 2, 3, 5], [0, 3, 4, 6]],
        sparse=True,
    )
    data = numpy.array(
        [[1, 2, 3], [4, 3, 0], [0, 1, 4], [0, 5, 6]],
        dtype=numpy.int64,
    )
    model.fit(data)
    # The declared input shape is [1, 3] while ``data`` has four rows.
    # NOTE(review): presumably the "-SkipDim1" suffix of the basename
    # below is what makes this acceptable -- confirm against the helper.
    model_onnx = convert_sklearn(
        model,
        "scikit-learn one-hot encoder",
        [("input", Int64TensorType([1, 3]))],
    )
    # assertIsNotNone yields a meaningful failure message, unlike
    # assertTrue(x is not None) which only reports "False is not true".
    self.assertIsNotNone(model_onnx)
    dump_data_and_model(
        data,
        model,
        model_onnx,
        basename="SklearnOneHotEncoderCatSparse-SkipDim1",
    )

@unittest.skipIf(
    not one_hot_encoder_supports_string(),
    reason="OneHotEncoder does not support this in 0.19",
)
def test_model_one_hot_encoder_list_dense(self):
    """Convert a OneHotEncoder with explicit categories and dense output.

    The encoder is created with one explicit list of integer categories
    per input column and ``sparse=False``, fitted on a 4x3 int64 matrix,
    converted with ``convert_sklearn`` and then passed, together with the
    data and the fitted model, to ``dump_data_and_model``.
    """
    model = OneHotEncoder(
        categories=[[0, 1, 4, 5], [1, 2, 3, 5], [0, 3, 4, 6]],
        sparse=False,
    )
    data = numpy.array(
        [[1, 2, 3], [4, 3, 0], [0, 1, 4], [0, 5, 6]],
        dtype=numpy.int64,
    )
    model.fit(data)
    # The declared input shape is [1, 3] while ``data`` has four rows.
    # NOTE(review): presumably the "-SkipDim1" suffix of the basename
    # below is what makes this acceptable -- confirm against the helper.
    model_onnx = convert_sklearn(
        model,
        "scikit-learn one-hot encoder",
        [("input", Int64TensorType([1, 3]))],
    )
    # assertIsNotNone yields a meaningful failure message, unlike
    # assertTrue(x is not None) which only reports "False is not true".
    self.assertIsNotNone(model_onnx)
    dump_data_and_model(
        data,
        model,
        model_onnx,
        basename="SklearnOneHotEncoderCatDense-SkipDim1",
    )


# Run the test suite through unittest when this file is executed directly.
if __name__ == "__main__":
unittest.main()

0 comments on commit 064eb31

Please sign in to comment.