Making softmax numerically stable #133
Conversation
Hi @joelshepherd! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations; afterwards, the pull request will be tagged. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thanks! This is definitely a better approach. With respect to the gradient, I think you need to mask out the non-max element indices (https://discuss.pytorch.org/t/differentiable-argmax/33020). As a reference, here's how the gradient state works: #113 (comment)
Thanks for the pointers! I solved what was missing from my earlier attempt with the help of those tests.

Side note: if in the future we can index by tensor, a potentially more performant solution might be: `ctx.backwards_input.index(ctx.forward_input[0].argmax(...))`

But for now, mask + sum seems to do the trick. This successfully converges my test models too.
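The mask + sum approach described above can be sketched in plain NumPy (illustrative only; the actual PR uses the project's own tensor ops and autodiff context, and `max_backward` is a hypothetical name):

```python
import numpy as np

def max_backward(forward_input, grad_output, axis=-1):
    # Build a 0/1 mask that is nonzero only at the positions holding the max,
    # so the incoming gradient is routed to the argmax element(s) and the
    # non-max indices are masked out.
    mask = (forward_input == np.max(forward_input, axis=axis, keepdims=True))
    # If several entries tie for the max, split the gradient evenly among them.
    mask = mask / np.sum(mask, axis=axis, keepdims=True)
    return mask * grad_output

x = np.array([1.0, 3.0, 2.0])
print(max_backward(x, np.array(1.0)))  # gradient flows only to the max: [0. 1. 0.]
```

The tensor-index variant mentioned above would avoid materializing the mask, but the mask version only needs comparison, sum, and multiply primitives, which every backend supports.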
@jacobkahn can you index by tensor in FL? I guess we should expose this! @joelshepherd looks great, tests are passing :) Mind signing the CLA? https://code.facebook.com/cla. I can land right after that.
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Looks like the bot just took a while to update the label. All good now!
@bwasti -- index by tensor is supported in Flashlight, yep. Indexing by multiple tensors does an outer product index (rather than a Cartesian product index), although this is behavior we're looking to change. Overall, indexing by a single tensor behaves as you'd expect. cc @StrongerXi if I'm missing anything.
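The two multi-tensor indexing semantics mentioned above can be illustrated with NumPy (as an analogy only; Flashlight's API differs). NumPy's default fancy indexing zips the index tensors elementwise, while `np.ix_` produces the outer-product behavior:

```python
import numpy as np

a = np.arange(9).reshape(3, 3)
rows, cols = np.array([0, 2]), np.array([1, 2])

# Default fancy indexing pairs the index tensors elementwise:
# picks a[0, 1] and a[2, 2].
print(a[rows, cols])          # [1 8]

# Outer-product indexing takes every row/column combination,
# i.e. rows {0, 2} x cols {1, 2}; np.ix_ constructs this.
print(a[np.ix_(rows, cols)])  # [[1 2]
                              #  [7 8]]
```

Which behavior a library picks for plain multi-tensor indexing is a design choice; the thread above notes Flashlight currently does the outer-product form and is considering changing it.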
Ah, one thing -- we don't have index-by-tensor support in the OneDNN backend yet, because afaik OneDNN doesn't expose any primitives to help us with that.
Modified the softmax function to be numerically stable with large exponents. Method taken from here.
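The standard max-subtraction trick can be sketched in NumPy (a minimal illustration of the technique, not the PR's actual code): softmax is invariant to subtracting a constant from every logit, so shifting by the max keeps `exp()` from overflowing on large inputs.

```python
import numpy as np

def softmax_stable(x, axis=-1):
    # Subtract the per-row max so exp() never sees a large positive input;
    # softmax(x) == softmax(x - c) for any constant c, so this is exact.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

# A naive exp(x) / sum(exp(x)) overflows here (exp(1002) is inf);
# the shifted version stays finite and sums to 1.
x = np.array([1000.0, 1001.0, 1002.0])
print(softmax_stable(x))
```

The cost is one extra max-reduction and subtraction per call, which is why the gradient of `amax` then needs to be defined for the autodiff pass.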
I am fairly new to autodiff gradient functions, so my implementation of `amax` may be way off the mark (it certainly looks wrong). I originally wrote the below code based on the min/max gradient functions that already exist, but it would not converge my test model (where the current implementation does).