From e15bd7228eab8ff03de7297470dcdef9262bd872 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Thu, 18 Dec 2025 16:52:41 +0800 Subject: [PATCH 1/8] [webgpu] Optimize attentionPrepare --- .../contrib_ops/webgpu/bert/attention.cc | 118 ++++++++++++------ .../contrib_ops/webgpu/bert/attention.h | 10 ++ .../webgpu/bert/attention_common.h | 3 + .../webgpu/bert/group_query_attention.cc | 36 ------ .../webgpu/bert/group_query_attention.h | 10 -- 5 files changed, 96 insertions(+), 81 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 8929c6b7cf6e4..f85892103d0eb 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -8,6 +8,7 @@ #include "contrib_ops/webgpu/bert/multihead_attention.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/math/matmul.h" using namespace onnxruntime::webgpu; using namespace ::onnxruntime::common; using namespace ONNX_NAMESPACE; @@ -70,6 +71,43 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h return context.RunProgram(program); }; +Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { + const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform); + const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); + const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); + const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); + sh.MainFunctionBody() << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n" + << " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n" + << " let index = " << packed_qkv.IndicesGet("packed_qkv_indices", "2") << ";\n" + << " if (index < uniforms.hidden_size) {\n" + << " " << query.SetByIndices("packed_qkv_indices", "input_data") << ";\n" + << " } else if (index < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n" + << " var key_indices = packed_qkv_indices;\n" + << " " << key.IndicesSet("key_indices", "2", "u32(index - uniforms.hidden_size)") << ";\n" + << " " << key.SetByIndices("key_indices", "input_data") << ";\n" + << " } else {\n" + << " var val_indices = packed_qkv_indices;\n" + << " " << value.IndicesSet("val_indices", "2", "u32(index - uniforms.hidden_size - uniforms.kv_hidden_size)") << ";\n" + << " " << value.SetByIndices("val_indices", "input_data") << ";\n" + << " }"; + return Status::OK(); +} + +Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, + const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val) { + SplitPackedQKVProgram program; + auto input_size = packedQKV->Shape().Size(); + program + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank}) + .AddOutputs({{query, ProgramTensorMetadataDependency::None}, {key, ProgramTensorMetadataDependency::None}, {val, ProgramTensorMetadataDependency::None}}) + .AddUniformVariables({ + {static_cast(params.hidden_size_)}, + {static_cast(params.kv_hidden_size_)}, + }) + .SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + return context.RunProgram(program); +} + void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) { if (has_seqlen_k) { ss << "total_sequence_length = u32(seqlen_k[batch_idx]) + 1;\n"; @@ -646,33 +684,23 @@ class AttentionPrepareProgram final : public Program { Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, const Tensor* input, const Tensor* weights, const Tensor* bias, Tensor* q, Tensor* k, Tensor* v) { - constexpr int TILE_SIZE = 12; - const int M = parameters.sequence_length_; - const int K = parameters.input_hidden_size_; - const int N = parameters.head_size_; - - const uint32_t dispatch_x = (parameters.head_size_ + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t dispatch_y = (parameters.sequence_length_ + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t dispatch_z = parameters.batch_size_ * parameters.num_heads_; - - AttentionPrepareProgram program{}; - program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, - {weights, ProgramTensorMetadataDependency::TypeAndRank}, - {bias, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({{q, ProgramTensorMetadataDependency::TypeAndRank}, - {k, ProgramTensorMetadataDependency::TypeAndRank}, - {v, ProgramTensorMetadataDependency::TypeAndRank}}) - .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) - .SetWorkgroupSize(TILE_SIZE, TILE_SIZE) - .AddUniformVariables({{static_cast(M)}, - {static_cast(K)}, - {static_cast(N)}, - {static_cast(parameters.num_heads_)}, - {static_cast(parameters.head_size_)}, - {static_cast(parameters.hidden_size_)}, - {static_cast(parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_)}}); + // Use MatMul to compute packed QKV output: input * weights + bias + // Then use SplitPackedQKV to split into Q, K, V in BSD format + // Returns Q, K, V in BSD format - caller can convert to BNSH if needed - return context.RunProgram(program); + // Create packed QKV tensor with shape [batch_size, sequence_length, hidden_size + hidden_size + v_hidden_size] + const int64_t packed_qkv_size = parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_; + TensorShapeVector packed_qkv_shape({parameters.batch_size_, parameters.sequence_length_, packed_qkv_size}); + Tensor packed_qkv = context.CreateGPUTensor(input->DataType(), TensorShape(packed_qkv_shape)); + + // Prepare inputs for MatMul + std::vector matmul_inputs = {input, weights, bias}; + + // Call MatMul: packed_qkv = input * weights + bias + ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, true)); + + // Split the packed QKV into Q, K, V in BSD format + return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v); } Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { @@ -727,15 +755,18 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) ORT_NOT_IMPLEMENTED("present tensor not implemented for webgpu Attention"); } - // Create Q, K, V tensors by computing input * weights + bias - TensorShapeVector qkv_shape({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, parameters.head_size_}); - Tensor Q = context.CreateGPUTensor(input->DataType(), qkv_shape); - Tensor K = context.CreateGPUTensor(input->DataType(), qkv_shape); - Tensor V = context.CreateGPUTensor(input->DataType(), qkv_shape); + // Create Q, K, V tensors in BSD format from input * weights + bias + TensorShapeVector qkv_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}); + TensorShapeVector v_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.v_hidden_size_}); + Tensor Q_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape)); + Tensor K_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape)); + Tensor V_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(v_bsd_shape)); + + // Compute Q, K, V from input, weights, and bias (returns BSD format) + ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q_bsd, &K_bsd, &V_bsd)); - // Compute Q, K, V from input, weights, and bias - ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V)); + // Update parameters for Q_K_V_BSNH format + parameters.qkv_format_ = Q_K_V_BSNH; // Check if we can use flash attention // For Attention operator, we need to create present_key and present_value tensors for flash attention @@ -746,10 +777,27 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) Tensor present_value = context.CreateGPUTensor(input->DataType(), present_kv_shape); if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) { - return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value, + // FlashAttention supports Q_K_V_BSNH format directly + return ApplyFlashAttention(&Q_bsd, &K_bsd, &V_bsd, attention_bias, output, nullptr, &present_key, nullptr, &present_value, parameters, context, nullptr); } + // For non-flash attention path, convert BSD to BNSH format + TensorShapeVector qkv_bnsh_shape({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, parameters.head_size_}); + TensorShapeVector v_bnsh_shape({parameters.batch_size_, parameters.num_heads_, + parameters.sequence_length_, parameters.v_head_size_}); + Tensor Q = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape)); + Tensor K = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape)); + Tensor V = context.CreateGPUTensor(input->DataType(), TensorShape(v_bnsh_shape)); + + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, + parameters.head_size_, &Q_bsd, nullptr, 0, &Q)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, + parameters.head_size_, &K_bsd, nullptr, 0, &K)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, + parameters.v_head_size_, &V_bsd, nullptr, 0, &V)); + // Apply the actual attention computation return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr, /* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index feaee1df7f6fc..b4afe7938883d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -32,6 +32,16 @@ class TransferBSDToBNSHProgram final : public Program bool has_bias_; }; +class SplitPackedQKVProgram final : public Program { + public: + SplitPackedQKVProgram() : Program{"SplitPackedQKV"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}); +}; + class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index eb1d637896f1d..d0fd17212a57a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -122,6 +122,9 @@ struct WebgpuAttentionParameters { Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor); +Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, + const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val); + Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, Tensor* output_qk, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 7e0cfb94d2bd2..59e0c7cfccc99 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -20,28 +20,6 @@ namespace onnxruntime { namespace contrib { namespace webgpu { -Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { - const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform); - const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); - const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); - const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); - sh.MainFunctionBody() << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n" - << " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n" - << " let index = " << packed_qkv.IndicesGet("packed_qkv_indices", "2") << ";\n" - << " if (index < uniforms.hidden_size) {\n" - << " " << query.SetByIndices("packed_qkv_indices", "input_data") << ";\n" - << " } else if (index < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n" - << " var key_indices = packed_qkv_indices;\n" - << " " << key.IndicesSet("key_indices", "2", "u32(index - uniforms.hidden_size)") << ";\n" - << " " << key.SetByIndices("key_indices", "input_data") << ";\n" - << " } else {\n" - << " var val_indices = packed_qkv_indices;\n" - << " " << value.IndicesSet("val_indices", "2", "u32(index - uniforms.hidden_size - uniforms.kv_hidden_size)") << ";\n" - << " " << value.SetByIndices("val_indices", "input_data") << ";\n" - << " }"; - return Status::OK(); -} - Status SplitPackedQKVWithRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& sh) const { const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseUniform); const auto& seqlens = sh.AddInput("seqlens", ShaderUsage::UseUniform); @@ -63,20 +41,6 @@ Status SplitPackedQKVWithRotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper WGSL_TEMPLATE_VARIABLE(val, val)); } -Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val) { - SplitPackedQKVProgram program; - auto input_size = packedQKV->Shape().Size(); - program - .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank}) - .AddOutputs({{query, ProgramTensorMetadataDependency::None}, {key, ProgramTensorMetadataDependency::None}, {val, ProgramTensorMetadataDependency::None}}) - .AddUniformVariables({ - {static_cast(params.hidden_size_)}, - {static_cast(params.kv_hidden_size_)}, - }) - .SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); - return context.RunProgram(program); -} - // Split packed QKV with Q/K rotary embedding fusion Status RunSplitPackedQKVWithRotaryEmbedding(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h index 5e780c4ca4cdf..077ec7768ea07 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.h @@ -14,16 +14,6 @@ namespace webgpu { using namespace onnxruntime::webgpu; -class SplitPackedQKVProgram final : public Program { - public: - SplitPackedQKVProgram() : Program{"SplitPackedQKV"} {} - - Status GenerateShaderCode(ShaderHelper& sh) const override; - - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}, - {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}); -}; - class SplitPackedQKVWithRotaryEmbeddingProgram final : public Program { public: SplitPackedQKVWithRotaryEmbeddingProgram(bool interleaved) : Program{"SplitPackedQKVWithRotaryEmbedding"}, interleaved_{interleaved} {} From 8d38f51bdf0ff7a3cbf35e6f0d19accc7ab7e7f5 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 19 Dec 2025 18:04:07 +0800 Subject: [PATCH 2/8] change the splitQKV to BNSH format --- .../contrib_ops/webgpu/bert/attention.cc | 94 +++++++++---------- .../contrib_ops/webgpu/bert/attention.h | 4 +- 2 files changed, 48 insertions(+), 50 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index f85892103d0eb..5f909235d12dd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -72,40 +72,52 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h }; Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { + // Inputs: packed_qkv [B, S, D], outputs: Q, K, V [B, N, S, H] const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform); const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); - sh.MainFunctionBody() << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n" - << " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n" - << " let index = " << packed_qkv.IndicesGet("packed_qkv_indices", "2") << ";\n" - << " if (index < uniforms.hidden_size) {\n" - << " " << query.SetByIndices("packed_qkv_indices", "input_data") << ";\n" - << " } else if (index < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n" - << " var key_indices = packed_qkv_indices;\n" - << " " << key.IndicesSet("key_indices", "2", "u32(index - uniforms.hidden_size)") << ";\n" - << " " << key.SetByIndices("key_indices", "input_data") << ";\n" - << " } else {\n" - << " var val_indices = packed_qkv_indices;\n" - << " " << value.IndicesSet("val_indices", "2", "u32(index - uniforms.hidden_size - uniforms.kv_hidden_size)") << ";\n" - << " " << value.SetByIndices("val_indices", "input_data") << ";\n" - << " }"; + // Uniforms: hidden_size, kv_hidden_size, num_heads, head_size, v_head_size, sequence_length + sh.MainFunctionBody() + << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n" + << " let batch = packed_qkv_indices[0];\n" + << " let seq = packed_qkv_indices[1];\n" + << " let d = packed_qkv_indices[2];\n" + << " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n" + << " if (d < uniforms.hidden_size) {\n" + << " let head = d / uniforms.head_size;\n" + << " let h = d % uniforms.head_size;\n" + << " " << query.SetByIndices("vec4(batch, head, seq, h)", "input_data") << ";\n" + << " } else if (d < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n" + << " let kd = d - uniforms.hidden_size;\n" + << " let head = kd / uniforms.head_size;\n" + << " let h = kd % uniforms.head_size;\n" + << " " << key.SetByIndices("vec4(batch, head, seq, h)", "input_data") << ";\n" + << " } else {\n" + << " let vd = d - uniforms.hidden_size - uniforms.hidden_size;\n" + << " let head = vd / uniforms.v_head_size;\n" + << " let h = vd % uniforms.v_head_size;\n" + << " " << value.SetByIndices("vec4(batch, head, seq, h)", "input_data") << ";\n" + << " }\n"; return Status::OK(); } Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val) { - SplitPackedQKVProgram program; - auto input_size = packedQKV->Shape().Size(); - program + // Output Q, K, V in BNSH format + SplitPackedQKVProgram program; + auto input_size = packedQKV->Shape().Size(); + program .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank}) - .AddOutputs({{query, ProgramTensorMetadataDependency::None}, {key, ProgramTensorMetadataDependency::None}, {val, ProgramTensorMetadataDependency::None}}) + .AddOutputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, {key, ProgramTensorMetadataDependency::TypeAndRank}, {val, ProgramTensorMetadataDependency::TypeAndRank}}) .AddUniformVariables({ - {static_cast(params.hidden_size_)}, - {static_cast(params.kv_hidden_size_)}, + {static_cast(params.hidden_size_)}, + {static_cast(params.kv_hidden_size_)}, + {static_cast(params.head_size_)}, + {static_cast(params.v_head_size_)}, }) .SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); - return context.RunProgram(program); + return context.RunProgram(program); } void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) { @@ -685,8 +697,8 @@ Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtte const Tensor* input, const Tensor* weights, const Tensor* bias, Tensor* q, Tensor* k, Tensor* v) { // Use MatMul to compute packed QKV output: input * weights + bias - // Then use SplitPackedQKV to split into Q, K, V in BSD format - // Returns Q, K, V in BSD format - caller can convert to BNSH if needed + // Then use SplitPackedQKV to split into Q, K, V in BNSH format + // Returns Q, K, V in BNSH format for direct comparison with AttentionPrepareProgram // Create packed QKV tensor with shape [batch_size, sequence_length, hidden_size + hidden_size + v_hidden_size] const int64_t packed_qkv_size = parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_; @@ -699,7 +711,7 @@ Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtte // Call MatMul: packed_qkv = input * weights + bias ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, true)); - // Split the packed QKV into Q, K, V in BSD format + // Output Q, K, V in BNSH format return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v); } @@ -755,15 +767,15 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) ORT_NOT_IMPLEMENTED("present tensor not implemented for webgpu Attention"); } - // Create Q, K, V tensors in BSD format from input * weights + bias - TensorShapeVector qkv_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}); - TensorShapeVector v_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.v_hidden_size_}); - Tensor Q_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape)); - Tensor K_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape)); - Tensor V_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(v_bsd_shape)); + // Create Q, K, V tensors in BNSH format from input * weights + bias + TensorShapeVector qkv_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_}); + TensorShapeVector v_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.v_head_size_}); + Tensor Q = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape)); + Tensor K = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape)); + Tensor V = context.CreateGPUTensor(input->DataType(), TensorShape(v_bnsh_shape)); - // Compute Q, K, V from input, weights, and bias (returns BSD format) - ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q_bsd, &K_bsd, &V_bsd)); + // Compute Q, K, V from input, weights, and bias (returns BNSH format) + ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V)); // Update parameters for Q_K_V_BSNH format parameters.qkv_format_ = Q_K_V_BSNH; @@ -778,26 +790,10 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) { // FlashAttention supports Q_K_V_BSNH format directly - return ApplyFlashAttention(&Q_bsd, &K_bsd, &V_bsd, attention_bias, output, nullptr, &present_key, nullptr, &present_value, + return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value, parameters, context, nullptr); } - // For non-flash attention path, convert BSD to BNSH format - TensorShapeVector qkv_bnsh_shape({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, parameters.head_size_}); - TensorShapeVector v_bnsh_shape({parameters.batch_size_, parameters.num_heads_, - parameters.sequence_length_, parameters.v_head_size_}); - Tensor Q = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape)); - Tensor K = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape)); - Tensor V = context.CreateGPUTensor(input->DataType(), TensorShape(v_bnsh_shape)); - - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, - parameters.head_size_, &Q_bsd, nullptr, 0, &Q)); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, - parameters.head_size_, &K_bsd, nullptr, 0, &K)); - ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, - parameters.v_head_size_, &V_bsd, nullptr, 0, &V)); - // Apply the actual attention computation return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr, /* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index b4afe7938883d..1848b02f4c8f3 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -39,7 +39,9 @@ class SplitPackedQKVProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}, - {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}); + {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"head_size", ProgramUniformVariableDataType::Uint32}, + {"v_head_size", ProgramUniformVariableDataType::Uint32}); }; class AttentionProbsProgram final : public Program { From dfb98e3f517170e53e0c6a670d94deefa7dc6c4f Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 22 Dec 2025 15:44:48 +0800 Subject: [PATCH 3/8] fix the bugs --- .../contrib_ops/webgpu/bert/attention.cc | 138 +++++++++++++++++- .../contrib_ops/webgpu/bert/attention.h | 1 - 2 files changed, 132 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 5f909235d12dd..3962effed27fd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -77,7 +77,6 @@ Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); - // Uniforms: hidden_size, kv_hidden_size, num_heads, head_size, v_head_size, sequence_length sh.MainFunctionBody() << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n" << " let batch = packed_qkv_indices[0];\n" @@ -88,7 +87,7 @@ Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { << " let head = d / uniforms.head_size;\n" << " let h = d % uniforms.head_size;\n" << " " << query.SetByIndices("vec4(batch, head, seq, h)", "input_data") << ";\n" - << " } else if (d < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n" + << " } else if (d < (uniforms.hidden_size + uniforms.hidden_size)) {\n" << " let kd = d - uniforms.hidden_size;\n" << " let head = kd / uniforms.head_size;\n" << " let h = kd % uniforms.head_size;\n" @@ -112,7 +111,6 @@ Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const Webgpu .AddOutputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, {key, ProgramTensorMetadataDependency::TypeAndRank}, {val, ProgramTensorMetadataDependency::TypeAndRank}}) .AddUniformVariables({ {static_cast(params.hidden_size_)}, - {static_cast(params.kv_hidden_size_)}, {static_cast(params.head_size_)}, {static_cast(params.v_head_size_)}, }) @@ -616,6 +614,61 @@ Attention::Attention(const OpKernelInfo& info) onnxruntime::contrib::AttentionBase(info, false) { } +// QKV preparation program - computes packed QKV from input, weights, and bias +class AttentionPreparePackedProgram final : public Program { + public: + AttentionPreparePackedProgram() : Program{"AttentionPreparePacked"} {} + + Status GenerateShaderCode(ShaderHelper& shader) const override { + shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("weight", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddOutput("packed_qkv", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + constexpr int TILE_SIZE = 12; + + shader.AdditionalImplementation() << "const TILE_SIZE = " << TILE_SIZE << "u;\n" + << "var tileInput: array;\n" + << "var tileWeight: array;\n"; + + shader.MainFunctionBody() + << "let batch_idx = workgroup_id.z;\n" + << "let m = global_id.y;\n" + << "let n = global_id.x;\n" + << "let inputOffset = batch_idx * (uniforms.M * uniforms.K) + m * uniforms.K;\n" + << "var value = input_value_t(0);\n" + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" + << " tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];\n" + << " }\n" + << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" + << " let offset = n + (w + local_id.y) * uniforms.ldb;\n" + << " tileWeight[TILE_SIZE * local_id.y + local_id.x] = weight[offset];\n" + << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k { public: @@ -715,6 +768,82 @@ Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtte return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v); } +Status PrepareQKVVersion2(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, + const Tensor* input, const Tensor* weights, const Tensor* bias, + Tensor* q, Tensor* k, Tensor* v) { + // Use AttentionPreparePackedProgram to compute packed QKV output + // Then use SplitPackedQKV to split into Q, K, V in BNSH format + // Returns Q, K, V in BNSH format for comparison with PrepareQKV + + // Create packed QKV tensor with shape [batch_size, sequence_length, hidden_size + hidden_size + v_hidden_size] + const int64_t packed_qkv_size = parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_; + TensorShapeVector packed_qkv_shape({parameters.batch_size_, parameters.sequence_length_, packed_qkv_size}); + Tensor packed_qkv = context.CreateGPUTensor(input->DataType(), TensorShape(packed_qkv_shape)); + + // Use custom program to compute packed QKV + AttentionPreparePackedProgram program; + const int M = static_cast(parameters.sequence_length_); + const int K = static_cast(input->Shape().GetDims()[2]); // input hidden size + const int N = static_cast(packed_qkv_size); + const int ldb = N; // leading dimension of weight (assumed transposed) + + constexpr int TILE_SIZE = 12; + const int num_workgroups_x = (N + TILE_SIZE - 1) / TILE_SIZE; + const int num_workgroups_y = (M + TILE_SIZE - 1) / TILE_SIZE; + const int num_workgroups_z = parameters.batch_size_; + + program + .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, + {weights, ProgramTensorMetadataDependency::TypeAndRank}, + {bias, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({&packed_qkv, ProgramTensorMetadataDependency::TypeAndRank}) + .SetDispatchGroupSize(num_workgroups_x, num_workgroups_y, num_workgroups_z) + .SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1) + .AddUniformVariables({ + {static_cast(M)}, + {static_cast(K)}, + {static_cast(N)}, + {static_cast(ldb)}, + }); + + ORT_RETURN_IF_ERROR(context.RunProgram(program)); + + // Split the packed QKV into Q, K, V in BNSH format + return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v); +} + +Status PrepareQKVVersion3(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, + const Tensor* input, const Tensor* weights, const Tensor* bias, + Tensor* q, Tensor* k, Tensor* v) { + constexpr int TILE_SIZE = 12; + const int M = parameters.sequence_length_; + const int K = parameters.input_hidden_size_; + const int N = parameters.head_size_; + + const uint32_t dispatch_x = (parameters.head_size_ + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_y = (parameters.sequence_length_ + TILE_SIZE - 1) / TILE_SIZE; + const uint32_t dispatch_z = parameters.batch_size_ * parameters.num_heads_; + + AttentionPrepareProgram program{}; + program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, + {weights, ProgramTensorMetadataDependency::TypeAndRank}, + {bias, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{q, ProgramTensorMetadataDependency::TypeAndRank}, + {k, ProgramTensorMetadataDependency::TypeAndRank}, + {v, ProgramTensorMetadataDependency::TypeAndRank}}) + .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) + .SetWorkgroupSize(TILE_SIZE, TILE_SIZE) + .AddUniformVariables({{static_cast(M)}, + {static_cast(K)}, + {static_cast(N)}, + {static_cast(parameters.num_heads_)}, + {static_cast(parameters.head_size_)}, + {static_cast(parameters.hidden_size_)}, + {static_cast(parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_)}}); + + return context.RunProgram(program); +} + Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* input = context.Input(0); const Tensor* weights = context.Input(1); @@ -777,9 +906,6 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) // Compute Q, K, V from input, weights, and bias (returns BNSH format) ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V)); - // Update parameters for Q_K_V_BSNH format - parameters.qkv_format_ = Q_K_V_BSNH; - // Check if we can use flash attention // For Attention operator, we need to create present_key and present_value tensors for flash attention // even though they are not exposed as outputs diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 1848b02f4c8f3..6540000845885 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -39,7 +39,6 @@ class SplitPackedQKVProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}, - {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}, {"head_size", ProgramUniformVariableDataType::Uint32}, {"v_head_size", ProgramUniformVariableDataType::Uint32}); }; From cbb8ef22e26337651567a87c00a10e1c01302e78 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 22 Dec 2025 17:33:50 +0800 Subject: [PATCH 4/8] remove debugging codes --- .../contrib_ops/webgpu/bert/attention.cc | 208 ------------------ 1 file changed, 208 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 3962effed27fd..9ed0f99158fc9 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -614,138 +614,6 @@ Attention::Attention(const OpKernelInfo& info) onnxruntime::contrib::AttentionBase(info, false) { } -// QKV preparation program - computes packed QKV from input, weights, and bias -class AttentionPreparePackedProgram final : public Program { - public: - AttentionPreparePackedProgram() : Program{"AttentionPreparePacked"} {} - - Status GenerateShaderCode(ShaderHelper& shader) const override { - shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("weight", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddOutput("packed_qkv", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - - constexpr int TILE_SIZE = 12; - - shader.AdditionalImplementation() << "const TILE_SIZE = " << TILE_SIZE << "u;\n" - << "var tileInput: array;\n" - << "var tileWeight: array;\n"; - - shader.MainFunctionBody() - << "let batch_idx = workgroup_id.z;\n" - << "let m = global_id.y;\n" - << "let n = global_id.x;\n" - << "let inputOffset = batch_idx * (uniforms.M * uniforms.K) + m * uniforms.K;\n" - << "var value = input_value_t(0);\n" - << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" - << " tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];\n" - << " }\n" - << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" - << " let offset = n + (w + local_id.y) * uniforms.ldb;\n" - << " tileWeight[TILE_SIZE * local_id.y + local_id.x] = weight[offset];\n" - << " }\n" - << " workgroupBarrier();\n" - << " for (var k: u32 = 0u; k { - public: - AttentionPrepareProgram() : Program{"AttentionPrepare"} {} - - Status GenerateShaderCode(ShaderHelper& shader) const override { - shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("weight", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddOutput("output_q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddOutput("output_k", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - shader.AddOutput("output_v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); - - constexpr int TILE_SIZE = 12; - - shader.AdditionalImplementation() << "const TILE_SIZE = " << TILE_SIZE << "u;\n" - << "var tileInput: array;\n" - << "var tileWeightQ: array;\n" - << "var tileWeightK: array;\n" - << "var tileWeightV: array;\n"; - - shader.MainFunctionBody() //<< shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.M * uniforms.N") - << "let batchIndex = workgroup_id.z / uniforms.num_heads;\n" - << "let headNumber = workgroup_id.z % uniforms.num_heads;\n" - << "let m = global_id.y;\n" - << "let n = global_id.x;\n" - << "let inputOffset = batchIndex * (uniforms.M * uniforms.K) + m * uniforms.K;\n" - << "let biasOffsetQ = headNumber * uniforms.head_size;\n" - << "let biasOffsetK = uniforms.hidden_size + biasOffsetQ;\n" - << "let biasOffsetV = uniforms.hidden_size + biasOffsetK;\n" - << "var valueQ = input_value_t(0);\n" - << "var valueK = input_value_t(0);\n" - << "var valueV = input_value_t(0);\n" - << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - << " if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" - << " tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];\n" - << " }\n" - << " if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" - << " let offset = n + (w + local_id.y) * uniforms.ldb;\n" - << " tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];\n" - << " tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];\n" - << " tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];\n" - << " }\n" - << " workgroupBarrier();\n" - << " for (var k: u32 = 0u; kDataType(), TensorShape(packed_qkv_shape)); - - // Use custom program to compute packed QKV - AttentionPreparePackedProgram program; - const int M = static_cast(parameters.sequence_length_); - const int K = static_cast(input->Shape().GetDims()[2]); // input hidden size - const int N = static_cast(packed_qkv_size); - const int ldb = N; // leading dimension of weight (assumed transposed) - - constexpr int TILE_SIZE = 12; - const int num_workgroups_x = (N + TILE_SIZE - 1) / TILE_SIZE; - const int num_workgroups_y = (M + TILE_SIZE - 1) / TILE_SIZE; - const int num_workgroups_z = parameters.batch_size_; - - program - .AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, - {weights, ProgramTensorMetadataDependency::TypeAndRank}, - {bias, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutput({&packed_qkv, ProgramTensorMetadataDependency::TypeAndRank}) - .SetDispatchGroupSize(num_workgroups_x, num_workgroups_y, num_workgroups_z) - .SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1) - .AddUniformVariables({ - {static_cast(M)}, - {static_cast(K)}, - {static_cast(N)}, - {static_cast(ldb)}, - }); - - ORT_RETURN_IF_ERROR(context.RunProgram(program)); - - // Split the packed QKV into Q, K, V in BNSH format - return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v); -} - -Status PrepareQKVVersion3(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& parameters, - const Tensor* input, const Tensor* weights, const Tensor* bias, - Tensor* q, Tensor* k, Tensor* v) { - constexpr int TILE_SIZE = 12; - const int M = parameters.sequence_length_; - const int K = parameters.input_hidden_size_; - const int N = parameters.head_size_; - - const uint32_t dispatch_x = (parameters.head_size_ + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t dispatch_y = (parameters.sequence_length_ + TILE_SIZE - 1) / TILE_SIZE; - const uint32_t dispatch_z = parameters.batch_size_ * parameters.num_heads_; - - AttentionPrepareProgram program{}; - program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank}, - {weights, ProgramTensorMetadataDependency::TypeAndRank}, - {bias, ProgramTensorMetadataDependency::TypeAndRank}}) - .AddOutputs({{q, ProgramTensorMetadataDependency::TypeAndRank}, - {k, ProgramTensorMetadataDependency::TypeAndRank}, - {v, ProgramTensorMetadataDependency::TypeAndRank}}) - .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) - .SetWorkgroupSize(TILE_SIZE, TILE_SIZE) - .AddUniformVariables({{static_cast(M)}, - {static_cast(K)}, - {static_cast(N)}, - {static_cast(parameters.num_heads_)}, - {static_cast(parameters.head_size_)}, - {static_cast(parameters.hidden_size_)}, - {static_cast(parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_)}}); - - return context.RunProgram(program); -} - Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { const Tensor* input = context.Input(0); const Tensor* weights = context.Input(1); From 0067a2d918c78a9171bae5fc20fd5fb23f42f2cb Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 22 Dec 2025 17:51:13 +0800 Subject: [PATCH 5/8] SplitPackedQKV with BSD format --- .../contrib_ops/webgpu/bert/attention.cc | 57 +++++++++++-------- .../contrib_ops/webgpu/bert/attention.h | 4 +- 2 files changed, 33 insertions(+), 28 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 9ed0f99158fc9..6fd0e5bf119a1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -72,7 +72,7 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h }; Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { - // Inputs: packed_qkv [B, S, D], outputs: Q, K, V [B, N, S, H] + // Inputs: packed_qkv [B, S, D], outputs: Q, K, V [B, S, D] const auto& packed_qkv = sh.AddInput("packed_qkv", ShaderUsage::UseOffsetToIndices | ShaderUsage::UseUniform); const auto& query = sh.AddOutput("query", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); @@ -84,26 +84,20 @@ Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { << " let d = packed_qkv_indices[2];\n" << " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n" << " if (d < uniforms.hidden_size) {\n" - << " let head = d / uniforms.head_size;\n" - << " let h = d % uniforms.head_size;\n" - << " " << query.SetByIndices("vec4(batch, head, seq, h)", "input_data") << ";\n" + << " " << query.SetByIndices("vec3(batch, seq, d)", "input_data") << ";\n" << " } else if (d < (uniforms.hidden_size + uniforms.hidden_size)) {\n" << " let kd = d - uniforms.hidden_size;\n" - << " let head = kd / uniforms.head_size;\n" - << " let h = kd % uniforms.head_size;\n" - << " " << key.SetByIndices("vec4(batch, head, seq, h)", "input_data") << ";\n" + << " " << key.SetByIndices("vec3(batch, seq, kd)", "input_data") << ";\n" << " } else {\n" << " let vd = d - uniforms.hidden_size - uniforms.hidden_size;\n" - << " let head = vd / uniforms.v_head_size;\n" - << " let h = vd % uniforms.v_head_size;\n" - << " " << value.SetByIndices("vec4(batch, head, seq, h)", "input_data") << ";\n" + << " " << value.SetByIndices("vec3(batch, seq, vd)", "input_data") << ";\n" << " }\n"; return Status::OK(); } Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val) { - // Output Q, K, V in BNSH format + // Output Q, K, V in BSD format SplitPackedQKVProgram program; auto input_size = packedQKV->Shape().Size(); program @@ -111,8 +105,6 @@ Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const Webgpu .AddOutputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, {key, ProgramTensorMetadataDependency::TypeAndRank}, {val, ProgramTensorMetadataDependency::TypeAndRank}}) .AddUniformVariables({ {static_cast(params.hidden_size_)}, - {static_cast(params.head_size_)}, - {static_cast(params.v_head_size_)}, }) .SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); return context.RunProgram(program); @@ -618,8 +610,8 @@ Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtte const Tensor* input, const Tensor* weights, const Tensor* bias, Tensor* q, Tensor* k, Tensor* v) { // Use MatMul to compute packed QKV output: input * weights + bias - // Then use SplitPackedQKV to split into Q, K, V in BNSH format - // Returns Q, K, V in BNSH format for direct comparison with AttentionPrepareProgram + // Then use SplitPackedQKV to split into Q, K, V in BSD format + // Returns Q, K, V in BSD format // Create packed QKV tensor with shape [batch_size, sequence_length, hidden_size + hidden_size + v_hidden_size] const int64_t packed_qkv_size = parameters.hidden_size_ + parameters.hidden_size_ + parameters.v_hidden_size_; @@ -632,7 +624,7 @@ Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtte // Call MatMul: packed_qkv = input * weights + bias ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, true)); - // Output Q, K, V in BNSH format + // Output Q, K, V in BSD format return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v); } @@ -688,15 +680,16 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) ORT_NOT_IMPLEMENTED("present tensor not implemented for webgpu Attention"); } - // Create Q, K, V tensors in BNSH format from input * weights + bias - TensorShapeVector qkv_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_}); - TensorShapeVector v_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.v_head_size_}); - Tensor Q = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape)); - Tensor K = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape)); - Tensor V = context.CreateGPUTensor(input->DataType(), TensorShape(v_bnsh_shape)); + // Create Q, K, V tensors in BSD format from input * weights + bias + TensorShapeVector qkv_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_}); + TensorShapeVector v_bsd_shape({parameters.batch_size_, parameters.sequence_length_, parameters.v_hidden_size_}); + Tensor Q_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape)); + Tensor K_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bsd_shape)); + Tensor V_bsd = context.CreateGPUTensor(input->DataType(), TensorShape(v_bsd_shape)); - // Compute Q, K, V from input, weights, and bias (returns BNSH format) - ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q, &K, &V)); + // Compute Q, K, V from input, weights, and bias (returns BSD format) + ORT_RETURN_IF_ERROR(PrepareQKV(context, parameters, input, weights, bias, &Q_bsd, &K_bsd, &V_bsd)); + parameters.qkv_format_ = Q_K_V_BSNH; // Check if we can use flash attention // For Attention operator, we need to create present_key and present_value tensors for flash attention @@ -708,10 +701,24 @@ Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) if (CanApplyFlashAttention(nullptr, &present_key, &present_value, parameters, context)) { // FlashAttention supports Q_K_V_BSNH format directly - return ApplyFlashAttention(&Q, &K, &V, attention_bias, output, nullptr, &present_key, nullptr, &present_value, + return ApplyFlashAttention(&Q_bsd, &K_bsd, &V_bsd, attention_bias, output, nullptr, &present_key, nullptr, &present_value, parameters, context, nullptr); } + // For non-flash attention path, convert BSD to BNSH format + TensorShapeVector qkv_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_}); + TensorShapeVector v_bnsh_shape({parameters.batch_size_, parameters.num_heads_, parameters.sequence_length_, parameters.v_head_size_}); + Tensor Q = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape)); + Tensor K = context.CreateGPUTensor(input->DataType(), TensorShape(qkv_bnsh_shape)); + Tensor V = context.CreateGPUTensor(input->DataType(), TensorShape(v_bnsh_shape)); + + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, + parameters.head_size_, &Q_bsd, nullptr, 0, &Q)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, + parameters.head_size_, &K_bsd, nullptr, 0, &K)); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads_, parameters.sequence_length_, + parameters.v_head_size_, &V_bsd, nullptr, 0, &V)); + // Apply the actual attention computation return ApplyAttention(&Q, &K, &V, attention_bias, nullptr, nullptr, output, /* present_key */ nullptr, /* present_value */ nullptr, /* output_qk */ nullptr, parameters, context, nullptr, nullptr, -1); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 6540000845885..f5ada2d6b7996 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -38,9 +38,7 @@ class SplitPackedQKVProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}, - {"head_size", ProgramUniformVariableDataType::Uint32}, - {"v_head_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}); }; class AttentionProbsProgram final : public Program { From efdc783d6687449b19099c04209c9b1c12a091a9 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 22 Dec 2025 18:06:01 +0800 Subject: [PATCH 6/8] make SplitPackedQKV work on GQA --- .../contrib_ops/webgpu/bert/attention.cc | 45 ++++++++++--------- .../contrib_ops/webgpu/bert/attention.h | 3 +- .../webgpu/bert/attention_common.h | 2 +- .../webgpu/bert/group_query_attention.cc | 2 +- 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 6fd0e5bf119a1..6843ae08ab4ba 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -78,36 +78,37 @@ Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); sh.MainFunctionBody() - << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n" - << " let batch = packed_qkv_indices[0];\n" - << " let seq = packed_qkv_indices[1];\n" - << " let d = packed_qkv_indices[2];\n" - << " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n" - << " if (d < uniforms.hidden_size) {\n" - << " " << query.SetByIndices("vec3(batch, seq, d)", "input_data") << ";\n" - << " } else if (d < (uniforms.hidden_size + uniforms.hidden_size)) {\n" - << " let kd = d - uniforms.hidden_size;\n" - << " " << key.SetByIndices("vec3(batch, seq, kd)", "input_data") << ";\n" - << " } else {\n" - << " let vd = d - uniforms.hidden_size - uniforms.hidden_size;\n" - << " " << value.SetByIndices("vec3(batch, seq, vd)", "input_data") << ";\n" - << " }\n"; + << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n" + << " let batch = packed_qkv_indices[0];\n" + << " let seq = packed_qkv_indices[1];\n" + << " let d = packed_qkv_indices[2];\n" + << " let input_data = " << packed_qkv.GetByOffset("global_idx") << ";\n" + << " if (d < uniforms.hidden_size) {\n" + << " " << query.SetByIndices("vec3(batch, seq, d)", "input_data") << ";\n" + << " } else if (d < (uniforms.hidden_size + uniforms.kv_hidden_size)) {\n" + << " let kd = d - uniforms.hidden_size;\n" + << " " << key.SetByIndices("vec3(batch, seq, kd)", "input_data") << ";\n" + << " } else {\n" + << " let vd = d - uniforms.hidden_size - uniforms.kv_hidden_size;\n" + << " " << value.SetByIndices("vec3(batch, seq, vd)", "input_data") << ";\n" + << " }\n"; return Status::OK(); } Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, - const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val) { - // Output Q, K, V in BSD format - SplitPackedQKVProgram program; - auto input_size = packedQKV->Shape().Size(); - program + const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val, int kv_hidden_size) { + // Output Q, K, V in BSD format + SplitPackedQKVProgram program; + auto input_size = packedQKV->Shape().Size(); + program .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank}) .AddOutputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, {key, ProgramTensorMetadataDependency::TypeAndRank}, {val, ProgramTensorMetadataDependency::TypeAndRank}}) .AddUniformVariables({ - {static_cast(params.hidden_size_)}, + {static_cast(params.hidden_size_)}, + {static_cast(kv_hidden_size)}, }) .SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); - return context.RunProgram(program); + return context.RunProgram(program); } void InitVarStub(std::ostringstream& ss, bool has_seqlen_k) { @@ -625,7 +626,7 @@ Status PrepareQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtte ORT_RETURN_IF_ERROR(onnxruntime::webgpu::ComputeMatMul(&context, Activation(), matmul_inputs, &packed_qkv, true)); // Output Q, K, V in BSD format - return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v); + return SplitPackedQKV(context, parameters, &packed_qkv, q, k, v, parameters.hidden_size_); } Status Attention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index f5ada2d6b7996..b4afe7938883d 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -38,7 +38,8 @@ class SplitPackedQKVProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}); }; class AttentionProbsProgram final : public Program { diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index d0fd17212a57a..4fc0e7826b49c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -123,7 +123,7 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor); Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, - const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val); + const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val, int kv_hidden_size); Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 59e0c7cfccc99..799f9d8a6a028 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -282,7 +282,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& qSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.hidden_size_})); kSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); vSplit = context.CreateGPUTensor(query->DataType(), TensorShape({parameters.batch_size_, parameters.sequence_length_, parameters.kv_hidden_size_})); - ORT_RETURN_IF_ERROR(SplitPackedQKV(context, parameters, query, &qSplit, &kSplit, &vSplit)); + ORT_RETURN_IF_ERROR(SplitPackedQKV(context, parameters, query, &qSplit, &kSplit, &vSplit, parameters.kv_hidden_size_)); parameters.is_packed_qkv_ = false; parameters.qkv_format_ = Q_K_V_BSNH; query = &qSplit; From ceea4d880aa881703b82e77f2c8a0ca3c28352e5 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 23 Dec 2025 10:37:33 +0800 Subject: [PATCH 7/8] Add component support to SplitPackedQKV --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 15 ++++++++++----- onnxruntime/contrib_ops/webgpu/bert/attention.h | 3 ++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 6843ae08ab4ba..bdb00eeb6f5a2 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -8,6 +8,7 @@ #include "contrib_ops/webgpu/bert/multihead_attention.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" #include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_utils.h" #include "core/providers/webgpu/math/matmul.h" using namespace onnxruntime::webgpu; using namespace ::onnxruntime::common; @@ -78,6 +79,7 @@ Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { const auto& key = sh.AddOutput("key", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); const auto& value = sh.AddOutput("val", ShaderUsage::UseSetByIndices | ShaderUsage::UseUniform); sh.MainFunctionBody() + << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.input_size") << " let packed_qkv_indices = " << packed_qkv.OffsetToIndices("global_idx") << ";\n" << " let batch = packed_qkv_indices[0];\n" << " let seq = packed_qkv_indices[1];\n" @@ -98,16 +100,19 @@ Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val, int kv_hidden_size) { // Output Q, K, V in BSD format + const int components = std::min(GetMaxComponents(params.hidden_size_), GetMaxComponents(kv_hidden_size)); SplitPackedQKVProgram program; auto input_size = packedQKV->Shape().Size(); + const uint32_t vectorized_input_size = static_cast(input_size / components); program - .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank}) - .AddOutputs({{query, ProgramTensorMetadataDependency::TypeAndRank}, {key, ProgramTensorMetadataDependency::TypeAndRank}, {val, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddInput({packedQKV, ProgramTensorMetadataDependency::TypeAndRank, components}) + .AddOutputs({{query, ProgramTensorMetadataDependency::TypeAndRank, components}, {key, ProgramTensorMetadataDependency::TypeAndRank, components}, {val, ProgramTensorMetadataDependency::TypeAndRank, components}}) .AddUniformVariables({ - {static_cast(params.hidden_size_)}, - {static_cast(kv_hidden_size)}, + {vectorized_input_size}, + {static_cast(params.hidden_size_ / components)}, + {static_cast(kv_hidden_size / components)}, }) - .SetDispatchGroupSize((input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + .SetDispatchGroupSize((vectorized_input_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); return context.RunProgram(program); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index b4afe7938883d..b8fc0ce0cc055 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -38,7 +38,8 @@ class SplitPackedQKVProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"hidden_size", ProgramUniformVariableDataType::Uint32}, + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"input_size", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, {"kv_hidden_size", ProgramUniformVariableDataType::Uint32}); }; From 419e26309e2260e723820f7a8e8cba69af4a6684 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 23 Dec 2025 11:02:07 +0800 Subject: [PATCH 8/8] address comments from Copilot --- onnxruntime/contrib_ops/webgpu/bert/attention.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index bdb00eeb6f5a2..3064c81506068 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -100,7 +100,7 @@ Status SplitPackedQKVProgram::GenerateShaderCode(ShaderHelper& sh) const { Status SplitPackedQKV(onnxruntime::webgpu::ComputeContext& context, const WebgpuAttentionParameters& params, const Tensor* packedQKV, Tensor* query, Tensor* key, Tensor* val, int kv_hidden_size) { // Output Q, K, V in BSD format - const int components = std::min(GetMaxComponents(params.hidden_size_), GetMaxComponents(kv_hidden_size)); + const int components = std::min({GetMaxComponents(params.hidden_size_), GetMaxComponents(kv_hidden_size), GetMaxComponents(params.v_hidden_size_)}); SplitPackedQKVProgram program; auto input_size = packedQKV->Shape().Size(); const uint32_t vectorized_input_size = static_cast(input_size / components);