@JohannesGaessler (Collaborator) commented Nov 25, 2025

This PR makes the following changes to the CUDA FlashAttention code:

  • All kernels have been extended with support for attention masks that are not padded in the mask->ne[1] direction. This is done by applying a modulo to the mask column being read, so no conditional statements need to be evaluated (a minimal sketch of the idea follows this list). The impact on performance is negligible and I do not deem it necessary to compile additional template specializations. See ggml : remove KQ mask padding #16309. cc @ggerganov.
  • The mma kernel has been extended with support for Volta tensor cores. Previously the WMMA kernel was used for Volta; it is now only needed for AMD. After AMD support has been added to the mma kernel the WMMA kernel can be safely removed, leaving only 3 kernels to maintain going forward. On master the mma kernel has defects w.r.t. tile shapes that do not manifest as bugs; those should be fixed with this PR, and I think it is now feasible for other developers to add support for e.g. AMD WMMA instructions. cc @zhang-hui-yulo @jiachengjason @unverbraucht.
  • The tile template in mma.cuh has been extended with additional, optional arguments to safely handle situations where tiles of the same shape can have different physical data layouts (a hypothetical illustration of the idea also follows this list).
  • The mma kernel has been refactored to allow more flexible configuration. The configuration is now also done without the use of templating, which seems to be causing issues for __launch_bounds__ when using ROCm (as of right now the mma kernel is not used with ROCm anyway).
  • The mma kernel has been extended with support for out-of-bounds checks in the direction of K->ne[1]. As with the tile kernel, because this comes at a cost to performance it is still preferable to pad the KV cache length. As of right now the padding is still required to be 256; for the currently supported GPUs it should be possible to lower it to 128 without issue once the WMMA kernel has been completely replaced. For Hopper it may still make sense to have a padding of 256, but as it is I have no idea whether the 256x64 instruction would actually have better performance than the 128x64 instruction.
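
As a minimal sketch of the mask handling described in the first bullet (illustrative names only, not the actual FlashAttention kernel code), assuming a mask laid out with ne0 = n_kv values per column: the column index is wrapped with a modulo so that loads stay in bounds even when the processed tile extends past mask->ne[1], without any conditionals.

```cuda
#include <cuda_fp16.h>

// Hypothetical helper: load a tile of the KQ mask into shared memory.
// Without padding of mask->ne[1], col0 + j can exceed ne1 for the last tile;
// wrapping the index with a modulo keeps the load in bounds without a branch.
// The wrapped values belong to out-of-range columns whose results are
// discarded on write-back anyway, so correctness is unaffected.
static __device__ void load_mask_tile_sketch(
        const half * __restrict__ mask,   // ne1 columns of ne0 = n_kv values each
        half       * __restrict__ tile,   // shared memory, [cols_per_tile][kv_per_iter]
        const int64_t ne0, const int ne1,
        const int col0,                   // first mask column handled by this block
        const int cols_per_tile,
        const int k0, const int kv_per_iter) {
    for (int j = 0; j < cols_per_tile; ++j) {
        const int col = (col0 + j) % ne1; // modulo instead of an if (col < ne1) check
        for (int k = threadIdx.x; k < kv_per_iter; k += blockDim.x) {
            tile[j*kv_per_iter + k] = mask[col*ne0 + k0 + k];
        }
    }
}
```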
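
Regarding the extended tile template, a hypothetical illustration of the idea (this is not the actual mma.cuh interface): an extra, defaulted template parameter encodes the physical data layout, so two tiles with the same logical shape but different register layouts get distinct types and cannot be mixed up silently.

```cuda
// Hypothetical layout tag; the real code may distinguish layouts differently.
enum class tile_layout { standard, volta_interleaved };

template <int I, int J, typename T, tile_layout layout = tile_layout::standard>
struct tile_sketch {
    static constexpr int ne = I*J/32; // values held per thread of a 32-thread warp
    T x[ne] = {};
};

// Functions can then be restricted to the layouts they actually support;
// passing a tile with a different layout tag becomes a compile-time error.
template <int I, int J, typename T>
static __device__ void store_standard(T * dst, const tile_sketch<I, J, T, tile_layout::standard> & t);
```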

As of right now the interface in mma.cuh is suboptimal and long-term I intend to refactor it to allow the use of tensor cores in a more uniform way. However, I don't know the exact requirements until we have proper support for AMD WMMA and AMD MFMA instructions. So for now I think the correct choice is to prioritize getting working support for those at the cost of maintainability and to do a refactor afterwards.

**V100 performance**

| GPU | Model | Microbatch size | Test | t/s master | t/s 277014f | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| V100-PCIE-32GB | deepseek2 16B Q4_0 | 1 | pp512@d32768 | 84.06 | 89.23 | 1.06 |
| V100-PCIE-32GB | deepseek2 16B Q4_0 | 2 | pp512@d32768 | 88.28 | 86.50 | 0.98 |
| V100-PCIE-32GB | deepseek2 16B Q4_0 | 4 | pp512@d32768 | 122.04 | 134.50 | 1.10 |
| V100-PCIE-32GB | deepseek2 16B Q4_0 | 8 | pp512@d32768 | 159.61 | 204.43 | 1.28 |
| V100-PCIE-32GB | deepseek2 16B Q4_0 | 16 | pp512@d32768 | 187.50 | 274.82 | 1.47 |
| V100-PCIE-32GB | deepseek2 16B Q4_0 | 32 | pp512@d32768 | 208.08 | 340.50 | 1.64 |
| V100-PCIE-32GB | deepseek2 16B Q4_0 | 64 | pp512@d32768 | 196.49 | 312.07 | 1.59 |
| V100-PCIE-32GB | deepseek2 16B Q4_0 | 128 | pp512@d32768 | 217.64 | 371.18 | 1.71 |
| V100-PCIE-32GB | deepseek2 16B Q4_0 | 256 | pp512@d32768 | 227.55 | 408.51 | 1.80 |
| V100-PCIE-32GB | deepseek2 16B Q4_0 | 512 | pp512@d32768 | 250.76 | 432.14 | 1.72 |
| V100-PCIE-32GB | gemma 2B Q4_0 | 1 | pp512@d32768 | 196.73 | 276.43 | 1.41 |
| V100-PCIE-32GB | gemma 2B Q4_0 | 2 | pp512@d32768 | 341.32 | 472.67 | 1.38 |
| V100-PCIE-32GB | gemma 2B Q4_0 | 4 | pp512@d32768 | 233.69 | 461.42 | 1.97 |
| V100-PCIE-32GB | gemma 2B Q4_0 | 8 | pp512@d32768 | 433.09 | 705.18 | 1.63 |
| V100-PCIE-32GB | gemma 2B Q4_0 | 16 | pp512@d32768 | 779.04 | 1095.12 | 1.41 |
| V100-PCIE-32GB | gemma 2B Q4_0 | 32 | pp512@d32768 | 981.00 | 1506.68 | 1.54 |
| V100-PCIE-32GB | gemma 2B Q4_0 | 64 | pp512@d32768 | 859.59 | 1260.66 | 1.47 |
| V100-PCIE-32GB | gemma 2B Q4_0 | 128 | pp512@d32768 | 1032.55 | 1735.64 | 1.68 |
| V100-PCIE-32GB | gemma 2B Q4_0 | 256 | pp512@d32768 | 1089.22 | 1833.70 | 1.68 |
| V100-PCIE-32GB | gemma 2B Q4_0 | 512 | pp512@d32768 | 995.95 | 1613.81 | 1.62 |
| V100-PCIE-32GB | llama 1B Q4_0 | 1 | pp512@d32768 | 237.92 | 323.72 | 1.36 |
| V100-PCIE-32GB | llama 1B Q4_0 | 2 | pp512@d32768 | 417.22 | 588.65 | 1.41 |
| V100-PCIE-32GB | llama 1B Q4_0 | 4 | pp512@d32768 | 448.34 | 838.65 | 1.87 |
| V100-PCIE-32GB | llama 1B Q4_0 | 8 | pp512@d32768 | 824.46 | 1445.37 | 1.75 |
| V100-PCIE-32GB | llama 1B Q4_0 | 16 | pp512@d32768 | 1435.92 | 1917.20 | 1.34 |
| V100-PCIE-32GB | llama 1B Q4_0 | 32 | pp512@d32768 | 1769.39 | 2566.43 | 1.45 |
| V100-PCIE-32GB | llama 1B Q4_0 | 64 | pp512@d32768 | 1991.61 | 2289.92 | 1.15 |
| V100-PCIE-32GB | llama 1B Q4_0 | 128 | pp512@d32768 | 2391.19 | 2843.04 | 1.19 |
| V100-PCIE-32GB | llama 1B Q4_0 | 256 | pp512@d32768 | 2312.60 | 2559.85 | 1.11 |
| V100-PCIE-32GB | llama 1B Q4_0 | 512 | pp512@d32768 | 1900.53 | 2137.76 | 1.12 |
| V100-PCIE-32GB | llama 8B Q4_0 | 1 | pp512@d32768 | 61.12 | 81.47 | 1.33 |
| V100-PCIE-32GB | llama 8B Q4_0 | 2 | pp512@d32768 | 115.57 | 154.44 | 1.34 |
| V100-PCIE-32GB | llama 8B Q4_0 | 4 | pp512@d32768 | 120.26 | 220.87 | 1.84 |
| V100-PCIE-32GB | llama 8B Q4_0 | 8 | pp512@d32768 | 215.88 | 323.48 | 1.50 |
| V100-PCIE-32GB | llama 8B Q4_0 | 16 | pp512@d32768 | 380.43 | 467.35 | 1.23 |
| V100-PCIE-32GB | llama 8B Q4_0 | 32 | pp512@d32768 | 470.78 | 656.82 | 1.40 |
| V100-PCIE-32GB | llama 8B Q4_0 | 64 | pp512@d32768 | 228.56 | 456.01 | 2.00 |
| V100-PCIE-32GB | llama 8B Q4_0 | 128 | pp512@d32768 | 278.85 | 670.43 | 2.40 |
| V100-PCIE-32GB | llama 8B Q4_0 | 256 | pp512@d32768 | 307.17 | 872.91 | 2.84 |
| V100-PCIE-32GB | llama 8B Q4_0 | 512 | pp512@d32768 | 314.34 | 932.41 | 2.97 |
**Other GPU performance**

| GPU | Model | Microbatch size | Test | t/s master | t/s e44ebb0 | Speedup |
| --- | --- | ---: | --- | ---: | ---: | ---: |
| MI60 / MI50 | llama 8B Q4_0 | 1 | pp512@d32768 | 59.80 | 64.40 | 1.08 |
| MI60 / MI50 | llama 8B Q4_0 | 2 | pp512@d32768 | 106.46 | 113.46 | 1.07 |
| MI60 / MI50 | llama 8B Q4_0 | 4 | pp512@d32768 | 119.84 | 97.07 | 0.81 |
| MI60 / MI50 | llama 8B Q4_0 | 8 | pp512@d32768 | 162.89 | 167.55 | 1.03 |
| MI60 / MI50 | llama 8B Q4_0 | 16 | pp512@d32768 | 228.46 | 229.93 | 1.01 |
| MI60 / MI50 | llama 8B Q4_0 | 32 | pp512@d32768 | 269.06 | 268.69 | 1.00 |
| MI60 / MI50 | llama 8B Q4_0 | 64 | pp512@d32768 | 291.15 | 289.38 | 0.99 |
| MI60 / MI50 | llama 8B Q4_0 | 128 | pp512@d32768 | 335.13 | 332.27 | 0.99 |
| MI60 / MI50 | llama 8B Q4_0 | 256 | pp512@d32768 | 351.75 | 349.71 | 0.99 |
| MI60 / MI50 | llama 8B Q4_0 | 512 | pp512@d32768 | 357.18 | 355.12 | 0.99 |
| MI100 | llama 8B Q4_0 | 1 | pp512@d32768 | 77.78 | 82.66 | 1.06 |
| MI100 | llama 8B Q4_0 | 2 | pp512@d32768 | 133.33 | 139.16 | 1.04 |
| MI100 | llama 8B Q4_0 | 4 | pp512@d32768 | 164.44 | 169.21 | 1.03 |
| MI100 | llama 8B Q4_0 | 8 | pp512@d32768 | 232.70 | 236.51 | 1.02 |
| MI100 | llama 8B Q4_0 | 16 | pp512@d32768 | 424.09 | 431.27 | 1.02 |
| MI100 | llama 8B Q4_0 | 32 | pp512@d32768 | 559.43 | 563.32 | 1.01 |
| MI100 | llama 8B Q4_0 | 64 | pp512@d32768 | 648.34 | 648.77 | 1.00 |
| MI100 | llama 8B Q4_0 | 128 | pp512@d32768 | 671.01 | 668.83 | 1.00 |
| MI100 | llama 8B Q4_0 | 256 | pp512@d32768 | 696.50 | 692.00 | 0.99 |
| MI100 | llama 8B Q4_0 | 512 | pp512@d32768 | 706.38 | 700.32 | 0.99 |
| P40 | llama 8B Q4_0 | 1 | pp512@d32768 | 31.00 | 32.45 | 1.05 |
| P40 | llama 8B Q4_0 | 2 | pp512@d32768 | 59.14 | 61.75 | 1.04 |
| P40 | llama 8B Q4_0 | 4 | pp512@d32768 | 87.36 | 89.87 | 1.03 |
| P40 | llama 8B Q4_0 | 8 | pp512@d32768 | 122.68 | 122.31 | 1.00 |
| P40 | llama 8B Q4_0 | 16 | pp512@d32768 | 178.33 | 175.34 | 0.98 |
| P40 | llama 8B Q4_0 | 32 | pp512@d32768 | 189.92 | 190.07 | 1.00 |
| P40 | llama 8B Q4_0 | 64 | pp512@d32768 | 209.02 | 208.27 | 1.00 |
| P40 | llama 8B Q4_0 | 128 | pp512@d32768 | 217.96 | 217.49 | 1.00 |
| P40 | llama 8B Q4_0 | 256 | pp512@d32768 | 223.15 | 222.81 | 1.00 |
| P40 | llama 8B Q4_0 | 512 | pp512@d32768 | 219.45 | 219.48 | 1.00 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 1 | pp512@d32768 | 23.92 | 24.10 | 1.01 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 2 | pp512@d32768 | 43.49 | 43.68 | 1.00 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 4 | pp512@d32768 | 77.88 | 78.19 | 1.00 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 8 | pp512@d32768 | 108.82 | 96.17 | 0.88 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 16 | pp512@d32768 | 138.58 | 140.27 | 1.01 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 32 | pp512@d32768 | 151.39 | 152.96 | 1.01 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 64 | pp512@d32768 | 74.81 | 76.94 | 1.03 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 128 | pp512@d32768 | 101.46 | 102.30 | 1.01 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 256 | pp512@d32768 | 115.59 | 115.84 | 1.00 |
| Radeon 8060S Graphics | llama 8B Q4_0 | 512 | pp512@d32768 | 117.65 | 118.57 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | 1 | pp512@d32768 | 87.54 | 92.96 | 1.06 |
| RTX 3090 | llama 8B Q4_0 | 2 | pp512@d32768 | 160.48 | 170.31 | 1.06 |
| RTX 3090 | llama 8B Q4_0 | 4 | pp512@d32768 | 293.48 | 303.46 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | 8 | pp512@d32768 | 429.51 | 439.54 | 1.02 |
| RTX 3090 | llama 8B Q4_0 | 16 | pp512@d32768 | 844.62 | 874.15 | 1.03 |
| RTX 3090 | llama 8B Q4_0 | 32 | pp512@d32768 | 1184.30 | 1194.99 | 1.01 |
| RTX 3090 | llama 8B Q4_0 | 64 | pp512@d32768 | 1491.70 | 1495.43 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 128 | pp512@d32768 | 1612.42 | 1617.77 | 1.00 |
| RTX 3090 | llama 8B Q4_0 | 256 | pp512@d32768 | 1716.96 | 1697.92 | 0.99 |
| RTX 3090 | llama 8B Q4_0 | 512 | pp512@d32768 | 1470.93 | 1448.12 | 0.98 |
| RTX 4090 | llama 8B Q4_0 | 1 | pp512@d32768 | 98.14 | 102.76 | 1.05 |
| RTX 4090 | llama 8B Q4_0 | 2 | pp512@d32768 | 178.13 | 190.39 | 1.07 |
| RTX 4090 | llama 8B Q4_0 | 4 | pp512@d32768 | 349.90 | 366.50 | 1.05 |
| RTX 4090 | llama 8B Q4_0 | 8 | pp512@d32768 | 618.83 | 646.33 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | 16 | pp512@d32768 | 1095.54 | 1140.84 | 1.04 |
| RTX 4090 | llama 8B Q4_0 | 32 | pp512@d32768 | 2007.89 | 2051.87 | 1.02 |
| RTX 4090 | llama 8B Q4_0 | 64 | pp512@d32768 | 3091.16 | 3089.09 | 1.00 |
| RTX 4090 | llama 8B Q4_0 | 128 | pp512@d32768 | 3188.55 | 3095.61 | 0.97 |
| RTX 4090 | llama 8B Q4_0 | 256 | pp512@d32768 | 2961.18 | 2892.63 | 0.98 |
| RTX 4090 | llama 8B Q4_0 | 512 | pp512@d32768 | 2464.56 | 2431.25 | 0.99 |
| RTX 5090 | llama 8B Q4_0 | 1 | pp512@d32768 | 155.78 | 167.41 | 1.07 |
| RTX 5090 | llama 8B Q4_0 | 2 | pp512@d32768 | 239.31 | 269.27 | 1.13 |
| RTX 5090 | llama 8B Q4_0 | 4 | pp512@d32768 | 461.48 | 486.56 | 1.05 |
| RTX 5090 | llama 8B Q4_0 | 8 | pp512@d32768 | 780.64 | 810.10 | 1.04 |
| RTX 5090 | llama 8B Q4_0 | 16 | pp512@d32768 | 1381.19 | 1408.61 | 1.02 |
| RTX 5090 | llama 8B Q4_0 | 32 | pp512@d32768 | 2253.55 | 2308.20 | 1.02 |
| RTX 5090 | llama 8B Q4_0 | 64 | pp512@d32768 | 2827.63 | 2828.64 | 1.00 |
| RTX 5090 | llama 8B Q4_0 | 128 | pp512@d32768 | 3009.14 | 3075.67 | 1.02 |
| RTX 5090 | llama 8B Q4_0 | 256 | pp512@d32768 | 3078.24 | 2981.31 | 0.97 |
| RTX 5090 | llama 8B Q4_0 | 512 | pp512@d32768 | 2698.04 | 2640.36 | 0.98 |
| RX 6800 | llama 8B Q4_0 | 1 | pp512@d32768 | 42.25 | 44.60 | 1.06 |
| RX 6800 | llama 8B Q4_0 | 2 | pp512@d32768 | 77.43 | 81.42 | 1.05 |
| RX 6800 | llama 8B Q4_0 | 4 | pp512@d32768 | 105.08 | 108.86 | 1.04 |
| RX 6800 | llama 8B Q4_0 | 8 | pp512@d32768 | 140.43 | 140.94 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 16 | pp512@d32768 | 173.28 | 175.32 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 32 | pp512@d32768 | 209.55 | 210.72 | 1.01 |
| RX 6800 | llama 8B Q4_0 | 64 | pp512@d32768 | 235.46 | 235.80 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 128 | pp512@d32768 | 262.63 | 262.85 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 256 | pp512@d32768 | 274.40 | 274.65 | 1.00 |
| RX 6800 | llama 8B Q4_0 | 512 | pp512@d32768 | 275.25 | 274.63 | 1.00 |
| RX 9060 XT | llama 8B Q4_0 | 1 | pp512@d32768 | 25.67 | 29.58 | 1.15 |
| RX 9060 XT | llama 8B Q4_0 | 2 | pp512@d32768 | 49.98 | 57.25 | 1.15 |
| RX 9060 XT | llama 8B Q4_0 | 4 | pp512@d32768 | 85.18 | 97.39 | 1.14 |
| RX 9060 XT | llama 8B Q4_0 | 8 | pp512@d32768 | 111.87 | 104.18 | 0.93 |
| RX 9060 XT | llama 8B Q4_0 | 16 | pp512@d32768 | 162.98 | 172.35 | 1.06 |
| RX 9060 XT | llama 8B Q4_0 | 32 | pp512@d32768 | 190.29 | 195.63 | 1.03 |
| RX 9060 XT | llama 8B Q4_0 | 64 | pp512@d32768 | 288.59 | 291.34 | 1.01 |
| RX 9060 XT | llama 8B Q4_0 | 128 | pp512@d32768 | 322.67 | 325.96 | 1.01 |
| RX 9060 XT | llama 8B Q4_0 | 256 | pp512@d32768 | 348.31 | 351.01 | 1.01 |
| RX 9060 XT | llama 8B Q4_0 | 512 | pp512@d32768 | 349.45 | 350.95 | 1.00 |

The performance numbers assume that the KQ mask is no longer being padded; this change is also part of this PR. I don't have a good overview of which other backends may still need support for this change and whether or not it should be reverted prior to merging.

@github-actions bot added the Nvidia GPU (Issues specific to Nvidia GPUs) and ggml (changes relating to the ggml tensor library for machine learning) labels on Nov 25, 2025
@JohannesGaessler changed the title from "CUDA: ganeralized (mma) FA, add Volta support" to "CUDA: generalized (mma) FA, add Volta support" on Nov 25, 2025
@JohannesGaessler force-pushed the cuda-fa-mma-update-5 branch 2 times, most recently from 48372ef to 2ef0c5f on November 25, 2025 23:09
@zhang-hui-yulo (Contributor) commented:

Thank you for the info; I shall work on FA for RDNA4 once this PR is merged. It looks like the logic for the transposed tile is still empty.
