Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test topology pruning #179

Merged
merged 22 commits into from
Jun 13, 2019
Merged
Changes from 21 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9c3d365
Enables opset 9 for TextVectorizer
sdpython Apr 24, 2019
9152cb9
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 2, 2019
ede9829
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 3, 2019
9ff2aa5
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 6, 2019
37ee770
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 9, 2019
2a6f00f
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 10, 2019
8df0812
check random forest
sdpython May 10, 2019
1c3b3e5
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 15, 2019
7934c4a
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 20, 2019
086f1fe
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 21, 2019
fa0517e
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 22, 2019
4e2053b
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 22, 2019
eb5923d
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython May 27, 2019
50311d5
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jun 3, 2019
4d82c47
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jun 6, 2019
6ade4aa
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jun 7, 2019
5a31e27
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jun 10, 2019
2bec92c
fix spaces
sdpython Jun 10, 2019
393a85a
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jun 10, 2019
974af24
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jun 12, 2019
b53ac8c
Test topology pruning
sdpython Jun 12, 2019
6922aba
Update test_topology_prune.py
sdpython Jun 13, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
77 changes: 77 additions & 0 deletions tests/test_topology_prune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Tests scikit-learn's binarizer converter.
prabhat00155 marked this conversation as resolved.
Show resolved Hide resolved
"""
import unittest
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import make_pipeline
from sklearn import datasets

from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.algebra.onnx_ops import OnnxIdentity


class IdentityTransformer(BaseEstimator, TransformerMixin):
def __init__(self):
TransformerMixin.__init__(self)
BaseEstimator.__init__(self)

def fit(self, X, y, sample_weight=None):
return self

def transform(self, X):
return X


class identity(IdentityTransformer):

def __init__(self):
IdentityTransformer.__init__(self)


def dummy_shape_calculator(operator):
op_input = operator.inputs[0]
N = op_input.type.shape[0]
C = op_input.type.shape[1]
operator.outputs[0].type = FloatTensorType([N, C])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need N and C here? We could just write:
operator.outputs[0].type = FloatTensorType(op_input.type.shape), couldn't we?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know. I update the master branch on my repo and I create new branches with git checkout -b …

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



def dummy_converter(scope, operator, container):
X = operator.inputs[0]
out = operator.outputs

id1 = OnnxIdentity(X)
id2 = OnnxIdentity(id1, output_names=out[1:])
id2.add_to(scope, container)


class TestTopologyPrune(unittest.TestCase):

def test_dummy_identity(self):

digits = datasets.load_digits(n_class=6)
Xd = digits.data[:20]
yd = digits.target[:20]
n_samples, n_features = Xd.shape

idtr = make_pipeline(IdentityTransformer(), identity())
idtr.fit(Xd, yd)

update_registered_converter(IdentityTransformer, "IdentityTransformer",
dummy_shape_calculator, dummy_converter)
update_registered_converter(identity, "identity",
dummy_shape_calculator, dummy_converter)

model_onnx = convert_sklearn(
idtr,
"idtr",
[("input", FloatTensorType([1, Xd.shape[1]]))],
)

idnode = [node for node in model_onnx.graph.node
if node.op_type == "Identity"]
assert len(idnode) == 2


if __name__ == "__main__":
unittest.main()