
Making softmax numerically stable #133

Merged
merged 2 commits into facebookresearch:main on Dec 6, 2022

Conversation

joelshepherd
Contributor

Modified the softmax function to be numerically stable with large exponents. Method taken from here.
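A quick illustration of the trick: subtracting the per-element maximum before exponentiating leaves the result unchanged (the exp(-max) factor cancels between numerator and denominator) but keeps the exponents from overflowing. A minimal sketch in plain TypeScript over a number array, purely illustrative rather than the shumai tensor implementation:

// Numerically stable softmax over a 1-D array of numbers.
// Shifting by the max is mathematically a no-op, but it prevents
// Math.exp from overflowing to Infinity for large inputs.
function softmax(xs: number[]): number[] {
  const max = Math.max(...xs);
  const exps = xs.map((x) => Math.exp(x - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

softmax([1000, 1001, 1002]); // finite values; the naive version returns NaN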

I am fairly new to autodiff gradient functions, so my implementation of amax may be way off the mark (it certainly looks wrong).

I originally wrote the code below based on the min/max gradient functions that already exist, but it would not converge my test model (whereas the current implementation does).

// Mask of positions where the input equals the reduced max,
// cast to the gradient's dtype
const mask = ctx.forward_inputs[0]
  .eq(ctx.forward_output)
  .astype(ctx.backward_input.dtype);
// Route the upstream gradient only to the max positions
return ctx.backward_input.mul(mask);

@facebook-github-bot

Hi @joelshepherd!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@bwasti
Contributor

bwasti commented Dec 5, 2022

thanks! this is definitely a better approach.

With respect to gradient I think you need to mask out the non-max element indices (https://discuss.pytorch.org/t/differentiable-argmax/33020).
To help with debugging, you can add an automated check for amax here: https://github.com/facebookresearch/shumai/blob/main/test/gradient.test.ts#L76

As a reference, here's how the gradient state works: #113 (comment)

@joelshepherd
Contributor Author

Thanks for the pointers! I solved what was missing from my earlier attempt with the help of those tests.

Side note: if in future we can index by tensor, a potentially more performant solution might be:

ctx.backward_input.index(ctx.forward_inputs[0].argmax(...))

But for now, mask + sum seems to do the trick. This successfully converges my test models too.
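Roughly, the idea behind the mask, sketched in plain TypeScript for a 1-D input and a scalar upstream gradient (the gist only, not the exact tensor code in this PR, which also has to handle shapes and broadcasting):

// Backward pass for y = max(xs): the upstream gradient dy flows only to
// positions where the input equals the maximum; everything else gets 0.
// (Ties all receive dy here; how to split them is a design choice.)
function maxBackward(xs: number[], dy: number): number[] {
  const y = Math.max(...xs);
  return xs.map((x) => (x === y ? dy : 0));
}

maxBackward([1, 3, 2], 1); // [0, 1, 0]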

@bwasti
Contributor

bwasti commented Dec 6, 2022

Side note: if in future we can index by tensor, a potentially more performant solution might be:

@jacobkahn can you index by tensor in FL? I guess we should expose this!

@joelshepherd looks great, tests are passing :) mind signing the CLA? https://code.facebook.com/cla I can land right after that

@facebook-github-bot added the CLA Signed label on Dec 6, 2022
@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@joelshepherd
Contributor Author

Looks like the bot just took a while to update the label. All good now!

@jacobkahn
Member

@bwasti -- index by tensor is supported in Flashlight, yep. Indexing by multiple tensors does an outer product index (rather than a Cartesian product index), although this is behavior we're looking to change.
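To make the distinction concrete, a toy sketch with plain nested arrays (not the Flashlight API): outer-product indexing combines every row index with every column index, whereas NumPy-style paired indexing zips them.

// Toy outer-product indexing on a 2-D array: rows [0, 2] and cols [1, 3]
// select every combination, yielding a 2x2 block.
function outerIndex(a: number[][], rows: number[], cols: number[]): number[][] {
  return rows.map((r) => cols.map((c) => a[r][c]));
}

const a = [
  [0, 1, 2, 3],
  [4, 5, 6, 7],
  [8, 9, 10, 11],
  [12, 13, 14, 15],
];
outerIndex(a, [0, 2], [1, 3]); // [[1, 3], [9, 11]]
// Paired (NumPy-style) indexing with the same index arrays would instead
// pick a[0][1] and a[2][3], i.e. [1, 11].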

Overall, indexing by a single tensor behaves as you'd expect. cc @StrongerXi if I'm missing anything.

@bwasti merged commit 326e730 into facebookresearch:main on Dec 6, 2022
@StrongerXi


Ah one thing -- we don't have index by tensor support in OneDNN backend yet, because afaik OneDNN doesn't expose any primitives to help us with that.
