WebNN: Define resample2d in mojo
This CL defines resample2d in mojo, implements `CreateResample2dOperation()` for creating the resample2d mojo operation, and adds validation on the service side.

Tests for resample2d are also added to both the blink unittests and the
services unittests.

Bug: 1273291
Change-Id: I69ffdf2f80797e385654b71052def9309e337285
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4924806
Reviewed-by: Jiewei Qian <qjw@chromium.org>
Commit-Queue: Mingming1 Xu <mingming1.xu@intel.com>
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Cr-Commit-Position: refs/heads/main@{#1210065}
mingmingtasd authored and Chromium LUCI CQ committed Oct 16, 2023
1 parent 6cced94 commit 34e5374
Showing 10 changed files with 420 additions and 123 deletions.
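
Taken together, the service-side pieces of this CL form one flow: the mojom graph gains a Resample2d operation, the GraphInfoBuilder test helper gains BuildResample2d(), and WebNNGraphImpl::ValidateGraph() checks the new operation. The sketch below condenses that flow using only helpers that appear in the hunks on this page; the include paths and the enclosing webnn namespace are assumptions made for illustration, and the snippet itself is not part of the CL.

#include "base/check.h"
#include "services/webnn/public/mojom/webnn_graph.mojom.h"  // assumed generated-header path
#include "services/webnn/webnn_graph_impl.h"                 // assumed header for WebNNGraphImpl
#include "services/webnn/webnn_test_utils.h"

namespace webnn {  // assumed namespace, matching usage in the unittest hunks below

void BuildAndValidateResample2d() {
  // Describe a 1x1x2x4 float32 input upsampled to 1x1x4x8 with linear
  // interpolation, mirroring the new Resample2dTester cases.
  GraphInfoBuilder builder;
  uint64_t input_id = builder.BuildInput(
      "input", {1, 1, 2, 4}, mojom::Operand::DataType::kFloat32);
  uint64_t output_id = builder.BuildOutput(
      "output", {1, 1, 4, 8}, mojom::Operand::DataType::kFloat32);
  builder.BuildResample2d(input_id, output_id,
                          mojom::Resample2d::InterpolationMode::kLinear);

  // The new ValidateResample2d() accepts this graph: input and output exist,
  // are distinct operands, and share the same data type.
  CHECK(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()));
}

}  // namespace webnn
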
27 changes: 11 additions & 16 deletions components/ml/webnn/graph_validation_utils.cc
@@ -32,11 +32,6 @@ bool IsFloatingPointType(Operand::DataType data_type) {
NOTREACHED_NORETURN();
}

struct FloatSize2D {
double height;
double width;
};

// Calculate the output size for conv2d based on WebNN spec:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-conv2d
// Return the calculated output size if no error.
@@ -81,15 +76,15 @@ base::expected<double, std::string> CalculateConv2dOutputSize(
// input sizes, filter sizes, padding, strides and dilations.
// Return the calculated output sizes in double precision floating point number
// if no errors.
base::expected<FloatSize2D, std::string> ValidateAndCalculateConv2dOutputSizes(
const uint32_t input_height,
const uint32_t input_width,
const uint32_t filter_height,
const uint32_t filter_width,
const Padding2d& padding,
const Size2d& strides,
const Size2d& dilations,
const AutoPad auto_pad) {
base::expected<Size2d<double>, std::string>
ValidateAndCalculateConv2dOutputSizes(const uint32_t input_height,
const uint32_t input_width,
const uint32_t filter_height,
const uint32_t filter_width,
const Padding2d& padding,
const Size2d<uint32_t>& strides,
const Size2d<uint32_t>& dilations,
const AutoPad auto_pad) {
uint32_t padding_beginning_height = padding.beginning.height;
uint32_t padding_ending_height = padding.ending.height;
uint32_t padding_beginning_width = padding.beginning.width;
@@ -147,8 +142,8 @@ base::expected<FloatSize2D, std::string> ValidateAndCalculateConv2dOutputSizes(
float_output_width.error());
}

return FloatSize2D({.height = float_output_height.value(),
.width = float_output_width.value()});
return Size2d<double>{.height = float_output_height.value(),
.width = float_output_width.value()};
}

} // namespace
25 changes: 13 additions & 12 deletions components/ml/webnn/graph_validation_utils.h
@@ -69,18 +69,19 @@ enum AutoPad { kExplicit, kSameUpper, kSameLower };
enum RoundingType { kFloor, kCeil };

// A size has height and width values.
template <typename T>
struct Size2d {
uint32_t height;
uint32_t width;
T height;
T width;
};

// The additional rows and columns added to the beginning and ending of each
// spatial dimension of input.
struct Padding2d {
// The height and width padding at the beginning of input tensor.
Size2d beginning;
Size2d<uint32_t> beginning;
// The height and width padding at the ending of input tensor.
Size2d ending;
Size2d<uint32_t> ending;
};

// Contains the attributes of conv2d operator.
@@ -98,9 +99,9 @@ struct Conv2dAttributes {
// spatial dimension of input.
Padding2d padding;
// The stride of the sliding window for each spatial dimension of input.
Size2d strides;
Size2d<uint32_t> strides;
// The dilation factor for each spatial dimension of input.
Size2d dilations;
Size2d<uint32_t> dilations;
// The automatic input padding options.
AutoPad auto_pad = AutoPad::kExplicit;
// The number of groups that input channels and output channels are divided
@@ -118,23 +119,23 @@ struct Conv2dAttributes {
// Contains the attributes of pool2d operator.
struct Pool2dAttributes {
// The dimensions of the sliding window.
absl::optional<Size2d> window_dimensions;
absl::optional<Size2d<uint32_t>> window_dimensions;
// The additional rows and columns added to the beginning and ending of each
// spatial dimension of input.
Padding2d padding;
// The element stride of the sliding window for each spatial dimension of
// input.
Size2d strides;
Size2d<uint32_t> strides;
// The dilation factor for each spatial dimension of input.
Size2d dilations;
Size2d<uint32_t> dilations;
// The automatic input padding options.
AutoPad auto_pad = AutoPad::kExplicit;
// The layout format of the input.
InputOperandLayout layout = InputOperandLayout::kNchw;
// The rounding function used to compute the output shape.
RoundingType rounding_type = RoundingType::kFloor;
// The element height and width of the output tensor.
absl::optional<Size2d> output_sizes;
absl::optional<Size2d<uint32_t>> output_sizes;
};

// Contains the attributes of gemm operator.
@@ -191,8 +192,8 @@ base::expected<Operand, std::string> ValidateConv2dAndInferOutput(
const Operand& filter,
const Conv2dAttributes& attributes);

// Validate a mean, L2 norm, or max reduction operator defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-pool2d
// Validate and infer output information of 2-D pooling operator defined in
// WebIDL here https://www.w3.org/TR/webnn/#api-mlgraphbuilder-pool2d
base::expected<Operand, std::string> ValidatePool2dAndInferOutput(
const Operand& input,
const Pool2dAttributes& attributes);
2 changes: 2 additions & 0 deletions services/webnn/dml/graph_impl.cc
@@ -101,6 +101,8 @@ std::string OpTagToString(Operation::Tag tag) {
return "conv2d";
case Operation::Tag::kPool2d:
return "pool2d";
case Operation::Tag::kResample2d:
return "resample2d";
case Operation::Tag::kRelu:
return "relu";
case Operation::Tag::kSplit:
17 changes: 17 additions & 0 deletions services/webnn/public/mojom/webnn_graph.mojom
@@ -228,6 +228,22 @@ union OperatorAttributes {
GemmAttributes gemm;
};

// Resample the tensor values from the source to the destination spatial
// dimensions.
struct Resample2d {
// The id of input operand.
uint64 input_operand_id;
// The id of output operand.
uint64 output_operand_id;

enum InterpolationMode {
kNearestNeighbor,
kLinear,
};

InterpolationMode mode;
};

// Represents the operations defined in `MLGraphBuilder` that describes the
// functional semantics.
struct Operator {
@@ -263,6 +279,7 @@ union Operation {
Conv2d conv2d;
Pool2d pool2d;
Relu relu;
Resample2d resample2d;
Softmax softmax;
Split split;
Transpose transpose;
69 changes: 45 additions & 24 deletions services/webnn/webnn_graph_impl.cc
@@ -108,13 +108,14 @@ webnn::Conv2dAttributes ConvertToConv2dAttributes(
// Convert padding, strides, dilations.
auto& mojo_padding = conv2d->padding;
component_attributes.padding = webnn::Padding2d{
.beginning = webnn::Size2d{.height = mojo_padding->beginning->height,
.width = mojo_padding->beginning->width},
.ending = webnn::Size2d{.height = mojo_padding->ending->height,
.width = mojo_padding->ending->width}};
component_attributes.strides = webnn::Size2d{
.beginning =
webnn::Size2d<uint32_t>{.height = mojo_padding->beginning->height,
.width = mojo_padding->beginning->width},
.ending = webnn::Size2d<uint32_t>{.height = mojo_padding->ending->height,
.width = mojo_padding->ending->width}};
component_attributes.strides = webnn::Size2d<uint32_t>{
.height = conv2d->strides->height, .width = conv2d->strides->width};
component_attributes.dilations = webnn::Size2d{
component_attributes.dilations = webnn::Size2d<uint32_t>{
.height = conv2d->dilations->height, .width = conv2d->dilations->width};

// Convert groups, input and filter layout.
@@ -144,28 +145,29 @@ webnn::Pool2dAttributes ConvertToPool2dAttributes(
const mojom::Operand* output) {
webnn::Pool2dAttributes component_attributes;
auto& window_dimensions = pool2d->window_dimensions;
component_attributes.window_dimensions = webnn::Size2d{
component_attributes.window_dimensions = webnn::Size2d<uint32_t>{
.height = window_dimensions->height, .width = window_dimensions->width};
auto& mojo_padding = pool2d->padding;
component_attributes.padding = webnn::Padding2d{
.beginning = webnn::Size2d{.height = mojo_padding->beginning->height,
.width = mojo_padding->beginning->width},
.ending = webnn::Size2d{.height = mojo_padding->ending->height,
.width = mojo_padding->ending->width}};
component_attributes.strides = webnn::Size2d{
.beginning =
webnn::Size2d<uint32_t>{.height = mojo_padding->beginning->height,
.width = mojo_padding->beginning->width},
.ending = webnn::Size2d<uint32_t>{.height = mojo_padding->ending->height,
.width = mojo_padding->ending->width}};
component_attributes.strides = webnn::Size2d<uint32_t>{
.height = pool2d->strides->height, .width = pool2d->strides->width};
component_attributes.dilations = webnn::Size2d{
component_attributes.dilations = webnn::Size2d<uint32_t>{
.height = pool2d->dilations->height, .width = pool2d->dilations->width};
component_attributes.layout =
MojoInputOperandLayoutToComponent(pool2d->layout);
CHECK_EQ(output->dimensions.size(), 4u);
switch (component_attributes.layout) {
case webnn::InputOperandLayout::kNchw:
component_attributes.output_sizes = webnn::Size2d{
component_attributes.output_sizes = webnn::Size2d<uint32_t>{
.height = output->dimensions[2], .width = output->dimensions[3]};
break;
case webnn::InputOperandLayout::kNhwc:
component_attributes.output_sizes = webnn::Size2d{
component_attributes.output_sizes = webnn::Size2d<uint32_t>{
.height = output->dimensions[1], .width = output->dimensions[2]};
break;
}
@@ -417,6 +419,23 @@ bool ValidateRelu(const IdToOperandMap& id_to_operand_map,
return true;
}

bool ValidateResample2d(const IdToOperandMap& id_to_operand_map,
const mojom::Resample2dPtr& resample2d) {
auto* input = GetMojoOperand(id_to_operand_map, resample2d->input_operand_id);
auto* output =
GetMojoOperand(id_to_operand_map, resample2d->output_operand_id);
if (!input || !output || output == input) {
// The resample2d operator is invalid.
return false;
}
if (output->data_type != input->data_type) {
// The output data type doesn't match input data type.
return false;
}

return true;
}

bool ValidateReshape(const IdToOperandMap& id_to_operand_map,
const mojom::OperatorPtr& operation) {
auto* input = GetMojoOperand(id_to_operand_map, operation->input_operands);
@@ -582,6 +601,8 @@ bool ValidateOperation(const IdToOperandMap& id_to_operand_map,
return ValidateConv2d(id_to_operand_map, operation->get_conv2d());
case mojom::Operation::Tag::kPool2d:
return ValidatePool2d(id_to_operand_map, operation->get_pool2d());
case mojom::Operation::Tag::kResample2d:
return ValidateResample2d(id_to_operand_map, operation->get_resample2d());
case mojom::Operation::Tag::kRelu:
return ValidateRelu(id_to_operand_map, operation->get_relu());
case mojom::Operation::Tag::kSoftmax:
@@ -712,15 +733,15 @@ bool WebNNGraphImpl::ValidateGraph(const mojom::GraphInfoPtr& graph_info) {
void WebNNGraphImpl::Compute(
base::flat_map<std::string, mojo_base::BigBuffer> named_inputs,
mojom::WebNNGraph::ComputeCallback callback) {
// Validate the inputs for computation match the built graph's expected.
if (!base::ranges::equal(
named_inputs, compute_resource_info_.input_name_to_byte_length_map,
[](const auto& iter_a, const auto& iter_b) {
// Compare the input name with the key of map and the byte length of
// buffer with value of map.
return iter_a.first == iter_b.first &&
iter_a.second.size() == iter_b.second;
})) {
// Validate the inputs for computation match the built graph's expectation.
if (!base::ranges::equal(named_inputs,
compute_resource_info_.input_name_to_byte_length_map,
[](const auto& iter_a, const auto& iter_b) {
// Compare the input name with the key of map and
// the byte length of buffer with value of map.
return iter_a.first == iter_b.first &&
iter_a.second.size() == iter_b.second;
})) {
std::move(callback).Run(mojom::ComputeResult::kInvalidInputs,
absl::nullopt);
return;
50 changes: 50 additions & 0 deletions services/webnn/webnn_graph_impl_unittest.cc
@@ -1130,6 +1130,56 @@ TEST_F(WebNNGraphImplTest, ReluTest) {
}
}

struct Resample2dTester {
OperandInfo input;
mojom::Resample2d::InterpolationMode mode =
mojom::Resample2d::InterpolationMode::kNearestNeighbor;
OperandInfo output;
bool expected;

void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildResample2d(input_operand_id, output_operand_id, mode);
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};

TEST_F(WebNNGraphImplTest, Resample2dTest) {
{
Resample2dTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 1, 2, 4}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 1, 2, 4}},
.expected = true}
.Test();
}
{
// Test resample2d with mode =
// "mojom::Resample2d::InterpolationMode::kLinear".
Resample2dTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 1, 2, 4}},
.mode = mojom::Resample2d::InterpolationMode::kLinear,
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 1, 4, 8}},
.expected = true}
.Test();
}
{
// Test the invalid graph for output types don't match.
Resample2dTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 1, 2, 4}},
.output = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {1, 1, 4, 8}},
.expected = false}
.Test();
}
}

struct ReshapeTester {
OperandInfo input;
OperandInfo output;
12 changes: 12 additions & 0 deletions services/webnn/webnn_test_utils.h
@@ -158,6 +158,18 @@ class GraphInfoBuilder final {

void BuildRelu(uint64_t input_operand_id, uint64_t output_operand_id);

void BuildResample2d(uint64_t input_operand_id,
uint64_t output_operand_id,
mojom::Resample2d::InterpolationMode mode) {
mojom::Resample2dPtr resample2d = mojom::Resample2d::New();
resample2d->input_operand_id = input_operand_id;
resample2d->output_operand_id = output_operand_id;
resample2d->mode = mode;

graph_info_->operations.push_back(
mojom::Operation::NewResample2d(std::move(resample2d)));
}

void BuildSoftmax(uint64_t input_operand_id, uint64_t output_operand_id);

void BuildSplit(uint64_t input_operand_id,