
Conversation


@LoserCheems LoserCheems commented Sep 16, 2025

Summary

#161 #169

  • Replace vectorized/cp.async loads of mask/bias with scalar (per-element) loads to:
    • Prevent materialization of expanded mask/bias tensors (preserve views, reduce memory).
    • Robustly handle irregular sequence lengths (non-128-aligned K), eliminating OOB and misaligned-address errors.

Root Cause

  • Vectorized/cp.async loads assumed 16B-aligned row strides. With unpadded seqlen_k = 4095, bias rows became misaligned (an 8190B row stride for fp16), causing CUDA misaligned-address errors; the compile-time sketch below makes the stride arithmetic concrete.
  • Padding mask/bias to a 128-aligned K avoided the fault but forced materialization, defeating the memory savings from expand views.
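
The stride arithmetic, as a compile-time illustration (illustrative only, not project code):

```cpp
// cp.async and 128-bit vectorized loads require 16-byte-aligned addresses.
// With an unpadded K of 4095 fp16 elements, consecutive rows are 8190 bytes
// apart, so every other row starts off a 16B boundary, and a vector load on
// such a row faults with "misaligned address".
constexpr long seqlen_k         = 4095;
constexpr long elem_bytes       = 2;                       // fp16/bf16
constexpr long row_stride_bytes = seqlen_k * elem_bytes;   // 8190 B
static_assert(row_stride_bytes % 16 != 0,
              "4095-column fp16 rows cannot all be 16B-aligned");
```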

Changes

  • Introduced column predicates (predicate_N) and updated the copy helper:
    • utils.h: copy_MN supports a per-column predicate and a row limit, with scalar elementwise copy (see the sketch after this list). When Bool_to_Element=true, it converts the bool mask to a numeric element.
  • Applied predicate-based scalar loads:
    • flash_fwd_kernel.h: Use copy_MN to load mask/bias with per-N predicates in the masking steps; use the even-shape fast path in the non-masking loop.
    • flash_bwd_kernel.h: Symmetric updates for dS/mask/bias loads and dbias writes, with Clear_OOB_MN=false.
  • Python interface semantically unchanged; mask/bias no longer need K-dimension padding, preserving expand views.
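
A minimal sketch of the predicated scalar copy described above (simplified; assumes CUTLASS CuTe, and copy_MN_sketch plus its exact signature are illustrative, not the PR's code):

```cpp
#include <cute/tensor.hpp>

// S/D are thread-partitioned (CPY, CPY_M, CPY_N) fragments, identity maps
// fragment indices to global (row, col) coordinates, and predicate_N marks
// in-bounds columns. Every element moves via a plain scalar assignment, so
// no 16B alignment is required and nothing is read out of bounds.
template <bool Clear_OOB_MN = false,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy_MN_sketch(
    cute::Tensor<Engine0, Layout0> const &S,
    cute::Tensor<Engine1, Layout1>       &D,
    cute::Tensor<Engine2, Layout2> const &identity,
    cute::Tensor<Engine3, Layout3> const &predicate_N,
    int const max_M
) {
    using cute::_;
    #pragma unroll
    for (int m = 0; m < cute::size<1>(S); ++m) {
        if (cute::get<0>(identity(0, m, 0)) < max_M) {           // row in bounds
            #pragma unroll
            for (int n = 0; n < cute::size<2>(S); ++n) {
                if (predicate_N(n)) {                            // column in bounds
                    #pragma unroll
                    for (int i = 0; i < cute::size<0>(S); ++i) {
                        D(i, m, n) = S(i, m, n);                 // scalar copy, no cp.async
                    }
                } else if (Clear_OOB_MN) {
                    cute::clear(D(_, m, n));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
}
```

The real helper additionally takes the Bool_to_Element/To_type pair, in which case the inner assignment converts a boolean mask value to the numeric element type instead of copying it verbatim.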

Reproduction

Before fix:

  1. Env: PyTorch 2.8.0a0+nv, CUDA, RTX 4090.
  2. Run benchmarks/forward_equivalence.py with:
    • batch=1, heads=2, kv_heads=1, seqlen_q=4095, seqlen_k=4095, head_dim=128, is_causal=True.
  3. Error: RuntimeError: CUDA error: misaligned address during forward.

After fix:

  • Same config passes; no CUDA errors.
  • Accuracy vs. the Python prototype meets the ≥95% threshold.

Tests

  • Forward/backward equivalence: standard suites across:
    • irregular K (4095), regular K (4096/4096, 8192/8192), causal/non-causal.
  • Performance sanity: minor overhead isolated to tail blocks; end-to-end impact negligible in typical shapes.

Compatibility

  • No API changes.
  • Behavior identical for regular shapes.
  • Memory usage reduced when using expand views for mask/bias (no materialization).

Checklist

  • Linked issue provided (internal)
  • Adds or updates tests (equivalence for irregular K; OOB regression)
  • Updates docs/changelog as needed
  • No correctness regressions
  • Perf impact acceptable (tail-only, negligible E2E)

Uses the in-place nan_to_num_ operation for better memory efficiency.

Updates tensor sanitization to use dtype-specific infinity bounds instead of fixed values, preventing potential overflow issues.

Changes tensor initialization from empty_like to zeros_like to ensure deterministic starting values for gradients.

Changes the bias padding value from the minimum float to zero for better numerical behavior.

Enhances documentation to clarify support for flexible mask and bias head dimensions in MQA/GQA scenarios.
Eliminates unnecessary padding of key and value tensors to multiples of 128 in the sequence-length dimension.

Removes associated context saving and gradient unpadding operations that are no longer needed without the sequence length padding.

Simplifies the forward and backward pass implementation by removing conditional padding logic for masks and biases.
Replaces vectorized copy with element-wise assignment to prevent memory access violations when bounds checking is disabled.

Changes predicate handling to use dedicated predicate tensor instead of coordinate-based bounds checking for improved safety.

Updates the default Clear_OOB_MN to false and removes the max_N parameter, as bounds checking now relies on the predicate tensor; a sketch of how such a predicate is typically built follows.
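
The per-column predicate is built once per tile from the thread's identity coordinates, along these lines (a sketch following common FlashAttention-style code; fill_predicate_N, cCoords, and max_N are illustrative names):

```cpp
#include <cute/tensor.hpp>

// cCoords is an identity tensor of global (row, col) coordinates for this
// thread's tile; max_N is the number of valid columns remaining in the tile
// (e.g. actual_seqlen_k - n_block * kBlockN). Each predicate entry records
// whether its column lies inside the real, unpadded sequence length.
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void fill_predicate_N(
    cute::Tensor<Engine0, Layout0> const &cCoords,
    cute::Tensor<Engine1, Layout1>       &predicate_N,
    int const max_N
) {
    #pragma unroll
    for (int n = 0; n < cute::size(predicate_N); ++n) {
        predicate_N(n) = cute::get<1>(cCoords(0, 0, n)) < max_N;
    }
}
```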
Copilot AI left a comment


Pull Request Overview

This PR fixes a critical bug that forced mask/bias materialization and caused out-of-bounds (OOB) errors for irregular sequence lengths. The fix replaces vectorized loads with scalar per-element loads to handle non-128-aligned sequence lengths robustly while preserving memory efficiency.

  • Removes K-dimension padding requirement for mask/bias tensors, preserving expand views
  • Updates copy utilities to use scalar loads with column predicates instead of vectorized operations
  • Improves tensor sanitization to use appropriate floating-point limits

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

  • flash_dmattn/flash_dmattn_interface.py: Removes K-dimension padding logic and improves tensor sanitization
  • csrc/flash_dmattn/src/utils.h: Updates the copy_MN function to use scalar loads with predicates instead of vectorized copies


```diff
 ):
     q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
-    dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
+    dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
```
Copilot AI Sep 16, 2025


Using torch.zeros_like() initializes all tensors with zeros, which is unnecessary overhead since these gradient tensors will be fully written by the backward kernel. Consider using torch.empty_like() for better performance.

Suggested change

```diff
-    dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
+    dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
```

```diff
 ):
     """
-    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    Supports multi-query attention and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
```
Copilot AI Sep 16, 2025


The word 'Similarity' on line 331 should be 'Similarly' (missing 'l').

```diff
-    typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
-    typename Engine2, typename Layout2
+    bool Is_even_MN=true, bool Clear_OOB_MN=false, bool Bool_to_Element=false, typename To_type=void,
+    // typename TiledCopy,
```
Copilot AI Sep 16, 2025


The commented-out typename TiledCopy parameter should be removed entirely rather than left as a comment, as it's no longer used in the function signature.

Suggested change

```diff
-    // typename TiledCopy,
```
Comment on lines 525 to 530

```diff
     // typename TiledCopy,
     typename Engine0, typename Layout0, typename Engine1, typename Layout1,
     typename Engine2, typename Layout2, typename Engine3, typename Layout3
 >
 __forceinline__ __device__ void copy_MN(
-    TiledCopy tiled_copy,
+    // TiledCopy tiled_copy,
```
Copilot AI Sep 16, 2025


The commented-out TiledCopy tiled_copy parameter should be removed entirely rather than left as a comment, as it's no longer used in the function.

Improves memory alignment by ensuring head dimensions are padded to multiples of 8 for 16-bit memory allocations.

Comments out sequence length padding implementation for future consideration, including corresponding mask and bias padding logic in both forward and backward passes.
Implements a device function that performs a logical OR reduction across mask tensor elements and synchronizes the result across the thread block using warp-level primitives.

Enables efficient sparse attention pattern processing by allowing threads to collectively determine whether any mask elements are active within a given region (see the sketch below).
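
A standalone sketch of that voting pattern (the PR's actual helper is copy_mask_with_or_reduce, which fuses this with the copy; block_any_active is an illustrative name):

```cpp
// Returns true in every thread of the block iff any thread saw an active
// mask element. __any_sync ORs the flag across each warp's lanes; the
// __syncthreads_or barrier then ORs the per-warp results block-wide.
__forceinline__ __device__ bool block_any_active(bool thread_saw_active) {
    bool warp_any = __any_sync(0xffffffffu, thread_saw_active);  // assumes full warps
    return __syncthreads_or(warp_any ? 1 : 0) != 0;
}
```

This lets the kernel cheaply detect fully inactive regions of a sparse attention pattern and skip work for them.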
Splits the generic copy_MN function into four specialized functions:
- copy_MN for basic tensor copying with tiled copy operations
- copy_mask for masked copying operations
- copy_mask_with_or_reduce for copying with OR reduction and block activity tracking
- copy_bias for bias-specific copying with element-wise assignment

Removes the Bool_to_Element template parameter and related conditional logic, simplifying the codebase by creating purpose-specific functions instead of a single overloaded function with multiple behaviors.
Refactors combined mask-bias memory operations into separate dedicated operations to improve performance and maintainability.

Introduces specialized copy functions for mask and bias operations with proper bounds checking and OR-reduction for mask activity detection.

Removes redundant synchronization points by leveraging built-in synchronization in the new copy functions.

Adds predicate tensor allocation for proper boundary handling in both regular and split-KV attention kernels.
Splits the unified GmemTiledCopyMaskBias into separate GmemTiledCopyMask and GmemTiledCopyBias operations to enable independent optimization of memory access patterns.

Introduces specialized copy_mask_with_or_reduce and copy_bias functions that replace generic copy_MN calls, allowing for better memory coalescing and reduced synchronization overhead.

Adds predicate tensor allocation for bounds checking on the N dimension to improve memory safety and enable more efficient vectorized operations in future optimizations.
Splits the combined GmemTiledCopyMaskBias type into separate GmemTiledCopyMask and GmemTiledCopyBias types in both forward and backward kernel traits.

This separation improves code clarity and allows for independent handling of mask and bias copy operations, enabling more flexible memory access patterns and potential optimizations.
@LoserCheems LoserCheems requested a review from Copilot September 17, 2025 04:54
Copilot AI left a comment


Pull Request Overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.



Comment on lines 14 to 19

```diff
 def _sanitize_tensors(*tensors: Optional[torch.Tensor], nan: float = 0.0, posinf: float = 1e6, neginf: float = -1e6) -> None:
     for t in tensors:
         if t is not None and isinstance(t, torch.Tensor):
-            torch.nan_to_num(t, nan=nan, posinf=posinf, neginf=neginf, out=t)
+            torch.nan_to_num_(t, nan=nan, posinf=posinf, neginf=neginf)
```

Copilot AI Sep 17, 2025


The function signature changed from using the out=t parameter to the in-place operation torch.nan_to_num_(), but the function parameters still include nan, posinf, and neginf, which suggests the old API expected these to be configurable. However, the function calls at lines 98 and 173 now pass specific dtype-based values, making the default parameters in the function signature potentially misleading.

```diff
 ):
     q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
-    dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
+    dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
```
Copilot AI Sep 17, 2025


Using torch.zeros_like() instead of torch.empty_like() initializes the tensors with zeros, which adds unnecessary overhead since these tensors will be completely overwritten by the backward computation.

Suggested change

```diff
-    dq, dk, dv, dbias = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v), torch.zeros_like(bias)
+    dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
```

```cpp
    typename Engine0, typename Layout0, typename Engine1, typename Layout1,
    typename Engine2, typename Layout2, typename Engine3, typename Layout3
>
__forceinline__ __device__ void copy_mask(
```
Copilot AI Sep 17, 2025


The copy_mask function at lines 585-612 is nearly identical to copy_MN function at lines 548-575, with only minor template parameter differences. This duplicates logic and creates maintenance burden.

Comment on lines +685 to 689

```cpp
    // cute::copy(tiled_copy, S(_, m, n), D(_, m, n));
    #pragma unroll
    for (int i = 0; i < size<0>(S); ++i) {
        D(i, m, n) = S(i, m, n);
    }
```
Copilot AI Sep 17, 2025


The commented-out cute::copy call on line 685 suggests this was changed to manual scalar copying, but the comment should be removed or should explain why the manual loop is necessary instead of using the copy utility.

Comment on lines +391 to +399

```diff
 FLASH_NAMESPACE::copy_bias<Is_even_MN, /*Clear_OOB_MN=*/true>(
     gmem_tiled_copy_Bias,
     tBiasgBias(_, _, _, n_block), tBiassBias,
-    tBiascBias,
-    binfo.actual_seqlen_q - m_block * kBlockM, binfo.actual_seqlen_k - n_block * kBlockN
+    tBiascBias, tBiaspBias,
+    binfo.actual_seqlen_q - m_block * kBlockM
 );
+// Because copy_bias currently uses scalar loads, we need to sync here.
+// TODO: Remove sync after fixing to vectorized loads.
+__syncthreads();
```
Copilot AI Sep 17, 2025


Multiple instances of manual __syncthreads() calls are added (lines 399, 525, 654, 1085, 1233, 1380) specifically for scalar bias copying. This adds synchronization overhead that could impact performance, and the TODO comments indicate this is a temporary workaround.

@LoserCheems LoserCheems merged commit 1cca349 into main Sep 17, 2025