Skip to content

Commit

Permalink
WebNN: Define slice operator in mojo
Browse files Browse the repository at this point in the history
Adds validation and unit tests for the slice operator in the WebNN
service.

Bug: 1273291
Change-Id: Id4ddd36dc4d2809e59d69a755d9d051caf26ef02
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4844179
Reviewed-by: Jiewei Qian <qjw@chromium.org>
Commit-Queue: Brandon1 Jones <brandon1.jones@intel.com>
Reviewed-by: Rafael Cintron <rafael.cintron@microsoft.com>
Reviewed-by: ningxin hu <ningxin.hu@intel.com>
Cr-Commit-Position: refs/heads/main@{#1210363}
  • Loading branch information
bjjones authored and Chromium LUCI CQ committed Oct 16, 2023
1 parent 6be1e13 commit 164347c
Show file tree
Hide file tree
Showing 11 changed files with 390 additions and 54 deletions.
63 changes: 63 additions & 0 deletions components/ml/webnn/graph_validation_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,69 @@ base::expected<Operand, std::string> ValidateTransposeAndInferOutput(
return Operand(input.data_type, std::move(output_shape));
}

SliceAttributes::SliceAttributes() = default;
SliceAttributes::~SliceAttributes() = default;

SliceAttributes::SliceAttributes(SliceAttributes&& other) = default;
SliceAttributes& SliceAttributes::operator=(SliceAttributes&& other) = default;

base::expected<Operand, std::string> ValidateSliceAndInferOutput(
const Operand& input,
const SliceAttributes& attributes) {
if (!attributes.sizes.size()) {
return base::unexpected("The length of sizes must be not be zero.");
}

const auto input_rank = input.dimensions.size();
if (attributes.starts.size() != input_rank) {
return base::unexpected(
"The length of starts must be equal to the rank of the input tensor.");
}

if (attributes.sizes.size() != input_rank) {
return base::unexpected(
"The length of sizes must be equal to the rank of the input tensor.");
}

for (uint32_t i = 0; i < input_rank; ++i) {
if (attributes.starts[i] >= input.dimensions[i]) {
return base::unexpected(base::StringPrintf(
"For dimension (%u): the starting index to slice must "
"be less than input size (%u).",
i, input.dimensions[i]));
}

// WebNN plans to allow 0 size dimensions and an issue has been filed to
// track it: https://github.com/webmachinelearning/webnn/issues/391.
if (attributes.sizes[i] == 0) {
return base::unexpected(base::StringPrintf(
"For dimension (%u): the number of elements to slice "
"must not be 0.",
i));
}

auto checked_ending_index =
base::MakeCheckedNum<uint32_t>(attributes.starts[i]) +
attributes.sizes[i];
if (!checked_ending_index.IsValid<uint32_t>()) {
return base::unexpected(base::StringPrintf(
"For dimension (%u): the ending index to slice is too large.", i));
}

if (checked_ending_index.ValueOrDie() > input.dimensions[i]) {
return base::unexpected(
base::StringPrintf("For dimension (%u): the ending index to slice "
"must not be greater than input size (%u).",
i, input.dimensions[i]));
}
}

// The output is a tensor the same as the specified slice sizes.
std::vector<uint32_t> output_shape;
output_shape.assign(attributes.sizes.begin(), attributes.sizes.end());
return Operand(input.data_type, std::move(output_shape));
}

base::expected<size_t, std::string> ValidateAndCalculateElementsNumber(
base::span<const uint32_t> dimensions) {
if (dimensions.empty()) {
Expand Down
24 changes: 24 additions & 0 deletions components/ml/webnn/graph_validation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,24 @@ struct GemmAttributes {
bool b_transpose = false;
};

struct SliceAttributes {
SliceAttributes();
~SliceAttributes();

SliceAttributes(SliceAttributes&& other);
SliceAttributes& operator=(SliceAttributes&& other);

SliceAttributes(const SliceAttributes&) = delete;
SliceAttributes& operator=(const SliceAttributes&) = delete;

// The sequence of unsigned integer values indicating the starting index to
// slice of each input dimension.
std::vector<uint32_t> starts;
// The sequence of unsigned integer values indicating the number of elements
// to slice of each input dimension.
std::vector<uint32_t> sizes;
};

// Validate softmax operator defined in WebIDL here
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-softmax
base::expected<Operand, std::string> ValidateSoftmaxAndInferOutput(
Expand Down Expand Up @@ -217,6 +235,12 @@ base::expected<Operand, std::string> ValidateTransposeAndInferOutput(
const Operand& input,
base::span<const uint32_t> permutation);

// Validate slice operator defined in WebIDL here:
// https://www.w3.org/TR/webnn/#api-mlgraphbuilder-slice
base::expected<Operand, std::string> ValidateSliceAndInferOutput(
const Operand& input,
const SliceAttributes& attributes);

base::expected<size_t, std::string> ValidateAndCalculateElementsNumber(
base::span<const uint32_t> dimensions);

Expand Down
2 changes: 2 additions & 0 deletions services/webnn/dml/graph_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ std::string OpTagToString(Operation::Tag tag) {
return "resample2d";
case Operation::Tag::kRelu:
return "relu";
case Operation::Tag::kSlice:
return "slice";
case Operation::Tag::kSplit:
return "split";
case Operation::Tag::kTranspose:
Expand Down
18 changes: 18 additions & 0 deletions services/webnn/public/mojom/webnn_graph.mojom
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,23 @@ struct Pool2d {
InputOperandLayout layout;
};

struct StartAndSize {
uint32 start;
uint32 size;
};

struct Slice {
// The id of input operand is used to get the `Operand` description from
// `GraphInfo.id_to_operand_map`.
uint64 input_operand_id;
// The id of output operand is used to get the `Operand` description from
// `GraphInfo.id_to_operand_map`.
uint64 output_operand_id;
// An array containing the number of elements of the input window in each
// dimension.
array<StartAndSize> starts_and_sizes;
};

// Contains the attributes of gemm operator.
struct GemmAttributes {
// The optional third tensor in expression alpha * A * B + beta * C.
Expand Down Expand Up @@ -280,6 +297,7 @@ union Operation {
Pool2d pool2d;
Relu relu;
Resample2d resample2d;
Slice slice;
Softmax softmax;
Split split;
Transpose transpose;
Expand Down
36 changes: 36 additions & 0 deletions services/webnn/webnn_graph_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,18 @@ absl::optional<webnn::GemmAttributes> ConvertToGemmAttributes(
return component_attributes;
}

webnn::SliceAttributes ConvertToSliceAttributes(
const webnn::mojom::SlicePtr& slice) {
webnn::SliceAttributes component_attributes;
component_attributes.starts.reserve(slice->starts_and_sizes.size());
component_attributes.sizes.reserve(slice->starts_and_sizes.size());
for (const auto& start_and_size : slice->starts_and_sizes) {
component_attributes.starts.push_back(start_and_size->start);
component_attributes.sizes.push_back(start_and_size->size);
}
return component_attributes;
}

// TODO(crbug.com/1273291): This function will replaced by `operation`
const mojom::Operand* GetMojoOperand(
const IdToOperandMap& id_to_operand_map,
Expand Down Expand Up @@ -464,6 +476,28 @@ bool ValidateReshape(const IdToOperandMap& id_to_operand_map,
return true;
}

bool ValidateSlice(const IdToOperandMap& id_to_operand_map,
const mojom::SlicePtr& slice) {
auto* input = GetMojoOperand(id_to_operand_map, slice->input_operand_id);
auto* output = GetMojoOperand(id_to_operand_map, slice->output_operand_id);

if (!input || !output || output == input) {
// The slice operator is invalid.
return false;
}

auto validated_output = ValidateSliceAndInferOutput(
ConvertToComponentOperand(input), ConvertToSliceAttributes(slice));
if (!validated_output.has_value()) {
return false;
}
if (validated_output != ConvertToComponentOperand(output)) {
return false;
}

return true;
}

bool ValidateSoftmax(const IdToOperandMap& id_to_operand_map,
const mojom::SoftmaxPtr& softmax) {
auto* input = GetMojoOperand(id_to_operand_map, softmax->input_operand_id);
Expand Down Expand Up @@ -605,6 +639,8 @@ bool ValidateOperation(const IdToOperandMap& id_to_operand_map,
return ValidateResample2d(id_to_operand_map, operation->get_resample2d());
case mojom::Operation::Tag::kRelu:
return ValidateRelu(id_to_operand_map, operation->get_relu());
case mojom::Operation::Tag::kSlice:
return ValidateSlice(id_to_operand_map, operation->get_slice());
case mojom::Operation::Tag::kSoftmax:
return ValidateSoftmax(id_to_operand_map, operation->get_softmax());
case mojom::Operation::Tag::kSplit:
Expand Down
109 changes: 109 additions & 0 deletions services/webnn/webnn_graph_impl_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1246,6 +1246,115 @@ TEST_F(WebNNGraphImplTest, ReshapeTest) {
.Test();
}
}
struct SliceTester {
struct SliceAttributes {
std::vector<uint32_t> starts;
std::vector<uint32_t> sizes;
};

OperandInfo input;
SliceAttributes attributes;
OperandInfo output;

bool expected;

void Test() {
// Build the graph with mojo type.
GraphInfoBuilder builder;
uint64_t input_operand_id =
builder.BuildInput("input", input.dimensions, input.type);
uint64_t output_operand_id =
builder.BuildOutput("output", output.dimensions, output.type);
builder.BuildSlice(input_operand_id, output_operand_id,
std::move(attributes.starts),
std::move(attributes.sizes));
EXPECT_EQ(WebNNGraphImpl::ValidateGraph(builder.GetGraphInfo()), expected);
}
};

TEST_F(WebNNGraphImplTest, SliceTest) {
{
// Test slice with output dimensions equal to input dimensions.
SliceTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {4, 4}},
.attributes = {.starts = {0, 0}, .sizes = {4, 4}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {4, 4}},
.expected = true}
.Test();
}
{
// Test 4x4 2-D Tensor to 2x2 slice
SliceTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {4, 4}},
.attributes = {.starts = {0, 0}, .sizes = {2, 2}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {2, 2}},
.expected = true}
.Test();
}
{
// Test 4x4 2-D Tensor to 2x2 slice with offsets
SliceTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {4, 4}},
.attributes = {.starts = {2, 2}, .sizes = {2, 2}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {2, 2}},
.expected = true}
.Test();
}
{
// Test that going out-of-bounds of the input tensor fails.
SliceTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {2, 2}},
.attributes = {.starts = {1, 0}, .sizes = {2, 2}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {2, 2}},
.expected = false}
.Test();
}
{
// Test that mismatched output dimensions and size attribute will fail.
SliceTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {2, 2}},
.attributes = {.starts = {0, 0}, .sizes = {1, 1}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {2, 1}},
.expected = false}
.Test();
}
{
// Test that using size zero will result in failure.
SliceTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {2, 2}},
.attributes = {.starts = {0, 0}, .sizes = {0, 1}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {1}},
.expected = false}
.Test();
}
{
// Test that having starts and sizes lengths not equal to the input rank
// will fail.
SliceTester{.input = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {4, 4}},
.attributes = {.starts = {0}, .sizes = {4}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {4, 4}},
.expected = false}
.Test();
}
{
// Test that input data type not equal to the output data type will fail.
SliceTester{.input = {.type = mojom::Operand::DataType::kFloat16,
.dimensions = {4, 4}},
.attributes = {.starts = {0, 0}, .sizes = {4, 4}},
.output = {.type = mojom::Operand::DataType::kFloat32,
.dimensions = {4, 4}},
.expected = false}
.Test();
}
}

struct SoftmaxTester {
OperandInfo input;
Expand Down
19 changes: 19 additions & 0 deletions services/webnn/webnn_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,25 @@ void GraphInfoBuilder::BuildTranspose(uint64_t input_operand_id,
mojom::Operation::NewTranspose(std::move(transpose)));
}

void GraphInfoBuilder::BuildSlice(uint64_t input_operand_id,
uint64_t output_operand_id,
std::vector<uint32_t> starts,
std::vector<uint32_t> sizes) {
CHECK(starts.size() == sizes.size());
mojom::SlicePtr slice = mojom::Slice::New();
slice->input_operand_id = input_operand_id;
slice->output_operand_id = output_operand_id;
for (uint32_t i = 0; i < starts.size(); ++i) {
mojom::StartAndSizePtr start_and_size = mojom::StartAndSize::New();
start_and_size->start = starts[i];
start_and_size->size = sizes[i];
slice->starts_and_sizes.push_back(std::move(start_and_size));
}

graph_info_->operations.push_back(
mojom::Operation::NewSlice(std::move(slice)));
}

mojom::GraphInfoPtr GraphInfoBuilder::CloneGraphInfo() const {
CHECK_IS_TEST();
mojom::GraphInfoPtr cloned_graph_info = mojom::GraphInfo::New();
Expand Down
5 changes: 5 additions & 0 deletions services/webnn/webnn_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ class GraphInfoBuilder final {
uint64_t output_operand_id,
std::vector<uint32_t> permutation);

void BuildSlice(uint64_t input_operand_id,
uint64_t output_operand_id,
std::vector<uint32_t> starts,
std::vector<uint32_t> sizes);

const mojom::GraphInfoPtr& GetGraphInfo() const { return graph_info_; }

// Get a clone of internal graph info. This is used by
Expand Down

0 comments on commit 164347c

Please sign in to comment.