
[WebGPU EP] MultiHeadAttention fix #27791

Draft
edgchen1 wants to merge 3 commits into main from
edgchen1/webgpu_fix_mha_issue

Conversation

@edgchen1
Contributor

Description

Fix WebGPU MultiHeadAttention flash path ignoring output_count for past KV

The flash attention path in MultiHeadAttention always used past_key/past_value
inputs even when present_key/present_value outputs were not requested. The
non-flash path (ApplyAttention) correctly ignores past when present outputs
are not available. This mismatch caused incorrect results on GPUs with
subgroup support, where the flash path is selected.

Fix by adding the same check inside ApplyFlashAttention: when present_key and
present_value are both nullptr, null out past_key/past_value and adjust
past_sequence_length and total_sequence_length accordingly. The parameters
signature is changed from const ref to by-value to allow this adjustment.
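The adjustment described above can be sketched in isolation. This is a minimal, hypothetical C++ model, not the actual onnxruntime code: `Params` stands in for `WebgpuAttentionParameters` with only the fields the fix touches, and the past/present tensors are modeled as opaque pointers.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical, simplified stand-in for WebgpuAttentionParameters; only the
// fields touched by the fix are modeled here.
struct Params {
  int past_sequence_length_;
  int kv_sequence_length_;
  int total_sequence_length_;
};

// Sketch of the check the PR adds inside ApplyFlashAttention. Taking Params
// by value (the signature change described above) keeps the adjustment local
// to the flash-attention call.
Params IgnorePastWhenNoPresent(Params params,
                               const void*& past_key, const void*& past_value,
                               const void* present_key, const void* present_value) {
  if (present_key == nullptr && present_value == nullptr) {
    // With no present outputs there is nowhere to write past + new KV, so
    // past must be ignored, matching the non-flash ApplyAttention path.
    past_key = nullptr;
    past_value = nullptr;
    params.past_sequence_length_ = 0;
    params.total_sequence_length_ = params.kv_sequence_length_;
  }
  return params;
}
```

For example, with a past length of 4 and a new KV length of 1, the effective total sequence length drops from 5 to 1 when neither present output is requested, and stays 5 otherwise.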

Motivation and Context

Address test failure:

"name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",

    [webgpu]MultiHeadAttention - MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue
e Validator 2026-03-20T23:45:27.942Z|abs/rel check failed-- index:3: actual=20,expected=12
e TestRunner 2026-03-20T23:45:27.945Z|Tensor mismatch: 
ACTUAL: type=float32; dims=[1,1,4]; data=[17,18,19,20]
EXPECT: type=float32; dims=[1,1,4]; data=[9,10,11,12]

edgchen1 and others added 3 commits March 18, 2026 14:27
Fix two bugs in the pack_input_as4 code path when a_components == 1
(triggered when input channels per group is not divisible by 2 or 4):

1. Weight values were read using x_offset (input tensor offset) instead
   of w_offset (weight tensor offset), producing incorrect results.

2. The remainder loop added raw input values to dotProd without
   multiplying by the corresponding weight values.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
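The two bugs in that commit can be illustrated with a scalar re-creation of the remainder loop. This is a hypothetical C++ sketch for illustration only; the real code is WGSL inside the conv shader, and `x_offset`/`w_offset` mirror the shader's input and weight offsets.

```cpp
#include <cassert>
#include <vector>

// Hypothetical scalar model of the pack_input_as4 remainder loop for the
// a_components == 1 case (channels per group not divisible by 2 or 4).
float DotProdRemainder(const std::vector<float>& x, const std::vector<float>& w,
                       int x_offset, int w_offset, int remainder) {
  float dotProd = 0.0f;
  for (int i = 0; i < remainder; ++i) {
    // Bug 1 fix: weights are indexed with w_offset, not x_offset.
    // Bug 2 fix: each input value is multiplied by its weight instead of
    // being added to dotProd raw.
    dotProd += x[x_offset + i] * w[w_offset + i];
  }
  return dotProd;
}
```

With `x = {1, 2, 3, 4, 5}`, `w = {10, 20, 30}`, `x_offset = 1`, `w_offset = 0`, and a remainder of 3, the correct result is 2*10 + 3*20 + 4*30 = 200; the buggy variants would instead read weights at the input offset or sum the inputs unweighted.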
…st KV

The flash attention path in MultiHeadAttention always used past_key/past_value
inputs even when present_key/present_value outputs were not requested. The
non-flash path (ApplyAttention) correctly ignores past when present outputs
are not available. This mismatch caused incorrect results on GPUs with
subgroup support, where the flash path is selected.

Fix by adding the same check inside ApplyFlashAttention: when present_key and
present_value are both nullptr, null out past_key/past_value and adjust
past_sequence_length and total_sequence_length accordingly. The parameters
signature is changed from const ref to by-value to allow this adjustment.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Copilot AI (Contributor) left a comment

Pull request overview

This PR fixes a correctness mismatch in the WebGPU EP’s MultiHeadAttention flash-attention path: when present_key/present_value outputs are not requested, the flash path should ignore past_key/past_value (matching the non-flash path behavior), avoiding incorrect results on subgroup-capable GPUs where flash attention is selected.

Changes:

  • Update ApplyFlashAttention to ignore past KV when both present outputs are not requested, and adjust sequence-length parameters accordingly.
  • Change ApplyFlashAttention’s parameters argument from const& to by-value to allow local adjustment.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
onnxruntime/contrib_ops/webgpu/bert/flash_attention.h Updates ApplyFlashAttention signature to pass WebgpuAttentionParameters by value.
onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc Implements the “ignore past KV when present outputs aren’t requested” behavior in the flash-attention path.


Comment on lines +420 to +421
ORT_ENFORCE((present_key == nullptr) == (present_value == nullptr),
"present_key and present_value must be both nullptr or both non-nullptr");
Copilot AI Mar 21, 2026
The new ORT_ENFORCE requiring present_key and present_value to be both nullptr or both non-nullptr can break valid graphs that request only one of the optional outputs (schemas allow outputs to be omitted independently, and CPU/CUDA paths already tolerate nullptr per-output). Consider removing this enforce and just treating "both nullptr" as the condition to ignore past; otherwise keep allocating internal tensors for whichever output is missing.

Suggested change (removal):
- ORT_ENFORCE((present_key == nullptr) == (present_value == nullptr),
-             "present_key and present_value must be both nullptr or both non-nullptr");

past_key = nullptr;
past_value = nullptr;
parameters.past_sequence_length_ = 0;
parameters.total_sequence_length_ = parameters.kv_sequence_length_;
Copilot AI Mar 21, 2026
When present outputs are not requested, this block nulls out past_* and shrinks total_sequence_length_. However seqlen_k (graph-capture path) is still honored later via use_seqlen_k, and the WGSL uses seqlens_k[0]+1 as total_sequence_length. That can cause the shaders to index past the internally-created present_key/present_value buffers (now sized to kv_sequence_length_), leading to OOB reads/writes or incorrect results. Consider also disabling seqlen_k/indirect-dispatch in this case (e.g., set the local seqlen_k pointer to nullptr and ensure use_seqlen_k/use_indirect_dispatch are false) so shader-side total length stays consistent with the resized parameters/buffers.

Suggested change:
  parameters.total_sequence_length_ = parameters.kv_sequence_length_;
+ // Disable seqlen_k-based dispatch when present outputs are not requested to keep
+ // shader-side total sequence length consistent with the resized parameters/buffers.
+ seqlen_k = nullptr;
