fix: eliminate RuntimeWarnings in von Mises-Fisher loss backward pass#824
carlosm-silva merged 2 commits into graphnet-team:main
Conversation
- Replace `np.where` with boolean masking to avoid double evaluation
- Add comprehensive unit tests for zero handling in `LogCMK.backward`
- Enhance docstrings with mathematical background and implementation details
- Add mathematical background documentation in `docs/source/models/`

Fixes division-by-zero warnings while maintaining numerical accuracy. Error-bound analysis shows <1e-21 accuracy for |κ| < 1e-6.
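The double-evaluation problem the first bullet refers to can be sketched in a few lines (the function names below are illustrative, not the actual graphnet code): `np.where` evaluates both branches over the full array, so `1/kappa` is computed even where `kappa == 0`, whereas boolean masking only touches the nonzero entries.

```python
import numpy as np

def grad_where(kappa):
    # np.where evaluates BOTH branches eagerly, so 1/kappa and 1/tanh(kappa)
    # are computed at kappa == 0, emitting divide-by-zero RuntimeWarnings.
    return np.where(kappa == 0, 0.0, 1.0 / kappa - 1.0 / np.tanh(kappa))

def grad_masked(kappa):
    # Boolean masking evaluates the expression only on nonzero entries;
    # zero entries keep the limit value 0 from np.zeros_like.
    out = np.zeros_like(kappa)
    nonzero = kappa != 0
    k = kappa[nonzero]
    out[nonzero] = 1.0 / k - 1.0 / np.tanh(k)
    return out

kappa = np.array([0.0, 0.5, 2.0])

# Promote warnings to errors to make the difference observable.
try:
    with np.errstate(divide="raise"):
        grad_where(kappa)
    raised = False
except FloatingPointError:
    raised = True

with np.errstate(divide="raise", invalid="raise"):
    g = grad_masked(kappa)  # completes without warnings
```

Under the default `np.errstate`, the `np.where` variant merely emits RuntimeWarnings rather than raising, which is exactly the noise this PR removes.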
shubhamos-ai
left a comment
📚 Well-structured code! Consider adding more detailed documentation, implementing CI/CD pipelines, and adding performance monitoring for production readiness.
RasmusOrsoe
left a comment
Hey @carlosm-silva thank you very much for this clean contribution!
I ran a few checks on the compatibility of the gradients from this change to backward with respect to the current implementation, and was happy to find that they agree except at kappa=0.
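A compatibility check like the one described above can be sketched as follows (hypothetical code, not the reviewer's actual script): evaluate an un-guarded direct implementation and a masked one on a grid of kappa values and compare everywhere except kappa = 0, where the direct version yields nan.

```python
import numpy as np

def grad_direct(kappa):
    # Un-guarded evaluation: at kappa == 0 this gives inf - inf = nan.
    with np.errstate(divide="ignore", invalid="ignore"):
        return 1.0 / kappa - 1.0 / np.tanh(kappa)

def grad_masked(kappa):
    # Guarded evaluation: kappa == 0 entries are set to the limit value 0.
    out = np.zeros_like(kappa)
    m = kappa != 0
    out[m] = 1.0 / kappa[m] - 1.0 / np.tanh(kappa[m])
    return out

kappa = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
nonzero = kappa != 0
agree = np.allclose(grad_direct(kappa)[nonzero], grad_masked(kappa)[nonzero])
```

The two implementations match on all nonzero entries, while at kappa = 0 only the masked version returns a finite value.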
When I took a closer look at the documentation introduced in the PR, I found that the math appears to be rendered incorrectly (see here). Could you double-check that it's OK?
Tagging @Aske-Rosted for completeness. @Aske-Rosted Do you have comments?
@@ -0,0 +1,200 @@
# Mathematical Background: von Mises-Fisher Loss Implementation
Most of the math here appears not to render correctly in the markdown file.
Interestingly, it renders correctly when I copy and paste it into my local Markdown environment (Obsidian). The minus sign is being interpreted as the start of a list. I'm unfamiliar with GitHub's Markdown notation, so I'm unsure how to proceed.
Corrected mathematical expressions in the von Mises-Fisher documentation so they are properly rendered in GitHub markdown.
@carlosm-silva thank you for the quick iteration! The docs now look much better. I have no further comments at this stage. Let's give @Aske-Rosted a chance to catch up before we proceed.
Hi @carlosm-silva, thanks for this very well-documented contribution!
🎯 Summary
Fixes RuntimeWarnings in `LogCMK.backward` when processing arrays containing zero values, while maintaining mathematical accuracy and numerical stability.

🔬 Mathematical Background
This PR implements a numerically stable solution for computing gradients in the von Mises-Fisher loss function. The core issue was that while the mathematical limit

lim_{κ→0} (1/κ − 1/tanh(κ)) = 0

is well-defined, naive floating-point evaluation triggers division-by-zero warnings.
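The stable evaluation strategy can be sketched as follows (a minimal illustration, assuming the threshold and Taylor series quoted in this PR description; `stable_term` is a made-up name, not the actual graphnet function):

```python
import numpy as np

THRESHOLD = 1e-6  # switch-over point quoted in the PR description

def stable_term(kappa):
    """Evaluate 1/kappa - 1/tanh(kappa) without division-by-zero warnings."""
    kappa = np.asarray(kappa, dtype=float)
    out = np.empty_like(kappa)
    small = np.abs(kappa) < THRESHOLD
    # Taylor expansion: 1/k - coth(k) = -k/3 + k^3/45 - ..., so truncating
    # after the first term has error <= |k|^3/45 (~1e-21 at the threshold).
    out[small] = -kappa[small] / 3.0
    k = kappa[~small]
    out[~small] = 1.0 / k - 1.0 / np.tanh(k)  # exact formula away from zero
    return out
```

Because the small-|κ| entries never reach the `1/k` expression, no warning is raised, and at κ = 0 the Taylor branch returns the exact limit value 0.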
Detailed mathematical derivations and proofs are provided in:
- `src/graphnet/training/loss_functions.py`
- `docs/source/models/von_mises_fisher_mathematical_background.md`

🔧 Changes
Code Changes
- Replaced `np.where()` with boolean masking to avoid double evaluation
- Added `test_logcmk_backward_zero_handling()` with warning capture

Implementation Details
- Uses `-κ/3` for `|κ| < 1e-6` in place of `1/κ - 1/tanh(κ)`
- Error bound `≤ |κ|³/45 ≲ O(10⁻²¹)` for threshold values

Documentation
- `docs/source/models/models.rst`

🧪 Testing
Test Coverage
- `f(0) = 0` exactly

📊 Before/After
Before:
After:
🔗 References
✅ Checklist
Mathematical accuracy preserved • Numerical stability achieved • Warnings eliminated