Skip to content

Commit

Permalink
Test topology pruning (#179)
Browse files Browse the repository at this point in the history
Test topology pruning
  • Loading branch information
xadupre committed Jun 13, 2019
1 parent bf5311a commit 4200f11
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions tests/test_topology_prune.py
@@ -0,0 +1,75 @@
"""
Tests topology.
"""
import unittest
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import make_pipeline
from sklearn import datasets

from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.algebra.onnx_ops import OnnxIdentity


class IdentityTransformer(BaseEstimator, TransformerMixin):
def __init__(self):
TransformerMixin.__init__(self)
BaseEstimator.__init__(self)

def fit(self, X, y, sample_weight=None):
return self

def transform(self, X):
return X


class identity(IdentityTransformer):

def __init__(self):
IdentityTransformer.__init__(self)


def dummy_shape_calculator(operator):
op_input = operator.inputs[0]
operator.outputs[0].type = FloatTensorType(op_input.type.shape)


def dummy_converter(scope, operator, container):
X = operator.inputs[0]
out = operator.outputs

id1 = OnnxIdentity(X)
id2 = OnnxIdentity(id1, output_names=out[1:])
id2.add_to(scope, container)


class TestTopologyPrune(unittest.TestCase):

def test_dummy_identity(self):

digits = datasets.load_digits(n_class=6)
Xd = digits.data[:20]
yd = digits.target[:20]
n_samples, n_features = Xd.shape

idtr = make_pipeline(IdentityTransformer(), identity())
idtr.fit(Xd, yd)

update_registered_converter(IdentityTransformer, "IdentityTransformer",
dummy_shape_calculator, dummy_converter)
update_registered_converter(identity, "identity",
dummy_shape_calculator, dummy_converter)

model_onnx = convert_sklearn(
idtr,
"idtr",
[("input", FloatTensorType([1, Xd.shape[1]]))],
)

idnode = [node for node in model_onnx.graph.node
if node.op_type == "Identity"]
assert len(idnode) == 2


if __name__ == "__main__":
unittest.main()

0 comments on commit 4200f11

Please sign in to comment.