From 5b4ee70e85cc5c515f8daaa6331ca2467e05667e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 12 Mar 2025 21:29:27 -0700 Subject: [PATCH 1/3] [webgpu] allow overloads to Program::AddIndices --- onnxruntime/core/providers/webgpu/program.cc | 10 ---------- onnxruntime/core/providers/webgpu/program.h | 8 +++++--- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 976b7927ac3dd..a70c6e540d830 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -308,16 +308,6 @@ ProgramBase& ProgramBase::AddOutputs(std::initializer_list<ProgramOutput> output return *this; } -ProgramBase& ProgramBase::AddIndices(const TensorShape& shape) { - indices_.emplace_back(shape); - return *this; -} - -ProgramBase& ProgramBase::AddIndices(TensorShape&& shape) { - indices_.emplace_back(shape); - return *this; -} - ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) { return SetDispatchGroupSize(x, 1, 1); } diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 95fef36144025..ea7d8ae5471af 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -271,9 +271,11 @@ class ProgramBase { // add multiple program outputs ProgramBase& AddOutputs(std::initializer_list<ProgramOutput> outputs); // add a program variable for indices - ProgramBase& AddIndices(const TensorShape& shape); - // add a program variable for indices - ProgramBase& AddIndices(TensorShape&& shape); + template <typename... Args> + ProgramBase& AddIndices(Args&&... args) { + indices_.emplace_back(std::forward<Args>(args)...); + return *this; + } // set the size of dispatch groups. Y and Z are 1 if not specified. 
ProgramBase& SetDispatchGroupSize(uint32_t x); From 49ddf81d3548b0034cd9f480e238bcc10e932eb7 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 12 Mar 2025 21:33:25 -0700 Subject: [PATCH 2/3] use r-value ref for params of AddIndices() when possible --- onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc | 2 +- .../core/providers/webgpu/math/binary_elementwise_ops.cc | 6 +++--- onnxruntime/core/providers/webgpu/tensor/expand.cc | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index 5759e7c1232de..7e4dbe2a111cd 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -97,7 +97,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt } program.AddOutputs({{present_key, ProgramTensorMetadataDependency::Rank, components}, {present_value, ProgramTensorMetadataDependency::Rank, components}}) - .AddIndices(valid_present_shape); + .AddIndices(std::move(valid_present_shape)); program.SetDispatchGroupSize(onnxruntime::narrow<uint32_t>(valid_kv_size + 63 / 64)) .SetWorkgroupSize(64) .CacheHint(has_past, parameters.qkv_format_, parameters.past_present_share_buffer_) diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 8a22e45f17047..13004af25726d 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -190,9 +190,9 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { // Mode Vectorize broadcast // cache hint: "V{a_rank};{b_rank};{output_rank}" program - .AddIndices(reshaped_output_shape) - .AddIndices(reshaped_lhs_shape) - .AddIndices(reshaped_rhs_shape) + 
.AddIndices(std::move(reshaped_output_shape)) + .AddIndices(std::move(reshaped_lhs_shape)) + .AddIndices(std::move(reshaped_rhs_shape)) .CacheHint("V"); } else { // Mode Broadcast diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 9bdebe2c1e0d3..3e831f9853451 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -53,7 +53,7 @@ Status Expand::ComputeInternal(ComputeContext& context) const { {data_size}, }); if (components_i != components_o) { - program.AddIndices(output_shape); + program.AddIndices(std::move(output_shape)); } return context.RunProgram(program); } From 970c10e3801a92b916c0f5e812035a0ef7b51abc Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 12 Mar 2025 21:34:09 -0700 Subject: [PATCH 3/3] fix r-value ref for AddInputs and AddOutputs --- onnxruntime/core/providers/webgpu/program.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index a70c6e540d830..73291e1e93ff1 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -289,7 +289,7 @@ ProgramBase::ProgramBase(std::string_view name, ProgramMetadata&& metadata) } ProgramBase& ProgramBase::AddInput(ProgramInput&& input) { - inputs_.emplace_back(input); + inputs_.emplace_back(std::move(input)); return *this; } @@ -299,7 +299,7 @@ ProgramBase& ProgramBase::AddInputs(std::initializer_list<ProgramInput> inputs) } ProgramBase& ProgramBase::AddOutput(ProgramOutput&& output) { - outputs_.emplace_back(output); + outputs_.emplace_back(std::move(output)); return *this; }