Skip to content

Commit

Permalink
Add support for H2O GBM MOJO
Browse files Browse the repository at this point in the history
  • Loading branch information
honzasterba committed Dec 12, 2019
1 parent 8ce90a3 commit 48f0eec
Show file tree
Hide file tree
Showing 17 changed files with 3,939 additions and 0 deletions.
1 change: 1 addition & 0 deletions onnxmltools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .convert import convert_sparkml
from .convert import convert_tensorflow
from .convert import convert_xgboost
from .convert import convert_h2o

from .utils import load_model
from .utils import save_model
1 change: 1 addition & 0 deletions onnxmltools/convert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from .main import convert_sparkml
from .main import convert_tensorflow
from .main import convert_xgboost
from .main import convert_h2o
2 changes: 2 additions & 0 deletions onnxmltools/convert/common/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class LightGbmModelContainer(CommonSklearnModelContainer):
class XGBoostModelContainer(CommonSklearnModelContainer):
pass

class H2OModelContainer(CommonSklearnModelContainer):
pass

class SparkmlModelContainer(RawModelContainer):

Expand Down
7 changes: 7 additions & 0 deletions onnxmltools/convert/h2o/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from .convert import convert
53 changes: 53 additions & 0 deletions onnxmltools/convert/h2o/_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from ..common._container import H2OModelContainer
from ..common._topology import *

def _parse_h2o(scope, model, inputs):
'''
:param scope: Scope object
:param model: A h2o model data object
:param inputs: A list of variables
:return: A list of output variables which will be passed to next stage
'''
this_operator = scope.declare_local_operator("H2OTreeMojo", model)
this_operator.inputs = inputs

if model["params"]["classifier"]:
label_variable = scope.declare_local_variable('label', FloatTensorType())
probability_map_variable = scope.declare_local_variable('probabilities', FloatTensorType())
this_operator.outputs.append(label_variable)
this_operator.outputs.append(probability_map_variable)
else:
variable = scope.declare_local_variable('variable', FloatTensorType())
this_operator.outputs.append(variable)
return this_operator.outputs


def parse_h2o(model, initial_types=None, target_opset=None,
custom_conversion_functions=None, custom_shape_calculators=None):

raw_model_container = H2OModelContainer(model)
topology = Topology(raw_model_container, default_batch_size='None',
initial_types=initial_types, target_opset=target_opset,
custom_conversion_functions=custom_conversion_functions,
custom_shape_calculators=custom_shape_calculators)
scope = topology.declare_scope('__root__')

inputs = []
for var_name, initial_type in initial_types:
inputs.append(scope.declare_local_variable(var_name, initial_type))

for variable in inputs:
raw_model_container.add_input(variable)

outputs = _parse_h2o(scope, model, inputs)

for variable in outputs:
raw_model_container.add_output(variable)

return topology
51 changes: 51 additions & 0 deletions onnxmltools/convert/h2o/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from uuid import uuid4
import json
import h2o

from ...proto import onnx, get_opset_number_from_onnx
from ..common._topology import convert_topology
from ._parse import parse_h2o

# Invoke the registration of all our converters and shape calculators
from . import operator_converters, shape_calculators


def convert(model_path, name=None, initial_types=None, doc_string='', target_opset=None,
targeted_onnx=onnx.__version__, custom_conversion_functions=None,
custom_shape_calculators=None):
'''
This function produces an equivalent ONNX model of the given h2o model.
:param model_path: A path to exported H2O MOJO model file
:param initial_types: a python list. Each element is a tuple of a variable name and a type defined in data_types.py
:param name: The name of the graph (type: GraphProto) in the produced ONNX model (type: ModelProto)
:param doc_string: A string attached onto the produced ONNX model
:param target_opset: number, for example, 7 for ONNX 1.2, and 8 for ONNX 1.3.
:param targeted_onnx: A string (for example, '1.1.2' and '1.2') used to specify the targeted ONNX version of the
produced model. If ONNXMLTools cannot find a compatible ONNX python package, an error may be thrown.
:param custom_conversion_functions: a dictionary for specifying the user customized conversion function
:param custom_shape_calculators: a dictionary for specifying the user customized shape calculator
:return: An ONNX model (type: ModelProto) which is equivalent to the input xgboost model
'''
if initial_types is None:
raise ValueError('Initial types are required. See usage of convert(...) in ' +
'onnxmltools.convert.h2o.convert for details')
if name is None:
name = str(uuid4().hex)

mojo_str = h2o.print_mojo(model_path, format="json")
mojo_model = json.loads(mojo_str)
if mojo_model["params"]["algo"] != "gbm":
raise ValueError("Only GBM Mojo supported for now (algo=%s)." % mojo_model["params"]["algo"])

target_opset = target_opset if target_opset else get_opset_number_from_onnx()
topology = parse_h2o(mojo_model, initial_types, target_opset, custom_conversion_functions, custom_shape_calculators)
topology.compile()
onnx_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx)
return onnx_model
8 changes: 8 additions & 0 deletions onnxmltools/convert/h2o/operator_converters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

# To register converter for scikit-learn operators, import associated modules here.
from . import h2o
165 changes: 165 additions & 0 deletions onnxmltools/convert/h2o/operator_converters/h2o.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import numpy as np
from ...common._registration import register_converter

_LINK_FUNCTION_TO_POST_TRANSFORM = {
'identity': 'NONE',
'logit': 'LOGISTIC',
'ologit': 'LOGISTIC'
}


def _get_post_transform(params):
link_function = params["link_function"]
family = params["family"]
if family == "multinomial":
return 'SOFTMAX'
elif link_function not in _LINK_FUNCTION_TO_POST_TRANSFORM.keys():
raise ValueError("Link function %s not supported." % link_function)
else:
return _LINK_FUNCTION_TO_POST_TRANSFORM[link_function]


def _get_default_tree_attribute_pairs(is_classifier, params):
attrs = {
'post_transform': _get_post_transform(params)
}
nclasses = params["nclasses"]
if is_classifier:
predicted_classes = nclasses if nclasses > 2 else 1
attrs['base_values'] = [params["base_score"] for _ in range(0, predicted_classes)]
else:
attrs['n_targets'] = 1
attrs['base_values'] = [params["base_score"]]
for k in {'nodes_treeids', 'nodes_nodeids',
'nodes_featureids', 'nodes_modes', 'nodes_values',
'nodes_truenodeids', 'nodes_falsenodeids', 'nodes_missing_value_tracks_true'}:
attrs[k] = []
node_attr_prefix = _node_attr_prefix(is_classifier)
for k in {'_treeids', '_nodeids', '_ids', '_weights'}:
attrs[node_attr_prefix + k] = []
return attrs


def _add_node(
attr_pairs, is_classifier, tree_id, node_id,
feature_id, mode, value, true_child_id, false_child_id, weights,
missing
):
attr_pairs['nodes_treeids'].append(tree_id)
attr_pairs['nodes_nodeids'].append(node_id)
attr_pairs['nodes_featureids'].append(feature_id)
attr_pairs['nodes_modes'].append(mode)
attr_pairs['nodes_values'].append(float(value))
attr_pairs['nodes_truenodeids'].append(true_child_id)
attr_pairs['nodes_falsenodeids'].append(false_child_id)
attr_pairs['nodes_missing_value_tracks_true'].append(missing)
if mode == 'LEAF':
node_attr_prefix = _node_attr_prefix(is_classifier)
for i, w in enumerate(weights):
attr_pairs[node_attr_prefix + '_treeids'].append(tree_id)
attr_pairs[node_attr_prefix + '_nodeids'].append(node_id)
attr_pairs[node_attr_prefix + '_ids'].append(i)
attr_pairs[node_attr_prefix + '_weights'].append(float(w))


def _node_attr_prefix(is_classifier):
return "class" if is_classifier else "target"


def _fill_node_attributes(tree_id, node, attr_pairs, is_classifier):
if 'leftChild' in node:
if node["isCategorical"]:
raise ValueError("categorical splits not supported, use one_hot_explicit")
else:
operator = 'BRANCH_GTE'
value = node['splitValue']
_add_node(
attr_pairs=attr_pairs,
is_classifier=is_classifier,
tree_id=tree_id,
mode=operator,
value=value,
node_id=node['id'],
feature_id=node['colId'],
true_child_id=node['rightChild']['id'],
false_child_id=node['leftChild']['id'],
weights=None,
missing=(0 if node["leftward"] else 1),
)
_fill_node_attributes(tree_id, node["leftChild"], attr_pairs, is_classifier)
_fill_node_attributes(tree_id, node["rightChild"], attr_pairs, is_classifier)
else: # leaf
weights = [node['predValue']]
_add_node(
attr_pairs=attr_pairs,
is_classifier=is_classifier,
tree_id=tree_id,
value=0.,
node_id=node['id'],
feature_id=0, mode='LEAF',
true_child_id=0, false_child_id=0,
weights=weights,
missing=False
)


def assign_node_ids(node, next_id):
if node is None:
return next_id
node["id"] = next_id
next_id += 1
next_id = assign_node_ids(node.get("leftChild", None), next_id)
return assign_node_ids(node.get("rightChild", None), next_id)


def fill_tree_attributes(model, attr_pairs, node_attr_prefix):
for tree in model["trees"]:
assign_node_ids(tree["root"], 0)
_fill_node_attributes(tree["index"], tree["root"], attr_pairs, node_attr_prefix)


def convert_regression(scope, operator, container, params):
model = operator.raw_operator

attr_pairs = _get_default_tree_attribute_pairs(False, params)
fill_tree_attributes(model, attr_pairs, False)

# add nodes
container.add_node('TreeEnsembleRegressor', operator.input_full_names,
operator.output_full_names, op_domain='ai.onnx.ml',
name=scope.get_unique_operator_name('TreeEnsembleRegressor'), **attr_pairs)


def convert_classifier(scope, operator, container, params):
model = operator.raw_operator

attr_pairs = _get_default_tree_attribute_pairs(True, params)
fill_tree_attributes(model, attr_pairs, True)

n_trees_in_group = params["n_trees_in_group"]
attr_pairs['class_ids'] = [v % n_trees_in_group for v in attr_pairs['class_treeids']]
attr_pairs['classlabels_strings'] = params["class_labels"]

container.add_node('TreeEnsembleClassifier', operator.input_full_names,
operator.output_full_names,
op_domain='ai.onnx.ml',
name=scope.get_unique_operator_name('TreeEnsembleClassifier'),
**attr_pairs)


def convert_h2o(scope, operator, container):
params = operator.raw_operator["params"]
is_classifier = params["classifier"]
if is_classifier:
return convert_classifier(scope, operator, container, params)
else: # regression
return convert_regression(scope, operator, container, params)


register_converter('H2OTreeMojo', convert_h2o)
8 changes: 8 additions & 0 deletions onnxmltools/convert/h2o/shape_calculators/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

# To register shape calculators for lightgbm operators, import associated modules here.
from . import h2otreemojo
31 changes: 31 additions & 0 deletions onnxmltools/convert/h2o/shape_calculators/h2otreemojo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

from ...common._registration import register_shape_calculator
from ...common.shape_calculator import calculate_linear_regressor_output_shapes
from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
from ...common.data_types import (FloatTensorType, StringTensorType, Int64TensorType)


def calculate_h2otree_output_shapes(operator):
params = operator.raw_operator["params"]
if params["classifier"]:
calculate_tree_classifier_output_shapes(operator, params)
else:
calculate_linear_regressor_output_shapes(operator)


def calculate_tree_classifier_output_shapes(operator, params):
check_input_and_output_numbers(operator, input_count_range=1, output_count_range=2)
check_input_and_output_types(operator, good_input_types=[FloatTensorType, Int64TensorType])
N = operator.inputs[0].type.shape[0]
nclasses = params["nclasses"]
operator.outputs[0].type = StringTensorType(shape=[N])
if nclasses > 1:
operator.outputs[1].type = FloatTensorType([N, nclasses])


register_shape_calculator('H2OTreeMojo', calculate_h2otree_output_shapes)
6 changes: 6 additions & 0 deletions onnxmltools/convert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,9 @@ def convert_xgboost(*args, **kwargs):
from .xgboost.convert import convert
return convert(*args, **kwargs)

def convert_h2o(*args, **kwargs):
if not utils.h2o_installed():
raise RuntimeError('h2o is not installed. Please install h2o to use this feature.')

from .h2o.convert import convert
return convert(*args, **kwargs)
3 changes: 3 additions & 0 deletions onnxmltools/utils/tests_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
elif hasattr(model, "predict_proba"):
# Classifier
prediction = [model.predict(data), model.predict_proba(data)]
elif hasattr(model, "predict_with_probabilities"):
# Classifier that returns all in one go
prediction = model.predict_with_probabilities(data)
elif hasattr(model, "decision_function"):
# Classifier without probabilities
prediction = [model.predict(data), model.decision_function(data)]
Expand Down
2 changes: 2 additions & 0 deletions onnxmltools/utils/utils_backend_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def compare_runtime(test, decimal=5, options=None, verbose=False, context=None):
except ExpectedAssertionError as expe:
raise expe
except Exception as e:
print(e)
if "CannotLoad" in options:
raise ExpectedAssertionError("Unable to load onnx '{0}' due to\n{1}".format(onx, e))
else:
Expand Down Expand Up @@ -155,6 +156,7 @@ def to_array(vv):
except ExpectedAssertionError as expe:
raise expe
except Exception as e:
raise e
if verbose:
import onnx
model = onnx.load(onx)
Expand Down
Loading

0 comments on commit 48f0eec

Please sign in to comment.