Conversation

@LoserCheems
Collaborator

This pull request introduces significant enhancements to the Flash Attention forward pass implementation, focusing on kernel definitions, architecture-specific optimizations, and memory layout improvements. The changes aim to improve code maintainability, support for dynamic masks, and performance across different GPU architectures.

Kernel and Architecture Enhancements:

  • Introduced macros to streamline kernel definitions and handle unsupported architectures in flash_fwd_launch_template.h. This includes defining DEFINE_FLASH_FORWARD_KERNEL for cleaner kernel declarations and FLASH_UNSUPPORTED_ARCH for centralized error messaging (see the sketch after this list).
  • Added architecture-specific optimizations for compute capabilities (e.g., SM80+) by adjusting kernel configurations and memory usage based on GPU capabilities.
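
As a rough illustration, the two macros might look like the minimal sketch below. The parameter lists, the Flash_fwd_params signature, and the flash::compute_attn entry point are assumptions for illustration, not copied from the diff:

```cpp
// Minimal sketch, not the actual diff: declares a templated __global__
// forward kernel with a uniform signature so each instantiation avoids
// repeating the declaration boilerplate.
#define DEFINE_FLASH_FORWARD_KERNEL(kernel_name, ...)                      \
    template <typename Kernel_traits, __VA_ARGS__>                         \
    __global__ void kernel_name(const Flash_fwd_params params)

// Minimal sketch: one central place for the error emitted when the kernel
// is compiled for an architecture the implementation does not support.
#define FLASH_UNSUPPORTED_ARCH                                             \
    printf("FATAL: FlashAttention forward requires sm80 or newer.\n")

// Example use: the fast path compiles only on SM80+, everything else falls
// back to the shared error message.
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    flash::compute_attn<Kernel_traits, Is_causal>(params);  // assumed entry point
#else
    FLASH_UNSUPPORTED_ARCH;
#endif
}
```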

Dynamic Mask Support:

  • Enhanced dynamic mask memory allocation in Flash_fwd_kernel_traits by splitting kDynamicMaskBufferPerQuery into separate components (kMaskValuesSize, kNonZeroIndicesSize, etc.) for better modularity and clarity (see the sketch after this list).
  • Defined shared memory layouts (SmemLayoutDynamicMaskValues, SmemLayoutNonZeroIndices, etc.) to support dynamic masks with improved memory organization.
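
A minimal sketch of how the split accounting and the CuTe shared-memory layouts could fit together is shown below. The element types, the kKeysPerQuery bound, and the plain row-by-column shapes are assumptions for illustration, not the trait definitions from the diff:

```cpp
#include <cute/layout.hpp>  // cute::make_layout, cute::make_shape, cute::Int

// Minimal sketch of the relevant trait members, not the actual
// Flash_fwd_kernel_traits.
template <int kBlockM, int kKeysPerQuery>
struct Dynamic_mask_traits_sketch {
    // Per-query dynamic-mask storage, broken into named components instead
    // of one opaque byte count (element types are assumed).
    static constexpr int kMaskValuesSize     = kKeysPerQuery * sizeof(float);
    static constexpr int kNonZeroIndicesSize = kKeysPerQuery * sizeof(int);
    static constexpr int kDynamicMaskBufferPerQuery =
        kMaskValuesSize + kNonZeroIndicesSize;

    // Shared-memory layouts for the mask values and their non-zero indices,
    // expressed with CuTe so they compose with the existing tile layouts.
    using SmemLayoutDynamicMaskValues = decltype(cute::make_layout(
        cute::make_shape(cute::Int<kBlockM>{}, cute::Int<kKeysPerQuery>{})));
    using SmemLayoutNonZeroIndices = decltype(cute::make_layout(
        cute::make_shape(cute::Int<kBlockM>{}, cute::Int<kKeysPerQuery>{})));
};
```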

Performance Improvements:

  • Optimized kernel configurations for specific head dimensions (e.g., 32, 64, 96, 128, 192, 256) in flash_fwd_launch_template.h, ensuring efficient memory and thread usage based on GPU architecture and workload characteristics.
  • Adjusted shared memory size calculations to account for dynamic mask buffers and non-zero indices, ensuring efficient use of shared memory resources (a launch-side sketch follows this list).
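
Put together, the launch side might look like the sketch below. The tile-size arithmetic, the field names on Flash_fwd_params, and the exact shared-memory terms are assumptions for illustration; cudaFuncSetAttribute with cudaFuncAttributeMaxDynamicSharedMemorySize is the standard way to opt into the larger dynamic shared-memory carve-out on SM80+:

```cpp
// Minimal sketch of the per-head-dimension launch logic, not the actual diff.
template <typename Kernel_traits>
void run_flash_fwd_sketch(Flash_fwd_params &params, cudaStream_t stream) {
    // Shared-memory budget = Q/K/V tiles plus the per-query dynamic-mask
    // buffers (values + non-zero indices) described above.
    constexpr int kSmemMask = Kernel_traits::kBlockM *
                              Kernel_traits::kDynamicMaskBufferPerQuery;
    constexpr int smem_size = Kernel_traits::kSmemSize + kSmemMask;

    // One block per (M-tile, batch, head); field names are assumed.
    const int num_m_blocks =
        (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
    dim3 grid(num_m_blocks, params.b, params.h);

    auto kernel = &flash_fwd_kernel<Kernel_traits, /*Is_causal=*/false>;
    if (smem_size >= 48 * 1024) {
        // Above the 48 KB default, SM80+ requires explicitly raising the
        // dynamic shared-memory limit for this kernel.
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_size);
    }
    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
}
```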

@LoserCheems
Collaborator Author

@wubingheng111

LoserCheems merged commit a843e46 into main on May 19, 2025.