Conversation
Fix two bugs in the pack_input_as4 code path when a_components == 1 (triggered when input channels per group is not divisible by 2 or 4):

1. Weight values were read using x_offset (the input tensor offset) instead of w_offset (the weight tensor offset), producing incorrect results.
2. The remainder loop added raw input values to dotProd without multiplying them by the corresponding weight values.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…st KV

The flash attention path in MultiHeadAttention always used past_key/past_value inputs even when present_key/present_value outputs were not requested. The non-flash path (ApplyAttention) correctly ignores past when present outputs are not available. This mismatch caused incorrect results on GPUs with subgroup support, where the flash path is selected.

Fix by adding the same check inside ApplyFlashAttention: when present_key and present_value are both nullptr, null out past_key/past_value and adjust past_sequence_length and total_sequence_length accordingly. The parameters signature is changed from const ref to by-value to allow this adjustment.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Pull request overview
This PR fixes a correctness mismatch in the WebGPU EP’s MultiHeadAttention flash-attention path: when present_key/present_value outputs are not requested, the flash path should ignore past_key/past_value (matching the non-flash path behavior), avoiding incorrect results on subgroup-capable GPUs where flash attention is selected.
Changes:
- Update ApplyFlashAttention to ignore past KV when both present outputs are not requested, and adjust sequence-length parameters accordingly.
- Change ApplyFlashAttention's parameters argument from const& to by-value to allow local adjustment.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.h | Updates ApplyFlashAttention signature to pass WebgpuAttentionParameters by value. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Implements the “ignore past KV when present outputs aren’t requested” behavior in the flash-attention path. |
```cpp
ORT_ENFORCE((present_key == nullptr) == (present_value == nullptr),
            "present_key and present_value must be both nullptr or both non-nullptr");
```
The new ORT_ENFORCE requiring present_key and present_value to be both nullptr or both non-nullptr can break valid graphs that request only one of the optional outputs (schemas allow outputs to be omitted independently, and CPU/CUDA paths already tolerate nullptr per-output). Consider removing this enforce and just treating "both nullptr" as the condition to ignore past; otherwise keep allocating internal tensors for whichever output is missing.
```cpp
past_key = nullptr;
past_value = nullptr;
parameters.past_sequence_length_ = 0;
parameters.total_sequence_length_ = parameters.kv_sequence_length_;
```
When present outputs are not requested, this block nulls out past_* and shrinks total_sequence_length_. However seqlen_k (graph-capture path) is still honored later via use_seqlen_k, and the WGSL uses seqlens_k[0]+1 as total_sequence_length. That can cause the shaders to index past the internally-created present_key/present_value buffers (now sized to kv_sequence_length_), leading to OOB reads/writes or incorrect results. Consider also disabling seqlen_k/indirect-dispatch in this case (e.g., set the local seqlen_k pointer to nullptr and ensure use_seqlen_k/use_indirect_dispatch are false) so shader-side total length stays consistent with the resized parameters/buffers.
Suggested change:

```cpp
parameters.total_sequence_length_ = parameters.kv_sequence_length_;
// Disable seqlen_k-based dispatch when present outputs are not requested to keep
// shader-side total sequence length consistent with the resized parameters/buffers.
seqlen_k = nullptr;
```
Description
Fix WebGPU MultiHeadAttention flash path ignoring output_count for past KV
The flash attention path in MultiHeadAttention always used past_key/past_value
inputs even when present_key/present_value outputs were not requested. The
non-flash path (ApplyAttention) correctly ignores past when present outputs
are not available. This mismatch caused incorrect results on GPUs with
subgroup support, where the flash path is selected.
Fix by adding the same check inside ApplyFlashAttention: when present_key and
present_value are both nullptr, null out past_key/past_value and adjust
past_sequence_length and total_sequence_length accordingly. The parameters
signature is changed from const ref to by-value to allow this adjustment.
Motivation and Context
Addresses a test failure in:
onnxruntime/js/web/test/data/ops/multihead-attention.jsonc
(line 260 at commit b1c5e04)