Skip to content

Commit

Permalink
Fixed GradientBoostingRegression model conversion failure with init=z…
Browse files Browse the repository at this point in the history
…ero (#164)

* Fixed GradientBoostingRegression model conversion failure with init=zero

* Loop on fit n_estimators_ instead of passed n_estimators

* Added corresponding unit test

* Fix error in scikit < 0.21
  • Loading branch information
Prabhat committed Jun 17, 2019
1 parent 1204066 commit 200cb00
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
19 changes: 12 additions & 7 deletions skl2onnx/operator_converters/gradient_boosting.py
Expand Up @@ -96,17 +96,22 @@ def convert_sklearn_gradient_boosting_regressor(scope, operator, container):
attrs['name'] = scope.get_unique_operator_name(op_type)
attrs['n_targets'] = 1

# constant_ was introduced in scikit-learn 0.21.
if hasattr(op.init_, 'constant_'):
cst = [float(x) for x in op.init_.constant_]
elif op.loss == 'ls':
cst = [op.init_.mean]
if op.init == 'zero':
cst = np.zeros((operator.inputs[0].type.shape[0], op.loss_.K))
else:
cst = [op.init_.quantile]
# constant_ was introduced in scikit-learn 0.21.
if hasattr(op.init_, 'constant_'):
cst = [float(x) for x in op.init_.constant_]
elif op.loss == 'ls':
cst = [op.init_.mean]
else:
cst = [op.init_.quantile]
attrs['base_values'] = [float(x) for x in cst]

tree_weight = op.learning_rate
for i in range(op.n_estimators):
n_est = (op.n_estimators_ if hasattr(op, 'n_estimators_') else
op.n_estimators)
for i in range(n_est):
tree = op.estimators_[i][0].tree_
tree_id = i
add_tree_to_attribute_pairs(attrs, False, tree, tree_id, tree_weight,
Expand Down
19 changes: 19 additions & 0 deletions tests/test_sklearn_gradient_boosting_converters.py
Expand Up @@ -163,6 +163,25 @@ def test_gradient_boosting_regressor_int(self):
" <= StrictVersion('0.2.1')"
)

def test_gradient_boosting_regressor_zero_init(self):
model, X = fit_regression_model(
GradientBoostingRegressor(n_estimators=30, init="zero",
random_state=42))
model_onnx = convert_sklearn(
model,
"gradient boosting regression",
[("input", FloatTensorType([1, X.shape[1]]))],
)
self.assertIsNotNone(model_onnx)
dump_data_and_model(
X,
model,
model_onnx,
basename="SklearnGradientBoostingRegressionZeroInit-Dec4",
allow_failure="StrictVersion(onnxruntime.__version__)"
" <= StrictVersion('0.2.1')"
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 200cb00

Please sign in to comment.