WebNN: Define prelu operation in mojo

The prelu operation calculates the parametric version of rectified
linear function (Parametric ReLU) on the input tensor element-wise.

This CL moves the validation steps of prelu to //components/,
implements CreatePreluOperation() for creating prelu mojo operation and
adds operator validation on service side.
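
For reference, a minimal sketch of the element-wise computation described above (illustrative only, not part of this CL; PreluReference is a hypothetical helper and the slope is assumed to already be broadcast to the input's shape):

// Illustrative sketch: element-wise PReLU, max(0, x) + slope * min(0, x).
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<float> PreluReference(const std::vector<float>& input,
                                  const std::vector<float>& slope) {
  std::vector<float> output(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    output[i] =
        std::max(0.0f, input[i]) + slope[i] * std::min(0.0f, input[i]);
  }
  return output;
}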

Bug: 1273291
Change-Id: I62f68dd2c1da6d735da2020affa2fa9a322cb657
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4950626
Commit-Queue: Shiyi Zou <shiyi.zou@intel.com>
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Reviewed-by: Jiewei Qian <qjw@chromium.org>
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Reviewed-by: Alex Gough <ajgo@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1212567}
shiyi9801 authored and Chromium LUCI CQ committed Oct 20, 2023
1 parent a35e5fc commit 5cacc81
Showing 11 changed files with 338 additions and 19 deletions.
21 changes: 21 additions & 0 deletions components/ml/webnn/graph_validation_utils.cc
@@ -660,6 +660,27 @@ base::expected<Operand, std::string> ValidateConcatAndInferOutput(
return Operand(output_type, std::move(output_shape));
}

base::expected<Operand, std::string> ValidatePreluAndInferOutput(
const Operand& input,
const Operand& slope) {
if (input.data_type != slope.data_type) {
return base::unexpected(
"The type of slope doesn't match the type of input.");
}
if (!IsFloatingPointType(input.data_type)) {
return base::unexpected(
"The type of input and slope must be one of the floating point types.");
}
// BroadcastShape unidirectionally broadcasts slope.dimensions to
// input.dimensions.
if (!BroadcastShapes(slope.dimensions, input.dimensions, false)) {
return base::unexpected(
"The shape of slope is not broadcastable to the shape of input.");
}

return Operand(input.data_type, input.dimensions);
}

base::expected<Operand, std::string> ValidateTransposeAndInferOutput(
const Operand& input,
base::span<const uint32_t> permutation) {
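
The unidirectional broadcast used in the validation above right-aligns the slope shape against the input shape and requires each slope dimension to equal the corresponding input dimension or be 1. A minimal sketch of such a check (illustrative only; CanUnidirectionallyBroadcast is a hypothetical helper, not the BroadcastShapes API called in this CL):

#include <cstddef>
#include <cstdint>
#include <vector>

// Returns true if `from` can be unidirectionally broadcast to `to`,
// e.g. {3, 1, 5} -> {3, 2, 5} succeeds, while {3, 5} -> {3, 2, 5} fails
// because the aligned dimensions 3 and 2 differ.
bool CanUnidirectionallyBroadcast(const std::vector<uint32_t>& from,
                                  const std::vector<uint32_t>& to) {
  if (from.size() > to.size())
    return false;
  for (std::size_t i = 0; i < from.size(); ++i) {
    const uint32_t from_dim = from[from.size() - 1 - i];
    const uint32_t to_dim = to[to.size() - 1 - i];
    if (from_dim != to_dim && from_dim != 1)
      return false;
  }
  return true;
}
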
6 changes: 6 additions & 0 deletions components/ml/webnn/graph_validation_utils.h
@@ -236,6 +236,12 @@ base::expected<Operand, std::string> ValidateConcatAndInferOutput(
const std::vector<Operand>& input,
const uint32_t axis);

// Validate prelu operator defined in WebIDL here:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-prelu
base::expected<Operand, std::string> ValidatePreluAndInferOutput(
const Operand& input,
const Operand& slope);

// Validate transpose operator defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-transpose
base::expected<Operand, std::string> ValidateTransposeAndInferOutput(
2 changes: 2 additions & 0 deletions services/webnn/dml/graph_impl.cc
@@ -100,6 +100,8 @@ std::string OpTagToString(Operation::Tag tag) {
return "pad";
case Operation::Tag::kPool2d:
return "pool2d";
case Operation::Tag::kPrelu:
return "prelu";
case Operation::Tag::kResample2d:
return "resample2d";
case Operation::Tag::kRelu:
15 changes: 15 additions & 0 deletions services/webnn/public/mojom/webnn_graph.mojom
@@ -248,6 +248,20 @@ struct GemmAttributes {
bool b_transpose = false;
};

// Represents a parametric relu operation whose calculation follows the
// expression max(0, x) + slope ∗ min(0, x).
struct Prelu {
// The id of input operand is used to get the `Operand` description from
// `GraphInfo.id_to_operand_map`.
uint64 input_operand_id;
// The id of slope operand is used to get the `Operand` description from
// `GraphInfo.id_to_operand_map`.
uint64 slope_operand_id;
// The id of output operand is used to get the `Operand` description from
// `GraphInfo.id_to_operand_map`.
uint64 output_operand_id;
};

// Corresponds to `MLOperand relu(MLOperand x)` that computes the rectified
// linear function of the input tensor.
struct Relu {
@@ -354,6 +368,7 @@ union Operation {
ElementWiseBinary element_wise_binary;
Pad pad;
Pool2d pool2d;
Prelu prelu;
Relu relu;
Resample2d resample2d;
Slice slice;
24 changes: 24 additions & 0 deletions services/webnn/webnn_graph_impl.cc
@@ -433,6 +433,28 @@ bool ValidatePool2d(const IdToOperandMap& id_to_operand_map,
return true;
}

bool ValidatePrelu(const IdToOperandMap& id_to_operand_map,
const mojom::PreluPtr& prelu) {
auto* input = GetMojoOperand(id_to_operand_map, prelu->input_operand_id);
auto* output = GetMojoOperand(id_to_operand_map, prelu->output_operand_id);
auto* slope = GetMojoOperand(id_to_operand_map, prelu->slope_operand_id);
if (!input || !output || !slope || output == input || output == slope) {
// The prelu operator is invalid.
return false;
}

auto validated_output = ValidatePreluAndInferOutput(
ConvertToComponentOperand(input), ConvertToComponentOperand(slope));
if (!validated_output.has_value()) {
return false;
}
if (validated_output != ConvertToComponentOperand(output)) {
return false;
}

return true;
}

bool ValidateRelu(const IdToOperandMap& id_to_operand_map,
const mojom::ReluPtr& relu) {
auto* input = GetMojoOperand(id_to_operand_map, relu->input_operand_id);
@@ -654,6 +676,8 @@ bool ValidateOperation(const IdToOperandMap& id_to_operand_map,
return ValidatePad(id_to_operand_map, operation->get_pad());
case mojom::Operation::Tag::kPool2d:
return ValidatePool2d(id_to_operand_map, operation->get_pool2d());
case mojom::Operation::Tag::kPrelu:
return ValidatePrelu(id_to_operand_map, operation->get_prelu());
case mojom::Operation::Tag::kResample2d:
return ValidateResample2d(id_to_operand_map, operation->get_resample2d());
case mojom::Operation::Tag::kRelu:
122 changes: 122 additions & 0 deletions services/webnn/webnn_graph_impl_unittest.cc
@@ -1181,6 +1181,128 @@ TEST_F(WebNNGraphImplTest, Pool2dTest) {
}
}

struct PreluTester {
OperandInfo input;
OperandInfo slope;
OperandInfo output;
bool expected;

void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
uint64_t slope_operand_id =
builder.BuildInput("slope", slope.dimensions, slope.type);
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildPrelu(input_operand_id, slope_operand_id, output_operand_id);
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};

TEST_F(WebNNGraphImplTest, PreluTest) {
{
// Test prelu operator when the input and the slope have the same shape.
PreluTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5}},
.slope = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5}},
.expected = true}
.Test();
}
{
// Test prelu operator with a broadcastable slope.
PreluTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5}},
.slope = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5}},
.expected = true}
.Test();
}
{
// Test the invalid graph when the slope's shape is not broadcastable to
// the input's shape.
PreluTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5}},
.slope = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 5}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5}},
.expected = false}
.Test();
}
{
// Test the invalid graph when the datatype isn't floating point.
PreluTester{.input = {.type = mojom::Operand::DataType::kInt32,
.dimensions = {3, 2, 5}},
.slope = {.type = mojom::Operand::DataType::kInt32,
.dimensions = {3, 2, 5}},
.output = {.type = mojom::Operand::DataType::kInt32,
.dimensions = {3, 2, 5}},
.expected = false}
.Test();
}
{
// Test the invalid graph when the slope datatype doesn't match the input's
// datatype.
PreluTester{.input = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {3, 2, 5}},
.slope = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5}},
.output = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {3, 2, 5}},
.expected = false}
.Test();
}
{
// Test the invalid graph when the output datatype doesn't match the input's
// datatype.
PreluTester{.input = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {3, 2, 5}},
.slope = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {3, 2, 5}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5}},
.expected = false}
.Test();
}
{
// Test the invalid graph when the output shape is not the expected shape.
PreluTester{.input = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {3, 2, 5}},
.slope = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {3, 2, 5}},
.output = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {3, 2, 6}},
.expected = false}
.Test();
}
{
// Test the invalid graph when the output is the same operand as the input.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", {2, 3}, mojom::Operand::DataType::kFloat32);
uint64_t slope_operand_id =
builder.BuildInput("slope", {2, 3}, mojom::Operand::DataType::kFloat32);
builder.BuildPrelu(input_operand_id, slope_operand_id, input_operand_id);
EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
}
{
// Test the invalid graph when the output is the same operand as the slope.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", {2, 3}, mojom::Operand::DataType::kFloat32);
uint64_t output_operand_id = builder.BuildOutput(
"output", {2, 3}, mojom::Operand::DataType::kFloat32);
builder.BuildPrelu(input_operand_id, output_operand_id, output_operand_id);
EXPECT_FALSE(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
}
}

struct ReluTester {
OperandInfo input;
OperandInfo output;
11 changes: 11 additions & 0 deletions services/webnn/webnn_test_utils.cc
@@ -162,6 +162,17 @@ void GraphInfoBuilder::BuildElementWiseBinary(
mojom::Operation::NewElementWiseBinary(std::move(binary)));
}

void GraphInfoBuilder::BuildPrelu(uint64_t input_operand_id,
uint64_t slope_operand_id,
uint64_t output_operand_id) {
mojom::PreluPtr prelu = mojom::Prelu::New();
prelu->input_operand_id = input_operand_id;
prelu->slope_operand_id = slope_operand_id;
prelu->output_operand_id = output_operand_id;
graph_info_->operations.push_back(
mojom::Operation::NewPrelu(std::move(prelu)));
}

void GraphInfoBuilder::BuildRelu(uint64_t input_operand_id,
uint64_t output_operand_id) {
mojom::ReluPtr relu = mojom::Relu::New();
4 changes: 4 additions & 0 deletions services/webnn/webnn_test_utils.h
@@ -168,6 +168,10 @@ class GraphInfoBuilder final {
mojom::Operation::NewPool2d(std::move(pool2d)));
}

void BuildPrelu(uint64_t input_operand_id,
uint64_t slope_operand_id,
uint64_t output_operand_id);

void BuildRelu(uint64_t input_operand_id, uint64_t output_operand_id);

void BuildResample2d(uint64_t input_operand_id,
26 changes: 8 additions & 18 deletions third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc
@@ -1227,30 +1227,20 @@ MLOperand* MLGraphBuilder::maxPool2d(const MLOperand* input,
MLOperand* MLGraphBuilder::prelu(const MLOperand* input,
const MLOperand* slope,
ExceptionState& exception_state) {
if (input->Type() != slope->Type()) {
exception_state.ThrowDOMException(
DOMExceptionCode::kDataError,
"The type of slope doesn't match the type of input.");
return nullptr;
}
if (!IsFloatingPointType(input->Type())) {
exception_state.ThrowDOMException(
DOMExceptionCode::kDataError,
"The type of input and slope must be one of the floating point types.");
return nullptr;
}
// BroadcastShape unidirectionally broadcasts the slope->Dimensions() to the
// input->Dimensions().
if (!BroadcastShapes(slope->Dimensions(), input->Dimensions(), false)) {
auto validated_output = webnn::ValidatePreluAndInferOutput(
ConvertToComponentOperand(input), ConvertToComponentOperand(slope));
if (!validated_output.has_value()) {
exception_state.ThrowDOMException(
DOMExceptionCode::kDataError,
"The shape of slope is not broadcastable to the shape of input.");
String::FromUTF8(validated_output.error()));
return nullptr;
}

auto* prelu =
MakeGarbageCollected<MLOperator>(this, MLOperator::OperatorKind::kPRelu);
auto output = MLOperand::ValidateAndCreateOutput(this, input->Type(),
input->Dimensions(), prelu);
auto output = MLOperand::ValidateAndCreateOutput(
this, ComponentOperandTypeToBlink(validated_output->data_type),
Vector<uint32_t>(validated_output->dimensions), prelu);
if (!output.has_value()) {
exception_state.ThrowDOMException(DOMExceptionCode::kDataError,
output.error());
