Skip to content

Commit

Permalink
Use WebNN constant + slice for Shape node
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Jun 14, 2023
1 parent 9eda260 commit 5002d00
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 20 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
{"ReduceMean", "reduceMean"},
{"Reshape", "reshape"},
{"Resize", "resample2d"},
{"Shape", "constant"},
{"Shape", "slice"},
{"Split", "split"},
{"Transpose", "transpose"},
{"Unsqueeze", "unsqueeze"},
Expand Down
36 changes: 17 additions & 19 deletions onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,32 +33,30 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
const auto rank = static_cast<int32_t>(input_shape.size());

emscripten::val desc = emscripten::val::object();
desc.set("type", emscripten::val("int64"));
emscripten::val dims = emscripten::val::array();
dims.call<void>("push", rank);
desc.set("dimensions", dims);
emscripten::val shape_buffer = emscripten::val::global("BigInt64Array").new_(emscripten::val::array(input_shape));
emscripten::val shape_constant = model_builder.GetBuilder().call<emscripten::val>("constant", desc, shape_buffer);

NodeAttrHelper helper(node);
auto true_start = helper.Get("start", 0);
auto true_end = helper.Get("end", rank);

// Deal with negative(s) and clamp.
true_start = true_start < 0 ? true_start + rank : true_start;
true_start = true_start < 0 ? 0 : ((true_start > rank) ? rank : true_start);

true_end = true_end < 0 ? true_end + rank : true_end;
true_end = true_end < 0 ? 0 : ((true_end > rank) ? rank : true_end);
true_start = std::clamp(true_start + (true_start < 0 ? rank : 0), 0, rank);
true_end = std::clamp(true_end + (true_end < 0 ? rank : 0), true_start, rank);
auto slice_length = true_end - true_start;

emscripten::val new_shape = emscripten::val::array(input_shape);
emscripten::val starts = emscripten::val::array();
starts.call<void>("push", true_start);
emscripten::val sizes = emscripten::val::array();
sizes.call<void>("push", slice_length);

// Slice the input shape if start or end attribute exists.
new_shape = new_shape.call<emscripten::val>("slice", true_start, true_end);

emscripten::val desc = emscripten::val::object();
desc.set("type", emscripten::val("int64"));
emscripten::val dims = emscripten::val::array();
auto slice_length = true_end < true_start ? 0 : (true_end - true_start);
dims.call<void>("push", slice_length);
desc.set("dimensions", dims);
emscripten::val output_buffer = emscripten::val::global("BigInt64Array").new_(new_shape);
// Since WebNN doesn't support Shape op, we calculate the Shape output and pass values to
// WebNN's constant op as workaround.
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("constant", desc, output_buffer);
// Since WebNN doesn't support Shape op, we use constant + slice ops as workaround.
emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("slice", shape_constant, starts, sizes);

model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
return Status::OK();
Expand Down

0 comments on commit 5002d00

Please sign in to comment.