CUDA: fix FA VKQ accumulator overflow #17746
Conversation
Wow, there is a 0% chance that I would have found my way to this solution. Aside from knowing the …
Confirmed that this does solve the bug on my GB10, including the issue mentioned in #17610 (comment) where subsequent short-context requests also exhibit error behavior.
I was not aware of similar issues with other models or with our tests in …
Thanks! We have been seeing numerical overflow issues with a number of these smaller models. And now that I type this, I wonder if this patch might also fix some of those issues. I'll dig a bit.
ggerganov left a comment
A bit worried that this change might have some side effects - not convinced that the perplexity test is enough. AFAIU the larger the context, the smaller the values would be. Perplexity runs with a fixed context of ~2048 tokens.
Using F32 accumulators seems like the proper solution, but it could affect the performance. So not really sure.
In principle one could do the softmax in FA by calculating the exponentials of the raw KQ values. But since that has a tendency to result in numerical overflows, one instead subtracts the KQ maximum (so far) from all KQ values. Beyond numerical differences this does not affect the final result, but it forces all input values for the VKQ matrix multiplication to be <= 1. The value that is being subtracted is, however, essentially arbitrary. With this patch an extra 2.079 is being subtracted from each KQ value, so post exponentiation the input values for the VKQ matrix multiplication will be <= 0.125. So there is a higher tolerance for numerically large values in the V cache. This is a constant factor and the context size is not relevant in that regard.

The only way in which the context size could be relevant is if small, additive contributions post softmax have a large impact on the total result of VKQ. But this seems unlikely to me since the V cache seems to be relatively robust against quantization, i.e. noise, and the values post softmax typically have a few very large values and lots of extremely small values that the model is trying to push towards 0 anyways. For training the patch in this PR would be problematic since it would be zeroing out gradients, but for inference I think it will be fine.

I chose the current offset as "conservative" in the sense of avoiding overflows, but if you are concerned that this changes the results too much we could also go with a lower offset. An offset of 0.6931 shifts the representable range by a factor of 2 and is already enough to fix the issue for Granite 4 1b + the test prompt. It would in principle also be possible to define these values in the GGUF files and set them at runtime.
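To make the mechanism concrete, here is a minimal scalar reference of the idea. This is an illustrative sketch only, not the actual MMA kernel from this PR; the function name `fa_softmax_vkq_ref`, the arguments, the head size / KV length constants and the two-pass structure are all assumptions made for readability (the real kernel tracks the maximum online and works warp-wide):

```cpp
// Illustrative sketch only: a scalar, two-pass reference of FA softmax + VKQ
// with an extra offset added to the KQ maximum. Not taken from the PR.
#include <cuda_fp16.h>
#include <math.h>

#define D    128   // head size (assumed for this sketch)
#define N_KV 4096  // number of KV entries (assumed)

__device__ void fa_softmax_vkq_ref(const float * kq, const half * V, float scale, half * dst) {
    const float kq_max_offset = 2.079f; // ln 8: shifts the FP16 accumulator range up by 8x

    // Pass 1: maximum of the scaled KQ values (the real kernel tracks this online).
    float kq_max = -INFINITY;
    for (int j = 0; j < N_KV; ++j) {
        kq_max = fmaxf(kq_max, kq[j]*scale);
    }

    half2 vkq[D/2];              // FP16 VKQ accumulators, the part that can overflow
    for (int i = 0; i < D/2; ++i) {
        vkq[i] = __float2half2_rn(0.0f);
    }
    float kq_sum = 0.0f;         // softmax denominator, scaled by the same factor

    // Pass 2: subtracting (kq_max + kq_max_offset) instead of just kq_max caps the
    // post-exponentiation values at exp(-2.079) = 0.125 instead of 1.
    for (int j = 0; j < N_KV; ++j) {
        const float p = expf(kq[j]*scale - (kq_max + kq_max_offset));
        kq_sum += p;

        const half2 p_h2 = __float2half2_rn(p);
        for (int i = 0; i < D/2; ++i) {
            vkq[i] = __hfma2(p_h2, ((const half2 *) V)[j*(D/2) + i], vkq[i]);
        }
    }

    // The offset cancels in the normalization: vkq and kq_sum both carry exp(-kq_max_offset).
    for (int i = 0; i < D/2; ++i) {
        const float2 v = __half22float2(vkq[i]);
        dst[2*i + 0] = __float2half(v.x / kq_sum);
        dst[2*i + 1] = __float2half(v.y / kq_sum);
    }
}
```

With the lower offset of 0.6931 (ln 2) the cap after exponentiation becomes 0.5 instead of 0.125.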
It is actually a context size of 512 tokens. With a context size of 8192 tokens the LLaMA 3 8b perplexity changes from 5.6651 to 5.6645 with an offset of 2.079 or to 5.6643 with an offset of 0.6931.
Each CUDA thread can use at most 255 registers. The MMA kernel runs on 64 Q/VKQ columns in its largest configuration. For a head size of 128 this requires 64 registers per thread simply to store the Q input and VKQ output values during the kernel. With FP32 VKQ accumulators that would increase to 96 registers. For a head size of 256 each thread currently needs 128 registers and the kernel is still viable (with limitations). So for head sizes <= 128 FP32 accumulators could probably be made to work, but not for head sizes 256 and 576/512. So it would not be a universal solution.
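A back-of-envelope reading of these figures (assuming the quoted 64/96/128 register counts split evenly between the Q input and the VKQ output per thread):

$$
\begin{aligned}
D = 128:&\quad \underbrace{32}_{Q\ \text{(FP16)}} + \underbrace{32}_{VKQ\ \text{(FP16)}} = 64 \;\rightarrow\; 32 + 64 = 96\ \text{(FP32 VKQ)} \\
D = 256:&\quad 64 + 64 = 128 \;\rightarrow\; 64 + 128 = 192\ \text{(FP32 VKQ)}
\end{aligned}
$$

At $D = 256$ that would leave only 63 of the 255 registers per thread for everything else in the kernel, presumably too little headroom.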
Yes, I agree with your analysis. This point here is my main concern. During training, the model didn't see the values flushed to zero. I understand the argument that these values are likely negligible, but at the same time I can't say I have a good intuition of how the ensemble of many small attentions would add up and the role that they would play in the end. Hence the concern.
Yeah, I agree - same issue in the Metal backend. I decided to still go with the F32 accumulators regardless, at the price that the performance of some models such as Gemma (head size = 256) is affected negatively.
Force-pushed from 1dd1272 to 59e6cba
For now I've reduced the offset to 0.6931 (factor of 2 post softmax). As long as people like @gabe-l-hart and the llama.cpp maintainers are aware that this is a possible issue in terms of numerics, the time lost on debugging should be comparatively small and we can adjust the offset if necessary.

There are no tensor cores for (FP16, FP16) -> BF16 math, but we could in principle emulate it by doing (FP16, FP16) -> FP32 math in combination with bitwise operators. Then the issue of the numerical range would be fixed, but we would in return have possible issues with reduced numerical precision in the VKQ accumulators.
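For reference, the bitwise part of that emulation could look roughly like the following. This is an assumption about the approach, not code from this PR: since BF16 shares the FP32 sign/exponent layout, truncating an FP32 accumulator to its upper 16 bits yields a BF16 value.

```cpp
// Hypothetical helpers for emulating BF16 accumulation via FP32 math plus bit tricks.
// Truncation rounding for brevity; a real implementation would probably round to nearest.
__device__ static unsigned short fp32_to_bf16_bits(float x) {
    return (unsigned short) (__float_as_uint(x) >> 16);   // keep FP32 sign/exponent + top mantissa bits
}

__device__ static float bf16_bits_to_fp32(unsigned short b) {
    return __uint_as_float(((unsigned int) b) << 16);     // re-expand with zero low mantissa bits
}
```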
Fixes #17610.
The problem is that this particular model is suffering from numerical overflow in the FP16 VKQ accumulators. In principle the use of FP32 or BF16 accumulators could fix the issue, but that would be problematic in terms of either register pressure or lack of hardware support. For this reason I think the least bad option is to apply an offset to the KQ maximum used as the scale in the softmax. By adding $\ln 8$ to the maximum, both the VKQ accumulators and the KQ sums are effectively being reduced by a factor of 8, so the range of representable values is in turn being shifted upwards by a factor of 8. The downside is that larger values will be flushed to 0 (the flush-to-zero threshold post softmax goes up from $2^{-14}$ to $2^{-11}$). However, this effect should be negligible for the model outputs. This PR changes the LLaMA 3 8b q4_0 perplexity over the Wikitext 2 test set from 6.717177 to 6.717161.
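For completeness, the reason the offset drops out of the final result: writing $s_j$ for the scaled KQ values, $m$ for their maximum and $v_j$ for the corresponding V rows (notation chosen here, not from the code),

$$
\mathrm{VKQ} = \frac{\sum_j e^{\,s_j - (m + \ln 8)}\, v_j}{\sum_j e^{\,s_j - (m + \ln 8)}}
             = \frac{\tfrac{1}{8}\sum_j e^{\,s_j - m}\, v_j}{\tfrac{1}{8}\sum_j e^{\,s_j - m}}
             = \frac{\sum_j e^{\,s_j - m}\, v_j}{\sum_j e^{\,s_j - m}},
$$

so both the accumulators and the sums become 8 times smaller while the normalized output is unchanged, up to the additional flush-to-zero of terms below $2^{-11}$.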