Skip to content

Commit

Permalink
Use n_estimators_ when set in GradientBoostingClassifier (#213)
Browse files Browse the repository at this point in the history
* Use n_estimators_ when set in GradientBoostingClassifier
  • Loading branch information
Prabhat authored and xadupre committed Jul 11, 2019
1 parent 01b37e3 commit b5333f5
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions skl2onnx/operator_converters/gradient_boosting.py
Expand Up @@ -70,13 +70,15 @@ def convert_sklearn_gradient_boosting_classifier(scope, operator, container):
raise ValueError('Labels must be all integer or all strings.')

tree_weight = op.learning_rate
n_est = (op.n_estimators_ if hasattr(op, 'n_estimators_') else
op.n_estimators)
if op.n_classes_ == 2:
for tree_id in range(op.n_estimators):
for tree_id in range(n_est):
tree = op.estimators_[tree_id][0].tree_
add_tree_to_attribute_pairs(attrs, True, tree, tree_id,
tree_weight, 0, False)
else:
for i in range(op.n_estimators):
for i in range(n_est):
for c in range(op.n_classes_):
tree_id = i * op.n_classes_ + c
tree = op.estimators_[i][c].tree_
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sklearn_adaboost_converter.py
Expand Up @@ -80,7 +80,7 @@ def test_ada_boost_regressor(self):
X,
model,
model_onnx,
basename="SklearnAdaBoostRegressor-OneOffArray",
basename="SklearnAdaBoostRegressor-OneOffArray-Dec4",
allow_failure="StrictVersion("
"onnxruntime.__version__) "
"<= StrictVersion('0.2.1') or "
Expand Down

0 comments on commit b5333f5

Please sign in to comment.