
[CUDA] PagedAttention: use exact max_query_len on FA path#28409

Open
elwhyjay wants to merge 2 commits into microsoft:main from elwhyjay:feature/paged-attention-rotary-max-seqlen-fix
Conversation

@elwhyjay
Contributor

@elwhyjay elwhyjay commented May 8, 2026

Description

Fix the FA dispatch path of PagedAttention to use the host-computed actual maximum new-query length per batch instead of the older token_count - batch_size + 1 heuristic. The same value is used for both mha_varlen_fwd (params.seqlen_q) and the rotary kernel grid.

With the exact host-computed maximum:

  • The rotary kernel grid is large enough to cover the true per-batch maximum, so no Q/K token is dropped from rotary.
  • mha_varlen_fwd always sees a positive, accurate seqlen_q, so its grid.x = ceil(params.seqlen_q / kBlockM) covers all query tokens (no silent drop at kBlockM=64 boundaries) and is never invalid (no CUDA error 9).

The MEA path already used data.max_query_len (host-computed in paged_attention.cc) from #28200. This PR moves that host computation out of the MEA-only block so FA also sees the exact value, and the FA dispatch reads the same data.max_query_len.
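The host computation itself is small; a minimal Python sketch of the idea (the real code is C++ in paged_attention.cc, and the function name here is illustrative): the exact value is the max over adjacent differences of the cumulative query-length prefix sums.

```python
def exact_max_query_len(cumulative_seqlens_q):
    # Exact per-batch max new-query length: batch i owns tokens
    # [cum[i], cum[i+1]), so its length is the adjacent difference.
    return max(b - a for a, b in zip(cumulative_seqlens_q, cumulative_seqlens_q[1:]))

# lens=[10, 0, 0, 0] -> cumulative [0, 10, 10, 10, 10] -> exact max is 10,
# while the old heuristic (token_count - batch_size + 1) gives 10 - 4 + 1 = 7.
assert exact_max_query_len([0, 10, 10, 10, 10]) == 10
```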

Motivation and Context

This is a follow-up to #28200. The MEA dispatch path was fixed there; the FA dispatch path uses the same heuristic and has the same root cause.

In paged_attention_impl.cu, the FA path computed:

    const int max_query_len = token_count - batch_size + 1;

This value was passed to mha_varlen_fwd as params.seqlen_q and also used as grid.x for LaunchRotaryEmbeddingKernel. The formula assumes each batch has at least one new token, which is not enforced by the op input.

Three failure modes from the same heuristic underestimation:

  1. Rotary silent drop: lens=[10, 0, 0, 0]. heuristic = 10 - 4 + 1 = 7, real max = 10. Tokens at positions s=7,8,9 in batch 0 are not rotated.
  2. FA grid silent drop at the kBlockM=64 boundary: lens=[65, 0]. heuristic = 64, real max = 65. mha_varlen_fwd launches with grid.x = 1 but the 65th query token needs grid.x = 2.
  3. Non-positive heuristic: lens=[10, 0, ..., 0] with batch_size = 16. heuristic = 10 - 16 + 1 = -5. The value reaches mha_varlen_fwd as params.seqlen_q, and the FA launch fails with CUDA error 9 (invalid configuration argument).
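The three numbers above can be reproduced in a few lines of Python (heuristic_max_query_len is an illustrative stand-in for the C++ expression):

```python
import math

def heuristic_max_query_len(lens):
    # The old FA-path formula: token_count - batch_size + 1.
    return sum(lens) - len(lens) + 1

kBlockM = 64  # FA m-block size on this path

# 1) Rotary silent drop: heuristic 7 vs real max 10.
assert heuristic_max_query_len([10, 0, 0, 0]) == 7 and max([10, 0, 0, 0]) == 10

# 2) kBlockM boundary: heuristic 64 yields 1 m-block, real 65 needs 2.
assert math.ceil(heuristic_max_query_len([65, 0]) / kBlockM) == 1
assert math.ceil(max([65, 0]) / kBlockM) == 2

# 3) Non-positive heuristic: invalid grid dimension (CUDA error 9).
assert heuristic_max_query_len([10] + [0] * 15) == -5
```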

Tests

New test class TestPagedAttentionRotaryZeroTokenRegression:

  • test_fa_rotary_zero_token_first_batch: lens=[10,0,0,0]. Rotary silent drop reproducer.
  • test_fa_rotary_zero_token_mixed: lens=[0,7,0,3].
  • test_fa_rotary_zero_token_large_batch: lens=[10, 0×15] (batch_size=16). Negative-heuristic CUDA error 9 reproducer.
  • test_fa_kblockm_boundary_zero_token: lens=[65, 0] with rotary on. FA grid kBlockM-boundary silent drop reproducer.
  • test_fa_kblockm_boundary_zero_token_no_rotary — same with rotary off.
  • test_mea_rotary_zero_token_no_regression — guards against regression of the MEA fix from #28200 ([CUDA] PagedAttention: add SM<80 fp16 fallback via memory-efficient attention).
  • test_fa_no_rotary_zero_token_sanity — sanity for FA without rotary.

parity_check_paged_attention got a new optional parameter new_seqlens_override so tests can pass a deterministic per-batch distribution instead of randint(1, ...).
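The shape of that helper change, sketched in Python (everything except the new_seqlens_override parameter name is illustrative, not the helper's real signature):

```python
import random

def draw_new_seqlens(batch_size, max_new_tokens, new_seqlens_override=None):
    # Sketch: tests may pass a deterministic per-batch distribution
    # (including zero-token batches) instead of the random draw.
    if new_seqlens_override is not None:
        assert len(new_seqlens_override) == batch_size
        return list(new_seqlens_override)
    return [random.randint(1, max_new_tokens) for _ in range(batch_size)]

# Deterministic reproducer for the zero-token cases:
assert draw_new_seqlens(4, 128, new_seqlens_override=[10, 0, 0, 0]) == [10, 0, 0, 0]
```

Note that randint(1, ...) can never produce a zero-length batch, which is why the heuristic's failure modes were unreachable by the existing randomized tests.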

Existing TestPagedAttention (24) + TestPagedAttentionMEA (24) suite passes (48/48). New regression class passes (7/7).

Performance

Measured on RTX PRO 4500 (sm_120, CUDA 12.8). 100 warmup + 500 iterations per case. GPU clock lock is not available on this cloud GPU, so absolute numbers carry some measurement noise, but the trend is consistent.

| case | base (us) | this PR (us) | delta |
|---|---:|---:|---|
| prefill B=1 rot=on (lens=[1024]) | 102.1 | 110.2 | +7.9% |
| prefill B=4 rot=on (lens=[1024]×4) | 567.9 | 360.6 | -36.5% |
| prefill B=16 rot=on (lens=[1024]×16) | 5988 | 1488 | -75.1% |
| prefill B=64 rot=on (lens=[1024]×64) | 82541 | 6317 | -92.3% |
| prefill B=64 rot=off (lens=[1024]×64) | 9496 | 4870 | -48.7% |
| decode B=1 rot=on (lens=[1]) | 26.4 | 34.8 | +31.4% |
| decode B=4 rot=on (lens=[1]×4) | 29.1 | 38.4 | +32.0% |
| decode B=16 rot=on (lens=[1]×16) | 46.1 | 55.6 | +20.5% |
| decode B=64 rot=on (lens=[1]×64) | 112.1 | 120.2 | +7.2% |
| mixed B=64 rot=on (lens=[1024,1×63]) | 1384 | 1390 | +0.4% |
| zero-mix B=4 rot=on (lens=[10,0,0,0]) | 27.7 | 34.6 | +25.0% |
| zero-mix B=16 rot=on (lens=[10, 0×15]) | CUDA error 9 | 37.6 | now passes |
| zero-mix B=64 rot=on (lens=[10, 0×63]) | CUDA error 9 | 49.9 | now passes |

The base heuristic over-launched the rotary kernel on prefill: for lens=[1024]×64 it computes 65473, vs the true max of 1024. The rotary kernel launched ~64x more blocks than needed; each over-launched block only hit the early-return path, but the launch overhead was visible in wall-clock time. The same heuristic was also passed to FA as params.seqlen_q, so FA also launched many unnecessary m-blocks in prefill-like cases.
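The ~64x figure follows directly from the formula:

```python
# Over-launch on the B=64 prefill case: lens = [1024] * 64.
lens = [1024] * 64
heuristic = sum(lens) - len(lens) + 1  # old basis for rotary grid / seqlen_q
exact = max(lens)                      # host-computed max_query_len
assert heuristic == 65473 and exact == 1024
print(f"rotary grid over-launch factor: ~{heuristic / exact:.0f}x")  # ~64x
```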

Using the exact maximum gives:

  • Large prefill workloads: significant speedup (-36% to -92%).
  • Decode small batches: small absolute regression (~+5–10us per call) from the host D2H copy + cudaStreamSynchronize newly added on the FA path. In percent this is +20–30% on cases that take ~25us; on decode B=64 rot=on it is +7%.
  • Inputs that previously crashed with CUDA error 9 now run.

The decode regression is the cost of the new D2H sync. It was already paid on the MEA path. The regression is small in absolute terms and the heuristic was silently wrong on these sparse-batch patterns, so this is a correctness fix.

Test plan

  • new regression tests pass (7/7) on RTX PRO 4500 (sm_120, CUDA 12.8)
  • existing test_paged_attention_cuda suite passes (48/48)
  • CI green

cc: @tianleiwu

@elwhyjay force-pushed the feature/paged-attention-rotary-max-seqlen-fix branch from b80299a to 11596b3 on May 8, 2026 05:52
@elwhyjay marked this pull request as draft May 8, 2026 07:52
@elwhyjay force-pushed the feature/paged-attention-rotary-max-seqlen-fix branch from 9e9ba43 to 8278551 on May 8, 2026 08:38
@elwhyjay changed the title from [CUDA] PagedAttention: use token_count for FA rotary grid to [CUDA] PagedAttention: use exact max_query_len on FA path on May 8, 2026
@elwhyjay closed this May 8, 2026
@elwhyjay reopened this May 8, 2026
@elwhyjay marked this pull request as ready for review May 8, 2026 08:51
Compute the exact per-batch max query length from cumulative_seqlens_q
and pass it to the FA path instead of using the token_count - batch_size + 1
heuristic.

The heuristic assumes every batch has at least one new token. When some
batches have zero new tokens, it can underestimate the true max query length
or become non-positive. This can cause rotary under-launch, FA kBlockM boundary
under-launch, or an invalid FA launch grid.

Use the exact max query length for both mha_varlen_fwd and the rotary kernel
grid, and add regression tests for zero-token batches, the non-positive
heuristic case, and the kBlockM boundary case.
@elwhyjay force-pushed the feature/paged-attention-rotary-max-seqlen-fix branch from 8278551 to 90c5702 on May 8, 2026 09:08
@tianleiwu requested a review from Copilot May 8, 2026 14:52
Contributor

Copilot AI left a comment


Pull request overview

Fixes CUDA PagedAttention’s FlashAttention (FA) dispatch to use the host-computed exact per-batch maximum new-query length (max_query_len) instead of the token_count - batch_size + 1 heuristic, preventing under-launch/silent token drops and invalid FA launches when some batches have zero new tokens.

Changes:

  • Hoist host-side computation of max_query_len in paged_attention.cc so it’s available to both FA and MEA paths, and plumb it through PagedAttentionData.
  • Update FA path in paged_attention_impl.cu to use data.max_query_len for both rotary grid sizing and mha_varlen_fwd’s seqlen_q.
  • Extend Python parity helper to accept deterministic new_seqlens_override and add regression tests covering zero-token batch cases and the kBlockM boundary.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

| File | Description |
|---|---|
| onnxruntime/test/python/transformers/test_paged_attention_cuda.py | Adds new_seqlens_override to the parity helper and introduces FA/MEA regression tests for zero-token batch distributions. |
| onnxruntime/contrib_ops/cuda/bert/paged_attention.cc | Computes max_query_len on host and stores it in PagedAttentionData for FA/MEA consumption. |
| onnxruntime/contrib_ops/cuda/bert/paged_attention_impl.cu | Replaces the FA heuristic max-query-length with host-computed data.max_query_len for rotary + mha_varlen_fwd. |

