Fix base_values in TreeEnsembleRegressor #5518
Please fix the lint errors. Thanks!
Thanks for the fix!
Lint errors should be fixed now according to local lintrunner.
@parameterized.expand(
    [
        (f"{agg}_{base_value}", base_value, agg)
        for base_value in (None, [1.0])
I would use itertools.product for this. A double for loop in a list comprehension can be confusing: https://google.github.io/styleguide/pyguide.html#274-decision
@justinchuby I don't know how to do this in a clear way with itertools.product because of the name variable at the beginning. I do want to keep the name variable, since I think that is the whole point of the parameterization: to see any failures more clearly, directly from the test name.
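For readers following this thread, here is a minimal sketch of one way the name could be kept while still using itertools.product. It is not the code that was merged in this PR. The aggregation names and the helper `make_cases` are assumptions for illustration; only the `(None, [1.0])` base values and the `f"{agg}_{base_value}"` name format come from the snippet under review.

```python
import itertools

# Assumed aggregation names for illustration; the PR's actual list is not
# visible in the reviewed snippet.
AGGREGATIONS = ("SUM", "AVERAGE", "MIN", "MAX")
# Base values taken from the reviewed snippet.
BASE_VALUES = (None, [1.0])


def make_cases(aggregations=AGGREGATIONS, base_values=BASE_VALUES):
    """Build (name, base_value, agg) tuples, e.g. for parameterized.expand.

    make_cases is a hypothetical helper, not part of the PR.
    """
    return [
        # The readable name still comes first, so failures show it directly.
        (f"{agg}_{base_value}", base_value, agg)
        for base_value, agg in itertools.product(base_values, aggregations)
    ]


cases = make_cases()
```

The single `for` clause iterates over the Cartesian product, so the name is still built per case and each test keeps a distinct, readable identifier.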
Could you rebase on main to remove the relevant commits? Thanks!
I tried a rebase on main but it still shows all the other merges / updates as commits. Not sure how to fix. It may require a new branch / PR? The signed commits + rebasing the code seems to be creating problems.
I usually do this: branch this out to a new branch, reset this current branch to upstream/main, then squash merge the new branch to this branch; force push.
Description
In the reference implementation of TreeEnsembleRegressor, when a value is provided for the base_values argument, this value replaces any prediction. This can't be what's intended and does not match onnxruntime. Instead, the base_value should be added to the prediction after applying any aggregation.
Motivation and Context
We are exporting regression trees into ONNX that have a non-zero base_value as the baseline prediction for the tree. Prediction works as expected in onnxruntime but not in the reference implementation. I believe this is an oversight and propose the fix (plus tests) below. I also think the documentation should be more explicit about what the base_values argument does.
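To make the intended semantics concrete, here is a minimal NumPy sketch of the behavior described above: aggregate the per-tree outputs first, then add the base value. This is an illustration under simplifying assumptions (a single regression target, per-tree leaf values already gathered into an array), not the reference implementation's code. The function name `aggregate_with_base` and its parameters are hypothetical.

```python
import numpy as np


def aggregate_with_base(leaf_values, aggregate_function="SUM", base_values=None):
    """Aggregate per-tree outputs, then add the base value (single target).

    leaf_values: array-like of shape (n_samples, n_trees), assumed to hold
    the leaf value each tree selected for each sample.
    """
    leaf_values = np.asarray(leaf_values, dtype=np.float32)
    if aggregate_function == "SUM":
        pred = leaf_values.sum(axis=1)
    elif aggregate_function == "AVERAGE":
        pred = leaf_values.mean(axis=1)
    elif aggregate_function == "MIN":
        pred = leaf_values.min(axis=1)
    elif aggregate_function == "MAX":
        pred = leaf_values.max(axis=1)
    else:
        raise ValueError(f"Unknown aggregate_function: {aggregate_function!r}")
    # The base value is added to the aggregated prediction; it does not
    # replace it, which is the bug this PR describes.
    if base_values is not None:
        pred = pred + np.float32(base_values[0])
    return pred


leaves = [[1.0, 2.0], [3.0, 5.0]]
summed = aggregate_with_base(leaves, "SUM", base_values=[1.0])
averaged = aggregate_with_base(leaves, "AVERAGE", base_values=[1.0])
```

With base_values left as None, the result is the plain aggregate, so a zero or missing baseline leaves predictions unchanged.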