
GQA unfused attention with FP32 QK accumulation (fixes #28195) #28198

Merged
tianleiwu merged 8 commits into main from tlwu/unfused_gqa
Apr 25, 2026

Conversation

@tianleiwu
Contributor

Description

Add a GQA-capable unfused CUDA attention kernel that writes Q·K^T to an FP32 scratch buffer, fixing the fp16/bf16 overflow that produces NaN when head_size > 256 at scale=1.0 (issue #28195, e.g. Gemma 4 global attention layers with head_dim=512).

Motivation

Gemma 4 uses head_dim=512 for its global attention layers (num_attention_heads=8, num_key_value_heads=4). Flash Attention and Memory-Efficient Attention cap at head_size=256, so these fall through to the unfused path. The existing unfused MHA runner produces NaN because even though cuBLAS accumulates in FP32, the Q·K^T output tensor is fp16 and overflows. Additionally, the MHA unfused runner cannot handle GQA (q_num_heads != kv_num_heads).
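
To make the failure concrete, here is a minimal host-side sketch (illustrative, not PR code; compiles with nvcc) of a Q·K^T sum that is finite in FP32 but overflows when stored to fp16:

#include <cuda_fp16.h>
#include <cstdio>

int main() {
  // Illustrative 512-term dot product: each fp16 operand is comfortably
  // representable, but the FP32-accumulated sum exceeds fp16 max (65504).
  float acc = 0.f;
  for (int d = 0; d < 512; ++d) acc += 12.f * 12.f;  // acc = 73728 in fp32
  __half stored = __float2half(acc);                 // overflows to +inf
  printf("fp32 sum = %f, fp16 store = %f\n", acc, __half2float(stored));
  return 0;
}

This mirrors the bug described above: cuBLAS accumulates in FP32 internally, but an fp16 output tensor reintroduces the overflow at the store, and the downstream softmax turns the resulting inf into NaN.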

Key Changes

New kernel (contrib_ops/cuda/bert/gqa_unfused_attention.cu/.h):

  • 3-stage pipeline: QK GEMM → softmax → AV GEMM
  • QK GEMM uses CUBLAS_COMPUTE_32F with CUDA_R_32F output type — raw Q·K^T scores written to FP32 scratch, eliminating fp16 overflow (see the sketch after this list)
  • Reshape-Q trick for native GQA support (no K/V head replication needed)
  • Softmax supports causal mask, sliding window (local_window_size), softcap, additive attention bias, and per-batch seqlens_k
  • Per-batch past calculation for correct sliding-window masking with variable-length sequences
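
For illustration, a minimal sketch of how the FP32-output GEMM described above is typically configured (the function name, pointer names, and operand layout are assumptions of this sketch, not the PR's exact code; per the reshape-Q trick, the q-heads sharing a kv head are folded into the row dimension, so one GEMM batch per (batch, kv_head) suffices):

#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t QkGemmFp32Sketch(cublasHandle_t handle,
                                const __half* q, const __half* k, float* qk_fp32,
                                int total_kv, int group_x_sq, int head_size,
                                long long stride_q, long long stride_k,
                                long long stride_qk, int batch_count) {
  const float alpha = 1.f;  // scale is applied later, inside the softmax
  const float beta = 0.f;
  return cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_T, CUBLAS_OP_N,
      total_kv, group_x_sq, head_size,
      &alpha,
      k, CUDA_R_16F, head_size, stride_k,        // fp16 K
      q, CUDA_R_16F, head_size, stride_q,        // fp16 Q (reshaped for GQA)
      &beta,
      qk_fp32, CUDA_R_32F, total_kv, stride_qk,  // FP32 output: no fp16 store
      batch_count,
      CUBLAS_COMPUTE_32F,                        // FP32 accumulation
      CUBLAS_GEMM_DEFAULT);
}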

GQA contrib op integration (group_query_attention.cc, group_query_attention_impl.cu):

  • Activates when Flash/MEA/XQA are all ineligible and KV cache is not quantized
  • Uses PrepareQKV for RoPE and K/V cache management, then routes to the new kernel

ONNX Attention op integration (attention.cc, attention.h):

  • New RunGqaUnfusedAttention path for GQA and fp16/bf16 with head_size > 128
  • Handles BSNH↔BNSH transposes, past_key concatenation, attn_mask→bias conversion, nonpad_kv_seqlen
  • Optimized: transposes BSNH K/V directly into present_key/present_value when available

UnpackRoPEAppend kernel (group_query_attention_qkv.cuh):

  • Raised MAX_HEAD_SIZE from 256 to 512 to support Gemma 4 head dimensions

Safety improvements:

  • SafeInt<size_t> for workspace size arithmetic (overflow protection)
  • static_assert guarding GQA transpose paths against accidental float instantiation

Testing

  • 8 new Gemma 4 regression tests in test_gqa.py: prompt/decode × fp16/bf16, softcap, sliding window, long past sequences
  • 2 new Gemma 4 benchmark configs in benchmark_gqa.py (global + local attention)
  • All TestGQARegressions tests pass locally (12/12)

Fixes

Fixes #28195

Add a GQA-capable unfused CUDA attention kernel that writes Q*K^T to an
FP32 scratch buffer, fixing fp16/bf16 overflow/NaN when head_size > 256
(e.g. Gemma 4 with head_dim=512, scale=1.0).

Key changes:
- New gqa_unfused_attention.cu/.h: 3-stage kernel (QK GEMM in FP32,
  softmax with causal/sliding-window/softcap/bias/seqlens_k masking,
  AV GEMM). Uses reshape-Q trick for GQA (no K/V head replication).
- GQA contrib op integration: fallback when Flash/MEA/XQA ineligible
  and KV is not quantized. Handles BNSH cache layout.
- ONNX Attention op integration: new RunGqaUnfusedAttention path for
  GQA (q_num_heads != kv_num_heads) and fp16/bf16 head_size > 128.
  Supports past_key/value, attn_mask, nonpad_kv_seqlen, softcap.
- Raise MAX_HEAD_SIZE from 256 to 512 in UnpackRoPEAppend kernel.
- Per-batch past calculation in softmax for correct sliding-window
  masking with variable-length seqlens_k.
- SafeInt overflow checks for workspace size arithmetic.
- Tests: 8 Gemma 4 regression tests (prompt/decode, fp16/bf16,
  softcap, sliding window) + benchmark configs.
Contributor

Copilot AI left a comment


Pull request overview

Adds a new CUDA unfused attention fallback that supports GQA and avoids fp16/bf16 QK overflow by writing Q·Kᵀ into an FP32 scratch buffer, addressing NaNs for large head sizes (e.g., Gemma 4 global attention with head_dim=512).

Changes:

  • Introduces gqa_unfused_attention.{cu,h} implementing a 3-stage unfused pipeline (QK GEMM → softmax → AV GEMM) with FP32 QK scratch and native GQA support.
  • Integrates the new fallback into both contrib GroupQueryAttention and the ONNX Attention CUDA kernel dispatch logic.
  • Extends RoPE+KV append support to head_size <= 512 and adds Gemma 4 regression tests/bench configs.

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.

Summary per file:

  • onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h: declares the params/workspace API and launch entrypoint for the new unfused GQA kernel.
  • onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu: implements the FP32 QK GEMM + softmax + AV GEMM unfused attention kernel with GQA support.
  • onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc: allocates scratch and enables the new unfused fallback when Flash/MEA/XQA are ineligible.
  • onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu: routes contrib GQA execution to the new unfused kernel when selected.
  • onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh: raises the supported RoPE append head size by adding a 512 specialization.
  • onnxruntime/contrib_ops/cuda/bert/attention_data.h: adds flags and scratch pointers for the new unfused fallback path in contrib GQA.
  • onnxruntime/core/providers/cuda/llm/attention.h: declares RunGqaUnfusedAttention for the ONNX Attention CUDA kernel.
  • onnxruntime/core/providers/cuda/llm/attention.cc: implements RunGqaUnfusedAttention and updates dispatch to use it for GQA and/or large-head fp16/bf16.
  • onnxruntime/test/python/transformers/test_gqa.py: adds Gemma 4 (head_dim=512) regression tests targeting the unfused fallback behavior.
  • onnxruntime/test/python/transformers/benchmark_gqa.py: adds Gemma 4 benchmark configurations (global + sliding-window variants).


Contributor

@titaiwangms left a comment


Implementation Review: GQA Unfused Attention with FP32 QK Accumulation

Overall this is a well-engineered PR. The kernel design is clean — the reshape-Q trick for GQA is elegant and memory-efficient, the 3-pass softmax correctly handles all edge cases (fully-masked rows emit zeros, sliding window, softcap, per-batch seqlens), and the ToFloat<T> template specializations properly handle half/bf16→float conversion.
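
For readers unfamiliar with the pattern, here is a deliberately simplified sketch of a 3-pass softmax with fully-masked-row handling (one thread per row for clarity; the PR's kernel is block-parallel and also applies the masking/softcap/bias logic that this sketch omits):

#include <math.h>

// Launch sketch: RowSoftmax3PassSketch<<<num_rows, 1, 0, stream>>>(...)
__global__ void RowSoftmax3PassSketch(const float* qk, float* out,
                                      int total_kv, int valid_kv) {
  const long long row = blockIdx.x;
  const float* in = qk + row * total_kv;
  float* o = out + row * total_kv;
  // Pass 1: row max over the valid (unmasked) prefix.
  float mx = -INFINITY;
  for (int c = 0; c < valid_kv; ++c) mx = fmaxf(mx, in[c]);
  if (mx == -INFINITY) {  // fully-masked row: emit zeros, not 0/0 = NaN
    for (int c = 0; c < total_kv; ++c) o[c] = 0.f;
    return;
  }
  // Pass 2: exp-sum with the max subtracted for numerical stability.
  float sum = 0.f;
  for (int c = 0; c < valid_kv; ++c) sum += expf(in[c] - mx);
  // Pass 3: normalize; positions beyond valid_kv get exactly zero weight.
  for (int c = 0; c < total_kv; ++c)
    o[c] = (c < valid_kv) ? expf(in[c] - mx) / sum : 0.f;
}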

Positive highlights:

  • The gqa_unfused_attention.cu kernel is self-contained with a clear 3-stage pipeline and excellent header documentation
  • SafeInt usage in group_query_attention.cc scratch allocation is well done
  • static_assert guards against accidental float instantiation and quantized KV cache — good defensive programming
  • Test coverage for the contrib GQA path is thorough (prompt/decode × fp16/bf16 × softcap/sliding-window)
  • ONNX Attention tests (TestONNXAttentionGQALargeHeadUnfused) added — this addresses the main test gap
  • Dispatch ordering (Flash → MEA → GQA Unfused → Legacy MHA) is correct

Issues found (detailed in inline comments):

  1. 🔴 Bug: y_bnsh_bytes uses H (head_size) instead of H_v (v_head_size) for the output buffer
  2. 🔴 High: attention.cc has 6 raw size_t multiplication chains without SafeInt
  3. 🔴 High: the AlignTo helper drops SafeInt protection by accepting size_t
  4. 🟡 Medium: ONNX Attention tests don't cover attn_mask or nonpad_kv_seqlen with head_size=512
  5. 🟡 Medium: missing VERBOSE logging at the GQA unfused dispatch branch
  6. 🟡 Medium: the softcap guard needs a defense-in-depth comment for when PR #27992 merges
  7. 🟢 Low: the transpose uses head_size instead of v_head_size for the output dimension

@titaiwangms
Contributor

Additional review comments (items not in diff-visible lines)

🔴 SafeInt missing in attention.cc RunGqaUnfusedAttention (6 locations):

The new method has raw size_t multiplication chains without SafeInt protection, unlike the contrib GQA path which correctly wraps everything in SafeInt. Affected lines in the diff:

  • const size_t q_bytes = static_cast<size_t>(B) * S_q * N_q * H * sizeof(T);
  • const size_t kn_bytes = static_cast<size_t>(B) * parameters.kv_sequence_length * N_kv * H * sizeof(T);
  • const size_t vn_bytes = ...
  • const size_t k_bytes = ..., const size_t v_bytes = ...
  • const size_t out_bytes = ...

All should use SafeInt<size_t>(B) * ... to match the pattern in group_query_attention.cc.
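
For reference, a self-contained sketch of the requested pattern (function name hypothetical; the include path is ORT's in-tree SafeInt wrapper):

#include "core/common/safeint.h"  // SafeInt throws on arithmetic overflow

// With SafeInt<size_t> as the first operand, every subsequent multiplication
// in the chain is overflow-checked before the result converts back to size_t.
template <typename T>
size_t QBytesChecked(int B, int S_q, int N_q, int H) {
  return SafeInt<size_t>(B) * S_q * N_q * H * sizeof(T);
}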


🟡 Missing VERBOSE logging at GQA unfused dispatch (attention.cc, around the dispatch block):

PR #27992 establishes VERBOSE logging at each dispatch branch. The new GQA unfused branch should include:

LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using GQA Unfused Attention"
                      << " (batch=" << B << ", q_seq=" << S_q
                      << ", total_kv=" << total_kv
                      << ", head_size=" << H
                      << ", fp32_qk=" << needs_fp32_qk_scratch << ")";

🟡 Softcap guard defense-in-depth comment (attention.cc, softcap rejection after GQA unfused):

After PR #27992 merges, the legacy MHA unfused path will support softcap too, making this guard dead code. Please add:

// Defense-in-depth: the GQA unfused path above supports softcap; the legacy MHA
// unfused path below does not yet. After PR #27992 merges, the legacy path will
// also support softcap and this guard becomes a safety net.

🟡 ONNX Attention test coverage gaps (test_onnx_attention/test_gqa.py):

The TestONNXAttentionGQALargeHeadUnfused tests cover core prompt/decode well, but don't exercise:

  1. attn_mask → attn_bias conversion with head_size=512
  2. nonpad_kv_seqlen → seqlens_k conversion

Suggest adding at least one test with has_attn_mask=True, attn_mask_dims=4 to the test cases generator.

- AlignTo now accepts SafeInt<size_t> to maintain overflow protection
  through alignment arithmetic (fixes SafeInt gap).
- y_bnsh_bytes uses H_v (v_head_size) instead of H (head_size) for the
  Y output buffer to prevent latent under-allocation if v_head_size
  ever differs from head_size.
- Add ORT_ENFORCE(head_size == v_head_size) assertion in
  UnfusedGqaAttention to make the invariant explicit.
…d_size

- Use SafeInt<size_t> for all 6 raw size_t multiplication chains in
  RunGqaUnfusedAttention (attention.cc) to detect overflow.
- Add VERBOSE logging at GQA unfused dispatch in ComputeInternal.
- Strengthen softcap guard comment for defense-in-depth.
- Use p.v_head_size for output BNSH->BSNH transpose in impl.
- Add test cases for attn_mask and nonpad_kv_seqlen with head_size=512
  in ONNX Attention GQA tests.
@tianleiwu
Contributor Author

Thanks for the thorough review! All 7 issues from the summary are now addressed:

  1. y_bnsh_bytes uses H not H_v: fixed, added H_v from v_head_size (eeec512)
  2. 6 raw size_t chains in attention.cc: fixed, all use SafeInt<size_t> now (b0f71ac)
  3. AlignTo drops SafeInt: fixed, now accepts SafeInt<size_t> (eeec512)
  4. Tests missing attn_mask/nonpad_kv_seqlen at head_size=512: added prompt_attn_mask and prompt_nonpad_seqlen test cases (b0f71ac)
  5. Missing VERBOSE logging: added at the GQA unfused dispatch branch (b0f71ac)
  6. Softcap guard comment: strengthened with a defense-in-depth note (b0f71ac)
  7. Output transpose uses head_size not v_head_size: fixed, uses p.v_head_size (b0f71ac)

Contributor

@github-actions bot left a comment


You can commit the suggested changes from lintrunner.

Contributor

@titaiwangms left a comment


Deep Code Review — Round 2

Focus: MEDIUM to CRITICAL correctness issues. Checked status of Round 1 findings.

Round 1 Status

  • FIXED: y_bnsh_bytes now correctly uses H_v instead of H — good fix
  • FIXED: SafeInt usage in buffer allocations is thorough
  • ⚠️ NEW findings below

Finding 1: MEDIUM-HIGH — Softcap+Bias Ordering Violates ONNX Spec

File: gqa_unfused_attention.cu, GqaUnfusedSoftmaxKernel (all 3 passes)

The kernel applies: scale → softcap → bias → mask → softmax
ONNX spec requires: scale → add_mask → softcap → softmax

When both softcap > 0 and attn_bias != nullptr are active (possible in the ONNX Attention path), results won't match spec. The contrib GQA path always passes attn_bias=nullptr, so it's unaffected.

Fix: In all three passes, swap the order — apply bias before softcap:

// Current (wrong):
float x = qk_in[...] * scale;
if (softcap > 0.f) { x = softcap * tanhf(x / softcap); }
if (has_bias) { x += ToFloat(attn_bias[...]); }

// Correct:
float x = qk_in[...] * scale;
if (has_bias) { x += ToFloat(attn_bias[...]); }
if (softcap > 0.f) { x = softcap * tanhf(x / softcap); }

This is the same ordering issue PR #27992 fixed for the legacy unfused path.


Finding 2: MEDIUM — LaunchConcatNewToPastKV Uses H for V Cache

File: attention.cc, RunGqaUnfusedAttention, line calling LaunchConcatNewToPastKV

The concat kernel takes a single head_size param used for both K and V. When H != H_v (asymmetric head sizes, allowed by ONNX Attention spec), the V cache concatenation uses wrong strides.

Mitigation: GQA models universally have H==H_v. This is a pre-existing API limitation.

Recommendation: Add a defensive guard:

ORT_RETURN_IF(H != H_v && past_key != nullptr,
              "RunGqaUnfusedAttention: past_key with H != H_v not supported");

Finding 3: LOW-MEDIUM — Integer Overflow in Stride Calculations

File: gqa_unfused_attention.cu, lines in LaunchQkGemmFp32 and LaunchAttnVGemm

Stride calculations like:

const int64_t stride_q = static_cast<int64_t>(group * S_q * H);

compute group * S_q * H in int before the cast to int64_t. With large sequences, this could overflow.

Fix: Cast earlier: static_cast<int64_t>(group) * static_cast<int64_t>(S_q) * H


Finding 4: MEDIUM — ONNX Attention Test Gaps

Missing test combinations for the ONNX Attention path at head_size=512:

  1. past_key + attn_mask — exercises concat + bias path together
  2. softcap + attn_mask — would expose Finding 1
  3. BSNH (3D) input — all C++ tests use 4D BNSH
  4. fp32 type — the float instantiation is untested

The contrib GQA tests are thorough — good work on prompt/decode/bf16/softcap/sliding-window coverage.


Positive Notes

  • Excellent header documentation in gqa_unfused_attention.h
  • SafeInt usage in allocations is solid
  • 3-pass softmax with fully-masked row handling is correct
  • cuBLAS FP32 compute for QK GEMM is the right fix for fp16 overflow
  • Clean test structure with _run_gemma4_gqa helper

@tianleiwu
Contributor Author

Finding 1: MEDIUM-HIGH — Softcap+Bias Ordering Violates ONNX Spec

The ONNX spec is wrong: softcap shall be applied before the mask (see onnx/onnx#7865). This implementation keeps the right order: scale → softcap → attn_bias → attn_mask → softmax.

@titaiwangms
Contributor

Correction to Round 2 Review — Finding 1 (Softcap Ordering)

After reviewing onnx/onnx#7865 filed by @tianleiwu, we are retracting Finding 1 (softcap+bias ordering violates ONNX spec).

@tianleiwu correctly identified that the ONNX spec itself is wrong: applying softcap after the mask maps -inf to -softcap (a finite value), which leaks nonzero probability to masked positions. The correct ordering (per the Gemma 2 / Flash Attention convention) is:

scale → softcap → add_bias → mask → softmax

PR #28198's GqaUnfusedSoftmaxKernel already implements this correct ordering. We are updating PR #27992's unfused path to match.
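
A standalone numeric check makes the leak concrete (illustrative sketch, not PR code):

#include <cmath>
#include <cstdio>

int main() {
  const float softcap = 50.f;
  const float masked_logit = -INFINITY;  // additive -inf mask already applied
  // Softcap after mask: tanh(-inf) = -1, so -inf becomes the finite -softcap.
  const float capped = softcap * tanhf(masked_logit / softcap);
  printf("capped = %f, exp = %g (nonzero weight at a masked position)\n",
         capped, expf(capped));
  return 0;
}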

Updated Findings for PR #28198

  1. MEDIUM-HIGH, RETRACTED: softcap ordering; the implementation is correct per the Gemma/Flash convention
  2. MEDIUM, still valid: LaunchConcatNewToPastKV H≠H_v limitation; suggest a defensive guard
  3. LOW-MEDIUM, still valid: int overflow risk in stride calculations (int math before the int64_t cast)
  4. MEDIUM, still valid: ONNX Attention test gaps (past+mask, softcap+mask, BSNH, fp32)

Contributor

@titaiwangms left a comment


Round 3 Critical Review — Architecture, Security, Performance

Reviewer: Critical Reviewer (AI Agent, PR #27992 crew)
Scope: Memory safety, FP32 accumulation correctness, edge cases, integration, error handling

Previous findings status

  • Softcap ordering (R2-H1): RETRACTED — PR ordering (softcap before bias) is correct per onnx/onnx#7865
  • H≠H_v in LaunchConcatNewToPastKV (R2-F2): ✅ FIXED — ORT_RETURN_IF(H != H_v, ...) guard added at attention.cc line ~895
  • Int overflow in strides (R2-F3): ✅ Verified safe — all stride math uses int64_t / SafeInt<size_t>
  • Test gaps (R2-F4): ✅ ADDRESSED — new tests for past_key+attn_mask, softcap+attn_mask, BSNH, fp32

New findings this round

MEDIUM: Softmax kernel grid.x can exceed 65535 for large models

File: gqa_unfused_attention.cu, LaunchGqaUnfusedSoftmax

const dim3 grid(params.num_heads * params.q_sequence_length, params.batch_size, 1);

grid.x = N_q * S_q. For prompt with N_q=128 heads and S_q=8192, grid.x = 1,048,576. This is fine — CUDA supports gridDim.x up to 2^31-1 (since compute capability 3.0). grid.y = batch_size is limited to 65535, which is not a practical concern.

Verdict: No issue. Verified safe for all realistic configurations.

MEDIUM: QkElementCount returns size_t but callers pass int args

File: gqa_unfused_attention.cu, line ~48

inline size_t QkElementCount(int batch_size, int num_heads, int q_seq, int total_kv) {
  return SafeInt<size_t>(batch_size) * num_heads * q_seq * total_kv;
}

This is correct — SafeInt<size_t> on the first operand ensures the entire multiplication chain is checked for overflow. No issue.

Verified correct: FP32 QK accumulation

The cublasGemmStridedBatchedEx call is correctly configured:

  • Input A/B type: CudaTypeFor<T>() (fp16/bf16/fp32 as appropriate)
  • Output C type: CUDA_R_32F (always FP32)
  • Compute type: CUBLAS_COMPUTE_32F
  • alpha/beta: float (matches compute type)

This means the QK GEMM accumulates in FP32 AND writes to FP32 scratch — no intermediate fp16 overflow possible. ✅

The AV GEMM correctly uses the original type T for output (the softmax weights are already bounded [0,1] so no overflow risk). ✅

Verified correct: Workspace layout and lifetime

The workspace is a single allocation split into two regions:

  1. [0, qk_bytes): FP32 QK scratch
  2. [qk_bytes, qk_bytes + softmax_bytes): Type T softmax output

Both regions are 256-byte aligned. The allocation is held by ws_buffer (ONNX path) or unfused_scratch (contrib path), both of which live until the function returns. All CUDA operations use the same stream, so there are no async lifetime issues. ✅
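
A sketch of that split with assumed names (AlignUp stands in for the PR's AlignTo helper; both regions come from the single allocation):

#include <cstddef>
#include <cstdint>

constexpr size_t kWorkspaceAlign = 256;

inline size_t AlignUp(size_t n) {
  return (n + kWorkspaceAlign - 1) / kWorkspaceAlign * kWorkspaceAlign;
}

void SplitWorkspace(void* workspace, size_t qk_bytes,
                    float** qk_fp32, void** softmax_out) {
  uint8_t* base = static_cast<uint8_t*>(workspace);
  *qk_fp32 = reinterpret_cast<float*>(base);  // region 1: FP32 QK scratch
  *softmax_out = base + AlignUp(qk_bytes);    // region 2: type-T softmax output
}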

Verified correct: Dispatch cascade integration

ONNX Attention (attention.cc):

Flash → MEA → GQA Unfused (new, when is_gqa || needs_fp32_qk_scratch) → MHA Unfused (legacy)

The GQA unfused check correctly sits AFTER Flash/MEA eligibility and BEFORE the legacy MHA unfused. The qk_matmul_output_mode_ == kNone guard ensures output_qk modes fall through to the legacy path. The softcap guard on the legacy path is correctly preserved. ✅

Contrib GQA (group_query_attention_impl.cu):

Flash → MEA → XQA → Unfused (new) → NOT_IMPLEMENTED error

Guards: !use_xqa && !use_flash && !use_mea && !quantized && !smooth_softmax && !head_sink && BNSH format. All correct. ✅

Verified correct: Error handling

All kernel launches check cudaGetLastError(). All cuBLAS calls check return status. ORT_RETURN_IF_ERROR is used consistently. Parameter validation in LaunchGqaUnfusedAttention covers all critical invariants (positive dimensions, num_heads % kv_num_heads == 0, workspace non-null). ✅

Overall assessment

APPROVE — No MEDIUM+ issues found in this round. Previous findings are addressed. The implementation is clean, memory-safe, and correctly integrated into both dispatch cascades. Test coverage is comprehensive (GQA + MHA, fp16 + bf16 + fp32, prompt + decode, causal + sliding window, softcap, BNSH + BSNH, attn_mask + past_key).

The only concern from earlier rounds that remains relevant is the softcap ordering, which was resolved by the ONNX spec correction (onnx/onnx#7865).

@titaiwangms
Contributor

Round 3 Review — Test Coverage Gap (MEDIUM)

The GQA unfused tests (test_gqa.py) use softcap but no explicit attention masks (has_attn_mask=False). They only use is_causal=1, which applies causal masking inside the softmax kernel rather than as additive bias before softcap.

This means the current tests would not catch wrong softcap+mask ordering (onnx/onnx#7865) when the GQA unfused kernel is called from the ONNX standard Attention op with real masks (via ConvertAttnMaskToBias).

Recommendation: Add at least one GQA unfused test with softcap > 0 AND an explicit attention mask containing -inf positions. A poison-value technique works well: set V to a large value (e.g., 1000) at masked positions, then verify the output stays clean (< some threshold). If softcap ordering is wrong, attention leaks to poison positions and the output exceeds the threshold.
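
A standalone numeric sketch of the poison-value idea (values illustrative; a real test would exercise the ONNX Attention op end to end):

#include <cmath>
#include <cstdio>

// Blend two values with softmax weights over two logits.
static float SoftmaxMix(float l0, float l1, float v0, float v1) {
  const float m = fmaxf(l0, l1);
  const float w0 = expf(l0 - m), w1 = expf(l1 - m);
  return (w0 * v0 + w1 * v1) / (w0 + w1);
}

int main() {
  const float softcap = 50.f, mask = -INFINITY;
  const float score = -200.f;                 // raw logit capping near -softcap
  const float v_ok = 1.f, v_poison = 1000.f;  // poison V at the masked position
  auto cap = [&](float x) { return softcap * tanhf(x / softcap); };
  // Correct order (softcap, then mask): the masked weight is exactly zero.
  const float good = SoftmaxMix(cap(score), cap(score) + mask, v_ok, v_poison);
  // Wrong order (mask, then softcap): -inf becomes -softcap and leaks.
  const float bad = SoftmaxMix(cap(score), cap(score + mask), v_ok, v_poison);
  printf("correct order -> %f (clean), wrong order -> %f (poisoned)\n",
         good, bad);
  return 0;
}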

This is important because the ONNX Attention dispatch now routes to the GQA unfused path, and that path receives attn_bias from ConvertAttnMaskToBias() — a code path not exercised by the current contrib GQA tests.

Also see our correction on the softcap ordering finding: previous comment

Copilot AI added a commit that referenced this pull request Apr 24, 2026
New C++ tests in attention_op_test.cc:
- BFloat16 large head_size (exercises __nv_bfloat16 template)
- head_size=512 (Gemma 4 exact config, exercises MAX_HEAD_SIZE=512)
- Multi-batch decode with past_key (per-batch seqlens logic)
- nonpad_kv_seqlen with GQA (seqlens_k conversion path)
- BSNH with present_key/value output (transpose-into-present optimization)
- Boolean attn_mask with GQA (ConvertAttnMaskToBias boolean path)
- MHA (non-GQA) with fp16 and large head_size (fp32 QK scratch without GQA)
- Causal + past_key decode (causal mask with past context)

New Python tests in test_gqa.py:
- Multi-batch Gemma 4 prompt (per-batch behavior)
- BFloat16 with sliding window
- Decode with softcap
- Combined sliding window + softcap

New Python tests in test_onnx_attention/test_gqa.py:
- Decode with softcap (two softcap values)
- Decode with past + attn_mask
- Single KV head (high group ratio 8:1)
- Multi-batch prompt
- Non-causal attention

Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/0d250102-d408-4c30-a82d-3592c0ba17d7

Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
titaiwangms previously approved these changes Apr 24, 2026
@tianleiwu merged commit 997c479 into main Apr 25, 2026
87 of 89 checks passed


Successfully merging this pull request may close these issues.

CUDA EP: Unfused Attention runner produces NaN for fp16 with head_dim > 256
