Fix base_values in TreeEnsembleRegressor #5518

Closed

corwinjoy
Contributor

Description

In the reference implementation of TreeEnsembleRegressor, when a value is provided for the base_values argument, that value replaces the prediction instead of being combined with it. This cannot be the intended behavior, and it does not match onnxruntime. Instead, base_values should be added to the prediction after any aggregation has been applied.
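
To make the proposed behavior concrete, here is a minimal sketch of the semantics for one sample and one target. It is not the code from this PR: the function name `aggregate_with_base` and its parameters are hypothetical, and the four aggregation modes are assumed to be SUM, AVERAGE, MIN and MAX.

```python
from typing import Optional, Sequence


def aggregate_with_base(
    tree_outputs: Sequence[float],
    aggregate: str = "SUM",
    base_value: Optional[float] = None,
) -> float:
    """Aggregate per-tree outputs, then add the base value (hypothetical helper)."""
    if aggregate == "SUM":
        result = sum(tree_outputs)
    elif aggregate == "AVERAGE":
        result = sum(tree_outputs) / len(tree_outputs)
    elif aggregate == "MIN":
        result = min(tree_outputs)
    elif aggregate == "MAX":
        result = max(tree_outputs)
    else:
        raise ValueError(f"Unknown aggregate function: {aggregate!r}")

    # The base value is an offset applied after aggregation.
    # It must not replace the aggregated prediction.
    if base_value is not None:
        result += base_value
    return result


# Two trees predict 0.5 and 1.5; the baseline prediction is 1.0.
print(aggregate_with_base([0.5, 1.5], "SUM", 1.0))      # 3.0
print(aggregate_with_base([0.5, 1.5], "AVERAGE", 1.0))  # 2.0
```

With the behavior reported above, both calls would return 1.0, the base value alone, regardless of what the trees predict.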

Motivation and Context

We are exporting regression trees into ONNX that have a non-zero base_value as the baseline prediction for the tree. Prediction works as expected in onnxruntime but not in the reference implementation. I believe this is an oversight and propose the fix (plus tests) below. I also think the documentation should be more explicit about what the base_values argument does.

@corwinjoy requested review from a team as code owners August 23, 2023 01:58
Contributor

@justinchuby left a comment

Please fix the lint errors. Thanks!

onnx/test/reference_evaluator_ml_test.py (outdated, resolved)
Contributor

@gramalingam left a comment

Thanks for the fix!

@xadupre enabled auto-merge August 29, 2023 07:57
auto-merge was automatically disabled September 5, 2023 18:17: head branch was pushed to by a user without write access

@corwinjoy
Contributor Author

Lint errors should be fixed now according to local lintrunner.

@corwinjoy force-pushed the patch_tree_reg_base_values branch 2 times, most recently from dfcc951 to aef8e0c, September 5, 2023 19:42
@corwinjoy changed the base branch from main to 1.13.1-protobuf4.21 September 5, 2023 20:00
@corwinjoy requested a review from a team as a code owner September 5, 2023 20:00
@corwinjoy changed the base branch from 1.13.1-protobuf4.21 to main September 5, 2023 20:00
@parameterized.expand(
[
(f"{agg}_{base_value}", base_value, agg)
for base_value in (None, [1.0])
Contributor

I would use itertools.product for this. A double for loop in a list comprehension can be confusing: https://google.github.io/styleguide/pyguide.html#274-decision

Contributor Author

@corwinjoy Sep 6, 2023

@justinchuby I don't know how to do this clearly with itertools.product because of the name variable at the beginning. I do want to keep the name variable, since I think that is the whole point of the parameterization: being able to see any failures directly from the test name.
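
For reference, one possible way to keep the readable test name while using itertools.product is sketched below. It is not taken from the PR: the `AGGREGATES` tuple is an assumption (the original comprehension is only partly visible above), and the resulting list is what would be handed to `parameterized.expand`.

```python
import itertools

# Assumed aggregation modes; the real test may use a different set.
AGGREGATES = ("SUM", "AVERAGE", "MIN", "MAX")
BASE_VALUES = (None, [1.0])

# A single loop over the Cartesian product. The first tuple element is the
# test-name suffix, so a failure still identifies the exact combination.
cases = [
    (f"{agg}_{base_value}", base_value, agg)
    for base_value, agg in itertools.product(BASE_VALUES, AGGREGATES)
]

for name, base_value, agg in cases:
    print(name, base_value, agg)
```

The name is built inside the comprehension from the unpacked product tuple, so the double loop goes away while the (name, base_value, agg) shape of each case stays the same.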

@justinchuby added the "auto update doc" label (Generate md/proto files automatically using the CI pipeline) Sep 6, 2023
@justinchuby
Contributor

Could you rebase on main to remove the relevant commits? Thanks!

@justinchuby added the "operator" label (Issues related to ONNX operators) and removed the "auto update doc" label (Generate md/proto files automatically using the CI pipeline) Sep 6, 2023
@corwinjoy
Contributor Author

corwinjoy commented Sep 6, 2023

I tried a rebase on main but it still shows all the other merges / updates as commits. Not sure how to fix. It may require a new branch / PR? The signed commits + rebasing the code seems to be creating problems.

@justinchuby
Contributor

justinchuby commented Sep 6, 2023

> I tried a rebase on main but it still shows all the other merges / updates as commits. Not sure how to fix. It may require a new branch / PR? The signed commits + rebasing the code seems to be creating problems.

I usually do this: branch this out to a new branch, reset this current branch to upstream/main, then squash merge the new branch to this branch; force push.
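
The workflow described above could be spelled out roughly as follows. This is only a sketch: the helper `squash_onto_upstream_commands` is hypothetical, the remote names `upstream` and `origin`, the backup branch name and the commit message are assumptions, and `--force-with-lease` is used here as a more cautious variant of a plain force push. The helper only builds the git command lines; it does not run them.

```python
from typing import List


def squash_onto_upstream_commands(
    branch: str,
    backup: str,
    message: str,
    upstream_remote: str = "upstream",
    base: str = "main",
) -> List[List[str]]:
    """Return the git commands for the squash-and-reset workflow (hypothetical helper)."""
    return [
        # 1. Branch the current work out to a new branch.
        ["git", "branch", backup],
        # 2. Reset the current branch to upstream/main.
        #    Note: --hard discards uncommitted local changes.
        ["git", "fetch", upstream_remote],
        ["git", "reset", "--hard", f"{upstream_remote}/{base}"],
        # 3. Squash-merge the saved work back as one signed-off commit.
        ["git", "merge", "--squash", backup],
        ["git", "commit", "--signoff", "-m", message],
        # 4. Force push the rewritten branch.
        ["git", "push", "--force-with-lease", "origin", branch],
    ]


commands = squash_onto_upstream_commands(
    branch="patch_tree_reg_base_values",
    backup="patch_tree_reg_base_values_backup",  # assumed name
    message="Fix base_values in TreeEnsembleRegressor",
)
for cmd in commands:
    print(" ".join(cmd))
```

Run from the PR branch, this leaves a single commit on top of upstream/main, which should remove the unrelated merges from the PR's commit list.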
