onnxruntime/core/providers/webgpu/math/softmax.cc (14 additions, 15 deletions)
@@ -59,19 +59,18 @@ static std::string MaxVector(const std::string& name, int components) {

Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
// Add input and output variables
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias);
const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
int components = input.NumComponents();

- const std::string thread_max_decl = is_fp32_
-     ? "var thread_max = x_value_t(-3.402823e+38f);\n"
-     : "var thread_max = x_value_t(-65504.0h);\n";
+ const std::string thread_max_decl = "var thread_max = f32_value_t(-3.402823e+38f);\n";

// Define shared memory for row max and row sum
shader.AdditionalImplementation()
<< "var<workgroup> row_max_shared : x_value_t;\n"
<< "var<workgroup> row_sum_shared : x_value_t;\n"
<< "var<workgroup> thread_shared : array<x_value_t, " << wg_ << ">;\n";
<< "alias f32_value_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
<< "var<workgroup> row_max_shared : f32_value_t;\n"
<< "var<workgroup> row_sum_shared : f32_value_t;\n"
<< "var<workgroup> thread_shared : array<f32_value_t, " << wg_ << ">;\n";

// Define helper functions to get and set values
shader.AdditionalImplementation()
@@ -97,7 +96,7 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< thread_max_decl
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let value = getValue(row, col, row_stride);\n"
<< " thread_max = max(thread_max, value);\n"
<< " thread_max = max(thread_max, f32_value_t(value));\n"
<< " }\n"
<< " if (lindex < cols) {\n"
<< " thread_shared[lindex] = thread_max;\n"
@@ -114,14 +113,14 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " row_max_shared = x_value_t(" << MaxVector("thread_shared[0]", components) << ");\n"
<< " row_max_shared = f32_value_t(" << MaxVector("thread_shared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Find the row's sum of exponentials
<< " var thread_sum = x_value_t(0.0);\n"
<< " var thread_sum = f32_value_t(0.0);\n"
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let sub_exp = exp(getValue(row, col, row_stride) - row_max_shared);\n"
<< " let sub_exp = exp(f32_value_t(getValue(row, col, row_stride)) - row_max_shared);\n"
<< " thread_sum += sub_exp;\n"
<< " }\n"
<< " thread_shared[lindex] = thread_sum;\n"
@@ -135,14 +134,14 @@ Status SoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " workgroupBarrier();\n"
<< " }\n"
<< " if (lindex == 0) {\n"
<< " row_sum_shared = x_value_t(" << SumVector("thread_shared[0]", components) << ");\n"
<< " row_sum_shared = f32_value_t(" << SumVector("thread_shared[0]", components) << ");\n"
<< " }\n"
<< " workgroupBarrier();\n"

// Calculate the final value for each element in the row
<< " for (var col = lindex; col < cols; col += wg) {\n"
<< " let value = exp(getValue(row, col, row_stride) - row_max_shared) / row_sum_shared;\n"
<< " setValue(row, col, row_stride, value);\n"
<< " let value = exp(f32_value_t(getValue(row, col, row_stride)) - row_max_shared) / row_sum_shared;\n"
<< " setValue(row, col, row_stride, result_value_t(value));\n"
<< " }\n";

return Status::OK();
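For context on the softmax change: the generator now accumulates the per-row max and the sum of exponentials in f32 even when the tensor elements are f16, instead of switching the accumulator literal on is_fp32_. The snippet below is a minimal standalone sketch of that alias-emission pattern, not ONNX Runtime code; EmitF32Alias and the example values components = 4 and wg = 64 are illustrative assumptions.

// Minimal sketch (not ONNX Runtime code) of the pattern used in the hunk above:
// choose a packed f32 WGSL type matching the input's component count, so the
// per-row max/sum accumulators are f32 regardless of the input element type.
#include <iostream>
#include <sstream>
#include <string>

std::string EmitF32Alias(int components, int wg) {  // hypothetical helper name
  std::ostringstream ss;
  ss << "alias f32_value_t = "
     << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
     << "var<workgroup> row_max_shared : f32_value_t;\n"
     << "var<workgroup> row_sum_shared : f32_value_t;\n"
     << "var<workgroup> thread_shared : array<f32_value_t, " << wg << ">;\n";
  return ss.str();
}

int main() {
  // With components = 4 and a workgroup size of 64, this prints WGSL declarations
  // of the same shape as those streamed into shader.AdditionalImplementation().
  std::cout << EmitF32Alias(4, 64);
}

Accumulating in f32 improves the precision of the exp sum and removes the need for the f16-specific -65504.0h initial max used by the old path.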
onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc (11 additions, 8 deletions)
@@ -120,7 +120,7 @@ std::unordered_map<std::string, ReduceOpType> reduce_op_types = {
std::unordered_map<ReduceOpType, std::string> reduce_op_code_map = {
{ReduceOpType::Max, "select(bestValue, candidate, candidate > bestValue)"},
{ReduceOpType::Min, "select(bestValue, candidate, candidate < bestValue)"},
- {ReduceOpType::Mean, "bestValue + candidate"},
+ {ReduceOpType::Mean, "bestValue + best_value_t(candidate)"},
{ReduceOpType::Sum, "bestValue + candidate"},
{ReduceOpType::Prod, "bestValue * candidate"},
{ReduceOpType::SumSquare, "bestValue + candidate * candidate"},
@@ -133,7 +133,7 @@ std::unordered_map<ReduceOpType, std::string> reduce_op_code_map = {
std::unordered_map<ReduceOpType, std::string> reduce_op_shared_code_map = {
{ReduceOpType::Max, "select(bestValue, candidate, candidate > bestValue)"},
{ReduceOpType::Min, "select(bestValue, candidate, candidate < bestValue)"},
- {ReduceOpType::Mean, "bestValue + candidate"},
+ {ReduceOpType::Mean, "bestValue + best_value_t(candidate)"},
{ReduceOpType::Sum, "bestValue + candidate"},
{ReduceOpType::Prod, "bestValue * candidate"},
{ReduceOpType::SumSquare, "bestValue + candidate"},
@@ -159,7 +159,7 @@ std::unordered_map<ReduceOpType, std::string> reduce_op_init_values_map = {
std::unordered_map<ReduceOpType, std::string> reduce_op_output_values_map = {
{ReduceOpType::Max, "bestValue"},
{ReduceOpType::Min, "bestValue"},
- {ReduceOpType::Mean, "bestValue"},
+ {ReduceOpType::Mean, "bestValue / best_value_t(uniforms.reduceSize)"},
{ReduceOpType::Sum, "bestValue"},
{ReduceOpType::Prod, "bestValue"},
{ReduceOpType::SumSquare, "bestValue"},
@@ -204,7 +204,7 @@ Status ReduceNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
bool reduce_on_all_axes = no_op_with_empty_axes_ == false && axes_.empty();
std::string loop_header = code.loop_header_.find("first_element") == std::string::npos ? code.loop_header_ : "let first_element = " + input.GetByIndices("input_indices") + ";\n" + code.loop_header_ + "\n";
- std::string loop_body = "let current_element: input_value_t = " + input.GetByIndices("input_indices") + ";\n" + code.loop_body_;
+ std::string loop_body = "let current_element: output_value_t = " + input.GetByIndices("input_indices") + ";\n" + code.loop_body_;
std::string loop_footer = code.loop_footer_;
const auto input_rank = input.Rank();
for (int i = 0, l = 0; i < input_rank; ++i) {
@@ -248,14 +248,17 @@ Status ReduceNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
Status ReduceSharedProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("_A", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+ int components = input.NumComponents();
shader.AdditionalImplementation()
<< "var<workgroup> aBestValues : array<output_value_t, " << workgroup_size_ << ">;\n\n"
<< "alias f32_value_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
<< "alias best_value_t = " << (reduce_op_type_ == ReduceOpType::Mean ? "f32_value_t" : "output_value_t") << ";\n"
<< "var<workgroup> aBestValues : array<best_value_t, " << workgroup_size_ << ">;\n\n"
<< "fn DIV_CEIL(a : u32, b : u32) -> u32 {\n"
<< " return ((a - 1u) / b + 1u);\n"
<< "}\n";
shader.MainFunctionBody() << "let outputIndex = global_idx / " << workgroup_size_ << ";\n"
<< "let offset = outputIndex * uniforms.reduceSize;\n"
<< "var bestValue = output_value_t(" << reduce_op_init_values_map[reduce_op_type_] << ");\n"
<< "var bestValue = best_value_t(" << reduce_op_init_values_map[reduce_op_type_] << ");\n"
<< "let length = uniforms.reduceSize;\n"
<< "for (var k = local_idx; k < length; k += " << workgroup_size_ << ") {\n"
<< " let candidate = output_value_t(" << input.GetByOffset("offset + k") << ");\n"
@@ -268,14 +271,14 @@ Status ReduceSharedProgram::GenerateShaderCode(ShaderHelper& shader) const {
<< " let interval = DIV_CEIL(reduceSize, 2u);\n"
<< " if (local_idx < currentSize) {\n"
<< " let candidate = aBestValues[local_idx + interval];\n"
<< " bestValue = " << reduce_op_shared_code_map[reduce_op_type_] << ";\n"
<< " bestValue = best_value_t(" << reduce_op_shared_code_map[reduce_op_type_] << ");\n"
<< " aBestValues[local_idx] = bestValue;\n"
<< " }\n"
<< " reduceSize = interval;\n"
<< " workgroupBarrier();\n"
<< "}\n"
<< "if (local_idx == 0) {\n"
<< " let outputValue = output_value_t(" << (reduce_op_type_ == ReduceOpType::Mean ? "(bestValue / output_element_t(uniforms.reduceSize))" : reduce_op_output_values_map[reduce_op_type_]) << ");\n"
<< " let outputValue = output_value_t(" << reduce_op_output_values_map[reduce_op_type_] << ");\n"
<< " " << output.SetByOffset("outputIndex", "outputValue") << ";\n"
<< "}\n";
return Status::OK();
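For the reduction change, the sketch below (again standalone, not ONNX Runtime code) mirrors the accumulator-type selection added to ReduceSharedProgram: Mean reductions accumulate in an f32 vector type and divide by uniforms.reduceSize once at the end via reduce_op_output_values_map, while the other ops keep accumulating in output_value_t. The enum, the helper name EmitAccumulatorAliases, and the example arguments are illustrative assumptions.

// Minimal sketch (not ONNX Runtime code) of the accumulator-alias selection above:
// Mean gets an f32 accumulator (best_value_t = f32_value_t) and a single division
// by reduceSize at output time; other reductions keep output_value_t.
#include <iostream>
#include <sstream>
#include <string>

enum class ReduceOpType { Max, Min, Mean, Sum, Prod, SumSquare };

std::string EmitAccumulatorAliases(ReduceOpType op, int components, int workgroup_size) {
  std::ostringstream ss;
  ss << "alias f32_value_t = "
     << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
     << "alias best_value_t = "
     << (op == ReduceOpType::Mean ? "f32_value_t" : "output_value_t") << ";\n"
     << "var<workgroup> aBestValues : array<best_value_t, " << workgroup_size << ">;\n";
  return ss.str();
}

int main() {
  // For a Mean reduction over a vec4 input with a 32-thread workgroup, every
  // intermediate bestValue is an f32 vector; the cast back to output_value_t
  // happens once, when the final value is written to the output buffer.
  std::cout << EmitAccumulatorAliases(ReduceOpType::Mean, 4, 32);
}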