[WebGPU] Add opset 24 support and Gemma 4 GQA enhancements #28501
feich-ms wants to merge 3 commits into
Conversation
Models exported at opset 24 (e.g. Gemma 4) require these registrations to avoid falling back to CPU for basic tensor operations. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Pull request overview
Adds WebGPU EP support for Gemma 4 E2B by registering opset 24 versions of Cast and Shape, and by extending the GroupQueryAttention kernel to handle the model's variable head_dim (e.g. 256) and KV-shared layers where kv_sequence_length==0 and present_key/present_value are optional.
Changes:
- Cast and Shape get explicit opset 23–23 versioned registrations plus new opset 24 registrations (kernels themselves unchanged).
- WebGPU GroupQueryAttention now accepts a missing `present_key`/`present_value` and a zero-length new KV input, routing the call through flash attention using `past_key`/`past_value` as the KV context; rotary on Q is supported via a dummy K buffer.
- `WebgpuAttentionParameters(GroupQueryAttentionParameters)` now initializes `kv_sequence_length_` from `parameters.kv_sequence_length` (previously the Q sequence length), and `ApplyFlashAttention`'s internal present K/V buffer shape uses `kv_num_heads_` instead of `num_heads_`.
- New WebGPU tests cross-check shared-KV decode/prompt/GQA-ratio/rotary against CPU.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Splits Shape and Cast opset 23 registrations into 23–23 versioned and adds opset 24 registrations. |
| onnxruntime/core/providers/webgpu/tensor/shape_op.cc | Adds Shape kernel definition for opset 23–23 and bumps the open-ended registration to opset 24. |
| onnxruntime/core/providers/webgpu/tensor/cast.cc | Adds explicit CreateCastKernelInfo<23,23> and CreateCastKernelInfo<24> template instantiations. |
| onnxruntime/contrib_ops/webgpu/bert/attention_common.h | Fixes GQA kv_sequence_length_ to read from parameters.kv_sequence_length rather than Q's sequence length. |
| onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | Uses kv_num_heads_ for internal present K/V shape; adds a kv_empty branch that aliases past_key/past_value as present_key/present_value via const_cast. |
| onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc | Drops the hard requirement on present K/V outputs, adds a kv_empty fast path that skips K/V processing (with dummy K for rotary), and errors out if the non-flash path is reached with shared KV. |
| onnxruntime/test/contrib_ops/group_query_attention_op_test.cc | Adds a use_webgpu parameter to two shared-KV test helpers and four new WebGPU shared-KV tests cross-checked against CPU. |
```cpp
// K output is discarded since attention uses past_key/past_value.
qRotary = context.CreateGPUTensor(query->DataType(), query->Shape());
kDummy = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, 1, parameters.kv_hidden_size_}));
kRotary = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, 1, parameters.kv_hidden_size_}));
```
Fixed in 9cf7ea3. Sized kDummy/kRotary with parameters.sequence_length_ instead of hardcoded 1, so the kernel's K write offsets stay in bounds during prefill. Also added a new test SharedKV_EmptyKV_WithPast_Rotary_Prompt_WebGPU with q_seq_len=6 to cover this case.
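The out-of-bounds risk can be illustrated with a small standalone sketch (hypothetical helper names, not the ORT kernel): a rotary kernel that writes one K row per (batch, token) pair needs a buffer sized by `sequence_length`, so a buffer sized with a hardcoded 1 is too small whenever the prompt has more than one token.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical sketch: total elements the K buffer must hold when the kernel
// writes one row of kv_hidden elements per (batch, token) pair.
size_t RequiredKBufferElems(size_t batch, size_t seq_len, size_t kv_hidden) {
  return batch * seq_len * kv_hidden;
}

// Linear offset of the last element the kernel writes during prefill.
size_t LastWriteOffset(size_t batch, size_t seq_len, size_t kv_hidden) {
  return (batch - 1) * seq_len * kv_hidden + (seq_len - 1) * kv_hidden + (kv_hidden - 1);
}
```

With `q_seq_len=6` and `kv_hidden=256`, the last write lands at offset 1535, which fits a correctly sized buffer (1536 elements) but far exceeds a buffer sized for a single token (256 elements).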
```cpp
if (kv_empty) {
  // kv_sequence_length==0: K/V inputs are empty (shared KV layer).
  // Skip CopyKVCache and fused split+rotary+copyKV.
  // Use past_key/past_value directly as the present buffers for attention.
  ORT_ENFORCE(!do_rotary, "Fused SplitPackedQKVWithRotaryEmbeddingAndCopyKV should not be used with kv_sequence_length==0.");
  if (past_key != nullptr && past_value != nullptr) {
    // Safe: flash attention kernels only read from present_key/present_value.
    // CopyKVCache is skipped when kv_empty, so no writes through these pointers.
    present_key = const_cast<Tensor*>(past_key);
    present_value = const_cast<Tensor*>(past_value);
  }
  // If past is also null, present_key/present_value were already set to internal empty tensors above.
```
Fixed in 9cf7ea3. Added an ORT_ENFORCE(!parameters.past_present_share_buffer_) after the aliasing to guard against future refactors that might invoke CopyKVCache (which writes to present buffers). This makes the contract explicit: the kv_empty path must never reach a write path through present_key/present_value.
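The alias-then-assert contract can be sketched in isolation (illustrative types and names, not the ORT implementation): the `const_cast` alias is sound only while every consumer of the present buffer is read-only, and the assertion documents that invariant at the aliasing site.

```cpp
#include <cassert>

// Illustrative stand-in for the real tensor type.
struct Tensor { float v; };

// Hypothetical sketch of the kv_empty routing: alias the const past buffer as
// the present buffer, then assert that no configuration which would write
// through the alias (e.g. a shared past/present buffer copied via CopyKVCache)
// is active on this path.
const Tensor* RouteSharedKV(const Tensor* past_key, bool past_present_share_buffer) {
  Tensor* present_key = const_cast<Tensor*>(past_key);
  assert(!past_present_share_buffer);  // guard: any write path here would be a bug
  return present_key;  // flash attention only reads from this pointer
}
```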
- Size kDummy/kRotary with sequence_length (not 1) to match the kernel's iteration domain, preventing OOB writes during prefill (q_seq > 1)
- Add ORT_ENFORCE in flash_attention kv_empty path to guard against future refactors that might write through the const_cast'd pointers
- Add new test SharedKV_EmptyKV_WithPast_Rotary_Prompt_WebGPU exercising the rotary + kv_empty path with q_seq_len=6

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
The internal present KV buffer shape must use num_heads_ for MHA (where kv_num_heads_ is 0) and kv_num_heads_ only for GQA. Using kv_num_heads_ unconditionally caused zero-sized buffers for MHA CrossAttention tests. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
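The head-count selection described in this commit can be sketched as a one-liner (hypothetical helper name): fall back to `num_heads` whenever `kv_num_heads` is unset, so MHA gets a full-sized buffer and GQA keeps its smaller KV head count.

```cpp
#include <cassert>

// Hypothetical sketch of the fix: MHA leaves kv_num_heads at 0, so the present
// K/V buffer shape must fall back to num_heads; GQA supplies kv_num_heads > 0.
int PresentKVHeads(int num_heads, int kv_num_heads) {
  return kv_num_heads > 0 ? kv_num_heads : num_heads;
}
```

Using `kv_num_heads` unconditionally is exactly the bug described above: for MHA it yields a zero-sized present K/V buffer.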
Models exported at opset 24 (e.g. Gemma 4) require these registrations to avoid falling back to CPU for basic tensor operations.
Description
This PR adds two categories of WebGPU EP changes to support Gemma 4 E2B model inference:
Opset 24 kernel registrations:
- Registers opset 24 versions of `Cast` and `Shape`, updating the `webgpu_execution_provider.cc` registration table accordingly

GQA (GroupQueryAttention) enhancements for Gemma 4:
- Supports variable `head_dim` (Gemma 4 uses head_dim=256 vs the typical 128)
- Makes the `present_key`/`present_value` outputs optional to support KV-shared layers (Gemma 4 has 15 KV layers shared across 35 GQA nodes, where 20 nodes reuse cached KV from earlier layers)

Motivation and Context
Gemma 4 E2B is exported at ONNX opset 24. Without Cast and Shape registrations for opset 24, these ops fall back to CPU, causing unnecessary data transfers and potential failures.
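The fallback behavior can be pictured with a toy versioned kernel table (a self-contained sketch, not the ORT registry): each op supports a set of inclusive opset ranges, and an opset that falls outside every range is not claimed by the EP, so the node falls back to CPU.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy model of an EP's versioned kernel table. Real ORT registrations use
// ONNX_OPERATOR_VERSIONED_KERNEL_EX / ONNX_OPERATOR_KERNEL_EX macros; this
// sketch only illustrates the range-matching semantics.
struct KernelTable {
  std::map<std::string, std::vector<std::pair<int, int>>> ranges;  // op -> [start, end]
  bool Supports(const std::string& op, int opset) const {
    auto it = ranges.find(op);
    if (it == ranges.end()) return false;
    for (const auto& r : it->second)
      if (opset >= r.first && opset <= r.second) return true;
    return false;  // no range claims this opset -> CPU fallback
  }
};
```

Before this PR the WebGPU table's last `Cast`/`Shape` range ended at opset 23, so an opset-24 model missed every range; adding an open-ended opset-24 registration closes the gap.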
Gemma 4 also uses an unconventional attention architecture with `head_dim=256` and KV-shared layers where only 15 of 35 GQA nodes produce new KV cache outputs. The GQA kernel changes enable the WebGPU EP to handle both of these patterns correctly.

🤖 Generated with Claude Code