Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a couple of mismatches for every tree in scikit-learn. #237

Merged
merged 22 commits into from
Aug 28, 2019
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
784d64b
Enables opset 9 for TextVectorizer
sdpython Apr 24, 2019
0047376
check random forest
sdpython May 10, 2019
d18b90c
fix spaces
sdpython Jun 10, 2019
249b19f
update converter for TfIdf after a change of spec in onnxruntime
sdpython Jun 11, 2019
1cae401
check random forest
sdpython May 10, 2019
98e45ef
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jun 20, 2019
2e796a4
Delete test_SklearnGradientBoostingConverters.py
sdpython Jun 20, 2019
f7eecdd
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jun 25, 2019
f4601b4
Update tests_helper.py
sdpython Jun 26, 2019
204c3c4
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jul 3, 2019
a8d0118
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jul 3, 2019
1b0aa55
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jul 24, 2019
9a446fc
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Jul 26, 2019
95c2930
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Aug 19, 2019
cd3ef25
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Aug 22, 2019
8f75535
adjust threshold for mismatch in sklearn trees
sdpython Aug 22, 2019
66680b4
Fix eps
sdpython Aug 23, 2019
8ffac5e
Merge branch 'master' into rdf
xadupre Aug 27, 2019
f3580cb
Merge branch 'master' of https://github.com/onnx/sklearn-onnx
sdpython Aug 28, 2019
22d92b8
Update tests_helper.py
sdpython Aug 28, 2019
bc6fd6e
dummy PR
sdpython Aug 28, 2019
6ec85f7
Merge branch 'doc' of https://github.com/xadupre/sklearn-onnx into rdf
sdpython Aug 28, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
80 changes: 76 additions & 4 deletions skl2onnx/common/tree_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
Common functions to convert any learner based on trees.
"""
import numpy as np


def get_default_tree_classifier_attribute_pairs():
Expand Down Expand Up @@ -47,14 +48,81 @@ def get_default_tree_regressor_attribute_pairs():
return attrs


def find_switch_point(fy, nfy):
"""
Finds the double so that
``(float)x != (float)(x + espilon)``.
"""
a = np.float64(fy)
b = np.float64(nfy)
fa = np.float32(a)
a0, b0 = a, a
while a != a0 or b != b0:
a0, b0 = a, b
m = (a + b) / 2
fm = np.float32(m)
if fm == fa:
a = m
fa = fm
else:
b = m
return a


def sklearn_threshold(dy, dtype, mode):
"""
*scikit-learn* does not compare x to a threshold
but (float)x to a double threshold. As we need a float
threshold, we need a different value than the threshold
rounded to float. For floats, it finds float *w* which
verifies::

(float)x <= y <=> (float)x <= w

For doubles, it finds double *w* which verifies::

(float)x <= y <=> x <= w
"""
if mode == "BRANCH_LEQ":
if dtype == np.float32:
fy = np.float32(dy)
if fy == dy:
return np.float64(fy)
if fy < dy:
return np.float64(fy)
eps = max(abs(fy), np.finfo(np.float32).eps) * 10
nfy = np.nextafter([fy], [fy - eps], dtype=np.float32)[0]
return np.float64(nfy)
elif dtype == np.float64:
fy = np.float32(dy)
eps = max(abs(fy), np.finfo(np.float32).eps) * 10
afy = np.nextafter([fy], [fy - eps], dtype=np.float32)[0]
afy2 = find_switch_point(afy, fy)
if fy > dy > afy2:
return afy2
bfy = np.nextafter([fy], [fy + eps], dtype=np.float32)[0]
bfy2 = find_switch_point(fy, bfy)
if fy <= dy <= bfy2:
return bfy2
return np.float64(fy)
raise TypeError("Unexpected dtype {}.".format(dtype))
raise RuntimeError("Threshold is not changed for other mode and "
"'BRANCH_LEQ' (actually '{}').".format(mode))


def add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id,
feature_id, mode, value, true_child_id, false_child_id,
weights, weight_id_bias, leaf_weights_are_counts):
weights, weight_id_bias, leaf_weights_are_counts,
adjust_threshold_for_sklearn, dtype):
attr_pairs['nodes_treeids'].append(tree_id)
attr_pairs['nodes_nodeids'].append(node_id)
attr_pairs['nodes_featureids'].append(feature_id)
attr_pairs['nodes_modes'].append(mode)
attr_pairs['nodes_values'].append(value)
if adjust_threshold_for_sklearn and mode != 'LEAF':
attr_pairs['nodes_values'].append(
sklearn_threshold(value, dtype, mode))
else:
attr_pairs['nodes_values'].append(value)
attr_pairs['nodes_truenodeids'].append(true_child_id)
attr_pairs['nodes_falsenodeids'].append(false_child_id)
attr_pairs['nodes_missing_value_tracks_true'].append(False)
Expand Down Expand Up @@ -91,7 +159,9 @@ def add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id,

def add_tree_to_attribute_pairs(attr_pairs, is_classifier, tree, tree_id,
tree_weight, weight_id_bias,
leaf_weights_are_counts):
leaf_weights_are_counts,
adjust_threshold_for_sklearn=False,
dtype=None):
for i in range(tree.node_count):
node_id = i
weight = tree.value[i]
Expand All @@ -111,4 +181,6 @@ def add_tree_to_attribute_pairs(attr_pairs, is_classifier, tree, tree_id,

add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id,
feat_id, mode, threshold, left_child_id, right_child_id,
weight, weight_id_bias, leaf_weights_are_counts)
weight, weight_id_bias, leaf_weights_are_counts,
adjust_threshold_for_sklearn=adjust_threshold_for_sklearn,
dtype=dtype)
5 changes: 3 additions & 2 deletions skl2onnx/operator_converters/ada_boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def convert_sklearn_ada_boost_classifier(scope, operator, container):
attrs['classlabels_strings'] = classes

add_tree_to_attribute_pairs(attrs, True, op.estimators_[tree_id].tree_,
0, 1, 0, True)
0, 1, 0, True, True, dtype=container.dtype)
container.add_node(
op_type, operator.input_full_names,
[label_name, proba_name],
Expand Down Expand Up @@ -234,7 +234,8 @@ def _get_estimators_label(scope, operator, container, model):
attrs['n_targets'] = int(model.estimators_[tree_id].n_outputs_)
add_tree_to_attribute_pairs(attrs, False,
model.estimators_[tree_id].tree_,
0, 1, 0, False)
0, 1, 0, False, True,
dtype=container.dtype)

container.add_node(op_type, input_name,
estimator_label_name, op_domain='ai.onnx.ml',
Expand Down
6 changes: 4 additions & 2 deletions skl2onnx/operator_converters/decision_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def convert_sklearn_decision_tree_classifier(scope, operator, container):
else:
raise ValueError('Labels must be all integers or all strings.')

add_tree_to_attribute_pairs(attrs, True, op.tree_, 0, 1., 0, True)
add_tree_to_attribute_pairs(attrs, True, op.tree_, 0, 1., 0, True,
True, dtype=container.dtype)

container.add_node(
op_type, operator.input_full_names,
Expand All @@ -51,7 +52,8 @@ def convert_sklearn_decision_tree_regressor(scope, operator, container):
attrs = get_default_tree_regressor_attribute_pairs()
attrs['name'] = scope.get_unique_operator_name(op_type)
attrs['n_targets'] = int(op.n_outputs_)
add_tree_to_attribute_pairs(attrs, False, op.tree_, 0, 1., 0, False)
add_tree_to_attribute_pairs(attrs, False, op.tree_, 0, 1., 0, False,
True, dtype=container.dtype)

input_name = operator.input_full_names
if type(operator.inputs[0].type) == Int64TensorType:
Expand Down
8 changes: 5 additions & 3 deletions skl2onnx/operator_converters/gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,16 @@ def convert_sklearn_gradient_boosting_classifier(scope, operator, container):
for tree_id in range(n_est):
tree = op.estimators_[tree_id][0].tree_
add_tree_to_attribute_pairs(attrs, True, tree, tree_id,
tree_weight, 0, False)
tree_weight, 0, False, True,
dtype=container.dtype)
else:
for i in range(n_est):
for c in range(op.n_classes_):
tree_id = i * op.n_classes_ + c
tree = op.estimators_[i][c].tree_
add_tree_to_attribute_pairs(attrs, True, tree, tree_id,
tree_weight, c, False)
tree_weight, c, False, True,
dtype=container.dtype)

container.add_node(
op_type, operator.input_full_names,
Expand Down Expand Up @@ -138,7 +140,7 @@ def convert_sklearn_gradient_boosting_regressor(scope, operator, container):
tree = op.estimators_[i][0].tree_
tree_id = i
add_tree_to_attribute_pairs(attrs, False, tree, tree_id, tree_weight,
0, False)
0, False, True, dtype=container.dtype)

input_name = operator.input_full_names
if type(operator.inputs[0].type) == Int64TensorType:
Expand Down
6 changes: 4 additions & 2 deletions skl2onnx/operator_converters/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def convert_sklearn_random_forest_classifier(scope, operator, container):
for tree_id in range(estimtator_count):
tree = op.estimators_[tree_id].tree_
add_tree_to_attribute_pairs(attr_pairs, True, tree, tree_id,
tree_weight, 0, True)
tree_weight, 0, True, True,
dtype=container.dtype)

container.add_node(
op_type, operator.input_full_names,
Expand All @@ -78,7 +79,8 @@ def convert_sklearn_random_forest_regressor_converter(scope,
for tree_id in range(estimtator_count):
tree = op.estimators_[tree_id].tree_
add_tree_to_attribute_pairs(attrs, False, tree, tree_id,
tree_weight, 0, False)
tree_weight, 0, False, True,
dtype=container.dtype)

input_name = operator.input_full_names
if type(operator.inputs[0].type) == Int64TensorType:
Expand Down
31 changes: 27 additions & 4 deletions tests/test_sklearn_random_forest_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

import unittest
import numpy
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import (
RandomForestClassifier, RandomForestRegressor,
ExtraTreesClassifier, ExtraTreesRegressor
)
from skl2onnx.common.data_types import onnx_built_with_ml, FloatTensorType
from test_utils import (
dump_one_class_classification,
Expand Down Expand Up @@ -67,6 +69,27 @@ def test_random_forest_classifier_mismatched_estimator_counts(self):
model.__class__.__name__ +
'_mismatched_estimator_counts')

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_random_forest_regressor_mismatches(self):
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, _ = train_test_split(
X, y, random_state=13)
X_test = X_test.astype(numpy.float32)
clr = RandomForestRegressor(n_jobs=1, n_estimators=100)
clr.fit(X_train, y_train)
clr.fit(X, y)
model_onnx, prefix = convert_model(clr, 'reg',
[('input',
FloatTensorType([None, 4]))])
dump_data_and_model(X_test, clr, model_onnx,
basename=prefix + "RegMis" +
clr.__class__.__name__ +
'_mismatched_estimator_counts')

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_random_forest_regressor(self):
model = RandomForestRegressor(n_estimators=3)
dump_single_regression(
Expand Down