
Commit

The custom operator support for the keras/coreml/sklearn converters. (#114)

* the custom operator interface for the converters.

* update the implementation.

* implement a thread-safe version.

* more polish.

* fix the failure during build verification.

* fix Python 2.7 issues.

* refine the custom op code.

* scikit-learn fixes.

* Fix the mutable default argument issue.
wenbingl committed Jul 21, 2018
1 parent 0fcf672 commit 5dc1622
Showing 13 changed files with 210 additions and 29 deletions.
2 changes: 2 additions & 0 deletions onnxmltools/__init__.py
@@ -21,6 +21,8 @@
from .convert import convert_sklearn
from .convert import convert_keras

from .convert.common.interface import *

from .utils import load_model
from .utils import save_model
from .utils import save_text
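With the wildcard import above, the callback interface types added in onnxmltools/convert/common/interface.py (shown later in this commit) become reachable from the package root. A minimal sketch, assuming interface.py defines no `__all__` so its public classes are picked up by the wildcard import:

```python
# Minimal sketch: the abstract callback interfaces are re-exported at package level.
import onnxmltools

print(onnxmltools.ModelContainer)  # abstract container passed to custom conversion functions
print(onnxmltools.OperatorBase)    # abstract operator object handed to custom callbacks
```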
3 changes: 2 additions & 1 deletion onnxmltools/convert/common/_container.py
@@ -7,6 +7,7 @@
import six
from distutils.version import StrictVersion
from ...proto import helper
from .interface import ModelContainer


class RawModelContainer(object):
@@ -106,7 +107,7 @@ def output_names(self):
return [name for name in self._output_raw_names]


class ModelComponentContainer:
class ModelComponentContainer(ModelContainer):
'''
In the conversion phase, this class is used to collect all materials required to build an ONNX GraphProto, which is
encapsulated in an ONNX ModelProto.
34 changes: 27 additions & 7 deletions onnxmltools/convert/common/_topology.py
@@ -12,7 +12,7 @@
from ._container import ModelComponentContainer
from . import _registration
from . import utils

from .interface import OperatorBase

class Variable:

@@ -42,7 +42,7 @@ def full_name(self):
return self.onnx_name


class Operator:
class Operator(OperatorBase):

def __init__(self, onnx_name, scope, type, raw_operator, targeted_onnx_version):
'''
@@ -81,10 +81,17 @@ def input_full_names(self):
@property
def output_full_names(self):
'''
Return all outpu variables' names
Return all output variables' names
'''
return [variable.full_name for variable in self.outputs]

@property
def original_operator(self):
'''
Return the original operator/layer
'''
return self.raw_operator

def infer_types(self):
# Invoke a core inference function
_registration.get_shape_calculator(self.type)(self)
@@ -225,7 +232,8 @@ def delete_local_variable(self, onnx_name):
class Topology:

def __init__(self, model, default_batch_size=1, initial_types=None,
reserved_variable_names=None, reserved_operator_names=None, targeted_onnx=None):
reserved_variable_names=None, reserved_operator_names=None, targeted_onnx=None,
custom_conversion_functions=None, custom_shape_calculators=None):
'''
Initialize a Topology object, which is an intermediate representation of a computational graph.
@@ -235,6 +243,8 @@ def __init__(self, model, default_batch_size=1, initial_types=None,
name and a type defined in data_types.py.
:param reserved_variable_names: A set of strings which are not allowed to be used as a variable name
:param reserved_operator_names: A set of strings which are not allowed to be used as an operator name
:param custom_conversion_functions: a dictionary mapping a custom operator type to its user-defined conversion function
:param custom_shape_calculators: a dictionary mapping a custom operator type to its user-defined shape calculator
'''
self.scopes = []
self.raw_model = model
@@ -244,6 +254,8 @@ def __init__(self, model, default_batch_size=1, initial_types=None,
self.initial_types = initial_types if initial_types else list()
self.default_batch_size = default_batch_size
self.targeted_onnx_version = StrictVersion(targeted_onnx)
self.custom_conversion_functions = custom_conversion_functions if custom_conversion_functions else {}
self.custom_shape_calculators = custom_shape_calculators if custom_shape_calculators else {}

# This attribute is used in optimizing the graph structure. If root_names is not empty, only the variables
# specified will be treated as the roots (i.e., set is_fed to True in the beginning of a graph evaluation) of
@@ -475,7 +487,12 @@ def _infer_all_types(self):

# Traverse the graph from roots to leaves
for operator in self.topological_operator_iterator():
operator.infer_types()
if operator.type in self.custom_shape_calculators:
self.custom_shape_calculators[operator.type](operator)
elif operator.type in self.custom_conversion_functions:
pass # in the Keras converter, the shape calculator can be optional.
else:
operator.infer_types()

def _resolve_duplicates(self):
'''
@@ -650,8 +667,11 @@ def convert_topology(topology, model_name, doc_string, targeted_onnx):
# Traverse the graph from roots to leaves
for operator in topology.topological_operator_iterator():
scope = next(scope for scope in topology.scopes if scope.name == operator.scope)
# Convert the selected operator into some ONNX objects and save them into the container
_registration.get_converter(operator.type)(scope, operator, container)
if operator.type in topology.custom_conversion_functions:
topology.custom_conversion_functions[operator.type](scope, operator, container)
else:
# Convert the selected operator into some ONNX objects and save them into the container
_registration.get_converter(operator.type)(scope, operator, container)

# When calling ModelComponentContainer's add_initializer(...), nothing is added into the input list. However, in
# ONNX initializers should also be model's (GraphProto) inputs. Thus, we create ValueInfoProto objects from
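Both new dictionaries are keyed by `operator.type`. During shape inference the registered calculator is invoked as `calculator(operator)`, and during conversion the registered function is invoked as `converter(scope, operator, container)`, as the two dispatch sites above show. A minimal sketch of callbacks matching those signatures; the operator type key `'MyCustomOp'` and the propagation rule are hypothetical:

```python
def my_shape_calculator(operator):
    # Invoked as custom_shape_calculators[operator.type](operator).
    # Hypothetical rule: the output keeps the input's type and shape.
    operator.outputs[0].type = operator.inputs[0].type

def my_converter(scope, operator, container):
    # Invoked as custom_conversion_functions[operator.type](scope, operator, container).
    # Emit a single ONNX node wired to the operator's input/output variable names.
    container.add_node('Identity', operator.input_full_names, operator.output_full_names)

custom_shape_calculators = {'MyCustomOp': my_shape_calculator}
custom_conversion_functions = {'MyCustomOp': my_converter}
```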
84 changes: 84 additions & 0 deletions onnxmltools/convert/common/interface.py
@@ -0,0 +1,84 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# This file defines the interfaces of the converter's internal objects used in callbacks,
# so the usage of the methods and properties listed here will not change across different versions.


import abc
import six

@six.add_metaclass(abc.ABCMeta)
class ModelContainer:
__metaclass__ = abc.ABCMeta

@abc.abstractmethod
def add_initializer(self, name, onnx_type, shape, content):
"""
Add a TensorProto into the initializer list of the final ONNX model
:param name: Variable name in the produced ONNX model.
:param onnx_type: Element types allowed in ONNX tensor, e.g., TensorProto.FLOAT and TensorProto.STRING.
:param shape: Tensor shape, a list of integers.
:param content: Flattened tensor values (i.e., a float list or a float array).
"""
return

@abc.abstractmethod
def add_node(self, op_type, inputs, outputs, op_domain='', op_version=1, **attrs):
"""
Add a NodeProto into the node list of the final ONNX model. If the input operator's domain-version information
cannot be found in our domain-version pool (a Python set), we may add it.
:param op_type: A string (e.g., Pool and Conv) indicating the type of the NodeProto
:param inputs: A list of strings. They are the input variables' names of the considered NodeProto
:param outputs: A list of strings. They are the output variables' names of the considered NodeProto
:param op_domain: The domain name (e.g., ai.onnx.ml) of the operator we are trying to add.
:param op_version: The version number (e.g., 0 and 1) of the operator we are trying to add.
:param attrs: A Python dictionary. Keys and values are attributes' names and attributes' values, respectively.
"""
return

@six.add_metaclass(abc.ABCMeta)
class OperatorBase:
__metaclass__ = abc.ABCMeta

@property
@abc.abstractmethod
def full_name(self):
"""
Return a globally unique operator ID
"""
pass

@property
@abc.abstractmethod
def input_full_names(self):
"""
Return all input variables' names
"""
pass

@property
@abc.abstractmethod
def output_full_names(self):
"""
Return all output variables' names
"""
pass

@property
@abc.abstractmethod
def original_operator(self):
"""
Return the original operator/layer
"""
pass

@six.add_metaclass(abc.ABCMeta)
class ScopeBase:
__metaclass__ = abc.ABCMeta

pass
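A sketch of a custom conversion function written purely against this interface, so it does not depend on converter internals across versions; the layer attribute access (`get_weights`) and the weight shapes are hypothetical:

```python
from onnx import TensorProto

def convert_my_layer(scope, operator, container):
    # `operator` implements OperatorBase, `container` implements ModelContainer.
    layer = operator.original_operator   # the original Keras/CoreML/scikit-learn object
    weights = layer.get_weights()[0]     # hypothetical: a (C_in, C_out) float matrix

    weight_name = operator.full_name + '_W'
    container.add_initializer(weight_name, TensorProto.FLOAT,
                              list(weights.shape), weights.flatten())

    # One MatMul node from the operator's first input (plus the initializer) to its outputs.
    container.add_node('MatMul', [operator.input_full_names[0], weight_name],
                       operator.output_full_names)
```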
7 changes: 5 additions & 2 deletions onnxmltools/convert/coreml/_parse.py
@@ -427,14 +427,16 @@ def _parse_neural_network_model(topology, parent_scope, model, inputs, outputs):
operator.outputs.append(parent_variable)


def parse_coreml(model, initial_types=None, targeted_onnx=onnx.__version__):
def parse_coreml(model, initial_types=None, targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
'''
This is the root function of the whole parsing procedure.
:param model: CoreML model
:param initial_types: A list providing some types for some root variables. Each element is a tuple of a variable
name and a type defined in data_types.py.
:param targeted_onnx: a version string such as `1.1.2` or `1.2.1` for specifying the ONNX version used to produce
the output model.
:param custom_conversion_functions: a dictionary mapping a custom operator type to its user-defined conversion function
:param custom_shape_calculators: a dictionary mapping a custom operator type to its user-defined shape calculator
:return: a Topology object. It's an intermediate representation of the input CoreML model
'''

@@ -453,7 +455,8 @@ def parse_coreml(model, initial_types=None, targeted_onnx=onnx.__version__):
# CoremlModelContainer, to make sure our topology-related functions can seamlessly handle both of CoreML and
# scikit-learn.
topology = Topology(CoremlModelContainer(model), default_batch_size, initial_types, reserved_variable_names,
targeted_onnx=targeted_onnx)
targeted_onnx=targeted_onnx, custom_conversion_functions=custom_conversion_functions,
custom_shape_calculators=custom_shape_calculators)
scope = topology.declare_scope('__root__')

# Instead of using CoremlModelContainer, we directly pass the model in because _parse_model is CoreML-specific.
7 changes: 5 additions & 2 deletions onnxmltools/convert/coreml/convert.py
@@ -18,7 +18,8 @@
from .shape_calculators import neural_network as nn_shape_calculators


def convert(model, name=None, initial_types=None, doc_string='', targeted_onnx=onnx.__version__):
def convert(model, name=None, initial_types=None, doc_string='',
targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
'''
This function converts the specified CoreML model into its ONNX counterpart. Some information such as the produced
ONNX model name can be specified.
@@ -30,6 +31,8 @@ def convert(model, name=None, initial_types=None, doc_string='', targeted_onnx=onnx.__version__):
:param doc_string: A string attached onto the produced ONNX model
:param targeted_onnx: A string (for example, '1.1.2' and '1.2') used to specify the targeted ONNX version of the
produced model. If ONNXMLTools cannot find a compatible ONNX python package, an error may be thrown.
:param custom_conversion_functions: a dictionary mapping a custom operator type to its user-defined conversion function
:param custom_shape_calculators: a dictionary mapping a custom operator type to its user-defined shape calculator
:return: An ONNX model (type: ModelProto) which is equivalent to the input CoreML model
Example of initial types:
@@ -47,7 +50,7 @@ def convert(model, name=None, initial_types=None, doc_string='', targeted_onnx=onnx.__version__):
name = str(uuid4().hex)

# Parse CoreML model as our internal data structure (i.e., Topology)
topology = parse_coreml(spec, initial_types, targeted_onnx)
topology = parse_coreml(spec, initial_types, targeted_onnx, custom_conversion_functions, custom_shape_calculators)

# Parse the CoreML description, author, and license. That information will be attached to the final ONNX model.
metadata = spec.description.metadata
8 changes: 5 additions & 3 deletions onnxmltools/convert/keras/_parse.py
@@ -75,7 +75,7 @@ def determine_tensor_type(tensor, default_batch_size, keras_shape=None):
raise ValueError('Unable to find out a correct type for tensor %s' % tensor)


def parse_keras(model, initial_types=None, targeted_onnx=onnx.__version__):
def parse_keras(model, initial_types=None, targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
'''
The main parsing function of Keras Model and Sequential objects.
@@ -84,12 +84,14 @@ def parse_keras(model, initial_types=None, targeted_onnx=onnx.__version__):
name and a type defined in data_types.py.
:param targeted_onnx: a version string such as `1.1.2` or `1.2.1` for specifying the ONNX version used to produce
the output model.
:param custom_conversion_functions: a dictionary mapping a custom operator type to its user-defined conversion function
:param custom_shape_calculators: a dictionary mapping a custom operator type to its user-defined shape calculator
:return: a Topology object. It's an intermediate representation of the input Keras model
'''
raw_model_container = KerasModelContainer(model)

topology = Topology(raw_model_container, default_batch_size=1, initial_types=initial_types,
targeted_onnx=targeted_onnx)
topology = Topology(raw_model_container, default_batch_size=1, initial_types=initial_types, targeted_onnx=targeted_onnx,
custom_conversion_functions=custom_conversion_functions, custom_shape_calculators=custom_shape_calculators)
scope = topology.declare_scope('__root__')

# Each inbound node defines an evaluation of the underlining model (if the model is called multiple times, it may
11 changes: 8 additions & 3 deletions onnxmltools/convert/keras/convert.py
@@ -7,14 +7,17 @@
from uuid import uuid4
from ...proto import onnx
from ..common._topology import convert_topology
from ..common._registration import register_converter
from ..common._registration import register_shape_calculator
from ._parse import parse_keras

# Register conversion functions and shape inference functions
from . import operator_converters
from . import shape_calculators


def convert(model, name=None, initial_types=None, doc_string='', targeted_onnx=onnx.__version__):
def convert(model, name=None, initial_types=None, doc_string='',
targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
'''
Convert Keras-Tensorflow Model and Sequence objects into Topology. Note that default batch size is 1 here instead of
`None` used in CoreML conversion framework. To overwrite this behavior, we can specify initial_types. Assume that a
@@ -29,15 +32,17 @@ def convert(model, name=None, initial_types=None, doc_string='', targeted_onnx=onnx.__version__):
:param doc_string: A string attached onto the produced ONNX model
:param targeted_onnx: A string (for example, '1.1.2' and '1.2') used to specify the targeted ONNX version of the
produced model. If ONNXMLTools cannot find a compatible ONNX python package, an error may be thrown.
:param custom_conversion_functions: a dictionary mapping a custom operator type to its user-defined conversion function
:param custom_shape_calculators: a dictionary mapping a custom operator type to its user-defined shape calculator
:return: An ONNX model (type: ModelProto) which is equivalent to the input Keras model
'''
topology = parse_keras(model, initial_types, targeted_onnx)

topology = parse_keras(model, initial_types, targeted_onnx, custom_conversion_functions, custom_shape_calculators)

topology.compile()

if name is None:
name = str(uuid4().hex)

onnx_model = convert_topology(topology, name, doc_string, targeted_onnx)

return onnx_model
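An end-to-end Keras sketch. `ScaleLayer` and its converter are hypothetical; the sketch assumes the Keras parser uses the layer class as the operator type (so the layer class is the dictionary key), and it relies on the point noted in `_infer_all_types` that a shape calculator may be omitted for Keras operators.

```python
from keras.layers import Dense, Input, Layer
from keras.models import Model
from onnx import TensorProto
from onnxmltools import convert_keras

class ScaleLayer(Layer):
    # Hypothetical custom layer: multiply its input by a fixed constant.
    def call(self, inputs):
        return inputs * 2.0

def convert_scale_layer(scope, operator, container):
    # Emit Mul(input, 2.0) wired to the operator's output variable names.
    const_name = operator.full_name + '_scale'
    container.add_initializer(const_name, TensorProto.FLOAT, [1], [2.0])
    container.add_node('Mul', [operator.input_full_names[0], const_name],
                       operator.output_full_names)

inputs = Input(shape=(4,))
outputs = ScaleLayer()(Dense(4)(inputs))
model = Model(inputs=inputs, outputs=outputs)

# Assumption: the layer class itself is the operator type assigned by the Keras parser.
onnx_model = convert_keras(model, custom_conversion_functions={ScaleLayer: convert_scale_layer})
```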
20 changes: 13 additions & 7 deletions onnxmltools/convert/main.py
@@ -8,25 +8,31 @@
from .common import utils


def convert_sklearn(model, name=None, initial_types=None, doc_string='', targeted_onnx=onnx.__version__):
def convert_sklearn(model, name=None, initial_types=None, doc_string='',
targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
if not utils.sklearn_installed():
raise RuntimeError('scikit-learn is not installed. Please install scikit-learn to use this feature.')

from .sklearn.convert import convert
return convert(model, name=name, initial_types=initial_types, doc_string=doc_string, targeted_onnx=targeted_onnx)
return convert(model, name, initial_types,
doc_string, targeted_onnx, custom_conversion_functions, custom_shape_calculators)


def convert_coreml(model, name=None, initial_types=None, doc_string='', targeted_onnx=onnx.__version__):
def convert_coreml(model, name=None, initial_types=None, doc_string='',
targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
if not utils.coreml_installed():
raise RuntimeError('coremltools is not installed. Please install coremltools to use this feature.')

from .coreml.convert import convert
return convert(model, name=name, initial_types=initial_types, doc_string=doc_string, targeted_onnx=targeted_onnx)
return convert(model, name, initial_types,
doc_string, targeted_onnx, custom_conversion_functions, custom_shape_calculators)


def convert_keras(model, name=None, initial_types=None, doc_string='', targeted_onnx=onnx.__version__):
def convert_keras(model, name=None, initial_types=None, doc_string='',
targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
if not utils.keras_installed():
raise RuntimeError('keras is not installed. Please install coremltools to use this feature.')
raise RuntimeError('keras is not installed. Please install it to use this feature.')

from .keras.convert import convert
return convert(model, name, initial_types=initial_types, doc_string=doc_string, targeted_onnx=targeted_onnx)
return convert(model, name, initial_types,
doc_string, targeted_onnx, custom_conversion_functions, custom_shape_calculators)
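All three public entry points now take the same two optional dictionaries and forward them to the backend converters. A hedged usage sketch; the operator-type keys shown are hypothetical and must match whatever type the corresponding parser assigns (a string for CoreML and scikit-learn operators, the layer class for Keras), and `my_converter` / `my_shape_calculator` are callbacks like those sketched earlier:

```python
from onnxmltools import convert_coreml, convert_sklearn

# coreml_model / sklearn_model are assumed to be loaded and fitted elsewhere (placeholders).
onnx_model_1 = convert_coreml(
    coreml_model,
    custom_conversion_functions={'myCustomLayer': my_converter},        # hypothetical type key
    custom_shape_calculators={'myCustomLayer': my_shape_calculator})

onnx_model_2 = convert_sklearn(
    sklearn_model,
    custom_conversion_functions={'SklearnSomeEstimator': my_converter},  # hypothetical type key
    custom_shape_calculators={'SklearnSomeEstimator': my_shape_calculator})
```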
5 changes: 3 additions & 2 deletions onnxmltools/convert/sklearn/_parse.py
@@ -167,13 +167,14 @@ def _parse_sklearn(scope, model, inputs):
return _parse_sklearn_simple_model(scope, model, inputs)


def parse_sklearn(model, initial_types=None, targeted_onnx=onnx.__version__):
def parse_sklearn(model, initial_types=None, targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
# Put scikit-learn object into an abstract container so that our framework can work seamlessly on models created
# with different machine learning tools.
raw_model_container = SklearnModelContainer(model)

# Declare a computational graph. It will become a representation of the input scikit-learn model after parsing.
topology = Topology(raw_model_container, initial_types=initial_types, targeted_onnx=targeted_onnx)
topology = Topology(raw_model_container, initial_types=initial_types, targeted_onnx=targeted_onnx,
custom_conversion_functions=custom_conversion_functions, custom_shape_calculators=custom_shape_calculators)

# Declare an object to provide variables' and operators' naming mechanism. In contrast to CoreML, one global scope
# is enough for parsing scikit-learn models.
