Skip to content

Commit

Permalink
Fix a couple of mismatches for every tree in scikit-learn. (#237)
Browse files Browse the repository at this point in the history
* adjust threshold for mismatch in sklearn trees
  • Loading branch information
xadupre committed Aug 28, 2019
1 parent 517b0bb commit 11ac5ef
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 18 deletions.
80 changes: 76 additions & 4 deletions skl2onnx/common/tree_ensemble.py
Expand Up @@ -6,6 +6,7 @@
"""
Common functions to convert any learner based on trees.
"""
import numpy as np


def get_default_tree_classifier_attribute_pairs():
Expand Down Expand Up @@ -47,14 +48,81 @@ def get_default_tree_regressor_attribute_pairs():
return attrs


def find_switch_point(fy, nfy):
"""
Finds the double so that
``(float)x != (float)(x + espilon)``.
"""
a = np.float64(fy)
b = np.float64(nfy)
fa = np.float32(a)
a0, b0 = a, a
while a != a0 or b != b0:
a0, b0 = a, b
m = (a + b) / 2
fm = np.float32(m)
if fm == fa:
a = m
fa = fm
else:
b = m
return a


def sklearn_threshold(dy, dtype, mode):
"""
*scikit-learn* does not compare x to a threshold
but (float)x to a double threshold. As we need a float
threshold, we need a different value than the threshold
rounded to float. For floats, it finds float *w* which
verifies::
(float)x <= y <=> (float)x <= w
For doubles, it finds double *w* which verifies::
(float)x <= y <=> x <= w
"""
if mode == "BRANCH_LEQ":
if dtype == np.float32:
fy = np.float32(dy)
if fy == dy:
return np.float64(fy)
if fy < dy:
return np.float64(fy)
eps = max(abs(fy), np.finfo(np.float32).eps) * 10
nfy = np.nextafter([fy], [fy - eps], dtype=np.float32)[0]
return np.float64(nfy)
elif dtype == np.float64:
fy = np.float32(dy)
eps = max(abs(fy), np.finfo(np.float32).eps) * 10
afy = np.nextafter([fy], [fy - eps], dtype=np.float32)[0]
afy2 = find_switch_point(afy, fy)
if fy > dy > afy2:
return afy2
bfy = np.nextafter([fy], [fy + eps], dtype=np.float32)[0]
bfy2 = find_switch_point(fy, bfy)
if fy <= dy <= bfy2:
return bfy2
return np.float64(fy)
raise TypeError("Unexpected dtype {}.".format(dtype))
raise RuntimeError("Threshold is not changed for other mode and "
"'BRANCH_LEQ' (actually '{}').".format(mode))


def add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id,
feature_id, mode, value, true_child_id, false_child_id,
weights, weight_id_bias, leaf_weights_are_counts):
weights, weight_id_bias, leaf_weights_are_counts,
adjust_threshold_for_sklearn, dtype):
attr_pairs['nodes_treeids'].append(tree_id)
attr_pairs['nodes_nodeids'].append(node_id)
attr_pairs['nodes_featureids'].append(feature_id)
attr_pairs['nodes_modes'].append(mode)
attr_pairs['nodes_values'].append(value)
if adjust_threshold_for_sklearn and mode != 'LEAF':
attr_pairs['nodes_values'].append(
sklearn_threshold(value, dtype, mode))
else:
attr_pairs['nodes_values'].append(value)
attr_pairs['nodes_truenodeids'].append(true_child_id)
attr_pairs['nodes_falsenodeids'].append(false_child_id)
attr_pairs['nodes_missing_value_tracks_true'].append(False)
Expand Down Expand Up @@ -91,7 +159,9 @@ def add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id,

def add_tree_to_attribute_pairs(attr_pairs, is_classifier, tree, tree_id,
tree_weight, weight_id_bias,
leaf_weights_are_counts):
leaf_weights_are_counts,
adjust_threshold_for_sklearn=False,
dtype=None):
for i in range(tree.node_count):
node_id = i
weight = tree.value[i]
Expand All @@ -111,4 +181,6 @@ def add_tree_to_attribute_pairs(attr_pairs, is_classifier, tree, tree_id,

add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id,
feat_id, mode, threshold, left_child_id, right_child_id,
weight, weight_id_bias, leaf_weights_are_counts)
weight, weight_id_bias, leaf_weights_are_counts,
adjust_threshold_for_sklearn=adjust_threshold_for_sklearn,
dtype=dtype)
3 changes: 2 additions & 1 deletion skl2onnx/convert.py
Expand Up @@ -27,7 +27,8 @@ def convert_sklearn(model, name=None, initial_types=None, doc_string='',
For pipeline conversion, user needs to make sure each component
is one of our supported items.
This function converts the specified *scikit-learn* model into its *ONNX* counterpart.
This function converts the specified *scikit-learn* model
into its *ONNX* counterpart.
Note that for all conversions, initial types are required.
*ONNX* model name can also be specified.
Expand Down
5 changes: 3 additions & 2 deletions skl2onnx/operator_converters/ada_boost.py
Expand Up @@ -160,7 +160,7 @@ def convert_sklearn_ada_boost_classifier(scope, operator, container):
attrs['classlabels_strings'] = classes

add_tree_to_attribute_pairs(attrs, True, op.estimators_[tree_id].tree_,
0, 1, 0, True)
0, 1, 0, True, True, dtype=container.dtype)
container.add_node(
op_type, operator.input_full_names,
[label_name, proba_name],
Expand Down Expand Up @@ -234,7 +234,8 @@ def _get_estimators_label(scope, operator, container, model):
attrs['n_targets'] = int(model.estimators_[tree_id].n_outputs_)
add_tree_to_attribute_pairs(attrs, False,
model.estimators_[tree_id].tree_,
0, 1, 0, False)
0, 1, 0, False, True,
dtype=container.dtype)

container.add_node(op_type, input_name,
estimator_label_name, op_domain='ai.onnx.ml',
Expand Down
6 changes: 4 additions & 2 deletions skl2onnx/operator_converters/decision_tree.py
Expand Up @@ -36,7 +36,8 @@ def convert_sklearn_decision_tree_classifier(scope, operator, container):
else:
raise ValueError('Labels must be all integers or all strings.')

add_tree_to_attribute_pairs(attrs, True, op.tree_, 0, 1., 0, True)
add_tree_to_attribute_pairs(attrs, True, op.tree_, 0, 1., 0, True,
True, dtype=container.dtype)

container.add_node(
op_type, operator.input_full_names,
Expand All @@ -51,7 +52,8 @@ def convert_sklearn_decision_tree_regressor(scope, operator, container):
attrs = get_default_tree_regressor_attribute_pairs()
attrs['name'] = scope.get_unique_operator_name(op_type)
attrs['n_targets'] = int(op.n_outputs_)
add_tree_to_attribute_pairs(attrs, False, op.tree_, 0, 1., 0, False)
add_tree_to_attribute_pairs(attrs, False, op.tree_, 0, 1., 0, False,
True, dtype=container.dtype)

input_name = operator.input_full_names
if type(operator.inputs[0].type) == Int64TensorType:
Expand Down
8 changes: 5 additions & 3 deletions skl2onnx/operator_converters/gradient_boosting.py
Expand Up @@ -92,14 +92,16 @@ def convert_sklearn_gradient_boosting_classifier(scope, operator, container):
for tree_id in range(n_est):
tree = op.estimators_[tree_id][0].tree_
add_tree_to_attribute_pairs(attrs, True, tree, tree_id,
tree_weight, 0, False)
tree_weight, 0, False, True,
dtype=container.dtype)
else:
for i in range(n_est):
for c in range(op.n_classes_):
tree_id = i * op.n_classes_ + c
tree = op.estimators_[i][c].tree_
add_tree_to_attribute_pairs(attrs, True, tree, tree_id,
tree_weight, c, False)
tree_weight, c, False, True,
dtype=container.dtype)

container.add_node(
op_type, operator.input_full_names,
Expand Down Expand Up @@ -138,7 +140,7 @@ def convert_sklearn_gradient_boosting_regressor(scope, operator, container):
tree = op.estimators_[i][0].tree_
tree_id = i
add_tree_to_attribute_pairs(attrs, False, tree, tree_id, tree_weight,
0, False)
0, False, True, dtype=container.dtype)

input_name = operator.input_full_names
if type(operator.inputs[0].type) == Int64TensorType:
Expand Down
6 changes: 4 additions & 2 deletions skl2onnx/operator_converters/random_forest.py
Expand Up @@ -55,7 +55,8 @@ def convert_sklearn_random_forest_classifier(scope, operator, container):
for tree_id in range(estimtator_count):
tree = op.estimators_[tree_id].tree_
add_tree_to_attribute_pairs(attr_pairs, True, tree, tree_id,
tree_weight, 0, True)
tree_weight, 0, True, True,
dtype=container.dtype)

container.add_node(
op_type, operator.input_full_names,
Expand All @@ -78,7 +79,8 @@ def convert_sklearn_random_forest_regressor_converter(scope,
for tree_id in range(estimtator_count):
tree = op.estimators_[tree_id].tree_
add_tree_to_attribute_pairs(attrs, False, tree, tree_id,
tree_weight, 0, False)
tree_weight, 0, False, True,
dtype=container.dtype)

input_name = operator.input_full_names
if type(operator.inputs[0].type) == Int64TensorType:
Expand Down
31 changes: 27 additions & 4 deletions tests/test_sklearn_random_forest_converters.py
Expand Up @@ -6,10 +6,12 @@

import unittest
import numpy
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import (
RandomForestClassifier, RandomForestRegressor,
ExtraTreesClassifier, ExtraTreesRegressor
)
from skl2onnx.common.data_types import onnx_built_with_ml, FloatTensorType
from test_utils import (
dump_one_class_classification,
Expand Down Expand Up @@ -67,6 +69,27 @@ def test_random_forest_classifier_mismatched_estimator_counts(self):
model.__class__.__name__ +
'_mismatched_estimator_counts')

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_random_forest_regressor_mismatches(self):
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, _ = train_test_split(
X, y, random_state=13)
X_test = X_test.astype(numpy.float32)
clr = RandomForestRegressor(n_jobs=1, n_estimators=100)
clr.fit(X_train, y_train)
clr.fit(X, y)
model_onnx, prefix = convert_model(clr, 'reg',
[('input',
FloatTensorType([None, 4]))])
dump_data_and_model(X_test, clr, model_onnx,
basename=prefix + "RegMis" +
clr.__class__.__name__ +
'_mismatched_estimator_counts')

@unittest.skipIf(not onnx_built_with_ml(),
reason="Requires ONNX-ML extension.")
def test_random_forest_regressor(self):
model = RandomForestRegressor(n_estimators=3)
dump_single_regression(
Expand Down

0 comments on commit 11ac5ef

Please sign in to comment.