
[WebGPU] Add opset 24 support and Gemma 4 GQA enhancements #28501

Open
feich-ms wants to merge 3 commits into main from user/feich/webgpu-gemma4-opset24

Conversation

@feich-ms
Contributor

Models exported at opset 24 (e.g. Gemma 4) require these registrations to avoid falling back to CPU for basic tensor operations.

Description

This PR adds two categories of WebGPU EP changes to support Gemma 4 E2B model inference:

Opset 24 kernel registrations:

  • Cast op: version the opset 23 registration to 23-23 and add opset 24 registration
  • Shape op: version the opset 23 registration to 23-23 and add opset 24 registration
  • Updated webgpu_execution_provider.cc registration table accordingly
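For reference, a split registration of this kind typically looks roughly like the sketch below in an ORT execution provider's kernel table. This is a minimal illustration using the common ORT macros; the actual entries in webgpu_execution_provider.cc (and the CreateCastKernelInfo<> instantiations in cast.cc) may be spelled differently.

// Declare a versioned kernel covering only opset 23, plus an open-ended opset 24 kernel.
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Cast);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, Cast);

// ...and reference both in the registration table.
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Cast)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, Cast)>,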

GQA (GroupQueryAttention) enhancements for Gemma 4:

  • Support variable head_dim (Gemma 4 uses head_dim=256 vs the typical 128)
  • Make present_key/present_value outputs optional to support KV-shared layers (Gemma 4 has 15 KV layers shared across 35 GQA nodes, where 20 nodes reuse cached KV from earlier layers)
  • Add unit tests for variable head_dim and optional present KV outputs
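As an illustration of the optional present KV change (a hedged sketch, not the exact kernel code): ONNX Runtime returns nullptr for an optional output the graph does not consume, so a KV-shared GQA node can detect that and skip cache production entirely.

// Sketch only; output indices follow the GQA schema (0: output, 1: present_key, 2: present_value).
Tensor* present_key = context->Output(1, present_kv_shape);
Tensor* present_value = context->Output(2, present_kv_shape);
if (present_key == nullptr && present_value == nullptr) {
  // KV-shared layer: nothing to write; attention reads past_key/past_value directly.
}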

Motivation and Context

Gemma 4 E2B is exported at ONNX opset 24. Without Cast and Shape registrations for opset 24, these ops fall back to CPU, causing unnecessary data transfers and potential failures.

Gemma 4 also uses an unconventional attention architecture with head_dim=256 and KV-shared layers where only 15 of 35 GQA nodes produce new KV cache outputs. The GQA kernel changes enable the WebGPU EP to handle both of these patterns correctly.

🤖 Generated with Claude Code

Models exported at opset 24 (e.g. Gemma 4) require these registrations
to avoid falling back to CPU for basic tensor operations.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
@feich-ms requested a review from Copilot May 14, 2026 03:56

Copilot AI left a comment


Pull request overview

Adds WebGPU EP support for Gemma 4 E2B by registering opset 24 versions of Cast and Shape, and by extending the GroupQueryAttention kernel to handle the model's variable head_dim (e.g. 256) and KV-shared layers where kv_sequence_length==0 and present_key/present_value are optional.

Changes:

  • Cast and Shape get explicit opset 23–23 versioned registrations plus new opset 24 registrations (kernels themselves unchanged).
  • WebGPU GroupQueryAttention now accepts a missing present_key/present_value and a zero-length new KV input, routing the call through flash attention using past_key/past_value as the KV context; rotary on Q is supported via a dummy K buffer.
  • WebgpuAttentionParameters(GroupQueryAttentionParameters) now initializes kv_sequence_length_ from parameters.kv_sequence_length (previously the Q sequence length), and ApplyFlashAttention's internal present K/V buffer shape uses kv_num_heads_ instead of num_heads_.
  • New WebGPU tests cross-check shared-KV decode/prompt/GQA-ratio/rotary against CPU.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.

Summary per file:

  • onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc: Splits Shape and Cast opset 23 registrations into 23–23 versioned entries and adds opset 24 registrations.
  • onnxruntime/core/providers/webgpu/tensor/shape_op.cc: Adds a Shape kernel definition for opset 23–23 and bumps the open-ended registration to opset 24.
  • onnxruntime/core/providers/webgpu/tensor/cast.cc: Adds explicit CreateCastKernelInfo<23,23> and CreateCastKernelInfo<24> template instantiations.
  • onnxruntime/contrib_ops/webgpu/bert/attention_common.h: Fixes GQA kv_sequence_length_ to read from parameters.kv_sequence_length rather than Q's sequence length.
  • onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc: Uses kv_num_heads_ for the internal present K/V shape; adds a kv_empty branch that aliases past_key/past_value as present_key/present_value via const_cast.
  • onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc: Drops the hard requirement on present K/V outputs, adds a kv_empty fast path that skips K/V processing (with a dummy K for rotary), and errors out if the non-flash path is reached with shared KV.
  • onnxruntime/test/contrib_ops/group_query_attention_op_test.cc: Adds a use_webgpu parameter to two shared-KV test helpers and four new WebGPU shared-KV tests cross-checked against CPU.
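The attention_common.h change amounts to a one-initializer fix. Roughly (a sketch with member names taken from the summary above; the surrounding initializers and exact constructor signature are elided or assumed):

WebgpuAttentionParameters(const GroupQueryAttentionParameters& parameters)
    : // ... other initializers ...
      // before: kv_sequence_length_(parameters.sequence_length)   // Q length; wrong when the new KV input is empty
      kv_sequence_length_(parameters.kv_sequence_length) {
}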


Comment on lines +293 to +296
// K output is discarded since attention uses past_key/past_value.
qRotary = context.CreateGPUTensor(query->DataType(), query->Shape());
kDummy = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, 1, parameters.kv_hidden_size_}));
kRotary = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, 1, parameters.kv_hidden_size_}));
@feich-ms (Contributor, Author) replied:

Fixed in 9cf7ea3. Sized kDummy/kRotary with parameters.sequence_length_ instead of hardcoded 1, so the kernel's K write offsets stay in bounds during prefill. Also added a new test SharedKV_EmptyKV_WithPast_Rotary_Prompt_WebGPU with q_seq_len=6 to cover this case.
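For readers following along, the fix roughly turns the quoted lines into the following (a sketch derived from this thread, not necessarily the literal diff in 9cf7ea3):

// Cover the full Q iteration domain so the rotary kernel's K writes stay in bounds during prefill.
kDummy = context.CreateGPUTensor(query->DataType(),
                                 TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));
kRotary = context.CreateGPUTensor(query->DataType(),
                                  TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_}));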

Comment on lines +465 to +477

if (kv_empty) {
  // kv_sequence_length==0: K/V inputs are empty (shared KV layer).
  // Skip CopyKVCache and fused split+rotary+copyKV.
  // Use past_key/past_value directly as the present buffers for attention.
  ORT_ENFORCE(!do_rotary, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV should not be used with kv_sequence_length==0.");
  if (past_key != nullptr && past_value != nullptr) {
    // Safe: flash attention kernels only read from present_key/present_value.
    // CopyKVCache is skipped when kv_empty, so no writes through these pointers.
    present_key = const_cast<Tensor*>(past_key);
    present_value = const_cast<Tensor*>(past_value);
  }
  // If past is also null, present_key/present_value were already set to internal empty tensors above.
@feich-ms (Contributor, Author) replied:

Fixed in 9cf7ea3. Added an ORT_ENFORCE(!parameters.past_present_share_buffer_) after the aliasing to guard against future refactors that might invoke CopyKVCache (which writes to present buffers). This makes the contract explicit: the kv_empty path must never reach a write path through present_key/present_value.
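Roughly, the guard sits immediately after the aliasing (a sketch based on this thread; the exact wording in 9cf7ea3 may differ):

present_key = const_cast<Tensor*>(past_key);
present_value = const_cast<Tensor*>(past_value);
// CopyKVCache writes through the present buffers; it must never run on the kv_empty path,
// where present_key/present_value alias read-only past KV.
ORT_ENFORCE(!parameters.past_present_share_buffer_,
            "kv_empty path must not share or write the past/present KV buffers.");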

feich-ms and others added 2 commits May 14, 2026 13:28
- Size kDummy/kRotary with sequence_length (not 1) to match the kernel's
  iteration domain, preventing OOB writes during prefill (q_seq > 1)
- Add ORT_ENFORCE in flash_attention kv_empty path to guard against
  future refactors that might write through the const_cast'd pointers
- Add new test SharedKV_EmptyKV_WithPast_Rotary_Prompt_WebGPU exercising
  the rotary + kv_empty path with q_seq_len=6

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>

The internal present KV buffer shape must use num_heads_ for MHA
(where kv_num_heads_ is 0) and kv_num_heads_ only for GQA. Using
kv_num_heads_ unconditionally caused zero-sized buffers for MHA
CrossAttention tests.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
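A minimal sketch of the head-count selection this commit describes (the shape members beyond the head count are assumptions; the actual construction in flash_attention.cc may differ):

// GQA sets kv_num_heads_ > 0; MHA leaves it at 0, so fall back to num_heads_ there.
const int present_kv_heads = parameters.kv_num_heads_ > 0 ? parameters.kv_num_heads_ : parameters.num_heads_;
// total_sequence_length_ and head_size_ are placeholders for whatever the kernel actually uses.
const TensorShape present_kv_shape({parameters.batch_size_, present_kv_heads,
                                    parameters.total_sequence_length_, parameters.head_size_});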
@guschmue added the ep:WebGPU (ort-web webgpu provider) label May 14, 2026

Labels

ep:WebGPU ort-web webgpu provider
