CUDA: no FP16 arithmetic for vector FA kernel #17558
Conversation
CISC left a comment:
I don't know why, but this seems to have improved 3 of the perf tests (the others are unaffected) on my 3090Ti:
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]): 41727 runs - 26.13 us/run - 16.78 MFLOP/run - 641.98 GFLOPS
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]): 23848 runs - 43.61 us/run - 33.55 MFLOP/run - 769.51 GFLOPS
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]): 13419 runs - 78.54 us/run - 67.11 MFLOP/run - 854.41 GFLOPS
vs
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]): 41727 runs - 24.51 us/run - 16.78 MFLOP/run - 684.44 GFLOPS
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]): 23848 runs - 42.17 us/run - 33.55 MFLOP/run - 795.63 GFLOPS
FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]): 13419 runs - 77.16 us/run - 67.11 MFLOP/run - 869.79 GFLOPS
Okay, what I said regarding which GPUs are affected was strictly speaking not 100% correct. Volta, Turing, and Ampere do in theory also have some cases where the vector FA kernel is used, but only if a model isn't using GQA, and nowadays essentially all of them do. So I didn't consider those cases for performance testing.
I have an RX 5700 XT (RDNA1) that I could use to test. However, it's been a while since I last used the HIP backend, and flash attention just throws an error with this card on both master and this PR. I don't know whether this is ROCm's fault (RDNA1 has never been officially supported, but it used to work until some time ago; I now run version 7.1) or whether the issue came with some update inside the GGML backend itself.
Can confirm that the fix works; GLM4 is now coherent again.
This reverts commit 73955f7.
Fixes #17549.

The problem is that while the CUDA FlashAttention vector kernel always uses FP32 for the KQ accumulation, the intermediate multiplication results can still be FP16 if the GPU lacks the `v_dot_f32_f16` instruction for FP16 dot products with FP32 accumulation. For models such as GLM 4 32b, numerical overflows can then occur. This PR makes GPUs without `v_dot_f32_f16` use the FP32 implementation instead. Due to the kernel selection logic, the only affected GPUs are NVIDIA GPUs Ada Lovelace or newer and AMD GPUs RDNA1 or older. With FP32 arithmetic the performance on my GPUs is essentially unchanged. I don't have an RDNA1 GPU for testing, but only the rare models without GQA should be affected, and even if there is a small performance regression, it's more important to ensure correct results.
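For readers unfamiliar with the failure mode: FP16 has a maximum finite value of 65504, so an individual Q·K product can already overflow to infinity before it is ever converted to FP32 for accumulation. The following standalone CUDA sketch (not the llama.cpp kernel; all names are illustrative, and it assumes a GPU with native FP16 arithmetic, i.e. compute capability 5.3 or newer) contrasts the two orderings:

```cpp
// Illustrative only: compares accumulating FP16 products vs. FP32 products
// into an FP32 accumulator. Not the llama.cpp FlashAttention kernel.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void dot_compare(const half *x, const half *y, int n,
                            float *out_fp16_products, float *out_fp32_products) {
    float acc16 = 0.0f; // FP32 accumulator, but each product is computed in FP16
    float acc32 = 0.0f; // FP32 accumulator, products computed in FP32
    for (int i = 0; i < n; ++i) {
        // FP16 multiply first: anything above ~65504 overflows to +inf
        acc16 += __half2float(__hmul(x[i], y[i]));
        // Convert first, multiply in FP32: no intermediate overflow
        acc32 += __half2float(x[i]) * __half2float(y[i]);
    }
    *out_fp16_products = acc16;
    *out_fp32_products = acc32;
}

int main() {
    const int n = 4;
    half hx[n], hy[n];
    for (int i = 0; i < n; ++i) {
        hx[i] = __float2half(300.0f); // 300 * 300 = 90000 > 65504 (FP16 max)
        hy[i] = __float2half(300.0f);
    }

    half *dx, *dy;
    float *d16, *d32;
    cudaMalloc(&dx, n * sizeof(half));
    cudaMalloc(&dy, n * sizeof(half));
    cudaMalloc(&d16, sizeof(float));
    cudaMalloc(&d32, sizeof(float));
    cudaMemcpy(dx, hx, n * sizeof(half), cudaMemcpyHostToDevice);
    cudaMemcpy(dy, hy, n * sizeof(half), cudaMemcpyHostToDevice);

    dot_compare<<<1, 1>>>(dx, dy, n, d16, d32);

    float r16, r32;
    cudaMemcpy(&r16, d16, sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(&r32, d32, sizeof(float), cudaMemcpyDeviceToHost);
    printf("FP16 products: %f\n", r16); // inf: each product overflowed in FP16
    printf("FP32 products: %f\n", r32); // 360000: no intermediate overflow

    cudaFree(dx); cudaFree(dy); cudaFree(d16); cudaFree(d32);
    return 0;
}
```

The fix described above avoids the first ordering on GPUs that cannot do FP16 multiplies with FP32 accumulation in hardware, by falling back to the all-FP32 path instead.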