Optimize ONNX Attention KV cache with ConcatNewToPast and add release-build kernel safety#27613

Merged
titaiwangms merged 12 commits into main from titaiwang/improve_present_kv_copy
Mar 13, 2026

Conversation


@titaiwangms titaiwangms commented Mar 10, 2026

This pull request introduces a major optimization to the CUDA attention kernel by fusing the key/value (KV) cache concatenation step into a single kernel, reducing memory overhead and kernel launches during autoregressive decoding. The update also clarifies and improves documentation and safety around attention mask handling, especially regarding right-padding validity and mask conversion. Support for native __nv_bfloat16 types is expanded for better hardware utilization.

Performance and Kernel Fusion Improvements:

  • The attention kernel now uses LaunchConcatNewToPastKV to fuse the past KV copy and new token append into a single CUDA kernel, eliminating the previous memset and strided copy overhead. This reduces memory traffic and kernel launch overhead during decoding, matching the efficiency of the contrib GQA path. [1] [2]
  • Support for native __nv_bfloat16 types is added to the fused concat kernel, enabling direct hardware arithmetic on supported GPUs and aligning with GQA's early type conversion pattern. [1] [2]
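To make the fusion concrete, here is a minimal host-side reference of the index logic the fused concat performs (names are illustrative, not the shipped `LaunchConcatNewToPastKV` signature; the real code is a CUDA kernel doing this per thread): each present position reads from `past_kv` if it falls before `past_len`, otherwise from `new_kv`, replacing the earlier memset + strided-copy pair with a single pass.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference of the fused concat for one head.
// The present buffer of length (past_len + new_len) is filled from past_kv
// for positions < past_len and from new_kv otherwise, in one pass.
std::vector<float> ConcatNewToPastReference(const std::vector<float>& past_kv,
                                            const std::vector<float>& new_kv,
                                            int64_t past_len, int64_t new_len,
                                            int64_t head_size) {
  const int64_t total_len = past_len + new_len;
  std::vector<float> present(total_len * head_size);
  for (int64_t pos = 0; pos < total_len; ++pos) {
    for (int64_t h = 0; h < head_size; ++h) {
      // One branch per element replaces the old memset + strided-copy pair.
      present[pos * head_size + h] =
          (pos < past_len) ? past_kv[pos * head_size + h]
                           : new_kv[(pos - past_len) * head_size + h];
    }
  }
  return present;
}
```

In the CUDA kernel, each thread computes its (pos, h) pair from its global index, so the whole concat is one launch instead of a memset plus two copies.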

Attention Mask Handling and Documentation:

  • The logic and documentation for converting boolean attention masks to sequence lengths are clarified. In debug builds, non-contiguous (non-right-padded) masks trigger a CUDA kernel assertion; in release builds, the output is memory-safe but may be semantically incorrect, as only leading True values are counted. [1] [2] [3] [4]
  • Comments and docstrings are updated to explain the handling of offsets in mask-to-seqlen conversion, and the broadcasting behavior in bias addition kernels is clarified. [1] [2]
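The described release-build behavior can be illustrated with a host-side analogue of the mask-to-seqlen scan (the real code is a CUDA kernel; this sketch just mirrors the counting logic): only leading True values are counted, so a valid right-padded mask yields its true length, while a non-contiguous mask yields a memory-safe but semantically wrong count.

```cpp
#include <cassert>
#include <vector>

// Illustrative analogue of the mask-to-seqlen conversion: count leading
// true values. For a right-padded mask this is the sequence length; for a
// non-contiguous mask like {true, false, true} the result (here: 1) is
// memory-safe but semantically incorrect, matching the documented
// release-build behavior (debug builds would assert instead).
int MaskToSeqlen(const std::vector<bool>& mask_row) {
  int seq_len = 0;
  for (bool m : mask_row) {
    if (!m) break;  // stop at the first false; trailing values are ignored
    ++seq_len;
  }
  return seq_len;
}
```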

Code Quality and Minor Fixes:

  • Includes minor code style improvements, such as using using instead of typedef for CUDA type aliases, and adds missing header includes for type conversion. [1] [2]
  • Adds missing includes and updates comments for clarity in several files. [1] [2] [3]
  • Adds preprocessor guards to prevent unused variable warnings in certain build configurations.

… safety for CUDA kernels

Task 1: Replace memset + cudaMemcpy2DAsync + Flash Append_KV with a
single LaunchConcatNewToPastKV call that writes past_key + new_key
directly into present_key. This eliminates the strided copy and
Flash's internal append, reducing memory operations significantly.

Task 2: Add release-build safety for CUDA_KERNEL_ASSERT sites:
- attention_mask_impl.cu: break after assert for non-contiguous mask
- tensorscatter_impl.cu: clamp write_indices and cache_pos to valid range
- rotary_embedding_impl.cu: remove redundant bare assert()

Fixes #27612

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Copilot AI left a comment

Pull request overview

This PR optimizes CUDA LLM attention decode by replacing the previous multi-step KV cache preparation (memset + strided copies + Flash internal append) with a fused “concat new KV to past” kernel, while also hardening several CUDA kernels to avoid unsafe behavior in release builds (where device asserts may be compiled out).

Changes:

  • Use LaunchConcatNewToPastKV to pre-populate present_key/value in ONNX Attention’s Flash decode path and call mha_fwd_kvcache with k_new/v_new=nullptr.
  • Add release-safety clamping/break behavior in CUDA kernels (TensorScatter write indices; boolean mask contiguity handling) to prevent OOB access/hard faults without host-side synchronization.
  • Remove an extra debug assert in the rotary embedding CUDA implementation.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
onnxruntime/core/providers/cuda/llm/attention.cc Switch Flash decode-with-past to fused KV concat kernel and skip Flash internal KV append.
onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu Make non-contiguous bool masks stop scanning after the first invalid transition (plus clamp seqlens to non-negative).
onnxruntime/core/providers/cuda/llm/attention_mask_impl.h Update comments documenting debug-only asserts and release-build behavior for invalid masks.
onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu Clamp negative / OOB write indices in-kernel for release safety.
onnxruntime/core/providers/cuda/llm/tensorscatter.cc Update documentation/comments about in-kernel validation and release behavior.
onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu Remove redundant assert (retains runtime enforcement).
Comments suppressed due to low confidence (1)

onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu:43

  • In the release-build path (when CUDA_KERNEL_ASSERT is compiled out), clamping out-of-bounds write_indices causes all OOB writes to be redirected to max_seq_len - 1. That can silently corrupt the last cache position (potentially overwriting valid data multiple times) instead of preserving the original output for invalid indices. Consider making invalid indices a no-op (skip the write) and/or clamping wi to the largest valid start index (max_seq_len - seq_len) so that the entire update range stays in-bounds without collapsing onto the last element.
  int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0;
  CUDA_KERNEL_ASSERT(wi >= 0);
  wi = max(wi, static_cast<int64_t>(0));  // Clamp for release safety (CUDA_KERNEL_ASSERT is debug-only)
  int64_t cache_pos;
  if (circular) {
    cache_pos = (wi + seq_idx) % max_seq_len;
  } else {
    cache_pos = wi + seq_idx;
    CUDA_KERNEL_ASSERT(cache_pos < max_seq_len);
    cache_pos = min(cache_pos, max_seq_len - 1);  // Clamp for release safety (CUDA_KERNEL_ASSERT is debug-only)
  }

  int64_t out_offset = prefix_idx * (max_seq_len * suffix_count) + cache_pos * suffix_count + suffix_idx;
  output_data[out_offset] = update_data[id];
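A minimal sketch of the alternative the reviewer suggests, as a host-side function with illustrative names (not the shipped kernel): invalid indices become a no-op (signaled here by returning -1) instead of being redirected onto the last cache slot, so out-of-bounds writes cannot corrupt valid data.

```cpp
#include <cstdint>

// Sketch of the suggested release-safe behavior: return the cache position
// to write, or -1 to skip the write entirely. Clamping OOB indices onto
// max_seq_len - 1 would collapse all invalid writes onto the last slot;
// skipping preserves the existing cache contents instead.
int64_t SafeCachePos(int64_t wi, int64_t seq_idx, int64_t max_seq_len,
                     bool circular) {
  if (wi < 0) return -1;  // invalid start index: skip rather than redirect
  if (circular) {
    return (wi + seq_idx) % max_seq_len;  // wrap-around is legal in ring-buffer mode
  }
  int64_t cache_pos = wi + seq_idx;
  return (cache_pos < max_seq_len) ? cache_pos : -1;  // OOB linear write: skip
}
```

In the kernel this would translate to an early `return` from the thread when the computed position is invalid, leaving `output_data` untouched for that element.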


@tianleiwu
Contributor

Please consider using LaunchUnpackRoPEAppend in onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh if possible. That's unified preprocessing for GQA/MHA.

titaiwangms and others added 2 commits March 10, 2026 22:13
…NH layout

- tensorscatter.cc: clarify that clamping provides memory-safe (not output-correct) behavior
- attention.cc: add inline comment explaining BSNH dimension semantics for LaunchConcatNewToPastKV

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms marked this pull request as ready for review March 10, 2026 23:49
Adds CUDA EP test coverage for the Flash-eligible decode configuration
with past_key/past_value and boolean attention mask.

Three tests added:
- All-true mask: verifies basic Flash decode + fp16 + bool mask path
- Partial mask [T,T,T,F]: verifies Flash's seqlens_k-based masking (CUDA-only)
- First decode (past_seq=0): verifies empty past tensor edge case

Uses analytically-verifiable test data (uniform keys → uniform softmax)
so expected output values can be computed exactly.

Part of #27612

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.



…tter overflow

- Modify ConcatNewToPastKV kernel to write zeros for tail positions
  (beyond valid tokens) ensuring deterministic present_key/value output
- Fix int64 overflow in tensorscatter: pre-clamp wi before addition,
  handle negative modulo in circular mode
- Tighten test tolerance for partial mask decode test
- Update comments for accuracy

Part of #27612

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.



…pping, add optimization TODO

- Revert tail-zeroing in ConcatNewToPastKV kernels (wasted I/O; Flash respects seqlens_k)
- Add comment explaining intentional no-zero-fill for trailing KV cache positions
- Fix ToCudaType → OrtToCudaType in rotary_embedding.cc for native bf16 HW arithmetic
- Add __nv_bfloat16 template instantiation in rotary_embedding_impl.cu
- Add TODO(titaiwang) for single-kernel preprocessing optimization
- Simplify partial mask test (skip tail validation)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
titaiwangms and others added 3 commits March 11, 2026 23:20
- Guard has_output_qk with #if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION
- Remove duplicate #include <cuda_bf16.h> in rotary_embedding_impl.cu

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…sk test

- Switch decode path CudaT from ToCudaType to OrtToCudaType (native __nv_bfloat16)
- Add __nv_bfloat16 template instantiation for LaunchConcatNewToPastKV
- Simplify partial mask test: skip present_key/value content validation
  (trailing positions intentionally uninitialized for performance)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Rename CudaT → NativeCudaT in attention.cc and rotary_embedding.cc
  Flash/kernel paths to distinguish OrtToCudaType (native __nv_bfloat16)
  from ToCudaType (ORT BFloat16 wrapper) in the unfused path
- Remove dead BFloat16 template instantiation in rotary_embedding_impl.cu
  (only __nv_bfloat16 is used by llm/ callers via OrtToCudaType)
- Modernize typedef → using in rotary_embedding.cc and attention.cc
- Improve comments: clarify right-padding convention, uninitialized tail
  semantics, seqlen_offset docs, cyclic broadcast wording, rotary
  head_size constraint, tensorscatter clamping

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.



- tensorscatter_impl.cu: replace double-modulo idiom with single %
  operator since wi and seq_idx are guaranteed non-negative after
  clamping on the preceding lines
- attention_op_test.cc: replace 1e10 tolerance hack for present_key/
  value with SetCustomOutputVerifier that validates only Y output.
  IEEE 754 NaN comparisons always return false regardless of tolerance,
  so uninitialized tail memory containing NaN would fail the test.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.



@titaiwangms titaiwangms requested a review from Copilot March 12, 2026 03:54
Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.



Copilot AI left a comment

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Comments suppressed due to low confidence (1)

onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu:86

  • ConvertMaskToSeqlensKernel treats any row with mask_row[0]==false as an all-false (seq_len=0) mask and skips scanning the rest of the row. That means an invalid/non-contiguous pattern like [False, True, ...] will not trigger CUDA_KERNEL_ASSERT in debug builds and will be silently misinterpreted as length 0 in release builds. If right-padding convention validation is intended, consider checking the remaining elements are all False when mask_row[0] is False (at least under CUDA_KERNEL_ASSERT/debug).
  if (!mask_row[0]) {
    // Entire row is padding (all-false mask)
    seq_len = 0;
  } else {
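The validation the reviewer suggests could look like the following host-side sketch (illustrative code; the real check would sit in the CUDA kernel under `CUDA_KERNEL_ASSERT`): when the first element is false, additionally confirm the rest of the row is all false, so that an invalid pattern like {false, true, ...} is flagged rather than silently read as length 0.

```cpp
#include <vector>

// Sketch of the suggested right-padding validation for the all-false branch:
// a row starting with false is only valid if every remaining element is also
// false. A true after a leading false indicates a non-contiguous mask.
bool IsValidAllFalseRow(const std::vector<bool>& mask_row) {
  if (mask_row.empty() || mask_row[0]) return true;  // not the all-false case
  for (bool m : mask_row) {
    if (m) return false;  // true after a leading false: non-contiguous mask
  }
  return true;
}
```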


@titaiwangms
Contributor Author

Addressing @tianleiwu's Review Feedback

Hi @tianleiwu — here's how each of your feedback items has been addressed:

1. Consider using LaunchUnpackRoPEAppend for unified preprocessing

  • Status: ✅ Acknowledged
  • Resolution: Added TODO at attention.cc:315-318 documenting this as a future optimization (fuse RoPE + mask + concat into ~2 kernel launches). Out of scope for this PR.
  • Commit: d093340

2. Avoid zero-filling present KV tail — I/O impacts performance

  • Status: ✅ Resolved
  • Resolution: Reverted tail-zeroing. Flash respects seqlens_k bounds so trailing positions are never read. Saves ~134MB I/O per decode step.
  • Commit: d093340

3. Use OrtToCudaType instead of ToCudaType for native bf16

  • Status: ✅ Addressed
  • Resolution: All Flash path casts now use OrtToCudaType<T>::type (NativeCudaT = __nv_bfloat16), consistent with the GQA pattern. The unfused path correctly retains ToCudaType for contrib_ops interface compatibility.
  • Commit: 79e5ca9

All 3 items addressed. Please let me know if anything needs further changes!

- tensorscatter_impl.cu: move min(wi, max_seq_len) into linear branch
  only. In circular (ring buffer) mode, wi > max_seq_len is legal and
  should wrap via modulo, not be clamped. Use (wi % max_seq_len +
  seq_idx) % max_seq_len to prevent int64 overflow while preserving
  wrap semantics.
- tensorscatter.cc: update comment to distinguish circular (modulo) vs
  linear (clamping) bounds handling.
- attention_op_test.cc: enhance custom verifier to validate
  present_key/value prefix (past rows + new token) while still
  skipping uninitialized tail positions (NaN-safe).
- attention_op_test.cc: rename FirstDecode test to PromptPath — past KV
  are absent (None), exercising mha_fwd prompt path, not decode.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms force-pushed the titaiwang/improve_present_kv_copy branch from 9184039 to f3424db on March 12, 2026 04:39
- Change explicit lambda capture to [&] per ORT convention
- Add FlashAttention_Decode_PartialMask_MultiBatch test with
  batch_size=2 and varying per-batch past_seqlens

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@titaiwangms titaiwangms merged commit 99e0119 into main Mar 13, 2026
91 checks passed
@titaiwangms titaiwangms deleted the titaiwang/improve_present_kv_copy branch March 13, 2026 03:56
@titaiwangms titaiwangms linked an issue Mar 17, 2026 that may be closed by this pull request


Development

Successfully merging this pull request may close these issues.

Use concat to ONNX attention cache update
