Optimize sparse GEMM and enable in attention computation #21
Conversation
Replaces per-MMA block activation checks with global activation detection to reduce warp divergence and improve performance. Removes template parameter for number of warps and simplifies tensor parameter list by eliminating unused shared memory mask tensor. Uses thread synchronization to ensure consistent branching behavior across all threads in the cooperative thread array, avoiding divergent execution paths that can hurt GPU performance.
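The detection pattern described here can be sketched as follows. This is a minimal illustration rather than the code in this PR: `cta_any_active` and the mask fragment `tCrM` are assumed names, and the actual loop in `csrc/src/utils.h` is organized around the MMA tiling.

```cpp
#include <cute/tensor.hpp>

// Sketch of a CTA-level activation check. `MaskFragment` is assumed to be a
// register-resident cute tensor holding this thread's slice of the active-block
// mask; names and layout are illustrative only.
template <class MaskFragment>
__device__ __forceinline__ bool cta_any_active(MaskFragment const& tCrM) {
    bool local_any_active = false;
    for (int i = 0; i < cute::size(tCrM); ++i) {
        if (tCrM(i) != 0) { local_any_active = true; break; }
    }
    // Block-wide OR: every thread receives the same value, so the caller can
    // branch on the result without creating divergent paths inside a warp.
    return __syncthreads_or(local_any_active);
}
```

The caller can then wrap the B-tile copy and the `cute::gemm` call in a single `if (any_active)` branch, which is the unified copy/clear logic referred to below.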
Replaces regular GEMM operations with sparse GEMM variants throughout the attention computation kernel. Updates both the key-query multiplication and the attention-value multiplication to use sparse operations, incorporating active mask indices for improved performance on sparse attention patterns. Removes TODO comments and activates previously commented sparse multiplication parameters.
Combines multi-line parameter lists into single lines to improve code readability and maintain consistent formatting style across the codebase.
Pull Request Overview
This PR integrates sparse GEMM into the attention kernel by detecting active blocks via a global mask reduction, removes per-MMA branching to simplify code paths, and refactors template parameter ordering for readability.
- Replace per-MMA activation flags with a single CTA-level `any_active` check
- Unify copy/clear logic under `any_active` for B loads
- Update calls in `flash_fwd_kernel.h` to use `sparse_gemm` and `sparse_gemm_rs`
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| csrc/src/utils.h | Simplified sparse_gemm template, replaced per-block mask loops with a CTA reduction, and unified copy/clear logic under any_active. |
| csrc/src/flash_fwd_kernel.h | Updated attention kernel calls to use sparse_gemm/sparse_gemm_rs and adjusted parameter lists for mask arguments. |
Comments suppressed due to low confidence (4)
csrc/src/utils.h:185
- [nitpick] The name `local_any_active` might be confused with warp- or CTA-level state. Consider renaming to `thread_any_active` or `cta_local_active` to clarify scope.
bool local_any_active = false;
csrc/src/utils.h:198
- New paths for `any_active == false` and `any_active == true` should be covered by tests (e.g., fully sparse vs. fully dense masks) to ensure correct behavior and catch regressions.
bool any_active = __syncthreads_or(local_any_active);
csrc/src/utils.h:203
- When `any_active` is true, this copies the entire batch of B tiles even for MMA blocks that may still be inactive. For very sparse patterns, this could waste memory bandwidth. Consider retaining per-MMA checks or using a masked/batched copy to only load truly active blocks.
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
csrc/src/utils.h:224
- Calling `cute::gemm` on the entire batch when only a subset of MMA blocks is active may result in unnecessary compute. If sparse patterns are common, a per-block invocation or conditional mask could reduce wasted FLOPs.
if (any_active) {
```cpp
mma_active_found:;
}
// Ensure all threads in the CTA have the same any_active value to avoid warp divergence
bool any_active = __syncthreads_or(local_any_active);
```
Copilot AI · Jun 21, 2025
Using `__syncthreads_or` for a CTA-wide reduction may not be supported in standard CUDA. Consider using warp-level intrinsics like `__any_sync` for each warp and a final combine via shared memory, or use Cooperative Groups (`cg::grid_group.any()`) to correctly reduce across all threads in the CTA.
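For reference, the warp-level alternative suggested above could be sketched like this; `block_any` is a hypothetical helper, and a 1-D thread block whose size is a multiple of 32 is assumed:

```cpp
__device__ bool block_any(bool pred) {
    __shared__ int warp_results[32];   // supports up to 1024 threads per block
    __shared__ int block_result;
    const int lane = threadIdx.x % 32;
    const int warp = threadIdx.x / 32;
    // Warp-level OR without an extra barrier.
    const int warp_any = __any_sync(0xffffffffu, pred);
    if (lane == 0) warp_results[warp] = warp_any;
    __syncthreads();
    if (threadIdx.x == 0) {
        int acc = 0;
        for (int w = 0; w < (blockDim.x + 31) / 32; ++w) acc |= warp_results[w];
        block_result = acc;
    }
    __syncthreads();
    return block_result != 0;
}
```

Functionally this reproduces what a single block-wide OR reduction intrinsic already provides.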
`__syncthreads_or` is in `sm_20_intrinsics.h`!
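For completeness, a minimal standalone demo of the intrinsic (`int __syncthreads_or(int predicate)`, available since compute capability 2.0); the kernel and buffer names here are illustrative:

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Every thread contributes a predicate; __syncthreads_or returns the block-wide
// OR of those predicates, and the result is identical across the whole CTA.
__global__ void any_active_demo(const int* mask, int n, int* out) {
    int tid = threadIdx.x;
    int local_any_active = (tid < n) ? (mask[tid] != 0) : 0;
    int any_active = __syncthreads_or(local_any_active);
    if (tid == 0) *out = any_active;
}

int main() {
    int h_mask[64] = {0};
    h_mask[37] = 1;  // one active element is enough to flip the block-wide result
    int *d_mask, *d_out, h_out = -1;
    cudaMalloc(&d_mask, sizeof(h_mask));
    cudaMalloc(&d_out, sizeof(int));
    cudaMemcpy(d_mask, h_mask, sizeof(h_mask), cudaMemcpyHostToDevice);
    any_active_demo<<<1, 64>>>(d_mask, 64, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
    printf("any_active = %d\n", h_out);  // expected: 1
    cudaFree(d_mask); cudaFree(d_out);
    return 0;
}
```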
@SNHuan please review this PR 🤗
SNHuan left a comment:
LGTM
Enhance performance by removing per-MMA branching and implementing global activation detection. Integrate sparse GEMM operations into the attention computation kernel, improving efficiency for sparse patterns. Refactor function parameter formatting for better readability.