Merged
174 changes: 45 additions & 129 deletions onnxruntime/core/providers/webgpu/math/gemm.cc
@@ -2,7 +2,7 @@
// Licensed under the MIT License.

#include "core/providers/webgpu/math/gemm.h"
#include "core/providers/webgpu/math/gemm_vec4.h"
#include "core/providers/webgpu/math/gemm_packed.h"

#include <vector>

@@ -38,130 +38,52 @@ WEBGPU_GEMM_VERSIONED_KERNEL(9, 10)
WEBGPU_GEMM_VERSIONED_KERNEL(11, 12)
WEBGPU_GEMM_KERNEL(13)

Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
const uint32_t TILE_SIZE = 16;

// Add shared memory arrays
shader.AdditionalImplementation() << "var<workgroup> tile_a: array<array<output_value_t, " << TILE_SIZE << ">, " << TILE_SIZE << ">;\n"
<< "var<workgroup> tile_b: array<array<output_value_t, " << TILE_SIZE << ">, " << TILE_SIZE << ">;\n\n";

Status GemmNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const {
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);

shader.MainFunctionBody() << " var value = output_value_t(0);\n\n"
<< " let tile_col_start = (workgroup_idx % uniforms.num_tile_n) * " << TILE_SIZE << "u;\n"
<< " let tile_row_start = (workgroup_idx / uniforms.num_tile_n) * " << TILE_SIZE << "u;\n";
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
<< " let m = global_idx / uniforms.N;\n"
<< " let n = global_idx % uniforms.N;\n"
<< " var value = output_value_t(0);\n"
<< "\n";

// When A or B is empty, we don't bind A and B, because WebGPU doesn't support binding a zero-sized buffer.
if (need_handle_matmul_) {
const ShaderVariableHelper& A = shader.AddInput("A", ShaderUsage::UseUniform);
const ShaderVariableHelper& B = shader.AddInput("B", ShaderUsage::UseUniform);

shader.MainFunctionBody()
<< " let num_tiles = (uniforms.K - 1u) / " << TILE_SIZE << "u + 1u;\n"
<< " var k_start = 0u;\n"
<< " for (var t = 0u; t < num_tiles; t = t + 1u) {\n";

// Fill workgroup shared memory
if (transA_ && transB_) {
shader.MainFunctionBody() << " var col = tile_row_start + local_id.x;\n"
<< " var row = k_start + local_id.y;\n"
<< " if (col < uniforms.M && row < uniforms.K) {\n"
<< " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.M + col") << ";\n"
<< " } else {\n"
<< " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
<< " }\n\n"
<< " col = k_start + local_id.x;\n"
<< " row = tile_col_start + local_id.y;\n"
<< " if (col < uniforms.K && row < uniforms.N) {\n"
<< " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.K + col") << ";\n"
<< " } else {\n"
<< " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
<< " }\n";
} else if (transA_ && !transB_) {
shader.MainFunctionBody() << " var col = tile_row_start + local_id.x;\n"
<< " var row = k_start + local_id.y;\n"
<< " if (col < uniforms.M && row < uniforms.K) {\n"
<< " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.M + col") << ";\n"
<< " } else {\n"
<< " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
<< " }\n\n"
<< " col = tile_col_start + local_id.x;\n"
<< " row = k_start + local_id.y;\n"
<< " if (col < uniforms.N && row < uniforms.K) {\n"
<< " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.N + col") << ";\n"
<< " } else {\n"
<< " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
<< " }\n";
} else if (!transA_ && transB_) {
shader.MainFunctionBody() << " var col = k_start + local_id.x;\n"
<< " var row = tile_row_start + local_id.y;\n"
<< " if (col < uniforms.K && row < uniforms.M) {\n"
<< " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.K + col") << ";\n"
<< " } else {\n"
<< " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
<< " }\n\n"
<< " col = k_start + local_id.x;\n"
<< " row = tile_col_start + local_id.y;\n"
<< " if (col < uniforms.K && row < uniforms.N) {\n"
<< " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.K + col") << ";\n"
<< " } else {\n"
<< " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
<< " }\n";
} else {
shader.MainFunctionBody() << " var col = k_start + local_id.x;\n"
<< " var row = tile_row_start + local_id.y;\n"
<< " if (col < uniforms.K && row < uniforms.M) {\n"
<< " tile_a[local_id.y][local_id.x] = " << A.GetByOffset("row * uniforms.K + col") << ";\n"
<< " } else {\n"
<< " tile_a[local_id.y][local_id.x] = output_value_t(0);\n"
<< " }\n\n"
<< " col = tile_col_start + local_id.x;\n"
<< " row = k_start + local_id.y;\n"
<< " if (col < uniforms.N && row < uniforms.K) {\n"
<< " tile_b[local_id.y][local_id.x] = " << B.GetByOffset("row * uniforms.N + col") << ";\n"
<< " } else {\n"
<< " tile_b[local_id.y][local_id.x] = output_value_t(0);\n"
<< " }\n";
}

shader.MainFunctionBody() << " k_start = k_start + " << TILE_SIZE << "u;\n"
<< " workgroupBarrier();\n\n"
<< " for (var k = 0u; k < " << TILE_SIZE << "u; k = k + 1u) {\n";
shader.MainFunctionBody() << " for (var k = 0u; k < uniforms.K; k = k + 1u) {\n";

if (transA_ && transB_) {
shader.MainFunctionBody() << " value = value + tile_a[k][local_id.y] * tile_b[local_id.x][k];\n";
shader.MainFunctionBody() << " value = value + " << A.GetByOffset("k * uniforms.M + m")
<< " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
} else if (transA_ && !transB_) {
shader.MainFunctionBody() << " value = value + tile_a[k][local_id.y] * tile_b[k][local_id.x];\n";
shader.MainFunctionBody() << " value = value + " << A.GetByOffset("k * uniforms.M + m")
<< " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
} else if (!transA_ && transB_) {
shader.MainFunctionBody() << " value = value + tile_a[local_id.y][k] * tile_b[local_id.x][k];\n";
shader.MainFunctionBody() << " value = value + " << A.GetByOffset("m * uniforms.K + k")
<< " * " << B.GetByOffset("n * uniforms.K + k") << ";\n";
} else {
shader.MainFunctionBody() << " value = value + tile_a[local_id.y][k] * tile_b[k][local_id.x];\n";
shader.MainFunctionBody() << " value = value + " << A.GetByOffset("m * uniforms.K + k")
<< " * " << B.GetByOffset("k * uniforms.N + n") << ";\n";
}

shader.MainFunctionBody() << " }\n"
<< " workgroupBarrier();\n"
<< " }\n\n";
shader.MainFunctionBody() << " }\n"
<< "\n";
}

// Calculate Alpha
if (alpha_) {
shader.MainFunctionBody() << " value = value * output_value_t(uniforms.alpha);\n";
}

shader.MainFunctionBody() << " let m = tile_row_start + local_id.y;\n"
<< " let n = tile_col_start + local_id.x;\n";

// Calculate Bias
if (need_handle_bias_) {
const ShaderVariableHelper& C = shader.AddInput("C", ShaderUsage::UseUniform);
shader.MainFunctionBody() << " value = value + output_value_t(uniforms.beta) * "
<< C.GetByOffset(C.BroadcastedIndicesToOffset("vec2(m, n)", output)) << ";\n";
}

// Write output
shader.MainFunctionBody() << " if (m < uniforms.M && n < uniforms.N) {\n"
<< " " << output.SetByOffset("m * uniforms.N + n", "value") << "\n"
<< " }\n";
shader.MainFunctionBody() << output.SetByOffset("global_idx", "value") << "\n";

return Status::OK();
}
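Taken together, the new GemmNaiveProgram emits a one-thread-per-output-element kernel. For reference, here is a hand-assembled WGSL sketch of what the generator above produces for transA = false, transB = false with a full M x N bias. The binding and uniform declarations are assumptions (ShaderHelper generates those), and output_value_t is shown as plain f32:

struct Uniforms {
  output_size : u32,
  M : u32,
  N : u32,
  K : u32,
  alpha : f32,
  beta : f32,
}

@group(0) @binding(0) var<storage, read> A : array<f32>;
@group(0) @binding(1) var<storage, read> B : array<f32>;
@group(0) @binding(2) var<storage, read> C : array<f32>;
@group(0) @binding(3) var<storage, read_write> output : array<f32>;
@group(0) @binding(4) var<uniform> uniforms : Uniforms;

@compute @workgroup_size(64)  // WORKGROUP_SIZE (assumed 64 here)
fn main(@builtin(global_invocation_id) gid : vec3<u32>) {
  let global_idx = gid.x;
  // GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
  if (global_idx >= uniforms.output_size) {
    return;
  }
  let m = global_idx / uniforms.N;
  let n = global_idx % uniforms.N;
  var value = 0.0;

  // need_handle_matmul_: plain row-major dot product, one k per iteration
  for (var k = 0u; k < uniforms.K; k = k + 1u) {
    value = value + A[m * uniforms.K + k] * B[k * uniforms.N + n];
  }

  value = value * uniforms.alpha;  // alpha_ != 0 branch
  // need_handle_bias_: the real code computes a broadcast offset for C;
  // a full M x N bias is shown for simplicity
  value = value + uniforms.beta * C[m * uniforms.N + n];
  output[global_idx] = value;
}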
@@ -182,14 +104,14 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input tensors A and B must be 2 dimensional.");
}

uint32_t M = onnxruntime::narrow<uint32_t>(transA_ ? A_shape[1] : A_shape[0]);
uint32_t K = onnxruntime::narrow<uint32_t>(transA_ ? A_shape[0] : A_shape[1]);
uint32_t N = onnxruntime::narrow<uint32_t>(transB_ ? B_shape[0] : B_shape[1]);

if ((transA_ ? A_shape[0] : A_shape[1]) != (transB_ ? B_shape[1] : B_shape[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inner dimensions of A and B must match.");
}

int64_t M = transA_ ? A_shape[1] : A_shape[0];
int64_t K = transA_ ? A_shape[0] : A_shape[1];
int64_t N = transB_ ? B_shape[0] : B_shape[1];

std::vector<int64_t> output_dims{M, N};
auto* Y = context.Output(0, output_dims);
int64_t output_size = Y->Shape().Size();
@@ -198,42 +120,36 @@ Status Gemm::ComputeInternal(ComputeContext& context) const {
return Status::OK();
}

// First try vec4 optimization if possible
if (CanApplyGemmVec4(A, B)) {
return ApplyGemmVec4(A, B, C, transA_, transB_, alpha_, beta_, context, Y);
}

// WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty.
bool need_handle_matmul = A_shape.Size() > 0 && B_shape.Size() > 0;
bool need_handle_bias = C && beta_;

GemmProgram program{transA_, transB_, alpha_, need_handle_bias, need_handle_matmul};
if (M <= 8 && N <= 8 && K <= 8) {
// Use naive implementation for small matrices
GemmNaiveProgram program{transA_, transB_, alpha_, need_handle_bias, need_handle_matmul};
if (need_handle_matmul) {
program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
{B, ProgramTensorMetadataDependency::Type}});
}

if (need_handle_matmul) {
program.AddInputs({{A, ProgramTensorMetadataDependency::Type},
{B, ProgramTensorMetadataDependency::Type}});
}
if (need_handle_bias) {
program.AddInput({C, ProgramTensorMetadataDependency::Rank});
}

if (need_handle_bias) {
program.AddInput({C, ProgramTensorMetadataDependency::Rank});
program.CacheHint(alpha_, transA_, transB_)
.AddOutputs({{Y, ProgramTensorMetadataDependency::Type}})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.SetWorkgroupSize(WORKGROUP_SIZE)
.AddUniformVariables({{static_cast<uint32_t>(output_size)},
{static_cast<uint32_t>(M)},
{static_cast<uint32_t>(N)},
{static_cast<uint32_t>(K)},
{alpha_},
{beta_}});
return context.RunProgram(program);
}
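(For scale: with M, N, and K all at most 8 the output has at most 64 elements, so this branch launches at most a handful of workgroups and skips the tile setup overhead of the packed path. WORKGROUP_SIZE comes from the provider's common headers; 64 is its usual value, stated here as an assumption, in which case the whole 8x8 worst case fits in a single workgroup.)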

const uint32_t TILE_SIZE = 16;
const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE;
const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE;

program.CacheHint(alpha_, transA_, transB_)
.AddOutputs({{Y, ProgramTensorMetadataDependency::Type}})
.SetDispatchGroupSize(num_tile_n * num_tile_m)
.SetWorkgroupSize(TILE_SIZE, TILE_SIZE)
.AddUniformVariables({{num_tile_n},
{M},
{N},
{K},
{alpha_},
{beta_}});

return context.RunProgram(program);
return ApplyGemmPacked(A, B, C, transA_, transB_, alpha_, beta_, context);
}

} // namespace webgpu
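For contrast with the packed path that replaces it, the 16x16 tiled kernel deleted above reassembles into the following complete shader (again for transA = false, transB = false, with the same assumed declarations; note the original bias read was unguarded, while this sketch folds it into the bounds check):

const TILE_SIZE = 16u;

struct Uniforms {
  num_tile_n : u32,
  M : u32,
  N : u32,
  K : u32,
  alpha : f32,
  beta : f32,
}

@group(0) @binding(0) var<storage, read> A : array<f32>;
@group(0) @binding(1) var<storage, read> B : array<f32>;
@group(0) @binding(2) var<storage, read> C : array<f32>;
@group(0) @binding(3) var<storage, read_write> output : array<f32>;
@group(0) @binding(4) var<uniform> uniforms : Uniforms;

var<workgroup> tile_a : array<array<f32, TILE_SIZE>, TILE_SIZE>;
var<workgroup> tile_b : array<array<f32, TILE_SIZE>, TILE_SIZE>;

@compute @workgroup_size(TILE_SIZE, TILE_SIZE)
fn main(@builtin(local_invocation_id) local_id : vec3<u32>,
        @builtin(workgroup_id) workgroup_id : vec3<u32>) {
  // the dispatch was 1-D (num_tile_n * num_tile_m groups), so recover the 2-D tile
  let workgroup_idx = workgroup_id.x;
  let tile_col_start = (workgroup_idx % uniforms.num_tile_n) * TILE_SIZE;
  let tile_row_start = (workgroup_idx / uniforms.num_tile_n) * TILE_SIZE;
  var value = 0.0;

  let num_tiles = (uniforms.K - 1u) / TILE_SIZE + 1u;
  var k_start = 0u;
  for (var t = 0u; t < num_tiles; t = t + 1u) {
    // cooperative load of one TILE_SIZE x TILE_SIZE tile of A and B,
    // zero-padded past the matrix edges
    var col = k_start + local_id.x;
    var row = tile_row_start + local_id.y;
    if (col < uniforms.K && row < uniforms.M) {
      tile_a[local_id.y][local_id.x] = A[row * uniforms.K + col];
    } else {
      tile_a[local_id.y][local_id.x] = 0.0;
    }
    col = tile_col_start + local_id.x;
    row = k_start + local_id.y;
    if (col < uniforms.N && row < uniforms.K) {
      tile_b[local_id.y][local_id.x] = B[row * uniforms.N + col];
    } else {
      tile_b[local_id.y][local_id.x] = 0.0;
    }
    k_start = k_start + TILE_SIZE;
    workgroupBarrier();

    for (var k = 0u; k < TILE_SIZE; k = k + 1u) {
      value = value + tile_a[local_id.y][k] * tile_b[k][local_id.x];
    }
    workgroupBarrier();
  }

  value = value * uniforms.alpha;
  let m = tile_row_start + local_id.y;
  let n = tile_col_start + local_id.x;
  if (m < uniforms.M && n < uniforms.N) {
    output[m * uniforms.N + n] = value + uniforms.beta * C[m * uniforms.N + n];
  }
}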
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/webgpu/math/gemm.h
@@ -10,10 +10,10 @@
namespace onnxruntime {
namespace webgpu {

class GemmProgram final : public Program<GemmProgram> {
class GemmNaiveProgram final : public Program<GemmNaiveProgram> {
public:
GemmProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul)
: Program{"Gemm"},
GemmNaiveProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul)
: Program{"GemmNaive"},
transA_{transA},
transB_{transB},
alpha_{alpha},
@@ -23,7 +23,7 @@ class GemmProgram final : public Program<GemmProgram> {
Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
{"num_tile_n", ProgramUniformVariableDataType::Uint32},
{"output_size", ProgramUniformVariableDataType::Uint32},
{"M", ProgramUniformVariableDataType::Uint32},
{"N", ProgramUniformVariableDataType::Uint32},
{"K", ProgramUniformVariableDataType::Uint32},
114 changes: 114 additions & 0 deletions onnxruntime/core/providers/webgpu/math/gemm_packed.cc
@@ -0,0 +1,114 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/math/gemm_packed.h"

#include "core/providers/webgpu/webgpu_utils.h"

#include "core/providers/webgpu/math/matmul_utils.h"
#include "core/providers/webgpu/math/gemm_utils.h"

namespace onnxruntime {
namespace webgpu {

Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const {
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);

// Each thread computes a 4x4 block of output elements
InlinedVector<int64_t> elements_per_thread = InlinedVector<int64_t>({4, 4, 1});

const std::string data_type = "output_element_t";

⚠ GitHub Actions / Optional Lint C++, gemm_packed.cc:20: [cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4]

if (need_handle_matmul_) {
const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);

MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_, is_vec4_);
}
if (is_vec4_) {
ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_));
} else {
ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_));
}
MatMulWriteFnSource(shader, output, need_handle_bias_, true, c_components_, output_components_, c_is_scalar_);

return Status::OK();
}
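The loop structure itself comes from the shared MakeMatMulPacked*Source helpers; this generator only contributes the per-tensor read/write hooks. A rough sketch of that contract for transA = false, transB = false and scalar (non-vec4) data follows. The hook names mm_readA / mm_readB / mm_write and their exact signatures are assumptions about what MatMulReadFnSource and MatMulWriteFnSource emit, not code from this PR; the uniform names match the dim_a_outer / dim_b_outer / dim_inner variables bound below:

// Hooks the packed matmul body calls; reads zero-pad out-of-range
// coordinates so the tiled loop never branches on the K edge.
fn mm_readA(row : u32, col : u32) -> f32 {
  if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) {
    return a[row * uniforms.dim_inner + col];  // A is M x K, row-major
  }
  return 0.0;
}

fn mm_readB(row : u32, col : u32) -> f32 {
  if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
    return b[row * uniforms.dim_b_outer + col];  // B is K x N, row-major
  }
  return 0.0;
}

fn mm_write(row : u32, col : u32, value_in : f32) {
  if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) {
    // alpha is applied inside the packed body (it is passed to
    // MakeMatMulPackedSource above); the write hook adds the beta * C term
    var value = value_in;
    value = value + uniforms.beta * c[row * uniforms.dim_b_outer + col];
    output[row * uniforms.dim_b_outer + col] = value;
  }
}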

Status ApplyGemmPacked(const Tensor* a,
const Tensor* b,
const Tensor* c,
bool transA,
bool transB,
float alpha,
float beta,
ComputeContext& context) {
const auto& a_shape = a->Shape();
const auto& b_shape = b->Shape();

uint32_t M = onnxruntime::narrow<uint32_t>(transA ? a_shape[1] : a_shape[0]);
uint32_t K = onnxruntime::narrow<uint32_t>(transA ? a_shape[0] : a_shape[1]);
uint32_t N = onnxruntime::narrow<uint32_t>(transB ? b_shape[0] : b_shape[1]);

std::vector<int64_t> output_dims{M, N};

⚠ GitHub Actions / Optional Lint C++, gemm_packed.cc:53: [cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4]
auto* y = context.Output(0, output_dims);
int64_t output_size = y->Shape().Size();

if (output_size == 0) {
return Status::OK();
}

// WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty.
bool need_handle_matmul = a_shape.Size() > 0 && b_shape.Size() > 0;
bool need_handle_bias = c && beta;

const bool is_vec4 = a_shape[1] % 4 == 0 && b_shape[1] % 4 == 0;

// Components for A, B
int components = is_vec4 ? 4 : 1;
// Components for Y
int output_components = (is_vec4 && N % 4 == 0) ? 4 : 1;
// Components for C.
int c_components = 1;

bool c_is_scalar = false;
if (need_handle_bias) {
const auto& c_shape = c->Shape();
int64_t c_last_dim = c_shape[c_shape.NumDimensions() - 1];
// `C` in GEMM might be broadcast to the output, and broadcasting requires the components to be consistent.
// So we use vec4 for C when its last dimension is N, and the output is also a vec4.
c_components = (c_last_dim == N && output_components == 4) ? 4 : 1;
c_is_scalar = c_shape.Size() == 1;
}
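The divisibility checks exist because a binding with 4 components packs 4 consecutive scalars along the tensor's innermost axis, so that axis length must be a multiple of 4 (in the non-transposed case, a_shape[1] is K, i.e. dim_inner). An illustrative accessor, for exposition only and not code from this PR; a_packed is a hypothetical binding name:

// With is_vec4, A is bound as array<vec4<f32>> packed along K and indexed
// in vec4 units: one load returns A[m, k], A[m, k+1], A[m, k+2], A[m, k+3].
fn read_a_vec4(m : u32, k : u32) -> vec4<f32> {
  return a_packed[(m * uniforms.dim_inner + k) / 4u];  // requires K % 4 == 0
}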

GemmProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, output_components, is_vec4};

if (need_handle_matmul) {
program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, components},
{b, ProgramTensorMetadataDependency::TypeAndRank, components}});
}

if (need_handle_bias) {
program.AddInput({c, ProgramTensorMetadataDependency::TypeAndRank, c_components});
}

const uint32_t TILE_SIZE = 32;
const uint32_t num_tile_n = (N + TILE_SIZE - 1) / TILE_SIZE;
const uint32_t num_tile_m = (M + TILE_SIZE - 1) / TILE_SIZE;

program.CacheHint(alpha, transA, transB, c_is_scalar)
.AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}})
.SetDispatchGroupSize(num_tile_n, num_tile_m, 1)
.SetWorkgroupSize(GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_X, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Y, GemmProgram::MATMUL_PACKED_WORKGROUP_SIZE_Z)
.AddUniformVariables({{alpha},
{beta},
{M}, /* dim_a_outer */
{N}, /* dim_b_outer */
{K}} /* dim_inner */
);

return context.RunProgram(program);
}

} // namespace webgpu
} // namespace onnxruntime