[WebNN EP] Support latest WebNN softmax op #20827

Merged 2 commits on Jun 11, 2024
js/web/docs/webnn-operators.md (2 changes: 1 addition & 1 deletion)
@@ -80,7 +80,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
 | Softsign | ai.onnx(7+) | softsign | ✗ | ✓ | |
 | Sin | ai.onnx(7+) | sin | ✗ | ✓ | |
 | Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice | ✓ | ✓ | Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value 1 |
-| Softmax | ai.onnx(7-10, 11-12, 13+) | softmax | ✓ | ✓ | Only supports input rank >= 2 |
+| Softmax | ai.onnx(7-10, 11-12, 13+) | softmax | ✓ | ✓ | |
 | Split | ai.onnx(7-10, 11-12, 13-17, 18+) | split | ✓ | ✓ | Input 'split' if present should be a constant |
 | Sqrt | ai.onnx(7-12, 13+) | sqrt | ✓ | ✓ | |
 | Squeeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | Input 'axes' if present should be a constant |
@@ -33,82 +33,16 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                const logging::Logger& logger) const {
   const auto& input_defs = node.InputDefs();
   emscripten::val input = model_builder.GetOperand(node.InputDefs()[0]->Name());
-  emscripten::val output = emscripten::val::object();
   std::vector<int64_t> input_shape;
   ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape");
   const auto input_size = input_shape.size();

   NodeAttrHelper helper(node);
-  if (node.SinceVersion() < 13) {
-    int32_t axis = helper.Get("axis", 1);
-    axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_size));
-    // Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}].
-    if (input_size != 2) {
-      int32_t first_dim = static_cast<int32_t>(std::reduce(input_shape.begin(), input_shape.begin() + axis,
-                                                           1, std::multiplies<int64_t>()));
-      int32_t second_dim = static_cast<int32_t>(std::reduce(input_shape.begin() + axis, input_shape.end(),
-                                                            1, std::multiplies<int64_t>()));
-      emscripten::val new_shape = emscripten::val::array(std::vector<int32_t>{first_dim, second_dim});
-      input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, new_shape);
-    }
-
-    output = model_builder.GetBuilder().call<emscripten::val>("softmax", input);
-
-    // Reshape output to the same shape of input.
-    if (input_size != 2) {
-      emscripten::val new_shape = emscripten::val::array();
-      for (size_t i = 0; i < input_size; i++) {
-        new_shape.call<void>("push", static_cast<int32_t>(input_shape[i]));
-      }
-      output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, new_shape);
-    }
-  } else {
-    int32_t axis = helper.Get("axis", -1);
-    axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_size));
-    // Wraparound for transpose the target axis to the last.
-    // WebNN compute the softmax values of the 2-D input tensor along axis 1.
-    // https://www.w3.org/TR/webnn/#api-mlgraphbuilder-softmax-method
-    if (axis != static_cast<int>(input_shape.size() - 1)) {
-      emscripten::val options = emscripten::val::object();
-      std::vector<uint32_t> permutation(input_shape.size());
-      std::iota(permutation.begin(), permutation.end(), 0);
-      std::rotate(permutation.begin() + axis, permutation.begin() + axis + 1, permutation.end());
-      options.set("permutation", emscripten::val::array(permutation));
-      input = model_builder.GetBuilder().call<emscripten::val>("transpose", input, options);
-    }
-    // Wraparound for reshape input tensor to 2-D.
-    if (input_shape.size() != 2) {
-      uint32_t first_dim = static_cast<uint32_t>(std::reduce(input_shape.begin(), input_shape.begin() + axis,
-                                                             1, std::multiplies<int64_t>()));
-      first_dim *= static_cast<uint32_t>(std::reduce(input_shape.begin() + axis + 1, input_shape.end(),
-                                                     1, std::multiplies<int64_t>()));
-      uint32_t second_dim = static_cast<uint32_t>(input_shape[axis]);
-      emscripten::val new_shape = emscripten::val::array(std::vector<uint32_t>{first_dim, second_dim});
-      input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, new_shape);
-    }
-
-    output = model_builder.GetBuilder().call<emscripten::val>("softmax", input);
-
-    // Restore from 2-D to the original shape.
-    if (input_shape.size() != 2) {
-      std::vector<uint32_t> new_shape;
-      std::transform(input_shape.begin(), input_shape.begin() + axis, std::back_inserter(new_shape),
-                     [](int64_t dim) -> uint32_t { return static_cast<uint32_t>(dim); });
-      std::transform(input_shape.begin() + axis + 1, input_shape.end(), std::back_inserter(new_shape),
-                     [](int64_t dim) -> uint32_t { return static_cast<uint32_t>(dim); });
-      new_shape.push_back(static_cast<int32_t>(input_shape[axis]));
-      output = model_builder.GetBuilder().call<emscripten::val>("reshape",
-                                                                output, emscripten::val::array(new_shape));
-    }
-    // Restore the corresponding axis back to the initial position from the last position.
-    if (axis != static_cast<int>(input_shape.size() - 1)) {
-      emscripten::val options = emscripten::val::object();
-      std::vector<uint32_t> permutation(input_shape.size());
-      std::iota(permutation.begin(), permutation.end(), 0);
-      std::rotate(permutation.rbegin(), permutation.rbegin() + 1, permutation.rend() - axis);
-      options.set("permutation", emscripten::val::array(permutation));
-      output = model_builder.GetBuilder().call<emscripten::val>("transpose", output, options);
-    }
-  }
+  const int32_t default_axis = node.SinceVersion() < 13 ? 1 : -1;
+  int32_t axis = helper.Get("axis", default_axis);
+  axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_size));
+
+  emscripten::val output = model_builder.GetBuilder().call<emscripten::val>("softmax", input, axis);
   model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output));
   return Status::OK();
 }
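
For context (not part of the diff): the axis defaults and the shape arithmetic behind the removed workaround can be reproduced in a few lines of standard C++. This is a minimal standalone sketch; `HandleNegativeAxis` below is a simplified stand-in for the ONNX Runtime helper of the same name, and the shape is an arbitrary example.

```cpp
// Sketch of the axis defaults and the pre-opset-13 2-D coercion shown in the diff above.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Simplified stand-in for the ORT helper used in the diff.
int64_t HandleNegativeAxis(int64_t axis, size_t rank) {
  return axis < 0 ? axis + static_cast<int64_t>(rank) : axis;
}

int main() {
  const std::vector<int64_t> input_shape = {2, 3, 4, 5};

  // Opset 13+ default: axis = -1, normalized to the last dimension (3 here).
  // The new code passes this normalized value straight to the WebNN softmax.
  const int64_t axis_13 = HandleNegativeAxis(-1, input_shape.size());

  // Pre-opset-13 default: axis = 1. The deleted path flattened the input to
  // [a_0 * ... * a_{axis-1}, a_axis * ... * a_{n-1}] before calling softmax,
  // then reshaped the result back to the original shape.
  const int64_t axis_pre13 = HandleNegativeAxis(1, input_shape.size());
  const int64_t first_dim = std::reduce(input_shape.begin(), input_shape.begin() + axis_pre13,
                                        int64_t{1}, std::multiplies<int64_t>());
  const int64_t second_dim = std::reduce(input_shape.begin() + axis_pre13, input_shape.end(),
                                         int64_t{1}, std::multiplies<int64_t>());

  std::cout << "opset 13+ normalized axis: " << axis_13 << "\n";                           // 3
  std::cout << "pre-13 coerced 2-D shape: [" << first_dim << ", " << second_dim << "]\n";  // [2, 60]
  return 0;
}
```

Since the latest WebNN softmax accepts the axis directly, neither the 2-D reshape nor the transpose-to-last-axis workaround is needed any more, which is why both branches above could be deleted.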
@@ -123,12 +57,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali
   std::vector<int64_t> input_shape;
   if (!GetShape(*input_defs[0], input_shape, logger))
     return false;
-  const auto input_size = input_shape.size();
-  if (input_size < 2) {
-    LOGS(logger, VERBOSE) << "SoftMax only support input size >= 2d shape, input is "
-                          << input_size << "d shape";
-    return false;
-  }

   return true;
 }
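
The removed rank >= 2 guard is no longer needed because a softmax parameterized by an explicit axis is well defined for any rank, including rank 1. The following standalone reference (an illustration, not WebNN or ONNX Runtime code) computes softmax along an arbitrary axis of a row-major tensor:

```cpp
// Reference softmax along an arbitrary axis of a flat, row-major N-D tensor.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

std::vector<float> SoftmaxAlongAxis(const std::vector<float>& data,
                                    const std::vector<size_t>& shape, size_t axis) {
  const size_t axis_dim = shape[axis];
  // Stride of the softmax axis: product of the dimensions after it.
  const size_t inner = std::accumulate(shape.begin() + axis + 1, shape.end(),
                                       size_t{1}, std::multiplies<size_t>());
  const size_t outer = data.size() / (axis_dim * inner);
  std::vector<float> out(data.size());
  for (size_t o = 0; o < outer; ++o) {
    for (size_t i = 0; i < inner; ++i) {
      const size_t base = o * axis_dim * inner + i;
      // Subtract the max for numerical stability, then normalize.
      float max_val = data[base];
      for (size_t a = 1; a < axis_dim; ++a) max_val = std::max(max_val, data[base + a * inner]);
      float sum = 0.0f;
      for (size_t a = 0; a < axis_dim; ++a) sum += std::exp(data[base + a * inner] - max_val);
      for (size_t a = 0; a < axis_dim; ++a) out[base + a * inner] = std::exp(data[base + a * inner] - max_val) / sum;
    }
  }
  return out;
}

int main() {
  // A rank-1 input, previously rejected by the WebNN EP, is simply softmax over axis 0.
  const std::vector<float> v = {1.0f, 2.0f, 3.0f};
  for (float x : SoftmaxAlongAxis(v, {3}, 0)) std::cout << x << " ";  // ~0.09 0.24 0.67
  std::cout << "\n";
  return 0;
}
```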