Skip to content

Commit

Permalink
WebNN: Implement the build method of matmul operation
Browse files Browse the repository at this point in the history
This CL implements the method of MLGraphBuilder that builds matmul
operation and adds validation under //components/.

Bug: 1273291
Change-Id: I12e92bd723be12c5004707f844d94ca7368e8a72
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4940628
Reviewed-by: Jiewei Qian <qjw@chromium.org>
Commit-Queue: Lisha Guo <lisha.guo@intel.com>
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Cr-Commit-Position: refs/heads/main@{#1215336}
  • Loading branch information
lisa0314 authored and Chromium LUCI CQ committed Oct 26, 2023
1 parent 275d551 commit fba6097
Show file tree
Hide file tree
Showing 21 changed files with 460 additions and 130 deletions.
61 changes: 61 additions & 0 deletions components/ml/webnn/graph_validation_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,67 @@ base::expected<Operand, std::string> ValidatePadAndInferOutput(
return Operand(input.data_type, std::move(output_shape));
}

base::expected<Operand, std::string> ValidateMatmulAndInferOutput(
const Operand& a,
const Operand& b) {
if (a.data_type != b.data_type) {
return base::unexpected("The types of first two inputs don't match.");
}

std::vector<uint32_t> a_dimensions = a.dimensions;
std::vector<uint32_t> b_dimensions = b.dimensions;

// Based on the WG discussion:
// https://github.com/webmachinelearning/webnn/issues/470, prototype the
// matmul without 1-D input tensors support.
if (a_dimensions.size() < 2 || b_dimensions.size() < 2) {
return base::unexpected(
"The rank of input must be larger than or equal to 2.");
}

// The number of columns in the first matrix must be equal to the number of
// rows in the second matrix.
const uint32_t a_cols = a_dimensions[a_dimensions.size() - 1];
const uint32_t a_rows = a_dimensions[a_dimensions.size() - 2];
const uint32_t b_cols = b_dimensions[b_dimensions.size() - 1];
const uint32_t b_rows = b_dimensions[b_dimensions.size() - 2];
if (a_cols != b_rows) {
return base::unexpected(base::StringPrintf(
"The number of columns (%u) in the first matrix isn't equal to "
"the number of rows (%u) in the second matrix.",
a_cols, b_rows));
}

size_t output_rank = std::max(a_dimensions.size(), b_dimensions.size());
std::vector<uint32_t> output_dimensions;
// Figure out the output shape by broadcasting all the dimensions except the
// last two. The output is 2-D tensor of shape [M, N].
if (a_dimensions.size() > 2 && b_dimensions.size() > 2) {
std::vector<uint32_t> sliced_a_dimensions(a_dimensions.begin(),
a_dimensions.end() - 2);
std::vector<uint32_t> sliced_b_dimensions(b_dimensions.begin(),
b_dimensions.end() - 2);
absl::optional<std::vector<uint32_t>> optional_output_dimensions =
BroadcastShapes(sliced_a_dimensions, sliced_b_dimensions, true);
if (!optional_output_dimensions) {
return base::unexpected("The matmul input shapes are not broadcastable.");
}
output_dimensions = *optional_output_dimensions;
output_dimensions.push_back(a_rows);
output_dimensions.push_back(b_cols);
} else if (a_dimensions.size() == 2 && b_dimensions.size() == 2) {
output_dimensions.push_back(a_rows);
output_dimensions.push_back(b_cols);
} else {
output_dimensions =
a_dimensions.size() > b_dimensions.size() ? a_dimensions : b_dimensions;
output_dimensions[output_rank - 2] = a_rows;
output_dimensions[output_rank - 1] = b_cols;
}
CHECK_EQ(output_rank, output_dimensions.size());
return Operand(a.data_type, std::move(output_dimensions));
}

base::expected<Operand, std::string> ValidatePool2dAndInferOutput(
const Operand& input,
const Pool2dAttributes& attributes) {
Expand Down
6 changes: 6 additions & 0 deletions components/ml/webnn/graph_validation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@ base::expected<Operand, std::string> ValidatePadAndInferOutput(
base::span<const uint32_t> beginning_padding,
base::span<const uint32_t> ending_padding);

// Validate and infer output information of matmul operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-matmul
base::expected<Operand, std::string> ValidateMatmulAndInferOutput(
const Operand& a,
const Operand& b);

// Validate and infer output information of 2-D pooling operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-pool2d
base::expected<Operand, std::string> ValidatePool2dAndInferOutput(
Expand Down
28 changes: 28 additions & 0 deletions third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1173,6 +1173,34 @@ MLActivation* MLGraphBuilder::leakyRelu(const MLLeakyReluOptions* options,
this, MLOperator::OperatorKind::kLeakyRelu, options);
}

MLOperand* MLGraphBuilder::matmul(const MLOperand* a,
const MLOperand* b,
ExceptionState& exception_state) {
auto validated_output = webnn::ValidateMatmulAndInferOutput(
ConvertToComponentOperand(a), ConvertToComponentOperand(b));
if (!validated_output.has_value()) {
exception_state.ThrowDOMException(
DOMExceptionCode::kDataError,
WTF::String::FromUTF8(validated_output.error()));
return nullptr;
}
// Create matmul operator and its output operand. Connect the matmul operator
// to its input and output operands.
auto* matmul =
MakeGarbageCollected<MLOperator>(this, MLOperator::OperatorKind::kMatmul);
auto output = MLOperand::ValidateAndCreateOutput(
this, ComponentOperandTypeToBlink(validated_output.value().data_type),
Vector<uint32_t>(validated_output.value().dimensions), matmul);
if (!output.has_value()) {
exception_state.ThrowDOMException(DOMExceptionCode::kDataError,
output.error());
return nullptr;
}
HeapVector<Member<const MLOperand>> inputs = {a, b};
matmul->Connect(std::move(inputs), {output.value()});
return output.value();
}

MLOperand* MLGraphBuilder::pad(const MLOperand* input,
const Vector<uint32_t>& beginning_padding,
const Vector<uint32_t>& ending_padding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ class MODULES_EXPORT MLGraphBuilder final : public ScriptWrappable {
MLActivation* leakyRelu(const MLLeakyReluOptions* options,
ExceptionState& exception_state);

MLOperand* matmul(const MLOperand* a,
const MLOperand* b,
ExceptionState& exception_state);

MLOperand* pad(const MLOperand* input,
const Vector<uint32_t>& beginningPadding,
const Vector<uint32_t>& endingPadding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ dictionary MLSplitOptions {
[RaisesException] MLOperand hardSwish(MLOperand x);
[RaisesException] MLActivation hardSwish();

[RaisesException] MLOperand matmul(MLOperand a, MLOperand b);

[RaisesException] MLOperand leakyRelu(MLOperand x, optional MLLeakyReluOptions options = {});
[RaisesException] MLActivation leakyRelu(optional MLLeakyReluOptions options = {});

Expand Down
133 changes: 133 additions & 0 deletions third_party/blink/renderer/modules/ml/webnn/ml_graph_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4314,6 +4314,139 @@ TEST_F(MLGraphBuilderTest, TanhTest) {
}
}

MLOperand* BuildMatmul(V8TestingScope& scope,
MLGraphBuilder* builder,
const MLOperand* a,
const MLOperand* b) {
auto* output = builder->matmul(a, b, scope.GetExceptionState());
EXPECT_NE(output, nullptr);
EXPECT_EQ(output->Kind(), MLOperand::OperandKind::kOutput);
EXPECT_EQ(output->Type(), a->Type());
auto* matmul = output->Operator();
EXPECT_NE(matmul, nullptr);
EXPECT_EQ(matmul->Kind(), MLOperator::OperatorKind::kMatmul);
EXPECT_EQ(matmul->IsConnected(), true);
EXPECT_EQ(matmul->Options(), nullptr);
return output;
}

TEST_F(MLGraphBuilderTest, MatmulTest) {
V8TestingScope scope;
auto* builder =
CreateMLGraphBuilder(scope.GetExecutionContext(), scope.GetScriptState(),
scope.GetExceptionState());
{
// Test throwing exception when the rank of input is smaller than 2.
auto* a = BuildInput(builder, "a", {2}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* b = BuildInput(builder, "b", {2}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* output = builder->matmul(a, b, scope.GetExceptionState());
EXPECT_EQ(output, nullptr);
EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
DOMExceptionCode::kDataError);
EXPECT_EQ(scope.GetExceptionState().Message(),
"The rank of input must be larger than or equal to 2.");
}
{
// Test building matmul with 2-D * 4-D inputs.
auto* a = BuildInput(builder, "a", {1, 4}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* b =
BuildInput(builder, "b", {2, 2, 4, 2}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* output = BuildMatmul(scope, builder, a, b);
EXPECT_EQ(output->Dimensions(), Vector<uint32_t>({2, 2, 1, 2}));
}
{
// Test building matmul with 2-D * 2-D inputs.
auto* a = BuildInput(builder, "a", {4, 2}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* b = BuildInput(builder, "b", {2, 3}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* output = BuildMatmul(scope, builder, a, b);
EXPECT_EQ(output->Dimensions(), Vector<uint32_t>({4, 3}));
}
{
// Test building matmul with 3-D * 3-D inputs using broadcast.
auto* a =
BuildInput(builder, "a", {2, 3, 4}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* b =
BuildInput(builder, "b", {1, 4, 1}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* output = BuildMatmul(scope, builder, a, b);
EXPECT_EQ(output->Dimensions(), Vector<uint32_t>({2, 3, 1}));
}
{
// Test building matmul with 4-D * 3-D inputs using broadcast.
auto* a =
BuildInput(builder, "a", {2, 2, 3, 4}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* b =
BuildInput(builder, "b", {1, 4, 5}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* output = BuildMatmul(scope, builder, a, b);
EXPECT_EQ(output->Dimensions(), Vector<uint32_t>({2, 2, 3, 5}));
}
{
// Test building matmul with 3-D * 3-D inputs.
auto* a =
BuildInput(builder, "a", {2, 3, 4}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* b =
BuildInput(builder, "b", {2, 4, 5}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* output = BuildMatmul(scope, builder, a, b);
EXPECT_EQ(output->Dimensions(), Vector<uint32_t>({2, 3, 5}));
}
{
// Test throwing exception when the types of first two inputs don't match.
auto* a = BuildInput(builder, "a", {2, 3}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* b = BuildInput(builder, "b", {3, 4}, V8MLOperandType::Enum::kInt32,
scope.GetExceptionState());
auto* output = builder->matmul(a, b, scope.GetExceptionState());
;
EXPECT_EQ(output, nullptr);
EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
DOMExceptionCode::kDataError);
EXPECT_EQ(scope.GetExceptionState().Message(),
"The types of first two inputs don't match.");
}
{
// Test throwing exception when the number in first matrix mismatches with
// the numbers in second matrix is .
auto* a = BuildInput(builder, "a", {2, 3}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* b = BuildInput(builder, "b", {2, 4}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* output = builder->matmul(a, b, scope.GetExceptionState());
;
EXPECT_EQ(output, nullptr);
EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
DOMExceptionCode::kDataError);
EXPECT_EQ(scope.GetExceptionState().Message(),
"The number of columns (3) in the first matrix isn't equal to "
"the number of rows (2) in the second matrix.");
}
{
// Test throwing exception when the input shapes are not broadcastable..
auto* a =
BuildInput(builder, "a", {3, 3, 4}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* b =
BuildInput(builder, "b", {2, 4, 1}, V8MLOperandType::Enum::kFloat32,
scope.GetExceptionState());
auto* output = builder->matmul(a, b, scope.GetExceptionState());
EXPECT_EQ(output, nullptr);
EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
DOMExceptionCode::kDataError);
EXPECT_EQ(scope.GetExceptionState().Message(),
"The matmul input shapes are not broadcastable.");
}
}

class FakeMLGraphBackend final : public MLGraph {
public:
// Create and build a FakeMLGraphBackend object. Resolve the promise with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ base::expected<OperationPtr, String> ConvertToMojoOperation(
case MLOperator::OperatorKind::kCeil:
case MLOperator::OperatorKind::kFloor:
case MLOperator::OperatorKind::kNeg:
case MLOperator::OperatorKind::kMatmul:
return base::unexpected(MLOperator::OperatorKindToString(op->Kind()) +
" is not implemented.");
}
Expand Down
2 changes: 2 additions & 0 deletions third_party/blink/renderer/modules/ml/webnn/ml_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ String MLOperator::OperatorKindToString(MLOperator::OperatorKind kind) {
return "averagePool2d";
case MLOperator::OperatorKind::kMaxPool2d:
return "maxPool2d";
case MLOperator::OperatorKind::kMatmul:
return "matmul";
case MLOperator::OperatorKind::kPad:
return "pad";
case MLOperator::OperatorKind::kPow:
Expand Down
1 change: 1 addition & 0 deletions third_party/blink/renderer/modules/ml/webnn/ml_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class MODULES_EXPORT MLOperator : public GarbageCollected<MLOperator> {
kGemm,
kHardSwish,
kAveragePool2d,
kMatmul,
kMaxPool2d,
kPad,
kPow,
Expand Down

This file was deleted.

0 comments on commit fba6097

Please sign in to comment.