Skip to content

Commit

Permalink
WebNN: Fix axis of concat MLOperator according to WebNN spec
Browse files Browse the repository at this point in the history
This CL implements the WebNN spec change [1] that changes the axis of the
concat MLOperator to be an unsigned integer.

The MLGraphBuilder and MLGraphBuilderTest are also updated according to
this IDL change.

[1]: webmachinelearning/webnn#359

Bug: 1273291
Change-Id: I519a070e6b2fed78cecc736004895c67bda9c785
Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/4328150
Commit-Queue: Bin Miao <bin.miao@intel.com>
Reviewed-by: Jiewei Qian <qjw@chromium.org>
Cr-Commit-Position: refs/heads/main@{#1115579}
  • Loading branch information
miaobin authored and Chromium LUCI CQ committed Mar 10, 2023
1 parent e1fb8f6 commit 3599a69
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 51 deletions.
28 changes: 11 additions & 17 deletions third_party/blink/renderer/modules/ml/webnn/ml_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ MLOperand* MLGraphBuilder::constant(const MLOperandDescriptor* desc,
}

MLOperand* MLGraphBuilder::concat(const HeapVector<Member<MLOperand>>& inputs,
int32_t axis,
const uint32_t axis,
ExceptionState& exception_state) {
auto* concat =
MakeGarbageCollected<MLOperator>(this, MLOperator::OperatorKind::kConcat);
Expand All @@ -565,22 +565,16 @@ MLOperand* MLGraphBuilder::concat(const HeapVector<Member<MLOperand>>& inputs,
// According to WebNN spec:
// https://www.w3.org/TR/webnn/#dom-mlgraphbuilder-concat-inputs-axis-axis,
// the axis that the inputs concatenate along, with the value in the interval
// [0, N) where N is the rank of all the inputs. We just check the first input
// rank here because we will check all inputs have same rank in the following
// loop.
//
// TODO(crbug.com/1273291): There is a WebNN spec issue discussing whether to
// support signed axis with [-N, N) range or unsigned integer. Update the
// implementation once the WG makes the consensus.
// https://github.com/webmachinelearning/webnn/issues/345
if (axis < 0 || base::MakeStrictNum(axis) >= first_input_rank) {
// [0, N-1] where N is the rank of input tensors. We just check the first
// input rank here because we will check all inputs have same rank in the
// following loop.
if (axis >= first_input_rank) {
exception_state.ThrowDOMException(
DOMExceptionCode::kDataError,
"The value of axis should be in the interval [0, N) where N is the "
"rank of all the inputs.");
"The value of axis should be in the interval [0, N-1] where N is the "
"rank of input tensors.");
return nullptr;
}
const auto concat_axis = base::checked_cast<uint32_t>(axis);
const auto output_type = inputs[0]->Type();
// The loop skips the first input to avoid repeated checks.
for (wtf_size_t i = 1; i < inputs.size(); ++i) {
Expand All @@ -603,7 +597,7 @@ MLOperand* MLGraphBuilder::concat(const HeapVector<Member<MLOperand>>& inputs,
// must have the same shape, except for the size of the dimension to
// concatenate on.
for (wtf_size_t dim = 0; dim < first_input_rank; ++dim) {
if (dim == concat_axis ||
if (dim == axis ||
inputs[i]->Dimensions()[dim] == first_input_shape[dim]) {
continue;
}
Expand All @@ -619,12 +613,12 @@ MLOperand* MLGraphBuilder::concat(const HeapVector<Member<MLOperand>>& inputs,
// has the same shape except on the dimension that all the inputs concatenated
// along. The size of that dimension is computed as the sum of all the input
// sizes of the same dimension.
auto concat_axis_size = base::MakeCheckedNum<uint32_t>(0);
auto axis_size = base::MakeCheckedNum<uint32_t>(0);
for (auto& input : inputs) {
concat_axis_size += input->Dimensions()[concat_axis];
axis_size += input->Dimensions()[axis];
}
auto output_shape = first_input_shape;
if (!concat_axis_size.AssignIfValid(&output_shape[concat_axis])) {
if (!axis_size.AssignIfValid(&output_shape[axis])) {
exception_state.ThrowDOMException(
DOMExceptionCode::kDataError,
"The concatenated dimension size is too large.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class MODULES_EXPORT MLGraphBuilder final : public ScriptWrappable {
ExceptionState& exception_state);

MLOperand* concat(const HeapVector<Member<MLOperand>>& inputs,
int32_t axis,
const uint32_t axis,
ExceptionState& exception_state);

MLOperand* conv2d(const MLOperand* input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ dictionary MLTransposeOptions{
[RaisesException] MLOperand clamp(MLOperand input, optional MLClampOptions options = {});
[RaisesException] MLActivation clamp(optional MLClampOptions options = {});

[RaisesException] MLOperand concat(sequence<MLOperand> inputs, long axis);
[RaisesException] MLOperand concat(sequence<MLOperand> inputs, [EnforceRange] unsigned long axis);

[RaisesException] MLOperand conv2d(MLOperand input, MLOperand filter, optional MLConv2dOptions options = {});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ TEST_F(MLGraphBuilderTest, ConcatTest) {
auto* input_a =
BuildInput(builder, "input_a", input_a_shape,
V8MLOperandType::Enum::kFloat32, scope.GetExceptionState());
int32_t axis = 2;
uint32_t axis = 2;
auto* output = builder->concat({input_a}, axis, scope.GetExceptionState());
EXPECT_NE(output, nullptr);
EXPECT_EQ(output->Kind(), MLOperand::OperandKind::kOutput);
Expand All @@ -228,7 +228,7 @@ TEST_F(MLGraphBuilderTest, ConcatTest) {
auto* input_b =
BuildInput(builder, "input_b", input_b_shape,
V8MLOperandType::Enum::kFloat32, scope.GetExceptionState());
int32_t axis = 1;
uint32_t axis = 1;
auto* output =
builder->concat({input_a, input_b}, axis, scope.GetExceptionState());
EXPECT_NE(output, nullptr);
Expand Down Expand Up @@ -256,7 +256,7 @@ TEST_F(MLGraphBuilderTest, ConcatTest) {
auto* input_c =
BuildInput(builder, "input_c", input_c_shape,
V8MLOperandType::Enum::kFloat32, scope.GetExceptionState());
int32_t axis = 2;
uint32_t axis = 2;
auto* output = builder->concat({input_a, input_b, input_c}, axis,
scope.GetExceptionState());
EXPECT_NE(output, nullptr);
Expand All @@ -280,7 +280,7 @@ TEST_F(MLGraphBuilderTest, ConcatTest) {
auto* input_b =
BuildInput(builder, "input_b", input_b_shape,
V8MLOperandType::Enum::kFloat32, scope.GetExceptionState());
int32_t axis = 0;
uint32_t axis = 0;
auto* output =
builder->concat({input_a, input_b}, axis, scope.GetExceptionState());
EXPECT_NE(output, nullptr);
Expand All @@ -295,7 +295,7 @@ TEST_F(MLGraphBuilderTest, ConcatTest) {
}
{
// Test throwing exception when the inputs are empty.
int32_t axis = 0;
uint32_t axis = 0;
auto* output = builder->concat({}, axis, scope.GetExceptionState());
EXPECT_EQ(output, nullptr);
EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
Expand All @@ -313,7 +313,7 @@ TEST_F(MLGraphBuilderTest, ConcatTest) {
auto* input_b =
BuildInput(builder, "input_b", input_b_shape,
V8MLOperandType::Enum::kInt32, scope.GetExceptionState());
int32_t axis = 0;
uint32_t axis = 0;
auto* output =
builder->concat({input_a, input_b}, axis, scope.GetExceptionState());
EXPECT_EQ(output, nullptr);
Expand All @@ -332,7 +332,7 @@ TEST_F(MLGraphBuilderTest, ConcatTest) {
auto* input_b =
BuildInput(builder, "input_b", input_b_shape,
V8MLOperandType::Enum::kFloat32, scope.GetExceptionState());
int32_t axis = 0;
uint32_t axis = 0;
auto* output =
builder->concat({input_a, input_b}, axis, scope.GetExceptionState());
EXPECT_EQ(output, nullptr);
Expand All @@ -342,7 +342,8 @@ TEST_F(MLGraphBuilderTest, ConcatTest) {
"All input tensors must have the same dimension.");
}
{
    // Test throwing exception when the axis is equal to or greater than the
    // rank of the input tensors.
Vector<uint32_t> input_a_shape({1, 1});
Vector<uint32_t> input_b_shape({1, 1});
auto* input_a =
Expand All @@ -351,35 +352,15 @@ TEST_F(MLGraphBuilderTest, ConcatTest) {
auto* input_b =
BuildInput(builder, "input_b", input_b_shape,
V8MLOperandType::Enum::kFloat32, scope.GetExceptionState());
int32_t axis = -1;
uint32_t axis = 2;
auto* output =
builder->concat({input_a, input_b}, axis, scope.GetExceptionState());
EXPECT_EQ(output, nullptr);
EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
DOMExceptionCode::kDataError);
EXPECT_EQ(scope.GetExceptionState().Message(),
"The value of axis should be in the interval [0, N) where N is "
"the rank of all the inputs.");
}
{
// Test throwing exception when the axis greater than the size of dimension.
Vector<uint32_t> input_a_shape({1, 1});
Vector<uint32_t> input_b_shape({1, 1});
auto* input_a =
BuildInput(builder, "input_a", input_a_shape,
V8MLOperandType::Enum::kFloat32, scope.GetExceptionState());
auto* input_b =
BuildInput(builder, "input_b", input_b_shape,
V8MLOperandType::Enum::kFloat32, scope.GetExceptionState());
int32_t axis = 2;
auto* output =
builder->concat({input_a, input_b}, axis, scope.GetExceptionState());
EXPECT_EQ(output, nullptr);
EXPECT_EQ(scope.GetExceptionState().CodeAs<DOMExceptionCode>(),
DOMExceptionCode::kDataError);
EXPECT_EQ(scope.GetExceptionState().Message(),
"The value of axis should be in the interval [0, N) where N is "
"the rank of all the inputs.");
"The value of axis should be in the interval [0, N-1] where N is "
"the rank of input tensors.");
}
{
// Test throwing exception when the inputs have other axes with different
Expand All @@ -392,7 +373,7 @@ TEST_F(MLGraphBuilderTest, ConcatTest) {
auto* input_b =
BuildInput(builder, "input_b", input_b_shape,
V8MLOperandType::Enum::kFloat32, scope.GetExceptionState());
int32_t axis = 1;
uint32_t axis = 1;
auto* output =
builder->concat({input_a, input_b}, axis, scope.GetExceptionState());
EXPECT_EQ(output, nullptr);
Expand Down

0 comments on commit 3599a69

Please sign in to comment.