Handle transformer_weights hyper-parameter in FeatureUnion #222

Merged 3 commits on Jul 17, 2019
12 changes: 12 additions & 0 deletions skl2onnx/_parse.py
@@ -160,6 +160,18 @@ def _parse_sklearn_feature_union(scope, model, inputs, custom_parsers=None):
            _parse_sklearn_simple_model(
                scope, transform, inputs,
                custom_parsers=custom_parsers)[0])
        if (model.transformer_weights is not None and name in
                model.transformer_weights):
            transform_result = [transformed_result_names.pop()]
            # Create a Multiply ONNX node
            multiply_operator = scope.declare_local_operator('SklearnMultiply')
            multiply_operator.inputs = transform_result
            multiply_operator.operand = model.transformer_weights[name]
            multiply_output = scope.declare_local_variable(
                'multiply_output', FloatTensorType())
            multiply_operator.outputs.append(multiply_output)
            transformed_result_names.append(multiply_operator.outputs[0])

    # Create a Concat ONNX node
    concat_operator = scope.declare_local_operator('SklearnConcat')
    concat_operator.inputs = transformed_result_names
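For context, this parser change mirrors what scikit-learn itself does: FeatureUnion multiplies each transformer's output by its weight before concatenating the blocks, so the converter routes a weighted transformer's output through a SklearnMultiply operator ahead of the final SklearnConcat. A minimal sketch of the scikit-learn behaviour being reproduced (the model and data below are illustrative, not part of the PR):

import numpy as np
from sklearn.pipeline import FeatureUnion
from sklearn.preprocessing import MinMaxScaler, StandardScaler

X = np.random.rand(10, 4).astype(np.float32)
union = FeatureUnion(
    [('standard', StandardScaler()), ('minmax', MinMaxScaler())],
    transformer_weights={'standard': 2, 'minmax': 4})
# FeatureUnion scales each block by its weight, then stacks them column-wise.
expected = np.hstack([2 * StandardScaler().fit_transform(X),
                      4 * MinMaxScaler().fit_transform(X)])
assert np.allclose(union.fit_transform(X), expected)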
2 changes: 1 addition & 1 deletion skl2onnx/common/_registration.py
@@ -31,7 +31,7 @@ def register_converter(operator_name, conversion_function, overwrite=False):
    to enable overwriting.
    """
    if not overwrite and operator_name in _converter_pool:
-        raise ValueError('We do not overwrite registrated converter '
+        raise ValueError('We do not overwrite registered converter '
                         'by default')
    _converter_pool[operator_name] = conversion_function

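For reference, the registry refuses to silently replace an existing converter; re-registering the same operator name raises ValueError unless overwrite=True is passed. A minimal sketch (the operator name and converter function below are hypothetical):

from skl2onnx.common._registration import register_converter

def convert_my_op(scope, operator, container):
    # hypothetical converter body
    pass

register_converter('SklearnMyOp', convert_my_op)
# Registering the same name again would raise ValueError,
# unless overwriting is explicitly requested:
register_converter('SklearnMyOp', convert_my_op, overwrite=True)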
2 changes: 2 additions & 0 deletions skl2onnx/operator_converters/__init__.py
@@ -27,6 +27,7 @@
from . import linear_classifier
from . import linear_regressor
from . import multilayer_perceptron
from . import multiply_op
from . import naive_bayes
from . import nearest_neighbours
from . import normaliser
@@ -64,6 +65,7 @@
    linear_classifier,
    linear_regressor,
    multilayer_perceptron,
    multiply_op,
    naive_bayes,
    nearest_neighbours,
    normaliser,
23 changes: 23 additions & 0 deletions skl2onnx/operator_converters/multiply_op.py
@@ -0,0 +1,23 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from ..common._apply_operation import apply_mul
from ..common._registration import register_converter
from ..proto import onnx_proto


def convert_sklearn_multiply(scope, operator, container):
    operand_name = scope.get_unique_variable_name(
        'operand')

    container.add_initializer(operand_name, onnx_proto.TensorProto.FLOAT,
                              [], [operator.operand])

    apply_mul(scope, [operator.inputs[0].full_name, operand_name],
              operator.outputs[0].full_name, container)


register_converter('SklearnMultiply', convert_sklearn_multiply)
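In effect, the converter stores the transformer weight as a scalar float initializer and emits a Mul node that broadcasts it over the transformer's output. A rough end-to-end check of the weighted conversion, assuming onnxruntime is installed (illustrative data, not part of the PR):

import numpy as np
import onnxruntime as rt
from sklearn.pipeline import FeatureUnion
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

X = np.random.rand(20, 4).astype(np.float32)
model = FeatureUnion([('standard', StandardScaler()),
                      ('minmax', MinMaxScaler())],
                     transformer_weights={'standard': 2, 'minmax': 4}).fit(X)
onx = convert_sklearn(model, 'feature union',
                      [('input', FloatTensorType(X.shape))])
sess = rt.InferenceSession(onx.SerializeToString())
got = sess.run(None, {'input': X})[0]
# The ONNX output should match scikit-learn's weighted concatenation.
assert np.allclose(got, model.transform(X), atol=1e-4)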
1 change: 1 addition & 0 deletions skl2onnx/shape_calculators/concat.py
@@ -40,6 +40,7 @@ def calculate_sklearn_concat(operator):
register_shape_calculator('SklearnConcat', calculate_sklearn_concat)
register_shape_calculator('SklearnGenericUnivariateSelect',
                          calculate_sklearn_concat)
register_shape_calculator('SklearnMultiply', calculate_sklearn_concat)
register_shape_calculator('SklearnRFE', calculate_sklearn_concat)
register_shape_calculator('SklearnRFECV', calculate_sklearn_concat)
register_shape_calculator('SklearnSelectFdr', calculate_sklearn_concat)
102 changes: 102 additions & 0 deletions tests/test_sklearn_feature_union.py
@@ -0,0 +1,102 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import unittest
import numpy as np
from sklearn.datasets import load_digits, load_iris
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.model_selection import train_test_split
from sklearn.pipeline import FeatureUnion
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType, Int64TensorType
from test_utils import dump_data_and_model


class TestSklearnFeatureUnionModels(unittest.TestCase):
    def test_feature_union_default(self):
        data = load_iris()
        X, y = data.data, data.target
        X = X.astype(np.float32)
        X_train, X_test, *_ = train_test_split(X, y, test_size=0.5,
                                               random_state=42)
        model = FeatureUnion([('standard', StandardScaler()),
                              ('minmax', MinMaxScaler())]).fit(X_train)
        model_onnx = convert_sklearn(
            model, 'feature union', [('input', FloatTensorType(X_test.shape))])
        self.assertTrue(model_onnx is not None)
        dump_data_and_model(X_test,
                            model,
                            model_onnx,
                            basename="SklearnFeatureUnionDefault")

    def test_feature_union_transformer_weights_0(self):
        data = load_iris()
        X, y = data.data, data.target
        X = X.astype(np.float32)
        X_train, X_test, *_ = train_test_split(X, y, test_size=0.5,
                                               random_state=42)
        model = FeatureUnion([('standard', StandardScaler()),
                              ('minmax', MinMaxScaler())],
                             transformer_weights={'standard': 2, 'minmax': 4}
                             ).fit(X_train)
        model_onnx = convert_sklearn(
            model, 'feature union', [('input', FloatTensorType(X_test.shape))])
        self.assertTrue(model_onnx is not None)
        dump_data_and_model(X_test,
                            model,
                            model_onnx,
                            basename="SklearnFeatureUnionTransformerWeights0")

    def test_feature_union_transformer_weights_1(self):
        data = load_digits()
        X, y = data.data, data.target
        X = X.astype(np.int64)
        X_train, X_test, *_ = train_test_split(X, y, test_size=0.5,
                                               random_state=42)
        model = FeatureUnion([('pca', PCA()),
                              ('svd', TruncatedSVD())],
                             transformer_weights={'pca': 10, 'svd': 3}
                             ).fit(X_train)
        model_onnx = convert_sklearn(
            model, 'feature union', [('input', Int64TensorType(X_test.shape))])
        self.assertTrue(model_onnx is not None)
        dump_data_and_model(
            X_test,
            model,
            model_onnx,
            basename="SklearnFeatureUnionTransformerWeights1-Dec4",
            allow_failure="StrictVersion("
                          "onnxruntime.__version__)"
                          "<= StrictVersion('0.2.1')",
        )

    def test_feature_union_transformer_weights_2(self):
        data = load_digits()
        X, y = data.data, data.target
        X = X.astype(np.float32)
        X_train, X_test, *_ = train_test_split(X, y, test_size=0.5,
                                               random_state=42)
        model = FeatureUnion([('pca', PCA()),
                              ('svd', TruncatedSVD())],
                             transformer_weights={'pca1': 10, 'svd2': 3}
                             ).fit(X_train)
        model_onnx = convert_sklearn(
            model, 'feature union', [('input', FloatTensorType(X_test.shape))])
        self.assertTrue(model_onnx is not None)
        dump_data_and_model(
            X_test,
            model,
            model_onnx,
            basename="SklearnFeatureUnionTransformerWeights2-Dec4",
            allow_failure="StrictVersion("
                          "onnxruntime.__version__)"
                          "<= StrictVersion('0.2.1')",
        )


if __name__ == "__main__":
    unittest.main()