
ENH more stable gradient of CrossEntropy #6327

Merged: 9 commits merged into microsoft:master on Feb 22, 2024

Conversation

@lorentzenchr (Contributor) commented on Feb 18, 2024:

Similar to scikit-learn/scikit-learn#28048.

There is a small runtime cost to pay, but gradient computation is not the main bottleneck of histogram-based gradient boosting.
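For context, here is a minimal self-contained sketch of the stable formulation discussed in the review below, assuming binary cross-entropy with a sigmoid link; the function name and signature are illustrative, not LightGBM's actual API:

```cpp
#include <cmath>

// Sketch: numerically stable gradient and hessian of binary cross-entropy
// with respect to the raw score, for a label y in {0, 1}.
void stable_grad_hess(double score, double y, double* grad, double* hess) {
  if (score > -37.0) {
    // exp(-score) stays below exp(37) ~ 1.2e16 here, so the closed-form
    // sigmoid expressions are safe to evaluate directly.
    const double exp_tmp = std::exp(-score);
    // gradient = sigmoid(score) - y, rearranged to reduce cancellation
    *grad = ((1.0 - y) - y * exp_tmp) / (1.0 + exp_tmp);
    // hessian = sigmoid(score) * (1 - sigmoid(score))
    *hess = exp_tmp / ((1.0 + exp_tmp) * (1.0 + exp_tmp));
  } else {
    // For score <= -37, sigmoid(score) ~ exp(score), so gradient and
    // hessian reduce to their first-order approximations.
    const double exp_tmp = std::exp(score);
    *grad = exp_tmp - y;
    *hess = exp_tmp;
  }
}
```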

@jameslamb (Collaborator) commented:

Thanks for this! I'll defer to @shiyu1994 and @guolinke to review.

Until then, can you please update this to the latest master?

@shiyu1994 (Collaborator) left a comment:

Thanks for the contribution. I just left a few comments about correcting the hessian computation.

```cpp
if (score[i] > -37.0) {
  const double exp_tmp = std::exp(-score[i]);
  gradients[i] = static_cast<score_t>(((1.0f - label_[i]) - label_[i] * exp_tmp) / (1.0f + exp_tmp));
  hessians[i] = static_cast<score_t>(exp_tmp / (1 + exp_tmp) * (1 + exp_tmp));
```
Collaborator:

The `/` followed by a `*` simply returns `exp_tmp`, which is not the expected hessian.

Suggested change:

```diff
-hessians[i] = static_cast<score_t>(exp_tmp / (1 + exp_tmp) * (1 + exp_tmp));
+hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)));
```
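To make the bug concrete, a quick standalone check (hypothetical, not part of the PR) comparing both expressions against the textbook hessian p(1 - p):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const double score = 2.0;                    // arbitrary example score
  const double exp_tmp = std::exp(-score);
  const double p = 1.0 / (1.0 + exp_tmp);      // sigmoid(score)
  const double expected = p * (1.0 - p);       // textbook hessian
  const double buggy = exp_tmp / (1 + exp_tmp) * (1 + exp_tmp);  // evaluates to exp_tmp
  const double fixed = exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp));
  std::printf("expected %.5f, buggy %.5f, fixed %.5f\n", expected, buggy, fixed);
  // prints: expected 0.10499, buggy 0.13534, fixed 0.10499
  return 0;
}
```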

Collaborator:

Is it possible that `hessians[i] = static_cast<score_t>(exp_tmp / (1 + exp_tmp) / (1 + exp_tmp));` could be more numerically stable?

@lorentzenchr (author):

First of all, yes, I forgot the parentheses. Thanks for spotting it. It is surprising that all the tests still pass with this bug.

As for `exp_tmp / (1 + exp_tmp) / (1 + exp_tmp)`: it is more numerically stable in the sense that it could prevent overflow. But in this branch `exp_tmp < exp(37) ≈ 1.2e16`, and squaring that stays within even single-precision range (max ≈ 3.4e38); note that `exp_tmp` is double precision anyway.
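A quick numeric check of those bounds (illustrative only):

```cpp
#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
  const double max_exp_tmp = std::exp(37.0);  // largest exp(-score) in this branch
  const double squared = max_exp_tmp * max_exp_tmp;
  std::printf("max exp_tmp ~ %.3e\n", max_exp_tmp);     // ~1.17e+16
  std::printf("squared     ~ %.3e\n", squared);         // ~1.37e+32
  std::printf("FLT_MAX     ~ %.3e\n", (double)FLT_MAX); // ~3.40e+38
  // The square sits about 6 orders of magnitude below even FLT_MAX, and
  // exp_tmp is computed in double precision (DBL_MAX ~ 1.8e+308), so the
  // squared denominator cannot overflow here.
  return 0;
}
```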

Collaborator:

Thanks. Could you also fix the hessian calculation in the else branch?

```cpp
} else {
  const double exp_tmp = std::exp(score[i]);
  gradients[i] = static_cast<score_t>(exp_tmp - label_[i]);
  hessians[i] = static_cast<score_t>(exp_tmp);
```
Collaborator:

Suggested change:

```diff
-hessians[i] = static_cast<score_t>(exp_tmp);
+hessians[i] = static_cast<score_t>(exp_tmp / ((1 + exp_tmp) * (1 + exp_tmp)));
```

@lorentzenchr (author):

This is not needed: here `exp_tmp < 1e-16` is tiny, so `(1 + exp_tmp)` is just 1 in floating point. Put differently, the implemented formula is the first-order Taylor expansion in `exp_tmp`.
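Spelled out, with x = `exp_tmp` ≤ exp(−37) ≈ 8.5e−17, the approximation being relied on is

$$\frac{x}{(1+x)^2} = x\,(1 - 2x + 3x^2 - \dots) \approx x,$$

so the first neglected term, 2x², is below 1.5e−32, a relative error of about 2x ≈ 1.7e−16, i.e. on the order of double-precision machine epsilon.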

Collaborator:

I see. That makes sense.

Collaborator:

But maybe it would still be better to write the original calculation formula explicitly to avoid ambiguity?

@lorentzenchr (author):

What do you mean by "ambiguity"? Writing out the full formula would not avoid the branch, and the current version is a tiny bit more efficient.
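For what it's worth, one way to keep the shortcut while making the reasoning explicit would be a comment at the branch, e.g. (illustrative, not necessarily what was committed):

```cpp
} else {
  // For score[i] <= -37, sigmoid(score) ~ exp(score), and the true hessian
  // exp_tmp / (1 + exp_tmp)^2 equals exp_tmp to first order, because
  // exp_tmp <= exp(-37) ~ 8.5e-17 and (1 + exp_tmp) rounds to 1.0 in doubles.
  const double exp_tmp = std::exp(score[i]);
  gradients[i] = static_cast<score_t>(exp_tmp - label_[i]);
  hessians[i] = static_cast<score_t>(exp_tmp);
}
```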

(Two further review threads on src/objective/xentropy_objective.hpp were marked resolved.)
@shiyu1994 merged commit 894066d into microsoft:master on Feb 22, 2024. All 43 checks passed.
@lorentzenchr deleted the stable_cross_entropy branch on February 24, 2024.
@jameslamb mentioned this pull request on May 1, 2024.