Update #64
Conversation
Increases BLOCK_N from 64 to 128 in backward pass autotuning configs to improve memory throughput. Separates forward pass block dimensions into distinct BLOCK_M (128) and BLOCK_N (64) variables instead of using a single BLOCK parameter, allowing for asymmetric block sizing that better matches memory access patterns.
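For reference, a minimal sketch of what the adjusted configurations might look like; the exact config lists, keys, and `num_warps` values here are assumptions for illustration, not copied from the PR:

```python
import triton

# Backward pass: BLOCK_N raised from 64 to 128 for both the sequence-parallel
# and non-sequence-parallel variants (values shown are illustrative).
bwd_configs = [
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8),
]

# Forward pass: a single BLOCK constant replaced by asymmetric tile sizes.
BLOCK_M = 128  # query rows per program
BLOCK_N = 64   # key/value columns per inner-loop tile
```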
Refactors the splitkv attention kernel to use attention mask and bias instead of ZOH (zero-out-head) and active mask patterns. Updates tensor definitions, memory layouts, and copy operations to support the new mask/bias approach while maintaining the same sparse matrix multiplication functionality. Removes the Append_KV template parameter as it's no longer needed with the simplified mask-based implementation.
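Conceptually, the mask/bias formulation corresponds to a dense reference like the following (a simplified PyTorch illustration only; the actual kernel operates on tiled shared-memory fragments):

```python
import torch

def masked_biased_attention(q, k, v, attn_mask, attn_bias):
    """Dense reference for attention with an additive bias and a boolean mask.

    q, k, v:   (batch, heads, seq, head_dim)
    attn_mask: (batch, heads, seq_q, seq_k) boolean, True = keep
    attn_bias: (batch, heads, seq_q, seq_k) additive bias
    """
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    scores = scores + attn_bias                               # bias added before softmax
    scores = scores.masked_fill(~attn_mask, float("-inf"))    # masked positions dropped
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", probs, v)
```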
Moves tensor gP declaration after tensor sS to improve memory access patterns and potentially optimize performance in the attention computation kernel.
Pull Request Overview
This PR refines block size configurations to boost performance and memory efficiency in Triton kernels and streamlines kernel templates.
- Increased `BLOCK_N` from 64 to 128 in the autotune configs of `init_func`.
- Replaced a single `BLOCK` constant with separate `BLOCK_M` and `BLOCK_N` in `_flash_attn_forward` (see the sketch after this list).
- Refactored `compute_attn_1rowblock_splitkv` in the C++ kernel: removed the `Append_KV` parameter, reorganized the `gP` tensor placement, and integrated separate mask and bias handling.
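With the block sizes split, the forward launch grid is derived from `BLOCK_M` alone, while `BLOCK_N` controls the K/V tile width consumed inside the kernel loop. A minimal sketch under assumed shapes (none of these values are taken from the repository):

```python
import triton

# Illustrative only: grid rows come from how many BLOCK_M-sized query tiles
# are needed; BLOCK_N never affects the grid, only the inner loop.
BLOCK_M = 128
BLOCK_N = 64

batch, nheads, seqlen_q = 2, 8, 1024
grid = (triton.cdiv(seqlen_q, BLOCK_M), batch * nheads)
assert grid == (8, 16)
```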
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| flash_dmattn/flash_dmattn_triton.py | Adjusted autotune configs (BLOCK_N), split BLOCK into BLOCK_M/BLOCK_N, and updated kernel invocation. |
| csrc/src/flash_fwd_kernel.h | Cleaned up compute_attn_1rowblock_splitkv template signature, moved gP tensor, and added mask/bias support while removing Append_KV. |
Comments suppressed due to low confidence (1)
flash_dmattn/flash_dmattn_triton.py:663
- [nitpick] Consider adding a brief comment explaining the rationale for increasing BLOCK_N to 128 in the autotune configs, to help future maintainers understand this performance tuning decision.
{"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
```python
BLOCK_M = 128
BLOCK_N = 64
```
Copilot AI (Jul 9, 2025)
[nitpick] Consider defining these magic block size values (BLOCK_M and BLOCK_N) as module-level constants or parameters, rather than inline literals, to centralize configuration and simplify tuning.
Suggested change:

```diff
- BLOCK_M = 128
- BLOCK_N = 64
+ # Use module-level constants for block sizes
+ BLOCK_M = BLOCK_M
+ BLOCK_N = BLOCK_N
```
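A minimal sketch of the module-level-constant approach the comment suggests; the constant names and helper function are hypothetical, not part of the repository:

```python
# Hypothetical module-level constants so the forward tile sizes live in one place.
FWD_BLOCK_M = 128  # query rows handled per program
FWD_BLOCK_N = 64   # key/value columns per inner-loop tile

def forward_block_sizes():
    """Return the (BLOCK_M, BLOCK_N) pair used when launching the forward kernel."""
    return FWD_BLOCK_M, FWD_BLOCK_N
```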
```cpp
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
// __syncthreads();
```
Copilot AI (Jul 9, 2025)
This debug print statement is commented out and can be removed to clean up the codebase and avoid clutter.
Suggested change:

```diff
- // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
- // __syncthreads();
```
LoserCheems left a comment
LGTM
This pull request updates the `flash_dmattn_triton.py` file to improve performance and memory efficiency by adjusting block size configurations in the Triton kernels. The changes focus on modifying block dimensions (`BLOCK_M` and `BLOCK_N`) to optimize computation and memory usage.

Block size configuration updates:

- Updated block sizes in Triton configurations within the `init_func` function to use `BLOCK_N=128` instead of `BLOCK_N=64`, for both sequence-parallel and non-sequence-parallel configurations. This change aims to improve performance by increasing the block size for certain dimensions.
- Replaced the hardcoded `BLOCK=64` with separate `BLOCK_M=128` and `BLOCK_N=64` in the `_flash_attn_forward` function. This allows more flexibility in defining block dimensions for different axes, potentially improving memory alignment and computation efficiency.
- Updated the kernel invocation in `_flash_attn_forward` to pass the new `BLOCK_M` and `BLOCK_N` variables instead of the single `BLOCK` parameter, ensuring compatibility with the updated block size configuration (a toy kernel sketch follows below).
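As a self-contained illustration of how separate `BLOCK_M` and `BLOCK_N` values reach a Triton kernel as compile-time constants, here is a toy copy kernel (not the project's attention kernel; all names and shapes are made up):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _copy_tile(x_ptr, out_ptr, n_rows, n_cols,
               BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # Each program handles BLOCK_M rows; BLOCK_N covers the columns of one tile.
    pid = tl.program_id(0)
    rows = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, BLOCK_N)
    offs = rows[:, None] * n_cols + cols[None, :]
    mask = (rows[:, None] < n_rows) & (cols[None, :] < n_cols)
    tile = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, tile, mask=mask)

# Usage sketch (assumes n_cols <= BLOCK_N so a single column tile suffices).
x = torch.randn(1024, 64, device="cuda")
out = torch.empty_like(x)
BLOCK_M, BLOCK_N = 128, 64
grid = (triton.cdiv(x.shape[0], BLOCK_M),)
_copy_tile[grid](x, out, x.shape[0], x.shape[1], BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
```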