
Conversation

@JohannesGaessler
Collaborator

Fixes #17549.

The problem is that, while the CUDA FlashAttention vector kernel always uses FP32 for the KQ accumulation, the intermediate multiplication results can still be FP16 if the GPU lacks the v_dot_f32_f16 instruction for FP16 dot products with FP32 accumulation. For models such as GLM 4 32b, numerical overflows can then occur. This PR makes GPUs without v_dot_f32_f16 use the FP32 implementation instead. Due to the kernel selection logic, the only affected GPUs are NVIDIA GPUs Ada Lovelace or newer and AMD GPUs RDNA1 or older. With FP32 arithmetic the performance on my GPUs is essentially unchanged:

| GPU | Model | Microbatch size | Test | t/s master | t/s b13fcf8 | Speedup |
| -------- | ------------- | --: | ------------ | ----: | ----: | ---: |
| RTX 4090 | llama 8B Q4_0 | 512 | tg128@d65536 | 58.24 | 58.79 | 1.01 |
| RTX 5090 | llama 8B Q4_0 | 512 | tg128@d65536 | 73.03 | 72.93 | 1.00 |
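
To illustrate the failure mode described above, here is a minimal sketch (not the actual ggml CUDA kernel code; the helper names are made up) contrasting a KQ dot product whose partial sums stay in FP16 with one that accumulates every partial sum in FP32. FP16 overflows past roughly 65504, so the first variant can blow up to ±inf long before the final FP32 conversion:

```cpp
// Illustrative only -- hypothetical helpers, not the ggml FlashAttention kernel.
#include <cuda_fp16.h>

__device__ float kq_dot_fp16_intermediate(const half2 * Q, const half2 * K, int n) {
    half2 acc2 = __float2half2_rn(0.0f);           // partial sums stay in FP16
    for (int i = 0; i < n; ++i) {
        acc2 = __hfma2(Q[i], K[i], acc2);          // can overflow to +/-inf past ~65504
    }
    return __low2float(acc2) + __high2float(acc2); // FP32 only at the very end
}

__device__ float kq_dot_fp32_accumulate(const half2 * Q, const half2 * K, int n) {
    float acc = 0.0f;                              // every partial sum lives in FP32
    for (int i = 0; i < n; ++i) {
        const float2 q = __half22float2(Q[i]);
        const float2 k = __half22float2(K[i]);
        acc += q.x*k.x + q.y*k.y;                  // no FP16 overflow possible here
    }
    return acc;
}
```

On GPUs with v_dot_f32_f16 the hardware already produces FP32 results from FP16 inputs, which is why they are unaffected; the fallback only matters where that instruction is missing.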

I don't have an RDNA1 GPU for testing, but only the rare models without GQA should be affected, and even if there is a small performance regression it's more important to ensure correct results.

@github-actions github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Nov 27, 2025
Collaborator

@CISC left a comment


I don't know why, but this seems to have improved 3 of the perf tests (the others are unaffected) on my 3090Ti:

  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 41727 runs -    26.13 us/run -  16.78 MFLOP/run - 641.98 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 23848 runs -    43.61 us/run -  33.55 MFLOP/run - 769.51 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                13419 runs -    78.54 us/run -  67.11 MFLOP/run - 854.41 GFLOPS

vs

  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=4096,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 41727 runs -    24.51 us/run -  16.78 MFLOP/run - 684.44 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=8192,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                 23848 runs -    42.17 us/run -  33.55 MFLOP/run - 795.63 GFLOPS
  FLASH_ATTN_EXT(hsk=128,hsv=128,nh=8,nr23=[1,1],kv=16384,nb=1,mask=1,sinks=0,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]):                13419 runs -    77.16 us/run -  67.11 MFLOP/run - 869.79 GFLOPS

@JohannesGaessler
Collaborator Author

Okay, what I said regarding which GPUs are affected was, strictly speaking, not 100% correct. Volta, Turing, and Ampere theoretically also have some cases where the vector FA kernel is used, but only if a model isn't using GQA, and nowadays essentially all models do, so I didn't consider those cases for performance testing.
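
The selection reasoning described in this thread roughly boils down to the following sketch. This is not the actual ggml-cuda dispatch code and the names are illustrative only; it just combines the two facts stated above (the MMA kernel covers the common GQA case on tensor-core GPUs, and GPUs without v_dot_f32_f16 now fall back to the FP32 vector kernel):

```cpp
// Hedged sketch -- not the real ggml-cuda kernel selection logic.
enum class fa_kernel { MMA, VEC_F16, VEC_F32 };

fa_kernel pick_fa_kernel(bool has_tensor_core_mma, bool model_uses_gqa,
                         bool has_v_dot_f32_f16) {
    // On Volta/Turing/Ampere the MMA kernel handles the common (GQA) case, so
    // the vector kernel is only reached for the rare models without GQA.
    if (has_tensor_core_mma && model_uses_gqa) {
        return fa_kernel::MMA;
    }
    // Without v_dot_f32_f16 the FP16 vector path can overflow in its
    // intermediate products (e.g. GLM 4 32b), so this PR routes such GPUs
    // to the FP32 vector kernel instead.
    return has_v_dot_f32_f16 ? fa_kernel::VEC_F16 : fa_kernel::VEC_F32;
}
```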

@daniandtheweb
Contributor

daniandtheweb commented Nov 28, 2025

I have an RX 5700 XT (RDNA1) that I could use to test; however, it's been a while since I last used the HIP backend, and flash attention just throws an error with this card on both master and this PR:

CUDA_VISIBLE_DEVICES=1 ./llama-bench -m ~/Applications/chat/gguf/llama-2-7b.Q4_0.gguf -ngl 100 -fa 0,1                   0.001s 
ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 1 ROCm devices:
  Device 0: AMD Radeon RX 5700 XT, gfx1010:xnack- (0x1010), VMM: no, Wave Size: 32
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | ROCm       | 100 |  0 |           pp512 |        356.28 ± 0.29 |
| llama 7B Q4_0                  |   3.56 GiB |     6.74 B | ROCm       | 100 |  0 |           tg128 |         62.08 ± 0.02 |
/home/daniandtheweb/Applications/chat/llama.cpp/ggml/src/ggml-cuda/template-instances/../fattn-common.cuh:898: GGML_ASSERT(max_blocks_per_sm > 0) failed
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libggml-base.so.0(+0x156f6) [0x7f2d6896c6f6]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libggml-base.so.0(ggml_print_backtrace+0x203) [0x7f2d6896cb33]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libggml-base.so.0(ggml_abort+0x130) [0x7f2d6896ccd0]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libggml-hip.so.0(_Z12launch_fattnILi128ELi64ELi1EEvR25ggml_backend_cuda_contextP11ggml_tensorPFvPKcS5_S5_S5_S5_PKiPfP15HIP_vector_typeIfLj2EEffffjfiiiiiiiiiiiiiliiliiiiilEimibbbi+0x12d0) [0x7f2d6b7a8c70]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libggml-hip.so.0(_Z34ggml_cuda_flash_attn_ext_tile_caseILi128ELi128EEvR25ggml_backend_cuda_contextP11ggml_tensor+0x21f) [0x7f2d6b7927ef]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libggml-hip.so.0(+0x2b2ed4d) [0x7f2d6b52ed4d]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libggml-hip.so.0(+0x2b2e23f) [0x7f2d6b52e23f]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libggml-base.so.0(ggml_backend_sched_graph_compute_async+0x813) [0x7f2d68987f73]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libllama.so.0(_ZN13llama_context13graph_computeEP11ggml_cgraphb+0xa0) [0x7f2d6bc9b140]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libllama.so.0(_ZN13llama_context14process_ubatchERK12llama_ubatch14llm_graph_typeP22llama_memory_context_iR11ggml_status+0xf3) [0x7f2d6bc9d003]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libllama.so.0(_ZN13llama_context6decodeERK11llama_batch+0x40f) [0x7f2d6bca21cf]
/home/daniandtheweb/Applications/chat/llama.cpp/build/bin/libllama.so.0(llama_decode+0xe) [0x7f2d6bca311e]
./llama-bench(+0x1ca6b) [0x5634881cba6b]
./llama-bench(+0x1a012) [0x5634881c9012]
/usr/lib/libc.so.6(+0x27635) [0x7f2d68027635]
/usr/lib/libc.so.6(__libc_start_main+0x89) [0x7f2d680276e9]
./llama-bench(+0x1b7f5) [0x5634881ca7f5]
zsh: IOT instruction (core dumped)  CUDA_VISIBLE_DEVICES=1 ./llama-bench -m  -ngl 100 -fa 0,1

I don't know if this issue is ROCm's fault (RDNA1 has never been officially supported, but it used to work until some time ago; I now run version 7.1) or if it came with some update inside the GGML backend itself.

@LostRuins
Collaborator

Can confirm that the fix works; GLM4 is now coherent again.

@JohannesGaessler JohannesGaessler merged commit 73955f7 into ggml-org:master Nov 28, 2025
61 of 63 checks passed
Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Dec 1, 2025

