GQA unfused attention with FP32 QK accumulation (fixes #28195) #28198
Conversation
Add a GQA-capable unfused CUDA attention kernel that writes Q*K^T to an FP32 scratch buffer, fixing fp16/bf16 overflow/NaN when head_size > 256 (e.g. Gemma 4 with head_dim=512, scale=1.0). Key changes:
- New gqa_unfused_attention.cu/.h: 3-stage kernel (QK GEMM in FP32, softmax with causal/sliding-window/softcap/bias/seqlens_k masking, AV GEMM). Uses reshape-Q trick for GQA (no K/V head replication).
- GQA contrib op integration: fallback when Flash/MEA/XQA ineligible and KV is not quantized. Handles BNSH cache layout.
- ONNX Attention op integration: new RunGqaUnfusedAttention path for GQA (q_num_heads != kv_num_heads) and fp16/bf16 head_size > 128. Supports past_key/value, attn_mask, nonpad_kv_seqlen, softcap.
- Raise MAX_HEAD_SIZE from 256 to 512 in UnpackRoPEAppend kernel.
- Per-batch past calculation in softmax for correct sliding-window masking with variable-length seqlens_k (masking rule sketched below).
- SafeInt overflow checks for workspace size arithmetic.
- Tests: 8 Gemma 4 regression tests (prompt/decode, fp16/bf16, softcap, sliding window) + benchmark configs.
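A minimal host-side sketch of that masking rule, assuming a causal mask plus an optional sliding window and a per-batch past length derived from seqlens_k. The function and variable names are illustrative, not the kernel's actual identifiers, and the exact window convention (inclusive vs. exclusive) may differ from the PR:

```cpp
#include <cstdio>

// Sketch only: decides whether KV position kv_pos is visible to query row q_pos.
// past_len would be seqlens_k[b] - q_seq_len for batch b (per-batch, so variable-length
// decode batches mask correctly); window <= 0 means no sliding window.
bool IsVisible(int q_pos, int kv_pos, int past_len, int window) {
  const int absolute_q = past_len + q_pos;                         // query's absolute position
  if (kv_pos > absolute_q) return false;                           // causal: no peeking ahead
  if (window > 0 && kv_pos <= absolute_q - window) return false;   // outside sliding window
  return true;
}

int main() {
  // Example: one batch with 10 past tokens, a single new query token, window of 4.
  for (int kv = 0; kv <= 10; ++kv) {
    std::printf("kv=%d visible=%d\n", kv, IsVisible(/*q_pos=*/0, kv, /*past_len=*/10, /*window=*/4));
  }
  return 0;
}
```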
Pull request overview
Adds a new CUDA unfused attention fallback that supports GQA and avoids fp16/bf16 QK overflow by writing Q·Kᵀ into an FP32 scratch buffer, addressing NaNs for large head sizes (e.g., Gemma 4 global attention with head_dim=512).
Changes:
- Introduces gqa_unfused_attention.{cu,h} implementing a 3-stage unfused pipeline (QK GEMM → softmax → AV GEMM) with FP32 QK scratch and native GQA support.
- Integrates the new fallback into both the contrib GroupQueryAttention and the ONNX Attention CUDA kernel dispatch logic.
- Extends RoPE+KV append support to head_size <= 512 and adds Gemma 4 regression tests/bench configs.
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.h | Declares params/workspace API and launch entrypoint for the new unfused GQA kernel. |
| onnxruntime/contrib_ops/cuda/bert/gqa_unfused_attention.cu | Implements FP32 QK GEMM + softmax + AV GEMM unfused attention kernel with GQA support. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc | Allocates scratch and enables new unfused fallback when Flash/MEA/XQA are ineligible. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu | Routes contrib GQA execution to the new unfused kernel when selected. |
| onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh | Raises supported RoPE append head size by adding a 512 specialization. |
| onnxruntime/contrib_ops/cuda/bert/attention_data.h | Adds flags and scratch pointers for the new unfused fallback path in contrib GQA. |
| onnxruntime/core/providers/cuda/llm/attention.h | Declares RunGqaUnfusedAttention for ONNX Attention CUDA kernel. |
| onnxruntime/core/providers/cuda/llm/attention.cc | Implements RunGqaUnfusedAttention and updates dispatch to use it for GQA and/or large-head fp16/bf16. |
| onnxruntime/test/python/transformers/test_gqa.py | Adds Gemma 4 (head_dim=512) regression tests targeting the unfused fallback behavior. |
| onnxruntime/test/python/transformers/benchmark_gqa.py | Adds Gemma 4 benchmark configurations (global + sliding-window variants). |
titaiwangms
left a comment
Implementation Review: GQA Unfused Attention with FP32 QK Accumulation
Overall this is a well-engineered PR. The kernel design is clean — the reshape-Q trick for GQA is elegant and memory-efficient, the 3-pass softmax correctly handles all edge cases (fully-masked rows emit zeros, sliding window, softcap, per-batch seqlens), and the ToFloat<T> template specializations properly handle half/bf16→float conversion.
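To make the reshape-Q trick concrete, here is a small host-side sketch (not the PR's code; sizes and helper names are illustrative). It checks that a contiguous BNSH Q buffer of shape [B, N_q, S_q, H] can be reinterpreted as [B, N_kv, group·S_q, H], so each KV head runs one strided-batched GEMM over its group of query heads without replicating K/V:

```cpp
#include <cassert>
#include <cstdio>

int main() {
  // Illustrative sizes only.
  const int B = 2, N_q = 8, N_kv = 4, S_q = 3, H = 4;
  const int group = N_q / N_kv;  // query heads sharing one KV head

  // Flat offset of Q[b, n, s, h] in a contiguous BNSH layout.
  auto bnsh = [&](int b, int n, int s, int h) {
    return ((static_cast<long long>(b) * N_q + n) * S_q + s) * H + h;
  };
  // Same buffer viewed as [B, N_kv, group * S_q, H]: row r packs (g, s).
  auto grouped = [&](int b, int n_kv, int r, int h) {
    return ((static_cast<long long>(b) * N_kv + n_kv) * (group * S_q) + r) * H + h;
  };

  // The two index functions address the same element, so one strided-batched GEMM
  // per (batch, KV head) can consume group * S_q query rows with no K/V replication.
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < N_q; ++n)
      for (int s = 0; s < S_q; ++s)
        for (int h = 0; h < H; ++h)
          assert(bnsh(b, n, s, h) == grouped(b, n / group, (n % group) * S_q + s, h));

  std::printf("reshape-Q view matches the BNSH layout for all indices\n");
  return 0;
}
```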
Positive highlights:
- The gqa_unfused_attention.cu kernel is self-contained with a clear 3-stage pipeline and excellent header documentation
- SafeInt usage in group_query_attention.cc scratch allocation is well done
- static_assert guards against accidental float instantiation and quantized KV cache — good defensive programming
- Test coverage for the contrib GQA path is thorough (prompt/decode × fp16/bf16 × softcap/sliding-window)
- ONNX Attention tests (TestONNXAttentionGQALargeHeadUnfused) added — this addresses the main test gap
- Dispatch ordering (Flash → MEA → GQA Unfused → Legacy MHA) is correct
Issues found (detailed in inline comments):
| # | Severity | Issue |
|---|---|---|
| 1 | 🔴 Bug | y_bnsh_bytes uses H (head_size) instead of H_v (v_head_size) for output buffer |
| 2 | 🔴 High | attention.cc has 6 raw size_t multiplication chains without SafeInt |
| 3 | 🔴 High | AlignTo helper drops SafeInt protection by accepting size_t |
| 4 | 🟡 Medium | ONNX Attention tests don't cover attn_mask or nonpad_kv_seqlen with head_size=512 |
| 5 | 🟡 Medium | Missing VERBOSE logging at GQA unfused dispatch branch |
| 6 | 🟡 Medium | Softcap guard needs defense-in-depth comment for when PR #27992 merges |
| 7 | 🟢 Low | Transpose uses head_size instead of v_head_size for output dimension |
Additional review comments (items not in diff-visible lines)

🔴 SafeInt missing in RunGqaUnfusedAttention (attention.cc): the new method has raw size_t multiplication chains. All should use SafeInt<size_t>.

🟡 Missing VERBOSE logging at GQA unfused dispatch (ComputeInternal): PR #27992 establishes VERBOSE logging at each dispatch branch. The new GQA unfused branch should include:

LOGS_DEFAULT(VERBOSE) << "ONNX Attention: using GQA Unfused Attention"
                      << " (batch=" << B << ", q_seq=" << S_q
                      << ", total_kv=" << total_kv
                      << ", head_size=" << H
                      << ", fp32_qk=" << needs_fp32_qk_scratch << ")";

🟡 Softcap guard defense-in-depth comment: after PR #27992 merges, the legacy MHA unfused path will support softcap too, making this guard dead code. Please add:

// Defense-in-depth: the GQA unfused path above supports softcap; the legacy MHA
// unfused path below does not yet. After PR #27992 merges, the legacy path will
// also support softcap and this guard becomes a safety net.

🟡 ONNX Attention test coverage gaps: the new tests don't cover attn_mask or nonpad_kv_seqlen with head_size=512. Suggest adding at least one test with each.
- AlignTo now accepts SafeInt<size_t> to maintain overflow protection through alignment arithmetic (fixes SafeInt gap).
- y_bnsh_bytes uses H_v (v_head_size) instead of H (head_size) for the Y output buffer to prevent latent under-allocation if v_head_size ever differs from head_size.
- Add ORT_ENFORCE(head_size == v_head_size) assertion in UnfusedGqaAttention to make the invariant explicit.
…d_size
- Use SafeInt<size_t> for all 6 raw size_t multiplication chains in RunGqaUnfusedAttention (attention.cc) to detect overflow.
- Add VERBOSE logging at GQA unfused dispatch in ComputeInternal.
- Strengthen softcap guard comment for defense-in-depth.
- Use p.v_head_size for output BNSH->BSNH transpose in impl.
- Add test cases for attn_mask and nonpad_kv_seqlen with head_size=512 in ONNX Attention GQA tests.
Thanks for the thorough review! All 7 issues from the summary are now addressed:
titaiwangms
left a comment
Deep Code Review — Round 2
Focus: MEDIUM to CRITICAL correctness issues. Checked status of Round 1 findings.
Round 1 Status
- ✅ FIXED: y_bnsh_bytes now correctly uses H_v instead of H — good fix
- ✅ FIXED: SafeInt usage in buffer allocations is thorough
⚠️ NEW findings below
Finding 1: MEDIUM-HIGH — Softcap+Bias Ordering Violates ONNX Spec
File: gqa_unfused_attention.cu, GqaUnfusedSoftmaxKernel (all 3 passes)
The kernel applies: scale → softcap → bias → mask → softmax
ONNX spec requires: scale → add_mask → softcap → softmax
When both softcap > 0 and attn_bias != nullptr are active (possible in the ONNX Attention path), results won't match spec. The contrib GQA path always passes attn_bias=nullptr, so it's unaffected.
Fix: In all three passes, swap the order — apply bias before softcap:
// Current (wrong):
float x = qk_in[...] * scale;
if (softcap > 0.f) { x = softcap * tanhf(x / softcap); }
if (has_bias) { x += ToFloat(attn_bias[...]); }
// Correct:
float x = qk_in[...] * scale;
if (has_bias) { x += ToFloat(attn_bias[...]); }
if (softcap > 0.f) { x = softcap * tanhf(x / softcap); }

This is the same ordering issue PR #27992 fixed for the legacy unfused path.
Finding 2: MEDIUM — LaunchConcatNewToPastKV Uses H for V Cache
File: attention.cc, RunGqaUnfusedAttention, line calling LaunchConcatNewToPastKV
The concat kernel takes a single head_size param used for both K and V. When H != H_v (asymmetric head sizes, allowed by ONNX Attention spec), the V cache concatenation uses wrong strides.
Mitigation: GQA models universally have H==H_v. This is a pre-existing API limitation.
Recommendation: Add a defensive guard:
ORT_RETURN_IF(H != H_v && past_key != nullptr,
              "RunGqaUnfusedAttention: past_key with H != H_v not supported");

Finding 3: LOW-MEDIUM — Integer Overflow in Stride Calculations
File: gqa_unfused_attention.cu, lines in LaunchQkGemmFp32 and LaunchAttnVGemm
Stride calculations like:
const int64_t stride_q = static_cast<int64_t>(group) * S_q * H;

These compute group * S_q * H in int before the cast to int64_t. With large sequences, this could overflow.
Fix: Cast earlier: static_cast<int64_t>(group) * static_cast<int64_t>(S_q) * H
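A small stand-alone demonstration of the failure mode (illustrative sizes, not from the PR); the wrapped 32-bit value is shown via an explicit truncation to avoid executing the undefined signed overflow itself:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical sizes chosen so group * S_q * H exceeds INT32_MAX.
  const int group = 8, S_q = 600000, H = 512;

  // Correct: widen before multiplying so the whole product is computed in 64-bit.
  const int64_t stride = static_cast<int64_t>(group) * static_cast<int64_t>(S_q) * H;

  // What a 32-bit intermediate would hold instead (truncation shown explicitly,
  // since evaluating group * S_q * H in int would be undefined signed overflow).
  const int32_t truncated = static_cast<int32_t>(stride);

  std::printf("64-bit stride = %lld, 32-bit intermediate = %d\n",
              static_cast<long long>(stride), truncated);
  return 0;
}
```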
Finding 4: MEDIUM — ONNX Attention Test Gaps
Missing test combinations for the ONNX Attention path at head_size=512:
- past_key + attn_mask — exercises concat + bias path together
- softcap + attn_mask — would expose Finding 1
- BSNH (3D) input — all C++ tests use 4D BNSH
- fp32 type — the float instantiation is untested
The contrib GQA tests are thorough — good work on prompt/decode/bf16/softcap/sliding-window coverage.
Positive Notes
- Excellent header documentation in gqa_unfused_attention.h
- SafeInt usage in allocations is solid
- 3-pass softmax with fully-masked row handling is correct
- cuBLAS FP32 compute for QK GEMM is the right fix for fp16 overflow
- Clean test structure with _run_gemma4_gqa helper
The ONNX spec is wrong; softcap should come before the mask. See onnx/onnx#7865. This implementation will keep the right order: scale -> softcap -> attn_bias -> attn_mask -> softmax.
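For reference, a minimal host-side sketch of that ordering for a single row of scores, assuming the mask is expressed as a boolean "fully masked" flag per position and the bias is an optional additive term; names are illustrative and this is not the kernel code:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Sketch of the agreed ordering: scale -> softcap -> attn_bias -> attn_mask -> softmax.
// Applying softcap before the mask keeps masked positions at -inf (zero weight);
// softcap applied after the mask would squash -inf to a finite value.
std::vector<float> RowSoftmax(std::vector<float> scores, float scale, float softcap,
                              const std::vector<float>& bias, const std::vector<bool>& masked) {
  const float neg_inf = -std::numeric_limits<float>::infinity();
  float row_max = neg_inf;
  for (size_t j = 0; j < scores.size(); ++j) {
    float x = scores[j] * scale;
    if (softcap > 0.f) x = softcap * std::tanh(x / softcap);
    if (!bias.empty()) x += bias[j];
    if (masked[j]) x = neg_inf;
    scores[j] = x;
    row_max = std::max(row_max, x);
  }
  // Fully-masked row: emit zeros instead of NaN.
  if (row_max == neg_inf) return std::vector<float>(scores.size(), 0.f);
  float sum = 0.f;
  for (float& x : scores) { x = std::exp(x - row_max); sum += x; }
  for (float& x : scores) x /= sum;
  return scores;
}

int main() {
  auto w = RowSoftmax({100.f, 90.f, 80.f}, /*scale=*/1.f, /*softcap=*/50.f, /*bias=*/{},
                      /*masked=*/{false, false, true});
  std::printf("%f %f %f\n", w[0], w[1], w[2]);  // last weight is exactly 0
  return 0;
}
```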
Correction to Round 2 Review — Finding 1 (Softcap Ordering)

After reviewing onnx/onnx#7865 filed by @tianleiwu, we are retracting Finding 1 (softcap+bias ordering violates ONNX spec). @tianleiwu correctly identified that the ONNX spec itself is wrong: applying softcap after the mask would map masked -inf scores to finite values, so PR #28198's ordering (softcap before bias/mask) is correct.

Updated Findings for PR #28198
titaiwangms
left a comment
Round 3 Critical Review — Architecture, Security, Performance
Reviewer: Critical Reviewer (AI Agent, PR #27992 crew)
Scope: Memory safety, FP32 accumulation correctness, edge cases, integration, error handling
Previous findings status
- Softcap ordering (R2-H1): RETRACTED — PR ordering (softcap before bias) is correct per onnx/onnx#7865
- H≠H_v in LaunchConcatNewToPastKV (R2-F2): ✅ FIXED — ORT_RETURN_IF(H != H_v, ...) guard added at attention.cc line ~895
- Int overflow in strides (R2-F3): ✅ Verified safe — all stride math uses int64_t / SafeInt<size_t>
- Test gaps (R2-F4): ✅ ADDRESSED — new tests for past_key+attn_mask, softcap+attn_mask, BSNH, fp32
New findings this round
MEDIUM: Softmax kernel grid.x can exceed 65535 for large models
File: gqa_unfused_attention.cu, LaunchGqaUnfusedSoftmax
const dim3 grid(params.num_heads * params.q_sequence_length, params.batch_size, 1);

grid.x = N_q * S_q. For prompt with N_q=128 heads and S_q=8192, grid.x = 1,048,576. This is fine — CUDA supports gridDim.x up to 2^31-1 (since compute capability 3.0). grid.y = batch_size is limited to 65535, which is not a practical concern.
Verdict: No issue. Verified safe for all realistic configurations.
MEDIUM: QkElementCount returns size_t but callers pass int args
File: gqa_unfused_attention.cu, line ~48
inline size_t QkElementCount(int batch_size, int num_heads, int q_seq, int total_kv) {
return SafeInt<size_t>(batch_size) * num_heads * q_seq * total_kv;
}

This is correct — SafeInt<size_t> on the first operand ensures the entire multiplication chain is checked for overflow. No issue.
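For illustration, a hedged sketch of that pattern, assuming the standalone SafeInt.hpp header from the SafeInt library with its default exception policy (ONNX Runtime wraps the same library with its own handler in core/common/safeint.h, so the real failure behavior may differ):

```cpp
#include <cstdio>
#include "SafeInt.hpp"  // https://github.com/dcleblanc/SafeInt

// Making the first operand SafeInt<size_t> promotes every subsequent "*" in the
// chain to checked arithmetic, so an overflow is detected instead of wrapping.
size_t QkElements(int batch, int heads, int q_seq, int total_kv) {
  return SafeInt<size_t>(batch) * heads * q_seq * total_kv;
}

int main() {
  try {
    std::printf("%zu\n", QkElements(4, 32, 8192, 8192));                      // fits in size_t
    std::printf("%zu\n", QkElements(1 << 20, 1 << 20, 1 << 20, 1 << 20));     // exceeds 64 bits
  } catch (const SafeIntException&) {
    std::printf("overflow detected\n");
  }
  return 0;
}
```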
Verified correct: FP32 QK accumulation
The cublasGemmStridedBatchedEx call is correctly configured:
- Input A/B type: CudaTypeFor<T>() (fp16/bf16/fp32 as appropriate)
- Output C type: CUDA_R_32F (always FP32)
- Compute type: CUBLAS_COMPUTE_32F
- alpha/beta: float (matches compute type)
This means the QK GEMM accumulates in FP32 AND writes to FP32 scratch — no intermediate fp16 overflow possible. ✅
The AV GEMM correctly uses the original type T for output (the softmax weights are already bounded [0,1] so no overflow risk). ✅
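As a concrete illustration of that configuration, a hedged sketch of a strided-batched QK GEMM with fp16 inputs and FP32 output/compute; the operand ordering, leading dimensions, and strides here are illustrative and may not match the PR's actual call sites:

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Sketch: scores = Q * K^T per (batch, KV head) pair, fp16 inputs, FP32 output.
// d_q / d_k are device fp16 buffers; d_scores is the device FP32 scratch region.
// Row-major tensors are fed to the column-major cuBLAS API via the usual operand swap.
cublasStatus_t QkGemmFp32(cublasHandle_t handle, const __half* d_q, const __half* d_k,
                          float* d_scores, int rows_q, int total_kv, int head_size,
                          long long stride_q, long long stride_k, long long stride_scores,
                          int batch_count, float scale) {
  const float alpha = scale;
  const float beta = 0.f;
  // Column-major view: C(total_kv x rows_q) = K^T(total_kv x head_size) * Q(head_size x rows_q)
  return cublasGemmStridedBatchedEx(
      handle, CUBLAS_OP_T, CUBLAS_OP_N,
      /*m=*/total_kv, /*n=*/rows_q, /*k=*/head_size,
      &alpha,
      d_k, CUDA_R_16F, /*lda=*/head_size, stride_k,
      d_q, CUDA_R_16F, /*ldb=*/head_size, stride_q,
      &beta,
      d_scores, CUDA_R_32F, /*ldc=*/total_kv, stride_scores,
      batch_count,
      CUBLAS_COMPUTE_32F,     // accumulate in FP32
      CUBLAS_GEMM_DEFAULT);   // FP32 output buffer: no fp16 overflow on write-back
}
```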
Verified correct: Workspace layout and lifetime
The workspace is a single allocation split into two regions:
- [0, qk_bytes): FP32 QK scratch
- [qk_bytes, qk_bytes + softmax_bytes): Type T softmax output
Both regions are 256-byte aligned. The allocation is held by ws_buffer (ONNX path) or unfused_scratch (contrib path), both of which live until the function returns. All CUDA operations use the same stream, so there are no async lifetime issues. ✅
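A small host-side sketch of that two-region layout, assuming a hypothetical AlignUp helper and illustrative sizes (the PR's AlignTo and exact element types may differ):

```cpp
#include <cstddef>
#include <cstdio>

// Round up to a multiple of `alignment` (power of two). The real code keeps this
// arithmetic in SafeInt<size_t> so the rounding itself cannot overflow.
size_t AlignUp(size_t bytes, size_t alignment) {
  return (bytes + alignment - 1) & ~(alignment - 1);
}

int main() {
  const size_t B = 2, N_q = 8, S_q = 1024, total_kv = 4096;
  const size_t kAlign = 256;

  // Region 1: FP32 QK scratch, one float per (b, n, s_q, kv) score.
  const size_t qk_bytes = AlignUp(B * N_q * S_q * total_kv * sizeof(float), kAlign);
  // Region 2: softmax output in the model's type T (2-byte fp16 used for illustration).
  const size_t softmax_bytes = AlignUp(B * N_q * S_q * total_kv * 2, kAlign);

  std::printf("workspace = %zu bytes ([0, %zu) fp32 QK, [%zu, %zu) T softmax)\n",
              qk_bytes + softmax_bytes, qk_bytes, qk_bytes, qk_bytes + softmax_bytes);
  return 0;
}
```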
Verified correct: Dispatch cascade integration
ONNX Attention (attention.cc):
Flash → MEA → GQA Unfused (new, when is_gqa || needs_fp32_qk_scratch) → MHA Unfused (legacy)
The GQA unfused check correctly sits AFTER Flash/MEA eligibility and BEFORE the legacy MHA unfused. The qk_matmul_output_mode_ == kNone guard ensures output_qk modes fall through to the legacy path. The softcap guard on the legacy path is correctly preserved. ✅
Contrib GQA (group_query_attention_impl.cu):
Flash → MEA → XQA → Unfused (new) → NOT_IMPLEMENTED error
Guards: !use_xqa && !use_flash && !use_mea && !quantized && !smooth_softmax && !head_sink && BNSH format. All correct. ✅
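In rough pseudo-C++, the contrib GQA selection described above reduces to something like the following; the enum and flag names are illustrative, not the member names used in the PR:

```cpp
#include <cstdio>

// Sketch of the selection order, not the actual ORT dispatch code.
enum class GqaKernel { Flash, MemoryEfficient, Xqa, GqaUnfused, Unsupported };

struct GqaDispatchInputs {
  bool use_flash, use_mea, use_xqa;
  bool kv_quantized, smooth_softmax, has_head_sink, bnsh_cache;
};

GqaKernel SelectGqaKernel(const GqaDispatchInputs& in) {
  if (in.use_flash) return GqaKernel::Flash;
  if (in.use_mea) return GqaKernel::MemoryEfficient;
  if (in.use_xqa) return GqaKernel::Xqa;
  // New fallback: only for plain (non-quantized) BNSH caches without
  // smooth-softmax or head-sink, per the guards listed above.
  if (!in.kv_quantized && !in.smooth_softmax && !in.has_head_sink && in.bnsh_cache)
    return GqaKernel::GqaUnfused;
  return GqaKernel::Unsupported;  // previously: NOT_IMPLEMENTED error
}

int main() {
  GqaDispatchInputs in{false, false, false, false, false, false, true};
  std::printf("selected = %d (3 == GqaUnfused)\n", static_cast<int>(SelectGqaKernel(in)));
  return 0;
}
```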
Verified correct: Error handling
All kernel launches check cudaGetLastError(). All cuBLAS calls check return status. ORT_RETURN_IF_ERROR is used consistently. Parameter validation in LaunchGqaUnfusedAttention covers all critical invariants (positive dimensions, num_heads % kv_num_heads == 0, workspace non-null). ✅
Overall assessment
APPROVE — No MEDIUM+ issues found in this round. Previous findings are addressed. The implementation is clean, memory-safe, and correctly integrated into both dispatch cascades. Test coverage is comprehensive (GQA + MHA, fp16 + bf16 + fp32, prompt + decode, causal + sliding window, softcap, BNSH + BSNH, attn_mask + past_key).
The only concern from earlier rounds that remains relevant is the softcap ordering, which was resolved by the ONNX spec correction (onnx/onnx#7865).
Round 3 Review — Test Coverage Gap (MEDIUM)

The GQA unfused tests do not combine softcap with an attention mask. This means the current tests would not catch wrong softcap+mask ordering (onnx/onnx#7865) when the GQA unfused kernel is called from the ONNX standard Attention op with real masks (via ConvertAttnMaskToBias).

Recommendation: Add at least one GQA unfused test with softcap + attn_mask. This is important because the ONNX Attention dispatch now routes to the GQA unfused path, and that path receives the mask converted to an additive attention bias.

Also see our correction on the softcap ordering finding: previous comment
New C++ tests in attention_op_test.cc:
- BFloat16 large head_size (exercises __nv_bfloat16 template)
- head_size=512 (Gemma 4 exact config, exercises MAX_HEAD_SIZE=512)
- Multi-batch decode with past_key (per-batch seqlens logic)
- nonpad_kv_seqlen with GQA (seqlens_k conversion path)
- BSNH with present_key/value output (transpose-into-present optimization)
- Boolean attn_mask with GQA (ConvertAttnMaskToBias boolean path)
- MHA (non-GQA) with fp16 and large head_size (fp32 QK scratch without GQA)
- Causal + past_key decode (causal mask with past context)

New Python tests in test_gqa.py:
- Multi-batch Gemma 4 prompt (per-batch behavior)
- BFloat16 with sliding window
- Decode with softcap
- Combined sliding window + softcap

New Python tests in test_onnx_attention/test_gqa.py:
- Decode with softcap (two softcap values)
- Decode with past + attn_mask
- Single KV head (high group ratio 8:1)
- Multi-batch prompt
- Non-causal attention

Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/0d250102-d408-4c30-a82d-3592c0ba17d7
Co-authored-by: tianleiwu <30328909+tianleiwu@users.noreply.github.com>
Description
Add a GQA-capable unfused CUDA attention kernel that writes Q·K^T to an FP32 scratch buffer, fixing fp16/bf16 overflow producing NaN when head_size > 256 at scale=1.0 (issue #28195, e.g. Gemma 4 global attention layers with head_dim=512).

Motivation
Gemma 4 uses head_dim=512 for its global attention layers (num_attention_heads=8, num_key_value_heads=4). Flash Attention and Memory-Efficient Attention cap at head_size=256, so these fall through to the unfused path. The existing unfused MHA runner produces NaN because even though cuBLAS accumulates in FP32, the Q·K^T output tensor is fp16 and overflows. Additionally, the MHA unfused runner cannot handle GQA (q_num_heads != kv_num_heads).
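To make the overflow concrete: a small host-side example with illustrative magnitudes (not taken from the PR or from Gemma weights) showing why an fp16 Q·K^T output tensor can overflow at head_dim=512 and scale=1.0 even though cuBLAS accumulates in FP32:

```cpp
#include <cstdio>

int main() {
  // fp16 has a maximum finite value of 65504. With head_dim = 512 and scale = 1.0,
  // a well-aligned query/key pair whose elements are around +/-12 (illustrative,
  // plausible magnitudes) already produces a score of ~512 * 144 = 73728.
  const int head_dim = 512;
  const float scale = 1.0f;
  float score = 0.f;
  for (int i = 0; i < head_dim; ++i) {
    const float q = 12.f, k = 12.f;
    score += q * k;  // FP32 accumulation is fine...
  }
  score *= scale;
  const float kFp16Max = 65504.f;
  std::printf("score = %.0f, fits in fp16: %s\n", score,
              score <= kFp16Max ? "yes" : "no (becomes +inf, then NaN in softmax)");
  return 0;
}
```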
Key Changes

New kernel (contrib_ops/cuda/bert/gqa_unfused_attention.cu/.h):
- QK GEMM uses CUBLAS_COMPUTE_32F with CUDA_R_32F output type — raw Q·K^T scores written to FP32 scratch, eliminating fp16 overflow
- Softmax supports causal masking, sliding window (local_window_size), softcap, additive attention bias, and per-batch seqlens_k
- Per-batch past calculation for correct sliding-window masking with variable-length sequences

GQA contrib op integration (group_query_attention.cc, group_query_attention_impl.cu):
- Fallback when Flash/MEA/XQA are ineligible and the KV cache is not quantized; handles BNSH cache layout
- Uses PrepareQKV for RoPE and K/V cache management, then routes to the new kernel

ONNX Attention op integration (attention.cc, attention.h):
- New RunGqaUnfusedAttention path for GQA and fp16/bf16 with head_size > 128
- Supports past_key/value, attn_mask, nonpad_kv_seqlen, softcap
- Writes present_key/present_value when available

UnpackRoPEAppend kernel (group_query_attention_qkv.cuh):
- Raise MAX_HEAD_SIZE from 256 to 512 to support Gemma 4 head dimensions

Safety improvements:
- SafeInt<size_t> for workspace size arithmetic (overflow protection)
- static_assert guarding GQA transpose paths against accidental float instantiation

Testing
- test_gqa.py: prompt/decode × fp16/bf16, softcap, sliding window, long past sequences
- New benchmark configs in benchmark_gqa.py (global + local attention)
- TestGQARegressions tests pass locally (12/12)

Fixes
Fixes #28195