
Conversation

@LoserCheems
Collaborator

This pull request adjusts block size configurations in the Triton kernels of flash_dmattn_triton.py to improve performance and memory efficiency, and refactors the split-KV attention kernel in csrc/src/flash_fwd_kernel.h. The Triton changes modify the block dimensions (BLOCK_M and BLOCK_N) to optimize computation and memory usage.

Block size configuration updates:

  • Updated block sizes in the Triton autotune configurations within the init_func function to use BLOCK_N=128 instead of BLOCK_N=64, for both the sequence-parallel and non-sequence-parallel configurations. This aims to improve performance by increasing the tile size along the key/value (N) dimension (see the sketch after this list).

  • Replaced the hardcoded BLOCK=64 with separate BLOCK_M=128 and BLOCK_N=64 in the _flash_attn_forward function. This allows more flexibility in defining block dimensions for different axes, potentially improving memory alignment and computation efficiency.

  • Updated kernel invocation in _flash_attn_forward to pass the new BLOCK_M and BLOCK_N variables instead of the single BLOCK parameter, ensuring compatibility with the updated block size configuration.
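
A minimal sketch of the direction of these changes is shown below (not the actual file contents). It assumes a layout similar to the upstream flash-attention Triton kernels; the names bwd_configs and _fwd_kernel are illustrative placeholders, and the config values mirror the one quoted later in the review.

import triton

# Backward-pass autotune candidates: BLOCK_N raised from 64 to 128 for both the
# sequence-parallel and non-sequence-parallel variants (values mirror the config
# quoted at flash_dmattn_triton.py:663 in the review comment below).
bwd_configs = [
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
                  num_warps=8, num_stages=1),
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
                  num_warps=8, num_stages=1),
]

# Forward-pass tile sizes: the single BLOCK = 64 constant becomes two values so
# the query-row tile (M) and the key/value-column tile (N) can differ.
BLOCK_M = 128  # rows of Q processed per program
BLOCK_N = 64   # columns of K/V processed per inner iteration
# The forward kernel launch then receives both, e.g.:
#   _fwd_kernel[grid](..., BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, ...)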

Increases BLOCK_N from 64 to 128 in backward pass autotuning configs to improve memory throughput.

Separates forward pass block dimensions into distinct BLOCK_M (128) and BLOCK_N (64) variables instead of using a single BLOCK parameter, allowing for asymmetric block sizing that better matches memory access patterns.
Refactors the splitkv attention kernel to use an attention mask and bias instead of the ZOH (zero-order hold) and active-mask patterns.

Updates tensor definitions, memory layouts, and copy operations to support the new mask/bias approach while maintaining the same sparse matrix multiplication functionality (a reference-level sketch of the mask/bias scheme follows these commit notes).

Removes the Append_KV template parameter, as it is no longer needed with the simplified mask-based implementation.
Moves tensor gP declaration after tensor sS to improve memory access patterns and potentially optimize performance in the attention computation kernel.
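
For readers unfamiliar with the mask/bias formulation, the following is a reference-level sketch in plain PyTorch, not the CUDA kernel itself, of the scheme the refactored splitkv kernel follows; the function name and tensor shapes are assumptions made for illustration.

import torch

def masked_biased_attention(q, k, v, attn_mask, attn_bias, softmax_scale):
    # Assumed shapes: q (B, H, M, D), k/v (B, H, N, D),
    # attn_mask (B, H, M, N) bool, attn_bias (B, H, M, N) float.
    scores = torch.einsum("bhmd,bhnd->bhmn", q, k) * softmax_scale
    scores = scores + attn_bias                              # bias added to raw scores
    scores = scores.masked_fill(~attn_mask, float("-inf"))   # inactive positions masked out
    p = torch.softmax(scores, dim=-1)
    return torch.einsum("bhmn,bhnd->bhmd", p, v)
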
Copilot AI review requested due to automatic review settings July 9, 2025 12:56
Contributor

Copilot AI left a comment


Pull Request Overview

This PR refines block size configurations to boost performance and memory efficiency in Triton kernels and streamlines kernel templates.

  • Increased BLOCK_N from 64 to 128 in the autotune configs of init_func.
  • Replaced a single BLOCK constant with separate BLOCK_M and BLOCK_N in _flash_attn_forward.
  • Refactored compute_attn_1rowblock_splitkv in the C++ kernel: removed the Append_KV parameter, reorganized the gP tensor placement, and integrated separate mask and bias handling.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Reviewed files:

  • flash_dmattn/flash_dmattn_triton.py: Adjusted autotune configs (BLOCK_N), split BLOCK into BLOCK_M/BLOCK_N, and updated the kernel invocation.
  • csrc/src/flash_fwd_kernel.h: Cleaned up the compute_attn_1rowblock_splitkv template signature, moved the gP tensor, and added mask/bias support while removing Append_KV.
Comments suppressed due to low confidence (1)

flash_dmattn/flash_dmattn_triton.py:663

  • [nitpick] Consider adding a brief comment explaining the rationale for increasing BLOCK_N to 128 in the autotune configs, to help future maintainers understand this performance tuning decision.
            {"BLOCK_M": 64, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},

Comment on lines +881 to +882
BLOCK_M = 128
BLOCK_N = 64

Copilot AI Jul 9, 2025


[nitpick] Consider defining these magic block size values (BLOCK_M and BLOCK_N) as module-level constants or parameters, rather than inline literals, to centralize configuration and simplify tuning.

Suggested change
BLOCK_M = 128
BLOCK_N = 64
# Use module-level constants for block sizes
BLOCK_M = BLOCK_M
BLOCK_N = BLOCK_N
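
One hedged way to act on this suggestion (the constant names below are hypothetical, not taken from the repository) would be to hoist the forward tile sizes to module-level constants and reference them from _flash_attn_forward:

# Illustrative only: centralize the forward tile sizes so they can be tuned in
# one place (constant names are hypothetical).
DEFAULT_FWD_BLOCK_M = 128
DEFAULT_FWD_BLOCK_N = 64

# Inside _flash_attn_forward, the inline literals would then become:
#   BLOCK_M = DEFAULT_FWD_BLOCK_M
#   BLOCK_N = DEFAULT_FWD_BLOCK_N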

Comment on lines +950 to 952
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
// __syncthreads();


Copilot AI Jul 9, 2025


This debug print statement is commented out and can be removed to clean up the codebase and avoid clutter.

Suggested change
// if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); }
// __syncthreads();

Collaborator Author

@LoserCheems left a comment


LGTM

@LoserCheems merged commit d198e30 into main on Jul 9, 2025.