ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (GQA=32) #22286
JohannesGaessler merged 7 commits into ggml-org:master
Conversation
Adds MMA-f16 and tile kernel configs, dispatch logic, template instances, and a tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to ncols2=32 to support GQA ratio 32 only.
@JohannesGaessler please help to get this merged. I will close my other PR for ncols2=8.
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true);
Referred to the 256,256 case for these configs.
return;

// cols_per_block must be >= ncols2 so ncols1 = cols_per_block/ncols2 is never 0 (integer division).
// Without if constexpr, NVCC/MSVC still instantiate flash_attn_tile<..., 0, ncols2, ...> when ncols2 > 16.
if constexpr (16 >= ncols2) {
For ncols2=32, 16/ncols2 evaluates to 0 during kernel instantiation.
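A minimal standalone C++ sketch of why the if constexpr guard is needed (illustrative only; kernel and launch are placeholder names, not the llama.cpp source): with a plain if, the compiler would still instantiate the template with ncols1 = 16/32 = 0.

// guard_demo.cpp - illustrative only.
template <int ncols1, int ncols2>
void kernel() {
    static_assert(ncols1 > 0, "ncols1 must be positive");
}

template <int ncols2>
void launch() {
    constexpr int cols_per_block = 16;
    if constexpr (cols_per_block >= ncols2) {       // guard: only instantiate valid configs
        kernel<cols_per_block / ncols2, ncols2>();  // ncols1 >= 1 here
    }
    // With a plain `if`, kernel<0, 32>() would still be instantiated for ncols2=32
    // and the static_assert (or a zero-sized array in the real kernel) would fire.
}

int main() {
    launch<8>();    // ok: ncols1 = 2
    launch<32>();   // guard skips the invalid kernel<0, 32> instantiation
    return 0;
}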
Hi @lnigam, thanks for your contribution! Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:
Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.
I tested with this PR merged on top of the current master, and FA seems to work OK on DGX Spark too.
Address review comments
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
…DV=512 instead of DQK=256,DV=256
…ted, but the sinks index is the same (0,...,15) for both groups, hence with sinks=1 the output does not match the CPU output. Added sink_base, which is the base index for each warp_group (threadIdx.y / np).
@JohannesGaessler I have addressed all the review comments. The test was failing because, with ncols=32 and cols_per_warp=16, two warps per group are created per token (grp-0, warp-0 = heads 0,1,...,15; grp-0, warp-1 = heads 16,17,...,31). But the jc value is local to each warp and is always between 0 and 15 for both warps in the group, hence the output did not match the CPU output. Fixed by adding a base index per warp within the group, and verified that the test passes now. With the latest changes, I am getting the following performance numbers on the 6000 Pro:
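For illustration, a standalone C++ sketch of the index mapping described above (sink_base, cols_per_warp, and jc are names taken from the comment; this is not the kernel source): without the per-warp base, both warps would read sink indices 0..15.

// sink_index_demo.cpp - illustrative only; mirrors the indexing arithmetic, not the CUDA kernel.
#include <cstdio>

int main() {
    const int cols_per_warp = 16;
    const int ncols2        = 32;                       // heads handled per block (GQA ratio)
    const int warps         = ncols2 / cols_per_warp;   // 2 warps share one warp group

    for (int warp = 0; warp < warps; ++warp) {
        const int sink_base = warp * cols_per_warp;     // 0 for warp 0, 16 for warp 1
        for (int jc = 0; jc < cols_per_warp; ++jc) {
            // Without sink_base, both warps would index sinks 0..15.
            printf("warp %d, jc %2d -> sink index %2d\n", warp, jc, sink_base + jc);
        }
    }
    return 0;
}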
JohannesGaessler
left a comment
Please stop adding superfluous comments.
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
IMbackK
left a comment
I have restarted windows-latest-hip, which timed out (known problem). Please wait until it finishes successfully, and then feel free to merge.
Very nice speedup!
Details: ggml_cuda_init: found 4 CUDA devices (Total VRAM: 84280 MiB)
before:
after:
ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (GQA=32)
Adds MMA-f16 and tile kernel configs, dispatch logic, template instances, and a tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to ncols2=32 to support GQA ratio 32 only.
Overview
Add fattn kernel instantiation for dimensions DKQ=320 and DV=256, required for Mistral Small 4. Kernel instantiation is forced to ncols2=32.
Additional information
Mistral Small 4 uses Multi-head Latent Attention (MLA). When running with flash-attn ON, fattn falls back to the CPU, which reduces performance. This model has GQA=32, hence the kernel is instantiated with ncols2=32.
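A hedged, self-contained sketch of that restriction (the function name is illustrative, not llama.cpp code): the new DKQ=320/DV=256 kernels exist only for ncols2=32, so the flash-attention path applies only at GQA ratio 32 and other ratios keep the existing fallback.

// gqa_check_demo.cpp - illustrative only; not the actual llama.cpp dispatch logic.
#include <cstdio>

static bool fattn_320_256_supported(int gqa_ratio) {
    // Kernels are instantiated only for ncols2=32, i.e. GQA ratio 32.
    return gqa_ratio == 32;
}

int main() {
    printf("GQA=32: %s\n", fattn_320_256_supported(32) ? "flash-attn kernel" : "fallback");
    printf("GQA= 8: %s\n", fattn_320_256_supported(8)  ? "flash-attn kernel" : "fallback");
    return 0;
}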
benchmark: .\llama-bench.exe --model "D:\Loveneet\Sinistea-119B-q4-k-m.gguf" -n 32 -ngl 500 -r 10 -fa 1 -d 4096 -p 512
on NVIDIA RTX PRO 6000 Blackwell Workstation Edition
With ncols2=32
With ncols2=8
Without changes:
Requirements