Commit
* remove unnecessary print, add quotes around filenames in some places
* replaces as_matrix by values (pandas warnings)
* changes variable name to avoid getting warnings about invalid names
* better consistency for converted, allows targeted onnx version to be None
* Revert "better consistency for converted, allows targetted onnx version to be None" This reverts commit e257ca1.
* handle the comparison of ONNX versions in only one place
* fix bug with OneHotEncoder and scikit-learn 0.20
* release the constraint on scikit-learn (0.20.0 allowed)
* fix one type issue for Python 2.7
* add documentation to compare_strict_version
* Fixes #151, BernoulliNB converter
* Removes unused nodes in graph
* Addresses issue #143, enables build with keras 2.1.2
* Revert modifications due to a wrong merge
* update keras version
* Disable test on keras/mobilenet as it does not work
* add unit test for xception (failing)
* remove duplicate install
* skip unit test if not installed (tensorflow still not available on python 3.7)
* Fix when keras is not available
* Fix missing import
* Update test_single_operator_with_cntk_backend.py
* Set up CI with Azure Pipelines
* Update azure pipeline
* Skip a unit test if tensorflow is not installed
* merge
* missing import
* Revert "Merge branch 'master' of https://github.com/onnx/onnxmltools" This reverts commit 178e763, reversing changes made to 1a617ef.
* revert changes
* Revert changes
* \r
* \r
* first step in the migration of xgboost code
* XGBoost regression works
* Finalize xgboost converter
* Update README.md
* Add function has_tensorflow
* Update test_single_operator_with_cntk_backend.py
* better design for a unit test
* update xgboost classifier
* Delete test_keras_xception.py
* Delete requirements-deep.txt
* Delete test_keras_modebilenetv2.py
* less spaces
* lower precision for xgboost comparison tests
* disable xgboost testing on python 2
Showing 24 changed files with 610 additions and 11 deletions.
@@ -0,0 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from .convert import convert
@@ -0,0 +1,91 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from ..common._container import XGBoostModelContainer
from ..common._topology import *

from xgboost import XGBRegressor, XGBClassifier

xgboost_classifier_list = [XGBClassifier]

# Associate types with our operator names.
xgboost_operator_name_map = {XGBClassifier: 'XGBClassifier',
                             XGBRegressor: 'XGBRegressor'}


def _get_xgboost_operator_name(model_type):
    '''
    Get the operator name for the input argument.
    :param model_type: The type of an xgboost model.
    :return: A string which stands for the type of the input model in our conversion framework
    '''
    if model_type not in xgboost_operator_name_map:
        raise ValueError("No proper operator name found for '%s'" % model_type)
    return xgboost_operator_name_map[model_type]


def _parse_xgboost_simple_model(scope, model, inputs):
    '''
    This function handles all non-pipeline models.
    :param scope: Scope object
    :param model: An xgboost object
    :param inputs: A list of variables
    :return: A list of output variables which will be passed to the next stage
    '''
    this_operator = scope.declare_local_operator(_get_xgboost_operator_name(type(model)), model)
    this_operator.inputs = inputs

    if type(model) in xgboost_classifier_list:
        # For classifiers, we may have two outputs, one for the label and the other for the probabilities of all classes.
        # Note that their types here are not necessarily correct; they will be fixed in the shape inference phase.
        label_variable = scope.declare_local_variable('label', FloatTensorType())
        probability_map_variable = scope.declare_local_variable('probabilities', FloatTensorType())
        this_operator.outputs.append(label_variable)
        this_operator.outputs.append(probability_map_variable)
    else:
        # We assume that any other model (for example a regressor) produces a single float tensor.
        variable = scope.declare_local_variable('variable', FloatTensorType())
        this_operator.outputs.append(variable)
    return this_operator.outputs


def _parse_xgboost(scope, model, inputs):
    '''
    This is a delegate function. It does nothing but invoke the correct parsing function according to the input
    model's type.
    :param scope: Scope object
    :param model: An xgboost object
    :param inputs: A list of variables
    :return: The output variables produced by the input model
    '''
    return _parse_xgboost_simple_model(scope, model, inputs)


def parse_xgboost(model, initial_types=None, target_opset=None,
                  custom_conversion_functions=None, custom_shape_calculators=None):

    raw_model_container = XGBoostModelContainer(model)
    topology = Topology(raw_model_container,
                        initial_types=initial_types, target_opset=target_opset,
                        custom_conversion_functions=custom_conversion_functions,
                        custom_shape_calculators=custom_shape_calculators)
    scope = topology.declare_scope('__root__')

    inputs = []
    for var_name, initial_type in initial_types:
        inputs.append(scope.declare_local_variable(var_name, initial_type))

    for variable in inputs:
        raw_model_container.add_input(variable)

    outputs = _parse_xgboost(scope, model, inputs)

    for variable in outputs:
        raw_model_container.add_output(variable)

    return topology
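The parser in this file picks an operator name by looking up the model's exact type, then declares two outputs for a classifier and one output otherwise. The following is a minimal sketch of that dispatch, written so it runs without xgboost or onnxmltools. `StubClassifier`, `StubRegressor`, `TunedClassifier`, `get_operator_name`, and `output_names` are hypothetical names introduced only for illustration; they are not part of either library.

```python
# Hypothetical stand-ins for xgboost.XGBClassifier / xgboost.XGBRegressor.
class StubClassifier:
    pass


class StubRegressor:
    pass


classifier_list = [StubClassifier]

# Same shape as xgboost_operator_name_map: exact type -> operator name.
operator_name_map = {StubClassifier: 'XGBClassifier',
                     StubRegressor: 'XGBRegressor'}


def get_operator_name(model_type):
    # A dictionary lookup matches the exact class only; subclasses are not resolved.
    if model_type not in operator_name_map:
        raise ValueError("No proper operator name found for '%s'" % model_type)
    return operator_name_map[model_type]


def output_names(model):
    # Raises ValueError for unsupported types, like the parser in the diff.
    get_operator_name(type(model))
    # Classifiers get a label and per-class probabilities; other models one tensor.
    if type(model) in classifier_list:
        return ['label', 'probabilities']
    return ['variable']


class TunedClassifier(StubClassifier):
    """A user-defined subclass, used to show the exact-type behaviour."""


print(output_names(StubClassifier()))
print(output_names(StubRegressor()))
try:
    output_names(TunedClassifier())
except ValueError as exc:
    print('unsupported:', type(exc).__name__)
```

One consequence of keying on `type(model)` rather than using `isinstance` is that a subclass of a supported model is rejected, as the last call in the sketch shows. Whether that is intended in the real converter is not stated in the commit.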
@@ -0,0 +1,16 @@
"""
Functions common to converters and shape calculators.
"""

def get_xgb_params(xgb_node):
    """
    Retrieves the parameters of a model.
    """
    if hasattr(xgb_node, 'kwargs'):
        # XGBoost >= 0.7
        params = xgb_node.get_xgb_params()
    else:
        # XGBoost < 0.7
        params = xgb_node.__dict__

    return params
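The helper in this file chooses between an accessor method and the raw instance dictionary by checking for a `kwargs` attribute. A small sketch of that fallback pattern follows, using two hypothetical stub classes (`NewStyleStub` and `OldStyleStub`) in place of real xgboost models, so the behaviour of each branch can be seen without installing xgboost. The stubs and the `get_params` name are assumptions made for this example only.

```python
def get_params(node):
    # Mirrors the hasattr check in the diff: use the accessor when the
    # 'kwargs' attribute is present, otherwise fall back to the instance dict.
    if hasattr(node, 'kwargs'):
        return node.get_xgb_params()
    return node.__dict__


class NewStyleStub:
    """Hypothetical stand-in for a model that exposes `kwargs` and an accessor."""

    def __init__(self):
        self.kwargs = {}
        self.max_depth = 3

    def get_xgb_params(self):
        # Returns a fresh dictionary on every call.
        return {'max_depth': self.max_depth}


class OldStyleStub:
    """Hypothetical stand-in for a model without a `kwargs` attribute."""

    def __init__(self):
        self.max_depth = 3
        self.n_estimators = 10


new_params = get_params(NewStyleStub())
old_model = OldStyleStub()
old_params = get_params(old_model)

print(sorted(new_params))
print(sorted(old_params))
# The fallback returns the live attribute dictionary, not a copy.
print(old_params is old_model.__dict__)
```

Because the fallback branch hands back the object's own `__dict__`, a caller that mutates the result also mutates the model. Wrapping the result in `dict(...)` would avoid that if a caller ever needs to modify the parameters.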
@@ -0,0 +1,44 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from uuid import uuid4
from ...proto import onnx, get_opset_number_from_onnx
from ..common._topology import convert_topology
from ._parse import parse_xgboost

# Invoke the registration of all our converters and shape calculators
# from . import shape_calculators
from . import operator_converters, shape_calculators


def convert(model, name=None, initial_types=None, doc_string='', target_opset=None,
            targeted_onnx=onnx.__version__, custom_conversion_functions=None,
            custom_shape_calculators=None):
    '''
    This function produces an equivalent ONNX model of the given xgboost model.
    :param model: An xgboost model
    :param initial_types: a python list. Each element is a tuple of a variable name and a type defined in data_types.py
    :param name: The name of the graph (type: GraphProto) in the produced ONNX model (type: ModelProto)
    :param doc_string: A string attached onto the produced ONNX model
    :param target_opset: number, for example, 7 for ONNX 1.2, and 8 for ONNX 1.3.
    :param targeted_onnx: A string (for example, '1.1.2' and '1.2') used to specify the targeted ONNX version of the
        produced model. If ONNXMLTools cannot find a compatible ONNX python package, an error may be thrown.
    :param custom_conversion_functions: a dictionary for specifying the user customized conversion function
    :param custom_shape_calculators: a dictionary for specifying the user customized shape calculator
    :return: An ONNX model (type: ModelProto) which is equivalent to the input xgboost model
    '''
    if initial_types is None:
        raise ValueError('Initial types are required. See usage of convert(...) in '
                         'onnxmltools.convert.xgboost.convert for details')
    if name is None:
        name = str(uuid4().hex)

    target_opset = target_opset if target_opset else get_opset_number_from_onnx()
    topology = parse_xgboost(model, initial_types, target_opset, custom_conversion_functions, custom_shape_calculators)
    topology.compile()
    onnx_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx)
    return onnx_model
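Before parsing, `convert` validates `initial_types`, generates a graph name when none is given, and falls back to a default opset when `target_opset` is falsy. The sketch below isolates only that argument handling so it can run without onnx or xgboost. `normalize_convert_args` is a hypothetical helper, `DEFAULT_OPSET` is an assumed placeholder for whatever `get_opset_number_from_onnx()` would return, and the type string in `types` is a placeholder rather than a real `data_types.py` object.

```python
from uuid import uuid4

# Assumed placeholder; the real default comes from get_opset_number_from_onnx().
DEFAULT_OPSET = 8


def normalize_convert_args(name=None, initial_types=None, target_opset=None):
    # initial_types is mandatory: the parser iterates over it to declare inputs.
    if initial_types is None:
        raise ValueError('Initial types are required.')
    # An unnamed graph gets a random 32-character hexadecimal name.
    if name is None:
        name = uuid4().hex
    # Truthiness test: both None and 0 select the default opset.
    target_opset = target_opset if target_opset else DEFAULT_OPSET
    return name, initial_types, target_opset


# Each entry pairs an input name with a type; the string is a placeholder here.
types = [('input', 'float tensor placeholder')]

name, _, opset = normalize_convert_args(initial_types=types)
print(len(name), opset)

_, _, pinned = normalize_convert_args(name='my_graph', initial_types=types, target_opset=7)
print(pinned)
```

Note that the opset check relies on truthiness rather than an explicit `is None` comparison, so passing `target_opset=0` behaves the same as leaving it unset.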