CUDA: generalized (mma) FA, add Volta support #17505
This PR makes the following changes to the CUDA FlashAttention code:
- Padding is no longer needed in the mask->ne[1] direction. This is done by applying a modulo on the mask column that is being read so no conditional statements need to be evaluated (a minimal sketch of the idea follows below this list). The impact on performance is negligible and I do not deem it necessary to compile additional template specializations. See ggml : remove KQ mask padding #16309. cc @ggerganov.
- The tile template in mma.cuh has been extended with additional, optional arguments to safely handle situations where tiles of the same shape can have different physical data layouts (see the second sketch below).
- […] __launch_bounds__ when using ROCm (as of right now ROCm is not used).
- Padding is no longer needed for K->ne[1] (see the third sketch below). As with the tile kernel, because this comes at a cost to performance it is still preferable to pad the KV cache length. As of right now this is still required to be 256; for the currently supported GPUs it should be possible to lower this to 128 without issue once the WMMA kernel has been completely replaced. For Hopper it may still make sense to have a padding of 256, but as it is I have no idea whether the 256x64 instruction would actually have better performance than the 128x64 instruction.
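A minimal sketch of the modulo idea from the first bullet (not the actual fattn-mma code; `add_mask_to_KQ_tile`, `cols_per_tile`, `stride_mask` etc. are made-up names): the index along mask->ne[1] is wrapped so every load stays in bounds without a per-column branch, and the wrapped values are harmless because the corresponding output columns are never written back.

```cpp
#include <cuda_fp16.h>

// Sketch only: add the KQ mask to a tile of KQ values without requiring
// mask->ne[1] to be padded to the tile width. The column index along
// mask->ne[1] is wrapped with a modulo so every load is in bounds and no
// per-column conditional is needed. Columns >= mask_ne1 read a
// wrong-but-valid mask value, which is harmless because those KQ columns
// are never written back to the output.
template <int cols_per_tile, int kv_per_tile>
__device__ void add_mask_to_KQ_tile(
        float                   * KQ_tile, // [cols_per_tile][kv_per_tile]
        const half * __restrict__ mask,    // [mask_ne1][stride_mask]
        const int col0,                    // first Q column handled by this tile
        const int kv0,                     // first KV position handled by this tile
        const int mask_ne1,                // mask->ne[1], not padded
        const int stride_mask) {           // mask row stride in elements
#pragma unroll
    for (int j = 0; j < cols_per_tile; ++j) {
        const int j_mask = (col0 + j) % mask_ne1; // modulo instead of `if (col0+j < mask_ne1)`
#pragma unroll
        for (int k = 0; k < kv_per_tile; ++k) {
            KQ_tile[j*kv_per_tile + k] += __half2float(mask[j_mask*stride_mask + kv0 + k]);
        }
    }
}
```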
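Roughly what "additional, optional arguments" on the tile template can look like (a hedged sketch, not the actual mma.cuh interface; the `tile_layout` tag and member names are assumptions): the physical layout becomes part of the type, so tiles of the same logical shape but different layouts can no longer be mixed up silently.

```cpp
// Sketch only: a tile type with an extra, defaulted template parameter that
// encodes its physical data layout. Two tiles with the same shape but
// different layouts then have distinct types, so passing the wrong one to an
// mma() wrapper is a compile-time error instead of silently wrong data.
enum class tile_layout { row_major, col_major };

template <int I, int J, typename T, tile_layout layout = tile_layout::row_major>
struct tile {
    static constexpr int ne = (I*J + 31) / 32; // elements per thread of a 32-thread warp (simplified)
    T x[ne] = {};

    __device__ __forceinline__       T & operator[](const int k)       { return x[k]; }
    __device__ __forceinline__ const T & operator[](const int k) const { return x[k]; }
};

// An mma wrapper can now require exactly the layouts the hardware instruction
// expects (definition omitted in this sketch):
template <int I, int J, int K, typename T>
__device__ void mma(
              tile<I, J, float>                        & D,
        const tile<I, K, T, tile_layout::row_major>    & A,
        const tile<J, K, T, tile_layout::col_major>    & B);
```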
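A sketch of the control flow implied by the last bullet, with hypothetical helper names (`process_KV_tile`, `loop_over_KV`): only the final, partial KV tile pays for bounds checks, which is why an unpadded K->ne[1] costs some performance and padding the KV cache length remains preferable.

```cpp
// Sketch only: iterate over the KV dimension when K->ne[1] is not a multiple
// of the KV tile size. Full tiles take a branch-free fast path; only the
// final, partial tile is bounds-checked.
template <bool check_bounds, int KV_TILE>
__device__ __forceinline__ void process_KV_tile(const int k0, const int kv_len) {
    // ... load K/V for positions [k0, k0 + KV_TILE), skipping or neutralizing
    //     positions >= kv_len when check_bounds is true ...
}

template <int KV_TILE>
__device__ void loop_over_KV(const int kv_len /* = K->ne[1], not padded */) {
    const int kv_full = (kv_len / KV_TILE) * KV_TILE; // last multiple of the tile size

    for (int k0 = 0; k0 < kv_full; k0 += KV_TILE) {
        process_KV_tile<false, KV_TILE>(k0, kv_len);     // fast path, no bounds checks
    }
    if (kv_full < kv_len) {
        process_KV_tile<true, KV_TILE>(kv_full, kv_len); // slower, bounds-checked tail
    }
}
```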
As of right now the interface in mma.cuh is suboptimal and long-term I intend to refactor it to allow the use of tensor cores in a more uniform way. However, I don't know the exact requirements until we have proper support for AMD WMMA and AMD MFMA instructions. So for now I think the correct choice is to prioritize getting working support for those at the cost of maintainability and to do a refactor afterwards.

V100 performance
Other GPU performance
The performance numbers assume that the KQ mask is no longer being padded. This change is also in this PR. I don't have a good overview of which other backends may still need support for this change and whether or not it should be reverted prior to merging.