Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class BinaryOpBuilder : public BaseOpBuilder {
private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -50,6 +54,24 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
return Status::OK();
}

bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const Node& node,
const WebnnDeviceType device_type,
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();

// XNNPACK prelu operator expects slope to be a static value.
// https://github.com/google/XNNPACK/issues/4692
// TODO: Remove this check after it is solved.
if (op_type == "PRelu" && !Contains(initializers, input_defs[1]->Name()) && device_type == WebnnDeviceType::CPU) {
LOGS(logger, VERBOSE) << "The second input (slope) for PRelu must be a constant initializer for WebNN CPU backend.";
return false;
}

return true;
}

void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
if (op_registrations.op_builder_map.count(op_type) > 0)
return;
Expand Down