Skip to content

Commit

Permalink
Support feature input types other than float in function transformer (#…
Browse files Browse the repository at this point in the history
…208)

* Support feature input types other than float in function transformer converter
* Remove unused import
  • Loading branch information
Prabhat authored and xadupre committed Jul 3, 2019
1 parent ad1987d commit 65951d5
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
11 changes: 6 additions & 5 deletions skl2onnx/shape_calculators/function_transformer.py
Expand Up @@ -4,15 +4,15 @@
# license information.
# --------------------------------------------------------------------------

import copy
from ..common._registration import register_shape_calculator
from ..common.data_types import FloatTensorType


def calculate_sklearn_function_transformer_output_shapes(operator):
'''
"""
This operator is used only to merge columns in a pipeline.
Only id function is supported.
'''
Only identity function is supported.
"""
if operator.raw_operator.func is not None:
raise RuntimeError("FunctionTransformer is not supported unless the "
"transform function is None (= identity). "
Expand All @@ -27,7 +27,8 @@ def calculate_sklearn_function_transformer_output_shapes(operator):
C = 'None'
break

operator.outputs[0].type = FloatTensorType([N, C])
operator.outputs[0].type = copy.deepcopy(operator.inputs[0].type)
operator.outputs[0].type.shape = [N, C]


register_shape_calculator('SklearnFunctionTransformer',
Expand Down
7 changes: 4 additions & 3 deletions tests/test_sklearn_function_transformer_converter.py
Expand Up @@ -2,8 +2,9 @@
Tests scikit-imputer converter.
"""
import unittest
import numpy as np
import pandas
from sklearn.datasets import load_iris
from sklearn.datasets import load_digits, load_iris
from sklearn.pipeline import Pipeline

try:
Expand Down Expand Up @@ -44,10 +45,10 @@ def convert_dataframe_schema(df, drop=None):
inputs.append((k, t))
return inputs

data = load_iris()
data = load_digits()
X = data.data[:, :2]
y = data.target
data = pandas.DataFrame(X, columns=["X1", "X2"])
data = pandas.DataFrame(X, columns=["X1", "X2"], dtype=np.int64)

pipe = Pipeline(steps=[
(
Expand Down

0 comments on commit 65951d5

Please sign in to comment.