diff --git a/skl2onnx/common/tree_ensemble.py b/skl2onnx/common/tree_ensemble.py index f11560a18..5b33e4db5 100644 --- a/skl2onnx/common/tree_ensemble.py +++ b/skl2onnx/common/tree_ensemble.py @@ -6,6 +6,7 @@ """ Common functions to convert any learner based on trees. """ +import numpy as np def get_default_tree_classifier_attribute_pairs(): @@ -47,14 +48,81 @@ def get_default_tree_regressor_attribute_pairs(): return attrs +def find_switch_point(fy, nfy): + """ + Finds the double so that + ``(float)x != (float)(x + epsilon)``. + """ + a = np.float64(fy) + b = np.float64(nfy) + fa = np.float32(a) + a0, b0 = a, a + while a != a0 or b != b0: + a0, b0 = a, b + m = (a + b) / 2 + fm = np.float32(m) + if fm == fa: + a = m + fa = fm + else: + b = m + return a + + +def sklearn_threshold(dy, dtype, mode): + """ + *scikit-learn* does not compare x to a threshold + but (float)x to a double threshold. As we need a float + threshold, we need a different value than the threshold + rounded to float. For floats, it finds float *w* which + verifies:: + + (float)x <= y <=> (float)x <= w + + For doubles, it finds double *w* which verifies:: + + (float)x <= y <=> x <= w + """ + if mode == "BRANCH_LEQ": + if dtype == np.float32: + fy = np.float32(dy) + if fy == dy: + return np.float64(fy) + if fy < dy: + return np.float64(fy) + eps = max(abs(fy), np.finfo(np.float32).eps) * 10 + nfy = np.nextafter([fy], [fy - eps], dtype=np.float32)[0] + return np.float64(nfy) + elif dtype == np.float64: + fy = np.float32(dy) + eps = max(abs(fy), np.finfo(np.float32).eps) * 10 + afy = np.nextafter([fy], [fy - eps], dtype=np.float32)[0] + afy2 = find_switch_point(afy, fy) + if fy > dy > afy2: + return afy2 + bfy = np.nextafter([fy], [fy + eps], dtype=np.float32)[0] + bfy2 = find_switch_point(fy, bfy) + if fy <= dy <= bfy2: + return bfy2 + return np.float64(fy) + raise TypeError("Unexpected dtype {}.".format(dtype)) + raise RuntimeError("Threshold is not changed for other mode and " + "'BRANCH_LEQ' 
(actually '{}').".format(mode)) + + def add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id, feature_id, mode, value, true_child_id, false_child_id, - weights, weight_id_bias, leaf_weights_are_counts): + weights, weight_id_bias, leaf_weights_are_counts, + adjust_threshold_for_sklearn, dtype): attr_pairs['nodes_treeids'].append(tree_id) attr_pairs['nodes_nodeids'].append(node_id) attr_pairs['nodes_featureids'].append(feature_id) attr_pairs['nodes_modes'].append(mode) - attr_pairs['nodes_values'].append(value) + if adjust_threshold_for_sklearn and mode != 'LEAF': + attr_pairs['nodes_values'].append( + sklearn_threshold(value, dtype, mode)) + else: + attr_pairs['nodes_values'].append(value) attr_pairs['nodes_truenodeids'].append(true_child_id) attr_pairs['nodes_falsenodeids'].append(false_child_id) attr_pairs['nodes_missing_value_tracks_true'].append(False) @@ -91,7 +159,9 @@ def add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id, def add_tree_to_attribute_pairs(attr_pairs, is_classifier, tree, tree_id, tree_weight, weight_id_bias, - leaf_weights_are_counts): + leaf_weights_are_counts, + adjust_threshold_for_sklearn=False, + dtype=None): for i in range(tree.node_count): node_id = i weight = tree.value[i] @@ -111,4 +181,6 @@ def add_tree_to_attribute_pairs(attr_pairs, is_classifier, tree, tree_id, add_node(attr_pairs, is_classifier, tree_id, tree_weight, node_id, feat_id, mode, threshold, left_child_id, right_child_id, - weight, weight_id_bias, leaf_weights_are_counts) + weight, weight_id_bias, leaf_weights_are_counts, + adjust_threshold_for_sklearn=adjust_threshold_for_sklearn, + dtype=dtype) diff --git a/skl2onnx/convert.py b/skl2onnx/convert.py index fe6592d32..77ad763bd 100644 --- a/skl2onnx/convert.py +++ b/skl2onnx/convert.py @@ -27,7 +27,8 @@ def convert_sklearn(model, name=None, initial_types=None, doc_string='', For pipeline conversion, user needs to make sure each component is one of our supported items. 
- This function converts the specified *scikit-learn* model into its *ONNX* counterpart. + This function converts the specified *scikit-learn* model + into its *ONNX* counterpart. Note that for all conversions, initial types are required. *ONNX* model name can also be specified. diff --git a/skl2onnx/operator_converters/ada_boost.py b/skl2onnx/operator_converters/ada_boost.py index d2f7bfc8a..926d54632 100644 --- a/skl2onnx/operator_converters/ada_boost.py +++ b/skl2onnx/operator_converters/ada_boost.py @@ -160,7 +160,7 @@ def convert_sklearn_ada_boost_classifier(scope, operator, container): attrs['classlabels_strings'] = classes add_tree_to_attribute_pairs(attrs, True, op.estimators_[tree_id].tree_, - 0, 1, 0, True) + 0, 1, 0, True, True, dtype=container.dtype) container.add_node( op_type, operator.input_full_names, [label_name, proba_name], @@ -234,7 +234,8 @@ def _get_estimators_label(scope, operator, container, model): attrs['n_targets'] = int(model.estimators_[tree_id].n_outputs_) add_tree_to_attribute_pairs(attrs, False, model.estimators_[tree_id].tree_, - 0, 1, 0, False) + 0, 1, 0, False, True, + dtype=container.dtype) container.add_node(op_type, input_name, estimator_label_name, op_domain='ai.onnx.ml', diff --git a/skl2onnx/operator_converters/decision_tree.py b/skl2onnx/operator_converters/decision_tree.py index 7c776c812..a2aa70f07 100644 --- a/skl2onnx/operator_converters/decision_tree.py +++ b/skl2onnx/operator_converters/decision_tree.py @@ -36,7 +36,8 @@ def convert_sklearn_decision_tree_classifier(scope, operator, container): else: raise ValueError('Labels must be all integers or all strings.') - add_tree_to_attribute_pairs(attrs, True, op.tree_, 0, 1., 0, True) + add_tree_to_attribute_pairs(attrs, True, op.tree_, 0, 1., 0, True, + True, dtype=container.dtype) container.add_node( op_type, operator.input_full_names, @@ -51,7 +52,8 @@ def convert_sklearn_decision_tree_regressor(scope, operator, container): attrs = 
get_default_tree_regressor_attribute_pairs() attrs['name'] = scope.get_unique_operator_name(op_type) attrs['n_targets'] = int(op.n_outputs_) - add_tree_to_attribute_pairs(attrs, False, op.tree_, 0, 1., 0, False) + add_tree_to_attribute_pairs(attrs, False, op.tree_, 0, 1., 0, False, + True, dtype=container.dtype) input_name = operator.input_full_names if type(operator.inputs[0].type) == Int64TensorType: diff --git a/skl2onnx/operator_converters/gradient_boosting.py b/skl2onnx/operator_converters/gradient_boosting.py index 8142fe7c1..cbc7064ab 100644 --- a/skl2onnx/operator_converters/gradient_boosting.py +++ b/skl2onnx/operator_converters/gradient_boosting.py @@ -92,14 +92,16 @@ def convert_sklearn_gradient_boosting_classifier(scope, operator, container): for tree_id in range(n_est): tree = op.estimators_[tree_id][0].tree_ add_tree_to_attribute_pairs(attrs, True, tree, tree_id, - tree_weight, 0, False) + tree_weight, 0, False, True, + dtype=container.dtype) else: for i in range(n_est): for c in range(op.n_classes_): tree_id = i * op.n_classes_ + c tree = op.estimators_[i][c].tree_ add_tree_to_attribute_pairs(attrs, True, tree, tree_id, - tree_weight, c, False) + tree_weight, c, False, True, + dtype=container.dtype) container.add_node( op_type, operator.input_full_names, @@ -138,7 +140,7 @@ def convert_sklearn_gradient_boosting_regressor(scope, operator, container): tree = op.estimators_[i][0].tree_ tree_id = i add_tree_to_attribute_pairs(attrs, False, tree, tree_id, tree_weight, - 0, False) + 0, False, True, dtype=container.dtype) input_name = operator.input_full_names if type(operator.inputs[0].type) == Int64TensorType: diff --git a/skl2onnx/operator_converters/random_forest.py b/skl2onnx/operator_converters/random_forest.py index 5cf762a1a..6cddff098 100644 --- a/skl2onnx/operator_converters/random_forest.py +++ b/skl2onnx/operator_converters/random_forest.py @@ -55,7 +55,8 @@ def convert_sklearn_random_forest_classifier(scope, operator, container): for tree_id in 
range(estimtator_count): tree = op.estimators_[tree_id].tree_ add_tree_to_attribute_pairs(attr_pairs, True, tree, tree_id, - tree_weight, 0, True) + tree_weight, 0, True, True, + dtype=container.dtype) container.add_node( op_type, operator.input_full_names, @@ -78,7 +79,8 @@ def convert_sklearn_random_forest_regressor_converter(scope, for tree_id in range(estimtator_count): tree = op.estimators_[tree_id].tree_ add_tree_to_attribute_pairs(attrs, False, tree, tree_id, - tree_weight, 0, False) + tree_weight, 0, False, True, + dtype=container.dtype) input_name = operator.input_full_names if type(operator.inputs[0].type) == Int64TensorType: diff --git a/tests/test_sklearn_random_forest_converters.py b/tests/test_sklearn_random_forest_converters.py index 3f84c299c..ff9a902d7 100644 --- a/tests/test_sklearn_random_forest_converters.py +++ b/tests/test_sklearn_random_forest_converters.py @@ -6,10 +6,12 @@ import unittest import numpy -from sklearn.ensemble import RandomForestClassifier -from sklearn.ensemble import RandomForestRegressor -from sklearn.ensemble import ExtraTreesClassifier -from sklearn.ensemble import ExtraTreesRegressor +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from sklearn.ensemble import ( + RandomForestClassifier, RandomForestRegressor, + ExtraTreesClassifier, ExtraTreesRegressor +) from skl2onnx.common.data_types import onnx_built_with_ml, FloatTensorType from test_utils import ( dump_one_class_classification, @@ -67,6 +69,27 @@ def test_random_forest_classifier_mismatched_estimator_counts(self): model.__class__.__name__ + '_mismatched_estimator_counts') + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") + def test_random_forest_regressor_mismatches(self): + iris = load_iris() + X, y = iris.data, iris.target + X_train, X_test, y_train, _ = train_test_split( + X, y, random_state=13) + X_test = X_test.astype(numpy.float32) + clr = RandomForestRegressor(n_jobs=1, 
n_estimators=100) + clr.fit(X_train, y_train) + clr.fit(X, y) + model_onnx, prefix = convert_model(clr, 'reg', + [('input', + FloatTensorType([None, 4]))]) + dump_data_and_model(X_test, clr, model_onnx, + basename=prefix + "RegMis" + + clr.__class__.__name__ + + '_mismatched_estimator_counts') + + @unittest.skipIf(not onnx_built_with_ml(), + reason="Requires ONNX-ML extension.") def test_random_forest_regressor(self): model = RandomForestRegressor(n_estimators=3) dump_single_regression(