[CUDA] PagedAttention: use exact max_query_len on FA path#28409
Open
elwhyjay wants to merge 2 commits into microsoft:main
Conversation
Compute the exact per-batch max query length from cumulative_seqlens_q and pass it to the FA path instead of using the token_count - batch_size + 1 heuristic. The heuristic assumes every batch has at least one new token. When some batches have zero new tokens, it can underestimate the true max query length or become non-positive. This can cause rotary under-launch, FA kBlockM boundary under-launch, or an invalid FA launch grid. Use the exact max query length for both mha_varlen_fwd and the rotary kernel grid, and add regression tests for zero-token batches, the non-positive heuristic case, and the kBlockM boundary case.
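The difference between the two computations can be sketched in a few lines (a minimal Python sketch with hypothetical helper names; the actual change is C++/CUDA in `paged_attention.cc`):

```python
def exact_max_query_len(cumulative_seqlens_q):
    # Per-batch new-token counts are the first differences of the cumulative
    # prefix sums; the exact max is their maximum (>= 0 even when some
    # batches contribute zero new tokens).
    return max(b - a for a, b in zip(cumulative_seqlens_q, cumulative_seqlens_q[1:]))

def heuristic_max_query_len(token_count, batch_size):
    # Old FA-path heuristic: only correct when every batch has >= 1 new token.
    return token_count - batch_size + 1

# lens=[10, 0, 0, 0] -> prefix sums [0, 10, 10, 10, 10], token_count=10, batch_size=4
cum = [0, 10, 10, 10, 10]
print(exact_max_query_len(cum))        # 10
print(heuristic_max_query_len(10, 4))  # 7: underestimates the true max of 10
```

With all-positive lengths such as `lens=[3, 2, 1]` the two agree only when exactly one batch carries all the extra tokens; zero-token batches break the heuristic outright.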
Contributor
Pull request overview
Fixes CUDA PagedAttention’s FlashAttention (FA) dispatch to use the host-computed exact per-batch maximum new-query length (max_query_len) instead of the token_count - batch_size + 1 heuristic, preventing under-launch/silent token drops and invalid FA launches when some batches have zero new tokens.
Changes:
- Hoist the host-side computation of `max_query_len` in `paged_attention.cc` so it is available to both FA and MEA paths, and plumb it through `PagedAttentionData`.
- Update the FA path in `paged_attention_impl.cu` to use `data.max_query_len` for both rotary grid sizing and `mha_varlen_fwd`'s `seqlen_q`.
- Extend the Python parity helper to accept a deterministic `new_seqlens_override` and add regression tests covering zero-token batch cases and the kBlockM boundary.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/test/python/transformers/test_paged_attention_cuda.py | Adds new_seqlens_override to parity helper and introduces FA/MEA regression tests for zero-token batch distributions. |
| onnxruntime/contrib_ops/cuda/bert/paged_attention.cc | Computes max_query_len on host and stores it in PagedAttentionData for FA/MEA consumption. |
| onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu | Replaces FA heuristic max-query-length with host-computed data.max_query_len for rotary + mha_varlen_fwd. |
Description
Fix the FA dispatch path of `PagedAttention` to use the host-computed actual maximum new-query length per batch instead of the older `token_count - batch_size + 1` heuristic. The same value is used for both `mha_varlen_fwd` (`params.seqlen_q`) and the rotary kernel grid.

With the exact host-computed maximum, `mha_varlen_fwd` always sees a positive, accurate `seqlen_q`, so its `grid.x = ceil(params.seqlen_q / kBlockM)` covers all query tokens (no silent drop at kBlockM=64 boundaries) and is never invalid (no CUDA error 9).

The MEA path already used `data.max_query_len` (host-computed in `paged_attention.cc`) from #28200. This PR moves that host computation out of the MEA-only block so FA also sees the exact value, and the FA dispatch reads the same `data.max_query_len`.

Motivation and Context
This is a follow-up to #28200. The MEA dispatch path was fixed there; the FA dispatch path uses the same heuristic and has the same root cause.
In `paged_attention_impl.cu`, the FA path computed the max query length as `token_count - batch_size + 1`. This value was passed to `mha_varlen_fwd` as `params.seqlen_q` and also used as `grid.x` for `LaunchRotaryEmbeddingKernel`. The formula assumes each batch has at least one new token, which is not enforced by the op input.

Three failure modes from the same heuristic underestimation:
- Rotary under-launch: `lens=[10, 0, 0, 0]`. heuristic = 10 - 4 + 1 = 7, real max = 10. Tokens at positions s=7,8,9 in batch 0 are not rotated.
- kBlockM boundary under-launch: `lens=[65, 0]`. heuristic = 64, real max = 65. `mha_varlen_fwd` launches with `grid.x = 1` but the 65th query token needs `grid.x = 2`.
- Invalid launch grid: `lens=[10, 0, ..., 0]` with batch_size = 16. heuristic = 10 - 16 + 1 = -5. The value reaches `mha_varlen_fwd` as `params.seqlen_q`, and the FA launch fails with CUDA error 9 (invalid configuration argument).

Tests
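The three cases reduce to a few lines of arithmetic (Python sketch; `k_block_m=64` matches the FA tile size discussed above):

```python
def heuristic(token_count, batch_size):
    # Old FA-path estimate of the per-batch max query length.
    return token_count - batch_size + 1

def fa_grid_x(seqlen_q, k_block_m=64):
    # Number of m-blocks the FA launch grid covers: ceil(seqlen_q / kBlockM).
    return -(-seqlen_q // k_block_m)

# 1) Rotary under-launch: lens=[10, 0, 0, 0]
print(heuristic(10, 4))   # 7, but the real max is 10: tokens s=7,8,9 missed
# 2) kBlockM boundary: lens=[65, 0]
print(heuristic(65, 2))   # 64 -> grid.x = 1
print(fa_grid_x(65))      # 2 m-blocks are actually needed for the 65th token
# 3) Non-positive: lens=[10] + [0]*15
print(heuristic(10, 16))  # -5 -> invalid launch grid (CUDA error 9)
```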
New test class `TestPagedAttentionRotaryZeroTokenRegression`:

- `test_fa_rotary_zero_token_first_batch` — `lens=[10,0,0,0]`. Rotary silent-drop reproducer.
- `test_fa_rotary_zero_token_mixed` — `lens=[0,7,0,3]`.
- `test_fa_rotary_zero_token_large_batch` — `lens=[10, 0×15]` (batch_size=16). Negative-heuristic CUDA error 9 reproducer.
- `test_fa_kblockm_boundary_zero_token` — `lens=[65, 0]` with rotary on. FA grid kBlockM-boundary silent-drop reproducer.
- `test_fa_kblockm_boundary_zero_token_no_rotary` — same with rotary off.
- `test_mea_rotary_zero_token_no_regression` — guards against regression of the #28200 MEA fix ([CUDA] PagedAttention: add SM<80 fp16 fallback via memory-efficient attention).
- `test_fa_no_rotary_zero_token_sanity` — sanity check for FA without rotary.

`parity_check_paged_attention` gained a new optional parameter `new_seqlens_override` so tests can pass a deterministic per-batch distribution instead of `randint(1, ...)`.
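A sketch of how such an override could thread through a parity helper (hypothetical simplified helper; the real `parity_check_paged_attention` builds full attention inputs around this):

```python
import random

def sample_new_seqlens(batch_size, max_new_tokens, new_seqlens_override=None, seed=0):
    # With an override, use the caller's deterministic distribution verbatim.
    # This is the only way to exercise zero-token batches: the original
    # randint(1, ...) sampling always yields at least one new token per batch.
    if new_seqlens_override is not None:
        assert len(new_seqlens_override) == batch_size
        return list(new_seqlens_override)
    rng = random.Random(seed)
    return [rng.randint(1, max_new_tokens) for _ in range(batch_size)]

print(sample_new_seqlens(4, 16, new_seqlens_override=[10, 0, 0, 0]))  # [10, 0, 0, 0]
```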
TestPagedAttention(24) +TestPagedAttentionMEA(24) suite passes (48/48). New regression class passes (7/7).Performance
Measured on RTX PRO 4500 (sm_120, CUDA 12.8). 100 warmup + 500 iterations per case. GPU clock lock is not available on this cloud GPU, so absolute numbers carry some measurement noise, but the trend is consistent.
The base heuristic also over-launched the rotary kernel on prefill: for `lens=[1024]×64` it computes 65473, vs. the true max of 1024. The rotary kernel was launching ~64x more blocks than needed; each over-launched block did only the early-return work, but the launch overhead was visible in wall-clock time. The same heuristic was also passed to FA as `params.seqlen_q`, so FA launched many unnecessary m-blocks in prefill-like cases.

Using the exact maximum gives:
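The prefill over-launch factor follows directly from the formula (Python sketch of the arithmetic):

```python
def heuristic(token_count, batch_size):
    # Old FA-path estimate of the per-batch max query length.
    return token_count - batch_size + 1

# Prefill: 64 batches, 1024 new tokens each.
token_count = 1024 * 64                 # 65536 total query tokens
print(heuristic(token_count, 64))       # 65473 -> rotary grid.x
# The true max is 1024, so the heuristic launches ~64x the needed blocks.
print(heuristic(token_count, 64) // 1024)
```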
The decode-side overhead comes from a `cudaStreamSynchronize` newly added on the FA path. In percent terms this is +20–30% on cases that take ~25us; on decode B=64 rot=on it is +7%.

The decode regression is the cost of the new D2H sync, which was already paid on the MEA path. The regression is small in absolute terms, and the heuristic was silently wrong on these sparse-batch patterns, so this is a correctness fix.
Test plan
cc: @tianleiwu