[BUG FIX] Fix mask/bias memory access and vectorization issues in kernels #182
Conversation
Introduces compile-time branching to separate the even-tile fast path from ragged edges. On even tiles, performs unguarded bulk copies and removes per-element predicates. On ragged tiles, guards M/N bounds, switches to element-wise copy, and explicitly clears out-of-bounds regions. Reduces runtime branching and divergence, improving correctness on partial tiles and performance on full tiles.
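A minimal sketch of that even/ragged split, with illustrative names rather than the kernel's actual helpers:

```cpp
// Even tiles take an unguarded bulk copy; ragged tiles are bounds-checked per element
// and the out-of-bounds region is explicitly cleared. Names are illustrative only.
#include <cute/tensor.hpp>

template <bool Is_even_MN, typename TiledCopy, typename TensorS, typename TensorD>
__device__ void copy_tile_sketch(TiledCopy const &tiled_copy, TensorS const &src, TensorD &dst,
                                 int m_limit, int n_limit) {
    if constexpr (Is_even_MN) {
        // Fast path: the whole tile is in bounds, so no per-element predicates.
        cute::copy(tiled_copy, src, dst);
    } else {
        // Ragged path: guard M and N and zero anything out of bounds.
        #pragma unroll
        for (int m = 0; m < cute::size<0>(src); ++m) {
            #pragma unroll
            for (int n = 0; n < cute::size<1>(src); ++n) {
                if (m < m_limit && n < n_limit) {
                    dst(m, n) = src(m, n);
                } else {
                    dst(m, n) = 0;
                }
            }
        }
    }
}
```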
Pads key/value sequence length to a multiple of 8 and adjusts mask/bias accordingly to satisfy kernel alignment. Stores the original length and slices gradients/bias in backward to restore shapes. Improves correctness and supports non-multiple-of-8 sequence lengths without shape mismatches.
Uses dedicated shared-memory copy ops for mask and bias to match their layouts, preventing stride/type mismatches in attention computation and improving correctness/perf. Applies to both regular and split-KV paths and cleans minor whitespace.
Standardizes the mask element type as uint8_t in base traits and exposes it in forward/backward kernel traits. Improves consistency and avoids missing-type compile errors where the mask type is referenced, while easing future type changes.
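For illustration only, the trait plumbing might look roughly like this; the struct and member names are assumptions, not the repository's exact `kernel_traits.h` definitions:

```cpp
#include <cstdint>

struct Flash_base_traits_sketch {
    using ElementMask = uint8_t;   // single canonical mask element type
};

template <typename Base = Flash_base_traits_sketch>
struct Flash_fwd_traits_sketch : public Base {
    // Re-exported so forward/backward kernels can name the mask type directly.
    using ElementMask = typename Base::ElementMask;
};
```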
Updates forward and backward paths to use a non-vectorized copy for masks and a hardware-tuned global copy for bias, avoiding unsafe 128B alignment assumptions on masks and improving portability. Improves correctness on potentially unaligned mask accesses and aligns bias copies with the chosen gmem policy, with minor cleanups in tiled copy definitions.
Removes block-wide barriers that were only needed when bias loads were scalar. With vectorized bias copies and async copy fencing in place, the extra synchronization is unnecessary. Reduces sync overhead and stalls, improving forward attention performance without affecting correctness.
Adds a dedicated memory copy path for bias gradients and uses proper shared-memory partitioning for mask/bias, aligning with the compute tile. Replaces scalar bias copies with vectorized transactions, allowing removal of explicit synchronization after bias copy operations. Improves performance and avoids layout mismatches in bias-enabled backward passes.
Collapses separate even/uneven paths into a single unrolled loop that uses tiled copies for in-bounds regions and clears out-of-bounds elements when requested. Replaces scalar element-wise copies with vectorized/tiled copies on valid tiles to improve performance and reduce code duplication while preserving correctness on partial tiles.
Removes tensor clamping in forward/backward to preserve true values and reduce overhead. Guards slicing of an optional bias to avoid None errors when sequence length isn't divisible by 8.
Unifies mask handling for even/odd shapes and N predicates, always using the tiled path and clearing OOB uniformly. Removes the type-cast template and per-element copy, reducing branching and improving performance. Fixes block activity detection by syncing and OR-reducing over the destination after copy, preventing false negatives; renames the output flag for clarity.
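A hedged sketch of the "copy, sync, then OR-reduce over the destination" pattern; a plain strided loop stands in for the tiled copy, and all names are illustrative:

```cpp
#include <cstdint>

__device__ bool block_has_active_mask(const uint8_t *smem_mask, int n_elems) {
    // The mask tile was just written to shared memory; sync before any thread reads it.
    __syncthreads();

    // Each thread ORs its strided slice of the destination buffer.
    int local_any = 0;
    for (int i = threadIdx.x; i < n_elems; i += blockDim.x) {
        local_any |= smem_mask[i];
    }

    // Combine per-thread results into a single block-wide flag.
    __shared__ int block_any;
    if (threadIdx.x == 0) { block_any = 0; }
    __syncthreads();
    if (local_any) { atomicOr(&block_any, 1); }
    __syncthreads();
    return block_any != 0;
}
```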
Separates per-matrix global-memory layouts and thread mapping to account for differing element sizes, improving coalescing and alignment. Switches mask transfers to aligned auto-vectorized paths and widens mask load width, plus adds divisibility assertions to catch misconfigurations early. Cleans up and clarifies shared-memory layout comments/structure for mask and bias, while preserving Q/K/V/O behavior.
Updates memory tiling to use per-type layouts (QKVO, mask, bias) with matching vector widths and thread mapping. Vectorizes mask copies in shared/global memory and increases mask read width to 16 for better bandwidth. Adds stronger compile-time checks to enforce alignment and divisibility, reducing misaligned accesses and improving coalescing and stability.
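Roughly, the per-type vector widths and compile-time checks could be expressed like this; the constant names echo the PR description, but the surrounding struct is only a sketch:

```cpp
template <int kBlockN, int kBlockKSmem, int kNThreads>
struct Gmem_tiling_sketch {
    // 128-bit transactions: 8 half-precision elements for QKVO/bias, 16 bytes for a uint8_t mask.
    static constexpr int kGmemElemsPerLoadQKVO = 8;
    static constexpr int kGmemElemsPerLoadBias = 8;
    static constexpr int kGmemElemsPerLoadMask = 16;

    static constexpr int kGmemThreadsPerRowQKVO = kBlockKSmem / kGmemElemsPerLoadQKVO;
    static constexpr int kGmemThreadsPerRowMask = kBlockN / kGmemElemsPerLoadMask;
    static constexpr int kGmemThreadsPerRowBias = kBlockN / kGmemElemsPerLoadBias;

    // Fail at compile time instead of producing misaligned or mis-partitioned accesses.
    static_assert(kBlockKSmem % kGmemElemsPerLoadQKVO == 0, "QKVO row not divisible by vector width");
    static_assert(kBlockN % kGmemElemsPerLoadMask == 0, "mask row not divisible by vector width");
    static_assert(kBlockN % kGmemElemsPerLoadBias == 0, "bias row not divisible by vector width");
    static_assert(kNThreads % kGmemThreadsPerRowQKVO == 0, "threads not divisible by threads-per-row");
};
```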
Standardizes mask dtype to an explicit element type in global/shared memory to fix type mismatches and ensure alignment. Aligns the shared mask buffer via a placeholder and updates the layout to avoid misaligned accesses. Replaces fused mask copy+reduce with a generic copy followed by an explicit OR-reduction and barrier for clearer synchronization and correctness. Unifies bias handling onto the generic copy path.
Aligns reduction configuration with the QKVO-specific per-row thread count to keep template and divisor consistent. Fixes a mismatch that could mis-partition threads, improving correctness and consistency in backward preprocessing.
Uses a dedicated mask element type with aligned shared memory, separating mask typing from shared buffers to prevent misalignment and aliasing. Replaces combined mask copy+reduce with a generic copy, explicit barrier, and a separate OR-reduction to ensure accurate activity detection. Unifies bias/mask transfers via generic copy utilities and updates the dot-product threading trait, improving correctness across mixed element types and preparing for varied mask formats.
Pull Request Overview
This PR fixes critical memory access and vectorization issues in forward and backward kernels for mask and bias operations. The main problem was incompatible memory copy strategies and incorrect synchronization that caused illegal memory access errors and static assertion failures.
- Separated memory layout configurations for QKV, mask, and bias tensors to prevent alignment conflicts
- Replaced cp.async operations with proper vectorized copy atoms for mask/bias loads
- Fixed synchronization strategy by using `__syncthreads()` instead of cp.async synchronization for mask/bias operations
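The synchronization point can be illustrated with a simplified (non-CuTe) copy loop; function and parameter names here are hypothetical:

```cpp
// Plain global loads into shared memory are not tracked by cp.async fences, so a
// block-wide barrier is what makes the tile visible to all threads.
#include <cstdint>

__device__ void load_mask_tile_to_smem(const uint8_t *gmem_mask, uint8_t *smem_mask, int n_elems) {
    // Ordinary loads/stores; cp_async_fence/cp_async_wait would not order these.
    for (int i = threadIdx.x; i < n_elems; i += blockDim.x) {
        smem_mask[i] = gmem_mask[i];
    }
    __syncthreads();  // every thread sees the complete tile before it is consumed
}
```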
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| flash_dmattn/flash_dmattn_interface.py | Uncommented tensor sanitization calls and re-enabled sequence length padding logic |
| csrc/flash_dmattn/src/utils.h | Updated copy helper functions to use proper vectorized copies and fixed OR-reduction logic |
| csrc/flash_dmattn/src/kernel_traits.h | Added separate memory layout atoms and thread configurations for mask, bias, and QKV tensors |
| csrc/flash_dmattn/src/flash_fwd_kernel.h | Replaced custom mask copy functions with standard copy operations and proper synchronization |
| csrc/flash_dmattn/src/flash_bwd_preprocess_kernel.h | Updated thread count references to use new QKV-specific constants |
| csrc/flash_dmattn/src/flash_bwd_kernel.h | Applied similar mask/bias copy and synchronization fixes as forward kernel |
```diff
         return_softmax,
     )
-    _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min)
+    # _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min)
```
Copilot AI · Oct 1, 2025
[nitpick] Commented out sanitization calls should be removed entirely rather than left as comments. If these are intended to be temporary debugging changes, consider using a feature flag instead.
Suggested change:

```diff
-    # _sanitize_tensors(out, nan=0.0, posinf=torch.finfo(out.dtype).max, neginf=torch.finfo(out.dtype).min)
```
```cpp
}

////////////////////////////////////////////////////////////////////////////////////////////////////
```
Copilot AI · Oct 1, 2025
[nitpick] The template parameter To_type was removed from the function signature but the PR description mentions this was intentional. Consider adding a comment explaining why the type conversion is no longer needed.
Suggested change:

```diff
+// [CodeQL][nitpick] The template parameter `To_type` was intentionally removed from the function signature.
+// Type conversion is no longer needed here because it is handled elsewhere, or the tensors involved are already of the correct type.
```
```diff
     ); // (kBlockN, kHeadDim, nblocksN)
     Tensor mMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<const bool*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
+        make_gmem_ptr(reinterpret_cast<ElementMask*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
```
Copilot AI · Oct 1, 2025
The `reinterpret_cast` removes the `const` qualifier from `params.mask_ptr`. This should be `reinterpret_cast<const ElementMask*>` to maintain const correctness.
Suggested change:

```diff
-        make_gmem_ptr(reinterpret_cast<ElementMask*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
+        make_gmem_ptr(reinterpret_cast<const ElementMask*>(params.mask_ptr) + binfo.mask_offset(params.mask_batch_stride, params.mask_row_stride, bidb)),
```
```diff
     );
     Tensor gMask = make_tensor(
-        make_gmem_ptr(reinterpret_cast<const bool *>(params.mask_ptr) + col_offset_mask),
+        make_gmem_ptr(reinterpret_cast<ElementMask *>(params.mask_ptr) + col_offset_mask),
```
Copilot AI · Oct 1, 2025
Same `const` qualifier issue as above. Should be `reinterpret_cast<const ElementMask *>` to maintain const correctness.
Suggested change:

```diff
-        make_gmem_ptr(reinterpret_cast<ElementMask *>(params.mask_ptr) + col_offset_mask),
+        make_gmem_ptr(reinterpret_cast<const ElementMask *>(params.mask_ptr) + col_offset_mask),
```
```cpp
         typename Kernel_traits::SmemLayoutMaskBiasPdS{}
     ); // For pointers alignment only
     Tensor sMask = make_tensor(
         make_smem_ptr(reinterpret_cast<ElementMask *>(sMaskPlace.data().get())),
```
Copilot AI · Oct 1, 2025
[nitpick] The pattern of creating sMaskPlace for pointer alignment and then reinterpreting it to ElementMask* is repeated multiple times. Consider extracting this into a helper function to reduce code duplication.
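One possible shape for such a helper, sketched with hypothetical names and reusing the `.data().get()` idiom visible in the diff above; this is not code from the PR:

```cpp
#include <cute/tensor.hpp>

template <typename ElementMask, typename PlaceTensor, typename MaskLayout>
__device__ auto make_aligned_smem_mask(PlaceTensor &place, MaskLayout const &layout) {
    // `place` views the shared-memory placeholder that exists only to pin alignment;
    // reuse its storage with the mask element type and the mask layout.
    return cute::make_tensor(
        cute::make_smem_ptr(reinterpret_cast<ElementMask *>(place.data().get())),
        layout);
}
```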
Summary
This PR fixes critical memory access and alignment issues in forward and backward kernels related to mask and bias loading operations. The fixes resolve illegal memory access errors, static assertion failures, and cp.async compatibility issues when using dynamic masks and attention bias.
Fixes issues:
Root Cause
The root causes of these issues were:
1. Incompatible memory copy strategy: The original implementation attempted to use `SM80_CP_ASYNC_CACHEGLOBAL` with `cp.async` instructions for mask and bias loading. However, `cp.async` has strict alignment requirements (128-bit aligned) and is designed for element types like `fp16`/`bf16`, not for `uint8_t` masks.
2. Incorrect vectorization configuration: The mask loading attempted to use 16 values per read (16 × `uint8_t` = 128-bit), but the memory layout and thread partitioning were not properly configured for this vectorization level, leading to alignment violations.
3. Synchronization mismatch: Mask/bias loads used `cp_async_fence` and `cp_async_wait` synchronization, which is incompatible with standard global memory loads and caused race conditions.
4. Shared memory layout conflicts: QKV tensors and mask/bias tensors used different access patterns but shared the same memory layout configuration, leading to bank conflicts and inefficient memory access.
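As a sketch of the resulting copy-atom split using CuTe types; the alias names are illustrative, not the exact ones in `kernel_traits.h`:

```cpp
#include <cstdint>
#include <cute/atom/copy_atom.hpp>
#include <cute/arch/copy_sm80.hpp>
#include <cutlass/numeric_types.h>

using ElementQKV  = cutlass::half_t;
using ElementMask = uint8_t;

// QKV tiles: fp16 data meets the 128-bit alignment cp.async expects, so keep cp.async.
using GmemCopyAtomQKV =
    cute::Copy_Atom<cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, ElementQKV>;

// Mask tiles: use an auto-vectorizing copy that assumes (checked) 128-bit alignment
// rather than forcing cp.async onto uint8_t data.
using GmemCopyAtomMask =
    cute::Copy_Atom<cute::AutoVectorizingCopyWithAssumedAlignment<128>, ElementMask>;
```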
Changes
1. Memory Layout Separation (`kernel_traits.h`)
- Added `GmemLayoutAtomMask` and `GmemLayoutAtomBias` for mask and bias tensors
- `kGmemThreadsPerRowQKVO` based on `kBlockKSmem / kGmemElemsPerLoadQKVO`
- `kGmemThreadsPerRowMask` based on `kBlockN / kGmemElemsPerLoadMask`
- `kGmemThreadsPerRowBias` based on `kBlockN / kGmemElemsPerLoadBias`
2. Vectorized Copy Atoms
- Switched mask/bias loads from `SM80_CP_ASYNC_CACHEGLOBAL` to `AutoVectorizingCopyWithAssumedAlignment<128>`
- Kept `SM80_CP_ASYNC_CACHEGLOBAL` for QKV tensors where cp.async is appropriate
3. Synchronization Strategy (`flash_fwd_kernel.h`, `flash_bwd_kernel.h`)
- Replaced `cp_async_fence` + `cp_async_wait` with `__syncthreads()`
- Added `copy_mask_with_or_reduce`, which performs OR-reduction and implicit synchronization
4. Helper Functions (`utils.h`)
- Updated `copy_mask` and `copy_bias` to handle boundary conditions properly
- Added `copy_mask_with_or_reduce` for efficient mask loading with early-exit optimization
Reproduction
Minimal Reproducible Example
Steps to Reproduce Original Issues
Tests
Validation Performed
Correctness Tests (forward_equivalence.py, backward_equivalence.py):
Performance Tests (forward_performance.py, backward_performance.py):
Edge Cases:
Test Results
Compatibility
Backward Compatibility
Migration Notes
Breaking Changes
None
Checklist