Skip to content

Commit

Permalink
WebNN: Define concat operator in mojo
Browse files Browse the repository at this point in the history
The concat operator concatenates the input tensors along a given axis.

This CL moves the validation steps of concat to //components/,
implements CreateConcatOperation() for creating concat mojo operation
and adds operator validation on service side.

Bug: 1273291
Change-Id: I5ad8d6ee387d44f104e8c988733e7b321b381a39
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4929772
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Reviewed-by: Jiewei Qian <qjw@chromium.org>
Reviewed-by: Alex Gough <ajgo@chromium.org>
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Commit-Queue: Shiyi Zou <shiyi.zou@intel.com>
Cr-Commit-Position: refs/heads/main@{#1209977}
  • Loading branch information
shiyi9801 authored and Chromium LUCI CQ committed Oct 16, 2023
1 parent 9442494 commit 65521ce
Show file tree
Hide file tree
Showing 11 changed files with 433 additions and 77 deletions.
62 changes: 62 additions & 0 deletions components/ml/webnn/graph_validation_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,68 @@ base::expected<Operand, std::string> ValidateGemmAndInferOutput(
return Operand(a.data_type, std::move(output_shape));
}

base::expected<Operand, std::string> ValidateConcatAndInferOutput(
const std::vector<Operand>& inputs,
const uint32_t axis) {
if (inputs.empty()) {
return base::unexpected("The inputs should not be empty.");
}
const auto& first_input_shape = inputs[0].dimensions;
const auto first_input_rank = first_input_shape.size();
// According to WebNN spec:
// https://www.w3.org/TR/webnn/#dom-mlgraphbuilder-concat-inputs-axis-axis,
// the axis that the inputs concatenate along, with the value in the interval
// [0, N-1] where N is the rank of input tensors. We just check the first
// input rank here because we will check all inputs have same rank in the
// following loop.
if (axis >= first_input_rank) {
return base::unexpected(
"The axis must be in the range [0, N-1] where N is the rank of input "
"tensor.");
}
const auto output_type = inputs[0].data_type;
// The loop skips the first input to avoid repeated checks.
for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs[i].data_type != output_type) {
return base::unexpected("The input types don't match.");
}
// According to WebNN spec:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-concat, all input tensors
// must have the same dimension.
if (inputs[i].dimensions.size() != first_input_rank) {
return base::unexpected(
"All input tensors must have the same dimension.");
}
// According to WebNN spec:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-concat, all input tensors
// must have the same shape, except for the size of the dimension to
// concatenate on.
for (size_t dim = 0; dim < first_input_rank; ++dim) {
if (dim == axis || inputs[i].dimensions[dim] == first_input_shape[dim]) {
continue;
}
return base::unexpected(
"All input tensors must have the same shape, except for the size of "
"the dimension to concatenate on.");
}
}
// Calculate the output shape according to WebNN spec:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-concat, the output tensor
// has the same shape except on the dimension that all the inputs concatenated
// along. The size of that dimension is computed as the sum of all the input
// sizes of the same dimension.
auto axis_size = base::MakeCheckedNum<uint32_t>(0);
for (auto& input : inputs) {
axis_size += input.dimensions[axis];
}
auto output_shape = first_input_shape;
if (!axis_size.AssignIfValid(&output_shape[axis])) {
return base::unexpected("The concatenated dimension size is too large.");
}

return Operand(output_type, std::move(output_shape));
}

base::expected<Operand, std::string> ValidateTransposeAndInferOutput(
const Operand& input,
base::span<const uint32_t> permutation) {
Expand Down
6 changes: 6 additions & 0 deletions components/ml/webnn/graph_validation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ base::expected<Operand, std::string> ValidateGemmAndInferOutput(
const Operand& b,
const GemmAttributes& attributes);

// Validate concat operator defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-concat
base::expected<Operand, std::string> ValidateConcatAndInferOutput(
const std::vector<Operand>& input,
const uint32_t axis);

// Validate transpose operator defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-transpose
base::expected<Operand, std::string> ValidateTransposeAndInferOutput(
Expand Down
2 changes: 2 additions & 0 deletions services/webnn/dml/graph_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ std::string OpTagToString(Operation::Tag tag) {
switch (tag) {
case Operation::Tag::kClamp:
return "clamp";
case Operation::Tag::kConcat:
return "concat";
case Operation::Tag::kConv2d:
return "conv2d";
case Operation::Tag::kPool2d:
Expand Down
14 changes: 14 additions & 0 deletions services/webnn/public/mojom/webnn_graph.mojom
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@ struct Clamp {
float max_value;
};

// Represents the concat operation that concatenates the input tensors along
// the given axis.
struct Concat {
// The ids of input operand are used to get the `Operand` description from
// `GraphInfo.id_to_operand_map`.
array<uint64> input_operand_ids;
// The id of output operand is used to get the `Operand` description from
// `GraphInfo.id_to_operand_map`.
uint64 output_operand_id;
// The axis used to concatenate along.
uint32 axis;
};

// Represents the `MLInputOperandLayout` that specifies the layout format of
// the input tensor. `kChannelsFirst` means `nchw` (batches, channels, height,
// width), `kChannelsLast` means `nhwc` (batches, height, width, channels).
Expand Down Expand Up @@ -246,6 +259,7 @@ struct Operator {
// Holds one of operator.
union Operation {
Clamp clamp;
Concat concat;
Conv2d conv2d;
Pool2d pool2d;
Relu relu;
Expand Down
31 changes: 31 additions & 0 deletions services/webnn/webnn_graph_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,35 @@ bool ValidateClamp(const IdToOperandMap& id_to_operand_map,
return true;
}

bool ValidateConcat(const IdToOperandMap& id_to_operand_map,
const mojom::ConcatPtr& concat) {
auto* output = GetMojoOperand(id_to_operand_map, concat->output_operand_id);
if (!output) {
// The concat operator is invalid.
return false;
}

std::vector<Operand> inputs;
inputs.reserve(concat->input_operand_ids.size());
for (const auto& input_operand_id : concat->input_operand_ids) {
auto* input = GetMojoOperand(id_to_operand_map, input_operand_id);
if (!input || input == output) {
return false;
}
inputs.push_back(ConvertToComponentOperand(input));
}

auto validated_output = ValidateConcatAndInferOutput(inputs, concat->axis);
if (!validated_output.has_value()) {
return false;
}
if (validated_output != ConvertToComponentOperand(output)) {
return false;
}

return true;
}

bool ValidateConv2d(const IdToOperandMap& id_to_operand_map,
const mojom::Conv2dPtr& conv2d) {
auto* input = GetMojoOperand(id_to_operand_map, conv2d->input_operand_id);
Expand Down Expand Up @@ -547,6 +576,8 @@ bool ValidateOperation(const IdToOperandMap& id_to_operand_map,
switch (operation->which()) {
case mojom::Operation::Tag::kClamp:
return ValidateClamp(id_to_operand_map, operation->get_clamp());
case mojom::Operation::Tag::kConcat:
return ValidateConcat(id_to_operand_map, operation->get_concat());
case mojom::Operation::Tag::kConv2d:
return ValidateConv2d(id_to_operand_map, operation->get_conv2d());
case mojom::Operation::Tag::kPool2d:
Expand Down
151 changes: 151 additions & 0 deletions services/webnn/webnn_graph_impl_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
// found in the LICENSE file.

#include "services/webnn/webnn_graph_impl.h"

#include <limits>

#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/test/bind.h"
#include "base/test/task_environment.h"
#include "components/ml/webnn/graph_validation_utils.h"
Expand Down Expand Up @@ -288,6 +292,153 @@ TEST_F(WebNNGraphImplTest, ClampTest) {
}
}

struct ConcatTester {
std::vector<OperandInfo> inputs;
uint32_t axis;
OperandInfo output;
bool expected;

void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
std::vector<uint64_t> input_operand_ids;
input_operand_ids.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
input_operand_ids.push_back(
builder.BuildInput(base::StringPrintf("input%zu", i),
inputs[i].dimensions, inputs[i].type));
}
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildConcat(std::move(input_operand_ids), output_operand_id, axis);
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};

TEST_F(WebNNGraphImplTest, ConcatTest) {
{
// Test concat operator with three inputs.
ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5, 6}},
{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5, 6}},
{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 3, 5, 6}}},
.axis = 1,
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 6, 5, 6}},
.expected = true}
.Test();
}
{
// Test concat operator when the input is the same as output.
ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5, 6}}},
.axis = 1,
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5, 6}},
.expected = true}
.Test();
}
{
// Test concat operator with empty inputs.
ConcatTester{
.inputs = {},
.axis = 0,
.output = {.type = mojom::Operand::DataType::kInt32, .dimensions = {1}},
.expected = false}
.Test();
}
{
// Test concat operator when the inputs' datatypes don't match each other.
ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5, 6}},
{.type = mojom::Operand::DataType::kInt32,
.dimensions = {3, 2, 5, 6}}},
.axis = 1,
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 3, 5, 6}},
.expected = false}
.Test();
}
{
// Test concat operator when the inputs can not be concatenated.
ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5}},
{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5, 6}}},
.axis = 1,
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 3, 5}},
.expected = false}
.Test();
}
{
// Test concat operator when the axis is equal to or greater than the
// size of dimension.
ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5, 6}},
{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5, 6}}},
.axis = 4,
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5, 12}},
.expected = false}
.Test();
}
{
// Test concat operator when the inputs have other axes with different
// sizes except on the axis.
ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5, 6}},
{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5, 1}}},
.axis = 1,
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5, 7}},
.expected = false}
.Test();
}
{
// Test concat operator when the concatenated dimension size overflows.
ConcatTester{
.inputs = {{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {std::numeric_limits<uint32_t>::max()}},
{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1}}},
.axis = 0,
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {0}},
.expected = false}
.Test();
}
{
// Test concat operator when the output datatype doesn't match the inputs'
// datatypes.
ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 5, 6}},
{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 2, 5, 6}}},
.axis = 1,
.output = {.type = mojom::Operand::DataType::kInt32,
.dimensions = {3, 3, 5, 6}},
.expected = false}
.Test();
}
{
// Test concat operator when the output dimension is incorrect.
ConcatTester{.inputs = {{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {3, 1, 2}},
{.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1, 1, 2}}},
.axis = 0,
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {5, 1, 2}},
.expected = false}
.Test();
}
}

struct Conv2dTester {
OperandInfo input;
OperandInfo filter;
Expand Down
11 changes: 11 additions & 0 deletions services/webnn/webnn_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,17 @@ void GraphInfoBuilder::BuildClamp(uint64_t input_operand_id,
mojom::Operation::NewClamp(std::move(clamp)));
}

void GraphInfoBuilder::BuildConcat(std::vector<uint64_t> input_operand_ids,
uint64_t output_operand_id,
uint32_t axis) {
mojom::ConcatPtr concat = mojom::Concat::New();
concat->input_operand_ids = std::move(input_operand_ids);
concat->output_operand_id = output_operand_id;
concat->axis = axis;
graph_info_->operations.push_back(
mojom::Operation::NewConcat(std::move(concat)));
}

void GraphInfoBuilder::BuildRelu(uint64_t input_operand_id,
uint64_t output_operand_id) {
mojom::ReluPtr relu = mojom::Relu::New();
Expand Down
4 changes: 4 additions & 0 deletions services/webnn/webnn_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ class GraphInfoBuilder final {
float min_value,
float max_value);

void BuildConcat(std::vector<uint64_t> input_operand_ids,
uint64_t output_operand_id,
uint32_t axis);

// A `Conv2dAttributes` type should have the following members:
// struct Conv2dAttributes {
// std::vector<uint32_t> padding;
Expand Down

0 comments on commit 65521ce

Please sign in to comment.