
NaN loss when using attention mask, or when not using attention mask but training for a long time #6

Closed
annasajkh opened this issue Aug 28, 2022 · 6 comments

Comments

@annasajkh

Even when I adjust the scaling in the FastAttention class, the loss still becomes NaN after training for a few hours.

@annasajkh
Author

(Screenshot attached: 2022-08-28 15-02-21)

@mtanghu
Owner

mtanghu commented Aug 28, 2022

Hi! Thanks for taking interest in this project; this should be fixed by today, or at worst tomorrow. I'm currently working on some issues with the loss going to NaN, which may actually be relevant to attention as a whole. (If you're interested: with linear attention especially, it is very easy for the attention scores to go off the charts and NaN out once the model "memorizes" that an attention score should be huge.)

Just some basic info:

  • Are you using your own training setup with your own dataset?
  • Are you using a model size larger than 128?
  • Are you using windowed/local attention?

Especially if the answer to the last question is yes, I'm almost certainly working on that exact issue. Comprehensive answers aren't needed.

@mtanghu
Owner

mtanghu commented Aug 28, 2022

Okay I have the fix! I'll just run a little bit more testing and it'll be ready by today!

The idea is pretty fun, and I'll explain it in terms of normal full attention with QKV. Attention scores are normally scaled because the variance of the dot-product scores grows with the model dimension. However, that only holds when the attention scores are random normal. In truth, as the parameters are updated the attention scores become less and less random, to the point where Q_i == K_i (when the alignment should be strong), and thus Q_i dot K_i will have a mean of d_model and a standard deviation of 0!
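
As a quick sanity check of that claim, here is a small PyTorch snippet (mine, not from the repo) comparing the dot-product statistics of independent random queries/keys against the aligned Q_i == K_i case:

```python
import torch

d_model = 128
q = torch.randn(10_000, d_model)
k = torch.randn(10_000, d_model)

random_scores = (q * k).sum(dim=-1)   # Q_i . K_i for independent Q, K
aligned_scores = (q * q).sum(dim=-1)  # Q_i . K_i in the limit Q_i == K_i

# Independent case: mean ~0, std ~sqrt(d_model); aligned case: mean ~d_model,
# so softmax saturates and large scores can eventually blow up to NaN.
print(random_scores.mean().item(), random_scores.std().item())
print(aligned_scores.mean().item(), aligned_scores.std().item())
```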

To counter this, we still scale by d_model and NOT sqrt(d_model), but since that would cap the similarity score at about 1, we also multiply by a fixed constant of 9 (9 was picked for numerical stability; preliminary testing found that 10 led to instability and divergence).

Also, we get rid of LayerNorm and replace it with a manual norming (i.e. de-meaning and dividing by the standard deviation), as preliminary testing found that the parameters in the LayerNorm would grow large enough that the attention scores would explode. We then also apply this manual norming to the queries and keys.
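
Putting the last two points together, a minimal sketch of the described fix might look like the following (this is my paraphrase of the scaling and manual norming above, not the actual repo code):

```python
import torch
import torch.nn.functional as F

def manual_norm(x, eps=1e-5):
    # "manual norming": de-mean and divide by the standard deviation
    return (x - x.mean(dim=-1, keepdim=True)) / (x.std(dim=-1, keepdim=True) + eps)

def scaled_attention(q, k, v):
    d_model = q.shape[-1]
    q, k = manual_norm(q), manual_norm(k)
    # divide by d_model (not sqrt(d_model)) so scores stay roughly in [-1, 1],
    # then multiply by the fixed constant 9 to keep the scores expressive
    scores = 9 * (q @ k.transpose(-2, -1)) / d_model
    return F.softmax(scores, dim=-1) @ v
```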

@mtanghu
Owner

mtanghu commented Aug 29, 2022

@annasajkh, the most recent merge #5 implements a number of changes, including fixes for any possible instability that could cause the loss to diverge.

Note that this project is rebranding to "LEAP" (Linear Explainable Attention In Parallel), as the merge also introduces a new LEAP Transformer that moves away from Additive Attention (but has all the same benefits, just more representational complexity). However, the full changes (mostly just to the README and some testing) are still being made, so I apologize if the READMEs are confusing right now.

We will maintain support for Fastformer (because it was the inspiration for this project), which is fully functioning and tested, now with lower loss. You will need to run:

pip install leap-transformer

and change your import to

from leap import FastformerForCausalLM, FastformerLMConfig
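
For reference, a minimal usage sketch might look like the following; the default config and the forward-call signature (HuggingFace-style `labels`/`.loss`) are assumptions on my part, so check the README for the actual API:

```python
import torch
from leap import FastformerForCausalLM, FastformerLMConfig

config = FastformerLMConfig()              # default config; field names not verified here
model = FastformerForCausalLM(config)

input_ids = torch.randint(0, 100, (1, 32))     # dummy token ids
outputs = model(input_ids, labels=input_ids)   # assumed HuggingFace-style call
print(outputs.loss)
```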

Please let me know if this works for you! If you're willing to answer, may I ask your reason for interest in this project?

@mtanghu
Owner

mtanghu commented Sep 3, 2022

@annasajkh Just a quick update: the new README and LEAP code release are finished if you'd like to check them out. This release also improves training stability (by adding numerical stability terms to denominators)!
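
For anyone curious what "numerical stability terms in denominators" means in practice, the generic pattern (not the repo's exact code) is just adding a small epsilon so the normalizer can never divide by something near zero:

```python
import torch

def normalize(scores, eps=1e-6):
    # scores assumed non-negative (e.g. after exp or elu + 1)
    denom = scores.sum(dim=-1, keepdim=True)
    return scores / (denom + eps)  # eps keeps the division finite, avoiding inf/NaN
```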

@mtanghu
Owner

mtanghu commented Sep 3, 2022

I'll go ahead and close this issue given the recent updates that address this directly, but please open a new one if you encounter NaN loss/any other issues!

@mtanghu mtanghu closed this as completed Sep 3, 2022