Update keras3 Softmax mask handling to be more numerically robust. #21850
Conversation
Code Review
This pull request improves the numerical robustness of softmax mask handling by replacing an arithmetic approach with backend.numpy.where. This is a solid improvement that also correctly handles floating-point masks by binarizing them. A key benefit of this change is the removal of in-place modification of the inputs tensor, which avoids potential side effects and aligns with best practices for writing neural network layers. I have one minor suggestion to correct a typo in a comment for clarity. Overall, this is a great change.
    ) * _large_negative_number(inputs.dtype)
    inputs += adder
    # We keep the positions where the mask is True or > 0.5, and set the
    # other (masked) positions to -1e.9.
There's a small typo in this comment. -1e.9 appears to be a typo for -1e9. To improve clarity and accuracy, especially since _large_negative_number can return different values based on dtype, I suggest making the comment more general.
Suggested change:
-    # other (masked) positions to -1e.9.
+    # other (masked) positions to a large negative number.
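To make the review's description concrete, here is a minimal sketch of what where-based masking with binarization can look like. This is not the Keras source: the function name `masked_softmax` is made up, the public `keras.ops` API stands in for the internal `backend.numpy.where` call, and the literal `-1e9` stands in for `_large_negative_number(inputs.dtype)`.

```python
from keras import ops


def masked_softmax(inputs, mask=None, axis=-1):
    """Hypothetical sketch of where-based softmax masking (not the Keras code)."""
    if mask is not None:
        # Binarize the mask so boolean and floating-point masks behave the same:
        # a position is kept if the mask is True or > 0.5.
        keep = ops.greater(ops.cast(mask, inputs.dtype), 0.5)
        # Select between the original logits and a large negative number,
        # instead of adding an "adder" term to `inputs` in place.
        inputs = ops.where(keep, inputs, ops.full_like(inputs, -1e9))
    return ops.softmax(inputs, axis=axis)
```

The key difference from the additive form is that positions the mask keeps pass through unchanged, so small noise in a floating-point mask cannot leak into the kept logits.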
hertschuh left a comment:
Thanks for the improvement!
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

    @@            Coverage Diff             @@
    ##           master   #21850      +/-   ##
    ==========================================
    - Coverage   82.48%   82.47%   -0.01%
    ==========================================
      Files         577      577
      Lines       59506    59508       +2
      Branches     9330     9332       +2
    ==========================================
    - Hits        49084    49082       -2
    - Misses       8010     8014       +4
      Partials     2412     2412
Oh, you'll need to reformat the code.
Otherwise the real way is:
The original impl of masking is essentially `(1.0 - mask) * -1e9 + inputs`; this can be sensitive to numerical noise on `mask` (imagine if it is slightly off from either 1 or 0; then we would be adding a very large perturbation to the inputs). Using comparison and where ops is much more numerically robust. Also, for cases where we get floating-point masks coming in, we add a binarization step to make them compatible w/ the `where` op.
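A tiny, self-contained example (plain NumPy, not the Keras implementation; the values are made up) shows the sensitivity being described: with the additive form, a mask entry of 0.9999 instead of 1.0 turns into a huge perturbation, while the where-based form with a 0.5 threshold is unaffected.

```python
import numpy as np

inputs = np.array([2.0, 1.0, 0.5], dtype="float32")
# A "keep everything" mask whose middle entry is slightly noisy.
noisy_mask = np.array([1.0, 0.9999, 1.0], dtype="float32")

# Additive masking, essentially (1.0 - mask) * -1e9 + inputs:
# the 1e-4 of noise becomes a ~-1e5 perturbation on the second logit.
additive = inputs + (1.0 - noisy_mask) * np.float32(-1e9)

# Where-based masking after binarizing at 0.5: kept logits are untouched.
where_based = np.where(noisy_mask > 0.5, inputs, np.float32(-1e9))

print(additive)     # [ 2.0, ~-1.0e5, 0.5 ] -> second position effectively masked out
print(where_based)  # [ 2.0,  1.0,    0.5 ] -> mask noise has no effect
```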