
ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (… #22286

Merged
JohannesGaessler merged 7 commits into ggml-org:master from lnigam:fattn_320_256_GQA32 on Apr 28, 2026

Conversation

@lnigam lnigam (Contributor) commented Apr 23, 2026

…GQA=32)

Adds MMA-f16 and tile kernel configs, dispatch logic, template instances, and tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to ncols2=32 to support GQA ratio 32 only.

Overview

Adds fattn kernel instantiations for dimensions DKQ=320 and DV=256, required for Mistral Small 4. Kernel instantiation is forced to ncols2=32.
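
For illustration, a minimal sketch of the restriction this implies, under the assumption that only ncols2=32 is instantiated for the 320/256 head sizes (hypothetical helper, not the actual dispatch code in ggml/src/ggml-cuda/fattn.cu):

```cpp
// Hypothetical sketch only: for DKQ=320/DV=256 only ncols2 == 32 is instantiated
// in this PR, so a model whose GQA ratio is not a multiple of 32 would not be
// served by the new kernel and would fall back to another path.
#include <cstdint>
#include <cstdio>

static int pick_ncols2_320_256(int64_t gqa_ratio) {
    return gqa_ratio % 32 == 0 ? 32 : 0; // 0 == "no 320/256 kernel available"
}

int main() {
    std::printf("GQA=32 -> ncols2=%d\n", pick_ncols2_320_256(32)); // Mistral Small 4: supported
    std::printf("GQA=8  -> ncols2=%d\n", pick_ncols2_320_256(8));  // not instantiated in this PR
    return 0;
}
```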

Additional information

Mistral Small 4 uses Multi-head Latent Attention (MLA); when running with flash-attn ON, fattn falls back to the CPU, which reduces performance. This model has GQA=32; the kernel was initially instantiated with ncols2=8.
benchmark: .\llama-bench.exe --model "D:\Loveneet\Sinistea-119B-q4-k-m.gguf" -n 32 -ngl 500 -r 10 -fa 1 -d 4096 -p 512
on NVIDIA RTX PRO 6000 Blackwell Workstation Edition
With ncols2=32

| model                     |      size |   params | backend | ngl | fa |          test |             t/s |
| ------------------------- | --------: | -------: | ------- | --: | -: | ------------: | --------------: |
| mistral4 ?B Q4_K - Medium | 67.32 GiB | 118.97 B | CUDA    | 500 |  1 | pp512 @ d4096 | 3752.47 ± 26.33 |
| mistral4 ?B Q4_K - Medium | 67.32 GiB | 118.97 B | CUDA    | 500 |  1 |  tg32 @ d4096 |   185.92 ± 4.43 |

With ncols2=8

| model                     |      size |   params | backend | ngl | fa |          test |             t/s |
| ------------------------- | --------: | -------: | ------- | --: | -: | ------------: | --------------: |
| mistral4 ?B Q4_K - Medium | 67.32 GiB | 118.97 B | CUDA    | 500 |  1 | pp512 @ d4096 | 3678.35 ± 21.68 |
| mistral4 ?B Q4_K - Medium | 67.32 GiB | 118.97 B | CUDA    | 500 |  1 |  tg32 @ d4096 |   182.24 ± 4.00 |

Without changes:

| model                     |      size |   params | backend | ngl | fa |          test |            t/s |
| ------------------------- | --------: | -------: | ------- | --: | -: | ------------: | -------------: |
| mistral4 ?B Q4_K - Medium | 67.32 GiB | 118.97 B | CUDA    | 500 |  1 | pp512 @ d4096 |  179.64 ± 2.66 |
| mistral4 ?B Q4_K - Medium | 67.32 GiB | 118.97 B | CUDA    | 500 |  1 |  tg32 @ d4096 |    33.05 ± 0.30 |

@lnigam lnigam requested a review from a team as a code owner April 23, 2026 12:18
@lnigam lnigam (Contributor, Author) commented Apr 23, 2026

@JohannesGaessler please help to get this merged. I will close my other PR for ncols2=8.

Comment thread ggml/src/ggml-cuda/fattn-mma-f16.cuh Outdated
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);

GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true);
lnigam (Contributor, Author):
The 256/256 case was used as a reference for these configs.

Comment thread ggml/src/ggml-cuda/fattn-tile.cuh Outdated
return;
// cols_per_block must be >= ncols2 so ncols1 = cols_per_block/ncols2 is never 0 (integer division).
// Without if constexpr, NVCC/MSVC still instantiate flash_attn_tile<..., 0, ncols2, ...> when ncols2 > 16.
if constexpr (16 >= ncols2) {
lnigam (Contributor, Author):
For ncols2=32, 16/ncols2 evaluates to 0 during kernel instantiation, hence the guard.
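
For context, a self-contained sketch of why the guard is needed (hypothetical stub names, not the actual fattn-tile code): without `if constexpr`, the compiler still instantiates the template on the dead branch, so 16/ncols2 becomes 0 for ncols2=32 and the instantiation is invalid.

```cpp
// Hypothetical stand-in for the real kernel template; fails to instantiate with ncols1 == 0.
template <int ncols1, int ncols2>
void flash_attn_tile_stub() {
    static_assert(ncols1 > 0, "ncols1 must be > 0");
}

template <int ncols2>
void launch(int cols_per_block) {
    if (cols_per_block == 16) {
        // With a plain `if` instead of `if constexpr`, this call is still instantiated
        // for ncols2 == 32, producing flash_attn_tile_stub<16/32, 32> == <0, 32> and failing.
        if constexpr (16 >= ncols2) {
            flash_attn_tile_stub<16 / ncols2, ncols2>();
        }
    }
    // ... other cols_per_block cases would go here ...
}

int main() {
    launch<32>(32); // compiles only because the 16-column path is guarded
    return 0;
}
```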

@ggml-gh-bot ggml-gh-bot (Bot) commented Apr 23, 2026

Hi @lnigam, thanks for your contribution!

Per our contribution guidelines, the automated PR checker found the following issue(s) that need your attention:

  • Multiple open PRs from a new contributor: We limit new contributors (those without a previously merged PR) to 1 open PR at a time. You currently have 2 open PRs.

Please note that maintainers reserve the right to make final decisions on PRs. If you believe there is a mistake, please comment below.

@github-actions github-actions Bot added labels Nvidia GPU (Issues specific to Nvidia GPUs), python (python script changes), and ggml (changes relating to the ggml tensor library for machine learning) Apr 23, 2026
@HelloKS HelloKS (Contributor) commented Apr 26, 2026

I tested with this PR merged on top of the current master, and FA seems to work fine on DGX Spark too.

lnigam and others added 3 commits April 28, 2026 12:46
Address review comments

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
…ted, but the sinks index is the same (0,...,15) for both groups, hence with sinks=1 the output does not match the CPU output. Added sink_base, which is the base index for each warp group (threadIdx.y / np).
@lnigam lnigam (Contributor, Author) commented Apr 28, 2026

@JohannesGaessler I have addressed all the review comments.
test-backend-ops was failing for CUDA with sinks=1:
Failing tests:
FLASH_ATTN_EXT(hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
FLASH_ATTN_EXT(hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
FLASH_ATTN_EXT(hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
FLASH_ATTN_EXT(hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=3,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
FLASH_ATTN_EXT(hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
FLASH_ATTN_EXT(hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=32,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
FLASH_ATTN_EXT(hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3])
FLASH_ATTN_EXT(hsk=320,hsv=256,nh=1,nr23=[32,1],kv=512,nb=75,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,2,1,3])
Backend CUDA0: FAIL
https://github.com/ggml-org/llama.cpp/actions/runs/25042595466/job/73353004904?pr=22286#step:3:43529

This fails because with ncols=32 and cols_per_warp=16 two warps are created per group (warp-0 = heads 0..15, warp-1 = heads 16..31), but the jc value is local to each warp and always ranges over 0..15 in both warps of the group, so the output does not match the CPU output. Fixed by adding a base index per warp within the group, and verified that the test now passes.
Please help to review this commit: 6e9ccb0
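
A rough sketch of the idea (hypothetical helper; the names np, cols_per_warp, and jc are taken from the discussion above, and this is not the actual fattn-mma-f16.cuh code):

```cuda
// Hypothetical CUDA sketch of the sink_base fix. With ncols = 32 and
// cols_per_warp = 16, two warps cover columns 0..15 and 16..31, but jc is a
// warp-local column index and runs over 0..15 in both warps. Without a base
// offset both warps read sinks[0..15]; sink_base shifts the second warp group
// to sinks[16..31] so each head gets its own sink value.
__device__ float load_sink(const float * sinks, const int jc,
                           const int np, const int cols_per_warp) {
    const int sink_base = (threadIdx.y / np) * cols_per_warp; // base index of this warp group
    return sinks[sink_base + jc];                             // global head index, not the local jc
}
```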

With the latest changes, I am getting the following performance numbers on the RTX PRO 6000:
.\llama-bench.exe --model "D:\mistral-119B-q4-k-m.gguf" -n 32 -ngl 500 -r 10 -fa 1 -d 4096 -p 512
ggml_cuda_init: found 1 CUDA devices (Total VRAM: 97886 MiB):
Device 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition, compute capability 12.0, VMM: yes, VRAM: 97886 MiB
| model                     |      size |   params | backend | ngl | fa |          test |             t/s |
| ------------------------- | --------: | -------: | ------- | --: | -: | ------------: | --------------: |
| mistral4 ?B Q4_K - Medium | 67.32 GiB | 118.97 B | CUDA    | 500 |  1 | pp512 @ d4096 | 3655.55 ± 15.81 |
| mistral4 ?B Q4_K - Medium | 67.32 GiB | 118.97 B | CUDA    | 500 |  1 |  tg32 @ d4096 |   186.31 ± 4.23 |

@JohannesGaessler JohannesGaessler (Contributor) left a comment

Please stop adding superfluous comments.

@IMbackK IMbackK (Collaborator) left a comment

I have restarted windows-latest-hip, which timed out (a known problem). Please wait until it finishes successfully, then feel free to merge.

@JohannesGaessler JohannesGaessler merged commit 7b8443a into ggml-org:master Apr 28, 2026
79 of 81 checks passed
@jacekpoplawski jacekpoplawski (Contributor) commented
Very nice speedup!

Details

ggml_cuda_init: found 4 CUDA devices (Total VRAM: 84280 MiB):
Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes, VRAM: 24124 MiB
Device 1: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes, VRAM: 24124 MiB
Device 2: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes, VRAM: 24124 MiB
Device 3: NVIDIA GeForce RTX 3060, compute capability 8.6, VMM: yes, VRAM: 11907 MiB

before:

| model                     |      size |   params | backend | ngl | fa |  test |           t/s |
| ------------------------- | --------: | -------: | ------- | --: | -: | ----: | ------------: |
| mistral4 ?B Q4_K - Medium | 68.62 GiB | 118.97 B | CUDA    | 999 |  1 | pp512 | 348.35 ± 3.64 |
| mistral4 ?B Q4_K - Medium | 68.62 GiB | 118.97 B | CUDA    | 999 |  1 | tg128 |  59.12 ± 0.14 |

after:

| model                     |      size |   params | backend | ngl | fa |  test |            t/s |
| ------------------------- | --------: | -------: | ------- | --: | -: | ----: | -------------: |
| mistral4 ?B Q4_K - Medium | 68.62 GiB | 118.97 B | CUDA    | 999 |  1 | pp512 | 1310.13 ± 9.14 |
| mistral4 ?B Q4_K - Medium | 68.62 GiB | 118.97 B | CUDA    | 999 |  1 | tg128 |   93.06 ± 0.89 |
