
Conversation

@LoserCheems
Collaborator

Enhance performance by removing per-MMA branching and implementing global activation detection. Integrate sparse GEMM operations into the attention computation kernel, improving efficiency for sparse patterns. Refactor function parameter formatting for better readability.

Replaces per-MMA block activation checks with global activation detection to reduce warp divergence and improve performance.

Removes the template parameter for the number of warps and simplifies the tensor parameter list by eliminating the unused shared-memory mask tensor.

Uses thread synchronization to ensure consistent branching behavior across all threads in the cooperative thread array, avoiding divergent execution paths that can hurt GPU performance.
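As a rough illustration of this detection step, here is a minimal plain-CUDA sketch; the function name cta_any_active, the float mask layout, and the per-thread bounds are assumptions made for the example, not the repository's actual code:

    __device__ bool cta_any_active(const float *mask_tile, int elems_per_thread, int thread_offset) {
        // Each thread scans only the mask elements it owns.
        bool local_any_active = false;
        for (int i = 0; i < elems_per_thread; ++i) {
            if (mask_tile[thread_offset + i] != 0.0f) {
                local_any_active = true;
                break;
            }
        }
        // Barrier plus CTA-wide OR: every thread receives the same flag,
        // so the branch taken afterwards is uniform across the block.
        return __syncthreads_or(local_any_active) != 0;
    }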
Replaces regular GEMM operations with sparse GEMM variants throughout the attention computation kernel.

Updates both the key-query multiplication and the attention-value multiplication to use sparse operations, incorporating active mask indices for improved performance on sparse attention patterns.

Removes TODO comments and activates previously commented sparse multiplication parameters.
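A minimal sketch of the resulting gate, using placeholder names such as sparse_tile_accumulate; the real kernel operates on CUTE tensors and tiled MMA objects rather than raw pointers:

    __device__ void sparse_tile_accumulate(const float *a_tile, const float *b_tile,
                                           float *acc, int k, bool any_active) {
        // If the whole CTA tile is masked out, skip the operand load and the MMA;
        // the accumulator keeps its cleared value.
        if (!any_active) {
            return;
        }
        float sum = 0.0f;
        for (int i = 0; i < k; ++i) {
            sum += a_tile[i] * b_tile[i];  // dense inner product over the active tile
        }
        *acc += sum;
    }

Both the key-query and the attention-value products follow the same pattern: one uniform any_active check per CTA instead of one check per MMA block.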
Combines multi-line parameter lists into single lines to improve code readability and maintain consistent formatting style across the codebase.
@LoserCheems LoserCheems requested review from Evanwu1125, SNHuan, Copilot, and wubingheng111, and removed the review request for Copilot on June 21, 2025, 03:54
@LoserCheems LoserCheems added the feature New feature request label Jun 21, 2025
Contributor

Copilot AI left a comment


Pull Request Overview

This PR integrates sparse GEMM into the attention kernel by detecting active blocks via a global mask reduction, removes per-MMA branching to simplify code paths, and refactors template parameter ordering for readability.

  • Replace per-MMA activation flags with a single CTA-level any_active check
  • Unify copy/clear logic under any_active for B loads
  • Update calls in flash_fwd_kernel.h to use sparse_gemm and sparse_gemm_rs

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File / Description
  • csrc/src/utils.h: Simplified sparse_gemm template, replaced per-block mask loops with a CTA reduction, and unified copy/clear logic under any_active.
  • csrc/src/flash_fwd_kernel.h: Updated attention kernel calls to use sparse_gemm/sparse_gemm_rs and adjusted parameter lists for mask arguments.
Comments suppressed due to low confidence (4)

csrc/src/utils.h:185

  • [nitpick] The name local_any_active might be confused with warp- or CTA-level state. Consider renaming to thread_any_active or cta_local_active to clarify scope.
    bool local_any_active = false;

csrc/src/utils.h:198

  • New paths for any_active == false and any_active == true should be covered by tests (e.g., fully sparse vs. fully dense masks) to ensure correct behavior and catch regressions.
    bool any_active = __syncthreads_or(local_any_active);

csrc/src/utils.h:203

  • When any_active is true, this copies the entire batch of B tiles even for MMA blocks that may still be inactive. For very sparse patterns, this could waste memory bandwidth. Consider retaining per-MMA checks or using a masked/batched copy to only load truly active blocks.
            cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));

csrc/src/utils.h:224

  • Calling cute::gemm on the entire batch when only a subset of MMA blocks is active may result in unnecessary compute. If sparse patterns are common, a per-block invocation or conditional mask could reduce wasted FLOPs.
        if (any_active) {

csrc/src/utils.h (inline comment on the diff):

    mma_active_found:;
    }
    // Ensure all threads in the CTA have the same any_active value to avoid warp divergence
    bool any_active = __syncthreads_or(local_any_active);

Copilot AI Jun 21, 2025


Using __syncthreads_or for a CTA-wide reduction may not be supported in standard CUDA. Consider using warp-level intrinsics like __any_sync for each warp and a final combine via shared memory or use Cooperative Groups (cg::grid_group.any()) to correctly reduce across all threads in the CTA.

@LoserCheems
Collaborator Author


__syncthreads_or is declared in sm_20_intrinsics.h!
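For reference, the intrinsic can be exercised with a small standalone kernel; the sketch below is illustrative only (it assumes blockDim.x is a multiple of 32 and at most 1024) and compares __syncthreads_or against the warp-vote-plus-shared-memory combine suggested in the review:

    __global__ void any_active_check(const int *flags, int *agree) {
        bool local = flags[blockIdx.x * blockDim.x + threadIdx.x] != 0;

        // CTA-wide OR with an implicit barrier; available since compute capability 2.0.
        bool via_intrinsic = __syncthreads_or(local) != 0;

        // The same result built from per-warp votes plus a shared-memory combine.
        __shared__ int warp_any[32];
        bool warp_vote = __any_sync(0xFFFFFFFFu, local) != 0;
        if ((threadIdx.x & 31) == 0) {
            warp_any[threadIdx.x >> 5] = warp_vote ? 1 : 0;
        }
        __syncthreads();
        bool via_warps = false;
        for (unsigned w = 0; w < (blockDim.x >> 5); ++w) {
            via_warps = via_warps || (warp_any[w] != 0);
        }

        // Both reductions should agree for every block.
        if (threadIdx.x == 0) {
            agree[blockIdx.x] = (via_intrinsic == via_warps) ? 1 : 0;
        }
    }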

@LoserCheems
Collaborator Author

@SNHuan please review this PR🤗

Collaborator

@SNHuan SNHuan left a comment


LGTM

@SNHuan SNHuan merged commit f936441 into main Jun 21, 2025

Labels

feature New feature request

5 participants