Optimize ONNX Attention KV cache with ConcatNewToPast and add release-build kernel safety #27613
Conversation
… safety for CUDA kernels

Task 1: Replace memset + cudaMemcpy2DAsync + Flash Append_KV with a single LaunchConcatNewToPastKV call that writes past_key + new_key directly into present_key. This eliminates the strided copy and Flash's internal append, reducing memory operations significantly.

Task 2: Add release-build safety for CUDA_KERNEL_ASSERT sites:
- attention_mask_impl.cu: break after assert for non-contiguous mask
- tensorscatter_impl.cu: clamp write_indices and cache_pos to valid range
- rotary_embedding_impl.cu: remove redundant bare assert()

Fixes #27612

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
This PR optimizes CUDA LLM attention decode by replacing the previous multi-step KV cache preparation (memset + strided copies + Flash internal append) with a fused “concat new KV to past” kernel, while also hardening several CUDA kernels to avoid unsafe behavior in release builds (where device asserts may be compiled out).
Changes:
- Use `LaunchConcatNewToPastKV` to pre-populate `present_key`/`present_value` in ONNX Attention's Flash decode path and call `mha_fwd_kvcache` with `k_new`/`v_new` = `nullptr`.
- Add release-safety clamping/break behavior in CUDA kernels (TensorScatter write indices; boolean mask contiguity handling) to prevent OOB access/hard faults without host-side synchronization.
- Remove an extra debug `assert` in the rotary embedding CUDA implementation.
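The fused path replaces several memory operations with one kernel that builds `present = [past ; new]` along the sequence axis, after which Flash is invoked with `k_new`/`v_new` set to `nullptr`. A minimal host-side C++ sketch of the per-(batch, head) effect — the helper name is hypothetical, not the ORT kernel:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side emulation (hypothetical helper, not the ORT CUDA kernel) of what
// the fused "concat new KV to past" produces per batch/head slice: the present
// cache is the past rows followed by the new token rows, written in one pass.
std::vector<float> ConcatNewToPast(const std::vector<float>& past,   // [past_len * head_size]
                                   const std::vector<float>& fresh,  // [new_len * head_size]
                                   size_t head_size) {
  std::vector<float> present;
  present.reserve(past.size() + fresh.size());
  present.insert(present.end(), past.begin(), past.end());    // copy past rows
  present.insert(present.end(), fresh.begin(), fresh.end());  // append new rows
  assert(present.size() % head_size == 0);
  return present;
}
```

Because the present cache is complete before the attention call, Flash no longer needs its internal append, which is what removes the memset and strided-copy steps.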
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxruntime/core/providers/cuda/llm/attention.cc | Switch Flash decode-with-past to fused KV concat kernel and skip Flash internal KV append. |
| onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu | Make non-contiguous bool masks stop scanning after the first invalid transition (plus clamp seqlens to non-negative). |
| onnxruntime/core/providers/cuda/llm/attention_mask_impl.h | Update comments documenting debug-only asserts and release-build behavior for invalid masks. |
| onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu | Clamp negative/OOB write indices in-kernel for release safety. |
| onnxruntime/core/providers/cuda/llm/tensorscatter.cc | Update documentation/comments about in-kernel validation and release behavior. |
| onnxruntime/core/providers/cuda/llm/rotary_embedding_impl.cu | Remove redundant assert (retains runtime enforcement). |
Comments suppressed due to low confidence (1)
onnxruntime/core/providers/cuda/llm/tensorscatter_impl.cu:43
- In the release-build path (when CUDA_KERNEL_ASSERT is compiled out), clamping out-of-bounds write_indices causes all OOB writes to be redirected to max_seq_len - 1. That can silently corrupt the last cache position (potentially overwriting valid data multiple times) instead of preserving the original output for invalid indices. Consider making invalid indices a no-op (skip the write) and/or clamping wi to the largest valid start index (max_seq_len - seq_len) so that the entire update range stays in-bounds without collapsing onto the last element.
```cuda
int64_t wi = (write_indices != nullptr) ? write_indices[batch_idx] : 0;
CUDA_KERNEL_ASSERT(wi >= 0);
wi = max(wi, static_cast<int64_t>(0));  // Clamp for release safety (CUDA_KERNEL_ASSERT is debug-only)
int64_t cache_pos;
if (circular) {
  cache_pos = (wi + seq_idx) % max_seq_len;
} else {
  cache_pos = wi + seq_idx;
  CUDA_KERNEL_ASSERT(cache_pos < max_seq_len);
  cache_pos = min(cache_pos, max_seq_len - 1);  // Clamp for release safety (CUDA_KERNEL_ASSERT is debug-only)
}
int64_t out_offset = prefix_idx * (max_seq_len * suffix_count) + cache_pos * suffix_count + suffix_idx;
output_data[out_offset] = update_data[id];
```
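The reviewer's alternative — skipping invalid writes rather than clamping them onto `max_seq_len - 1` — can be sketched on the host like this (hypothetical helper, not the ORT kernel; `-1` stands in for "do not write"):

```cpp
#include <cstdint>

// Sketch of the suggested no-op behavior: an out-of-bounds write index skips
// the write instead of collapsing onto the last cache slot, which clamping
// would silently corrupt. Returns -1 when the write should be skipped.
int64_t SafeCachePos(int64_t wi, int64_t seq_idx, int64_t max_seq_len, bool circular) {
  if (wi < 0) return -1;  // invalid start index: skip the write
  if (circular) {
    // Ring buffer: wrap via modulo; reduce wi first so the sum cannot overflow int64.
    return (wi % max_seq_len + seq_idx) % max_seq_len;
  }
  int64_t pos = wi + seq_idx;
  return (pos < max_seq_len) ? pos : -1;  // OOB linear write: skip, don't clamp
}
```

The trade-off is memory safety either way, but the no-op variant preserves existing cache contents for invalid indices instead of overwriting the final position repeatedly.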
…NH layout

- tensorscatter.cc: clarify that clamping provides memory-safe (not output-correct) behavior
- attention.cc: add inline comment explaining BSNH dimension semantics for LaunchConcatNewToPastKV

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Adds CUDA EP test coverage for the Flash-eligible decode configuration with past_key/past_value and a boolean attention mask. Three tests added:

- All-true mask: verifies the basic Flash decode + fp16 + bool mask path
- Partial mask [T,T,T,F]: verifies Flash's seqlens_k-based masking (CUDA-only)
- First decode (past_seq=0): verifies the empty past tensor edge case

Uses analytically verifiable test data (uniform keys → uniform softmax) so expected output values can be computed exactly.

Part of #27612

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
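The "uniform keys → uniform softmax" trick makes expected outputs exact: when every key is the same vector, all q·k logits are equal, softmax weights are uniform, and the attention output reduces to the arithmetic mean of the value rows. A minimal C++ sketch of that identity (hypothetical helper, illustration only):

```cpp
#include <cstddef>
#include <vector>

// With identical keys, softmax over identical logits is uniform, so attention
// output equals the mean of the value rows — exactly computable by hand.
std::vector<float> UniformKeyAttention(const std::vector<std::vector<float>>& values) {
  size_t n = values.size();
  size_t d = values[0].size();
  std::vector<float> out(d, 0.0f);
  for (const std::vector<float>& v : values) {
    for (size_t i = 0; i < d; ++i) {
      out[i] += v[i] / static_cast<float>(n);  // uniform weight 1/n per row
    }
  }
  return out;
}
```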
Pull request overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
…tter overflow

- Modify ConcatNewToPastKV kernel to write zeros for tail positions (beyond valid tokens), ensuring deterministic present_key/value output
- Fix int64 overflow in tensorscatter: pre-clamp wi before addition, handle negative modulo in circular mode
- Tighten test tolerance for the partial mask decode test
- Update comments for accuracy

Part of #27612

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
…pping, add optimization TODO

- Revert tail-zeroing in ConcatNewToPastKV kernels (wasted I/O; Flash respects seqlens_k)
- Add comment explaining intentional no-zero-fill for trailing KV cache positions
- Fix ToCudaType → OrtToCudaType in rotary_embedding.cc for native bf16 HW arithmetic
- Add __nv_bfloat16 template instantiation in rotary_embedding_impl.cu
- Add TODO(titaiwang) for single-kernel preprocessing optimization
- Simplify partial mask test (skip tail validation)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Guard has_output_qk with #if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION
- Remove duplicate #include <cuda_bf16.h> in rotary_embedding_impl.cu

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…sk test

- Switch decode path CudaT from ToCudaType to OrtToCudaType (native __nv_bfloat16)
- Add __nv_bfloat16 template instantiation for LaunchConcatNewToPastKV
- Simplify partial mask test: skip present_key/value content validation (trailing positions intentionally uninitialized for performance)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- Rename CudaT → NativeCudaT in attention.cc and rotary_embedding.cc Flash/kernel paths to distinguish OrtToCudaType (native __nv_bfloat16) from ToCudaType (ORT BFloat16 wrapper) in the unfused path
- Remove dead BFloat16 template instantiation in rotary_embedding_impl.cu (only __nv_bfloat16 is used by llm/ callers via OrtToCudaType)
- Modernize typedef → using in rotary_embedding.cc and attention.cc
- Improve comments: clarify right-padding convention, uninitialized tail semantics, seqlen_offset docs, cyclic broadcast wording, rotary head_size constraint, tensorscatter clamping

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
- tensorscatter_impl.cu: replace double-modulo idiom with a single % operator, since wi and seq_idx are guaranteed non-negative after clamping on the preceding lines
- attention_op_test.cc: replace the 1e10 tolerance hack for present_key/value with SetCustomOutputVerifier that validates only the Y output. IEEE 754 NaN comparisons always return false regardless of tolerance, so uninitialized tail memory containing NaN would fail the test.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
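The NaN point is worth spelling out: no tolerance, however large, can make a comparison involving NaN pass, because every ordered comparison with NaN evaluates to false. A small C++ illustration (hypothetical helper mirroring a tolerance check, not the ORT test harness):

```cpp
#include <cmath>

// A typical absolute-tolerance check. If either operand is NaN, fabs(a - b)
// is NaN, and NaN <= tol is false for ANY tol — so uninitialized tail memory
// containing NaN fails the check no matter how large the tolerance is.
bool WithinTol(float a, float b, float tol) {
  return std::fabs(a - b) <= tol;
}
```

This is why the test switched to a custom verifier that skips the uninitialized tail entirely rather than relying on a huge tolerance.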
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.
Force-pushed from a8d21b9 to 9184039.
Pull request overview
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (1)
onnxruntime/core/providers/cuda/llm/attention_mask_impl.cu:86
- ConvertMaskToSeqlensKernel treats any row with mask_row[0]==false as an all-false (seq_len=0) mask and skips scanning the rest of the row. That means an invalid/non-contiguous pattern like [False, True, ...] will not trigger CUDA_KERNEL_ASSERT in debug builds and will be silently misinterpreted as length 0 in release builds. If right-padding convention validation is intended, consider checking the remaining elements are all False when mask_row[0] is False (at least under CUDA_KERNEL_ASSERT/debug).
```cuda
if (!mask_row[0]) {
  // Entire row is padding (all-false mask)
  seq_len = 0;
} else {
```
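The validation the review suggests — verify that everything after the first `false` is also `false` so patterns like `[false, true, ...]` are flagged rather than read as length 0 — can be sketched on the host like this (hypothetical helper, not the ORT kernel):

```cpp
// Right-padding-convention check: count leading true values, then confirm no
// true appears after the first false. A row like {false, true} is reported as
// invalid instead of being silently treated as an empty sequence.
int ValidSeqLen(const bool* mask_row, int total_len, bool* valid) {
  int seq_len = 0;
  while (seq_len < total_len && mask_row[seq_len]) ++seq_len;  // leading trues
  *valid = true;
  for (int i = seq_len; i < total_len; ++i) {
    if (mask_row[i]) {  // true after false: violates right-padding convention
      *valid = false;
      break;
    }
  }
  return seq_len;
}
```

In a kernel this extra scan could be guarded so it only runs under `CUDA_KERNEL_ASSERT`/debug builds, as the comment proposes.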
Addressing @tianleiwu's Review Feedback

Hi @tianleiwu — here's how each of your feedback items has been addressed:

1. Consider using
- tensorscatter_impl.cu: move min(wi, max_seq_len) into the linear branch only. In circular (ring buffer) mode, wi > max_seq_len is legal and should wrap via modulo, not be clamped. Use (wi % max_seq_len + seq_idx) % max_seq_len to prevent int64 overflow while preserving wrap semantics.
- tensorscatter.cc: update comment to distinguish circular (modulo) vs linear (clamping) bounds handling.
- attention_op_test.cc: enhance the custom verifier to validate the present_key/value prefix (past rows + new token) while still skipping uninitialized tail positions (NaN-safe).
- attention_op_test.cc: rename FirstDecode test to PromptPath — past KV are absent (None), exercising the mha_fwd prompt path, not decode.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
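Two pitfalls motivate the pre-reduction before the modulo: C++'s `%` keeps the sign of the dividend (so `(-3) % 8 == -3`), and `wi + seq_idx` can overflow int64 before the wrap is applied. A host-side sketch of the safe wrap (hypothetical helper, mirroring the commit's formula):

```cpp
#include <cstdint>

// Circular (ring buffer) cache position: reduce wi first so the addition
// cannot overflow int64, and normalize a negative remainder so the result
// always lands in [0, max_seq_len).
int64_t WrapCachePos(int64_t wi, int64_t seq_idx, int64_t max_seq_len) {
  int64_t w = wi % max_seq_len;
  if (w < 0) w += max_seq_len;         // C++ % keeps the dividend's sign
  return (w + seq_idx) % max_seq_len;  // bounded sum: no overflow
}
```

With non-negative inputs this matches the commit's `(wi % max_seq_len + seq_idx) % max_seq_len`, while also behaving sensibly for a negative start index.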
Force-pushed from 9184039 to f3424db.
- Change explicit lambda capture to [&] per ORT convention
- Add FlashAttention_Decode_PartialMask_MultiBatch test with batch_size=2 and varying per-batch past_seqlens

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This pull request introduces a major optimization to the CUDA attention kernel by fusing the key/value (KV) cache concatenation step into a single kernel, reducing memory overhead and kernel launches during autoregressive decoding. The update also clarifies and improves documentation and safety around attention mask handling, especially regarding right-padding validity and mask conversion. Support for native `__nv_bfloat16` types is expanded for better hardware utilization.

Performance and Kernel Fusion Improvements:

- Uses `LaunchConcatNewToPastKV` to fuse the past KV copy and new token append into a single CUDA kernel, eliminating the previous memset and strided copy overhead. This reduces memory traffic and kernel launch overhead during decoding, matching the efficiency of the contrib GQA path. [1] [2]
- Support for native `__nv_bfloat16` types is added to the fused concat kernel, enabling direct hardware arithmetic on supported GPUs and aligning with GQA's early type conversion pattern. [1] [2]

Attention Mask Handling and Documentation:

- Mask conversion follows the right-padding convention: only leading `True` values are counted. [1] [2] [3] [4]

Code Quality and Minor Fixes:

- Uses `using` instead of `typedef` for CUDA type aliases, and adds missing header includes for type conversion. [1] [2]