
Fix CUDA ONNX Attention: min_bias_align crash on SM<80 and MEA NaN for fully-masked batches#27831

Open
titaiwangms wants to merge 1 commit into main from titaiwang/fix_cuda

Conversation

Contributor

@titaiwangms titaiwangms commented Mar 24, 2026

Summary

Fixes two bugs in the CUDA ONNX Attention operator:

  1. min_bias_align crash on SM<80: The alignment check for the Memory Efficient Attention (MEA) bias was too strict on SM<80 GPUs, incorrectly rejecting MEA and forcing a fallback to the unfused path. Fixed by using a conservative 4*sizeof(T) alignment that is valid across all SM architectures.

  2. MEA NaN for fully-masked batches: When all positions in a batch are masked (nonpad_kv_seqlen=0), CUTLASS MEA computes 1/s_prime where s_prime=0, producing NaN in the output. Added ZeroOutputForFullyMaskedBatches kernel to zero the output for these batches before MEA runs.

Additional improvements

  • Set is_bsnh explicitly to false in the unfused decode path (which is BNSH-only)
  • Added TODO(titaiwang) documenting Flash Attention's semantic mismatch with ONNX spec for bool attn_mask + past_key (Flash interprets bool mask as padding mask, spec treats it as general attention mask)
  • Improved comments and docstrings for routing logic and ConvertAttnMaskToBias

Related

@yuslepukhin yuslepukhin requested review from Copilot and tianleiwu and removed request for Copilot March 24, 2026 19:34
Contributor

@github-actions github-actions bot left a comment


You can commit the suggested changes from lintrunner.

titaiwangms added a commit that referenced this pull request Mar 24, 2026
…igned dims

PR #27831 fell back to CUBLAS_DEFAULT_MATH which still uses TF32 on Ampere+ GPUs
(SM>=80) since cuBLAS 11.0. Changed to CUBLAS_PEDANTIC_MATH when dimensions are
not 4-aligned to guarantee no tensor core usage, preventing CUDA error 716
(misaligned address) on CUDA 12.9+.

Three-way logic in all three float cuBLAS GEMM helpers:
- TF32 requested + dimensions aligned: CUBLAS_TF32_TENSOR_OP_MATH
- TF32 requested + dimensions NOT aligned: CUBLAS_PEDANTIC_MATH
- TF32 not requested: CUBLAS_DEFAULT_MATH (unchanged)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms force-pushed the titaiwang/fix_cuda branch 2 times, most recently from 38fef79 to 0796037 on March 25, 2026 20:22
@titaiwangms titaiwangms changed the title Fix TF32 misaligned address error in cuBLAS GEMM functions Fix misaligned BiasLoader access in CUTLASS FMHA attention dispatch Mar 25, 2026
@titaiwangms
Contributor Author

titaiwangms commented Mar 25, 2026

@titaiwangms titaiwangms force-pushed the titaiwang/fix_cuda branch 2 times, most recently from 9a1bf88 to 2cccef2 on March 26, 2026 23:44
Contributor

@github-actions github-actions bot left a comment


You can commit the suggested changes from lintrunner.

@titaiwangms titaiwangms changed the title Fix misaligned BiasLoader access in CUTLASS FMHA attention dispatch Fix ONNX Attention CUDA: bias alignment, unfused decode concat, and MEA NaN Mar 27, 2026
@titaiwangms titaiwangms requested a review from Copilot March 27, 2026 05:48
Contributor

Copilot AI left a comment


Pull request overview

This PR fixes correctness and stability issues in the CUDA EP implementation of the ONNX Attention op, focusing on decode-time bool masking with past KV, and NaNs in CUTLASS Memory Efficient Attention (MEA) for fully-masked batches.

Changes:

  • Aligns MEA bias-stride eligibility logic in attention.cc and adds post-MEA output zeroing for fully-masked batches.
  • Fixes unfused decode behavior for bool masks with past KV by using variable-length KV concat consistent with Flash layout semantics (and zero-inits present buffers).
  • Adds C++ and Python tests covering decode bool-mask edge cases (partial masks, divergent per-batch seqlens).

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Summary per file:

  • onnxruntime/test/python/transformers/test_onnx_attention/test_mha.py: Adds CUDA-only graph-level tests for bool-mask decode with past KV (partial mask + divergent seqlens).
  • onnxruntime/test/providers/cpu/llm/attention_op_test.cc: Adds CUDA execution tests forcing the unfused path and verifying variable-length concat correctness in decode.
  • onnxruntime/core/providers/cuda/llm/attention_mask_impl.h: Declares a CUDA helper to zero outputs for fully-masked batches (seqlens_k==0).
  • onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu: Implements and instantiates the ZeroOutputForFullyMaskedBatches kernel and launcher.
  • onnxruntime/core/providers/cuda/llm/attention.cc: Applies fully-masked output zeroing for the MEA and unfused nonpad paths; fixes unfused bool-mask decode concat semantics; updates the MEA eligibility alignment check.


Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.



@titaiwangms titaiwangms marked this pull request as ready for review March 27, 2026 17:13
@titaiwangms titaiwangms reopened this Mar 27, 2026

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.



@titaiwangms
Contributor Author

  1. UnfusedRunner decode with bool mask + past KV (attention.cc)
    When a 4D bool attention mask is provided with past key/value in decode mode, the unfused path now uses variable-length concat (LaunchConcatNewToPastKV), placing the new token at position seqlens_k[b] to match the Flash Attention present_key/value layout contract. This fixes incorrect attention where ConcatPastToPresent placed the new token at a fixed position that did not match the mask.

This does not seem right. I will need to take another look.

- Use 4*sizeof(T) convention for min_bias_align (matches contrib_ops)
- Fix unfused decode with bool mask + past KV: variable-length concat
  placing new token at seqlens_k[b] position (Flash layout contract)
- Add ZeroOutputForFullyMaskedBatches kernel for MEA path (CUTLASS
  epilogue produces NaN when s_prime=0 for fully-masked batches)
- Fix is_bsnh mismatch in LaunchConcatNewToPastKV call
- Fix BFloat16 linker error: use OrtToCudaType instead of ToCudaType for
  direct CUDA kernel launches
- Zero-init present buffers before variable-length concat
- Add C++ tests: partial mask decode, multi-batch divergent seqlens,
  all-false mask decode
- Add Python graph-level tests: partial mask decode, multi-batch
  divergent seqlens

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated no new comments.



@titaiwangms titaiwangms changed the title Fix ONNX Attention CUDA: bias alignment, unfused decode concat, and MEA NaN Fix CUDA ONNX Attention: min_bias_align crash on SM<80 and MEA NaN for fully-masked batches Mar 28, 2026
@titaiwangms titaiwangms marked this pull request as ready for review March 28, 2026 18:43
