
Conversation

@LoserCheems
Collaborator

Reduces block dimensions across multiple head sizes to improve GPU occupancy and memory efficiency. Updates block configurations for head dimensions 32, 64, 96, 128, 192, and 256 to achieve better CTA (Cooperative Thread Array) utilization per streaming multiprocessor.

Adds debug comments for occupancy analysis and updates memory usage calculations in comments to reflect the new configurations.
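
As a rough illustration of the occupancy trade-off behind these changes (a minimal sketch, not code from this PR; the Headdim, kBlockM, kBlockN and 2-byte fp16 element size below are assumed values chosen only for the example):

    // Approximate dynamic shared memory one CTA needs for its Q tile plus K and V tiles (fp16).
    constexpr int Headdim = 128;       // head dimension being tuned
    constexpr int kBlockM = 64;        // rows of Q processed per CTA
    constexpr int kBlockN = 64;        // rows of K/V processed per iteration
    constexpr int kBytesPerElem = 2;   // fp16
    constexpr int smem_per_cta = kBytesPerElem * Headdim * (kBlockM + 2 * kBlockN);  // 48 KB here
    // If the SM's shared-memory budget is at least twice this footprint, two CTAs can be
    // resident at once, which is the per-SM CTA utilization the description is targeting.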

Contributor

Copilot AI left a comment

Pull Request Overview

Optimizes CUDA kernel block dimensions across different head sizes to improve GPU occupancy and memory efficiency, while adding debug placeholders and updating related comments.

  • Adjust block configurations for head dimensions 32, 64, 96, 128, 192, and 256.
  • Insert commented debug printf and occupancy analysis snippets.
  • Update memory usage comments to reflect new block sizes and shared memory calculations.
Comments suppressed due to low confidence (3)

csrc/src/flash_fwd_launch_template.h:161

  • This historical segfault comment no longer aligns with the updated kBlockN logic. Consider removing or updating it to reflect the current block size thresholds.
    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,

csrc/src/flash_fwd_launch_template.h:181

  • This performance comment refers to a 64×128 block size, but the code now launches with 64×64. Please update the comment to match the current configuration.
            // Using block size (64 x 128) is 27% slower for seqlen=2k

csrc/src/flash_fwd_launch_template.h:229

  • The comment suggests a 64×64 configuration for alignment, but the is_sm8x branch actually uses 64×32. Align the comment to reflect the code or vice versa.
            // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment

Comment on lines 113 to +115
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
// printf("Split = %d, Append_KV = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Append_KV), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));

Copilot AI Jul 1, 2025

[nitpick] Consider removing or wrapping this large block of commented-out debug code in a dedicated debug macro to reduce noise in the main dispatch path.

Suggested change
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
// printf("Split = %d, Append_KV = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Append_KV), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
DEBUG_PRINT("Split = %d, Append_KV = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Append_KV), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
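
If the suggestion is adopted, the DEBUG_PRINT macro would still need to be defined somewhere; one possible shape (a hedged sketch only, since neither DEBUG_PRINT nor the FLASH_DEBUG flag exists in the current sources):

    #include <cstdio>

    // Compile with -DFLASH_DEBUG to enable dispatch-path diagnostics; otherwise the call
    // expands to nothing and the printf is compiled out of the dispatch path.
    #ifdef FLASH_DEBUG
    #define DEBUG_PRINT(...) printf(__VA_ARGS__)
    #else
    #define DEBUG_PRINT(...) ((void)0)
    #endif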

// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
// For A100, we want to run with 64 x 64 (112KB smem).
// For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM.
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {

Copilot AI Jul 1, 2025

[nitpick] The threshold expression uses raw magic numbers. Extract these calculations into named constants or add a brief comment explaining the chosen factors for clarity.

Suggested change
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
// Constants for shared memory calculations
constexpr static int SMEM_BASE = 128; // Base shared memory size
constexpr static int SMEM_MULTIPLIER = 64; // Multiplier for shared memory size
constexpr static int SMEM_FACTOR = 2; // Factor for additional shared memory
if (max_smem_per_block >= SMEM_FACTOR * Headdim * (SMEM_BASE + SMEM_FACTOR * SMEM_MULTIPLIER) &&
max_smem_per_sm < 4 * Headdim * (SMEM_MULTIPLIER + SMEM_FACTOR * SMEM_MULTIPLIER)) {
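
For context on what the two magic-number expressions represent, one reading of the thresholds (a hedged sketch; the constant names below are illustrative and assume 2-byte fp16 elements with Q, K and V tiles held in shared memory):

    constexpr int kBytesPerElem  = 2;    // fp16
    constexpr int kBlockN        = 64;   // K/V tile rows
    constexpr int kBlockM_large  = 128;  // candidate single-CTA configuration
    constexpr int kBlockM_small  = 64;   // configuration that fits two CTAs per SM
    // Left operand:  2 * Headdim * (128 + 2 * 64) == kBytesPerElem * Headdim * (kBlockM_large + 2 * kBlockN),
    //                i.e. the smem footprint of one 128 x 64 CTA.
    // Right operand: 4 * Headdim * (64 + 2 * 64)  == 2 * kBytesPerElem * Headdim * (kBlockM_small + 2 * kBlockN),
    //                i.e. the smem footprint of two 64 x 64 CTAs resident on the same SM.
    // The branch thus appears to pick the larger tile only when it fits within the per-block
    // smem limit and the SM could not host two of the smaller CTAs anyway.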

@LoserCheems merged commit 2920d92 into main on Jul 1, 2025

Labels

feature New feature request
