Skip to content

Commit

Permalink
Generalize input/output number check in shape inference (#6005)
Browse files Browse the repository at this point in the history
### Description

- Add `OpSchema::VerifyInputNum` and `OpSchema::VerifyOutputNum`
methods. The logic has been moved from `OpSchema::Verify`.
- Replaced the input/output number check logic in
`OpSchema::CheckInputOutputType` with calls to the new interfaces.
- Without using the `Node(domain::op_type::version)` string pattern in
`OpSchema::CheckInputOutputType` when calling `fail_check`, as the
`op_type` will be displayed by the shape inference common exception.
Otherwise, it would appear redundant.

### Motivation and Context

follow-up #5990

resolve #5993

---------

Signed-off-by: opluss <opluss@qq.com>
  • Loading branch information
OYCN committed Mar 19, 2024
1 parent 17dbae7 commit 0bb2775
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 59 deletions.
112 changes: 53 additions & 59 deletions onnx/defs/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "onnx/defs/schema.h"

#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_set>
#include <utility>

Expand Down Expand Up @@ -107,30 +109,10 @@ OpSchemaRegistry* OpSchemaRegistry::Instance() {

void OpSchema::CheckInputOutputType(struct InferenceContext& ctx) const {
std::unordered_map<std::string, std::string> type_constraints;
if (inputs_.empty() && ctx.getNumInputs() > 0) {
fail_check(
"Node (",
domain(),
"::",
Name(),
":",
since_version(),
") takes zero inputs, but got ",
ctx.getNumInputs(),
" in graph");
}
if (outputs_.empty() && ctx.getNumOutputs() > 0) {
fail_check(
"Node (",
domain(),
"::",
Name(),
":",
since_version(),
") yields zero outputs, but got ",
ctx.getNumOutputs(),
" in graph");
}
// Check the number of inputs / output.
VerifyInputNum(ctx.getNumInputs());
VerifyOutputNum(ctx.getNumOutputs());

// check all input types
for (size_t in_idx = 0; in_idx < ctx.getNumInputs(); ++in_idx) {
// If the last input is Variadic by definition, checker still needs to check the rest of actual input's type
Expand Down Expand Up @@ -200,41 +182,8 @@ void OpSchema::Verify(const NodeProto& node) const {
fail_check("Operator '", name_, "' has been deprecated since version ", since_version_);
}

// Check the number of inputs.
if (node.input_size() < min_input_ || node.input_size() > max_input_) {
fail_check(
"Node (",
node.name(),
") has input size ",
node.input_size(),
" not in range [min=",
min_input_,
", max=",
max_input_,
"].");
}

if (!num_inputs_allowed_(node.input_size())) {
fail_check("Node (", node.name(), ") has input size ", node.input_size(), " not in allowed input sizes.");
}

// Check the number of outputs.
if (node.output_size() < min_output_ || node.output_size() > max_output_) {
fail_check(
"Node (",
node.name(),
") has output size ",
node.output_size(),
" not in range [min=",
min_output_,
", max=",
max_output_,
"].");
}

if (!num_outputs_allowed_(node.output_size())) {
fail_check("Node (", node.name(), "has output size ", node.output_size(), " not in allowed output sizes.");
}
VerifyInputNum(node.input_size(), node.name());
VerifyOutputNum(node.output_size(), node.name());

// Check the values of inputs / outputs
for (int in_idx = 0; in_idx < node.input_size(); ++in_idx) {
Expand Down Expand Up @@ -381,6 +330,51 @@ void OpSchema::Verify(const NodeProto& node) const {
// Phew. All verifications passed.
}

std::string OpSchema::VerifyFailPrefix(std::string_view node_name) const {
std::string str = "Node";
if (!node_name.empty()) {
str = str + "(" + std::string(node_name) + ")";
}
str = str + " with schema(" + domain() + "::" + Name() + ":" + std::to_string(since_version()) + ")";
return str;
}

void OpSchema::VerifyInputNum(int input_num, std::string_view node_name) const {
if (input_num < min_input_ || input_num > max_input_) {
fail_check(
VerifyFailPrefix(node_name),
" has input size ",
input_num,
" not in range [min=",
min_input_,
", max=",
max_input_,
"].");
}

if (!num_inputs_allowed_(input_num)) {
fail_check(VerifyFailPrefix(node_name), " has input size ", input_num, " not in allowed input sizes.");
}
}

void OpSchema::VerifyOutputNum(int output_num, std::string_view node_name) const {
if (output_num < min_output_ || output_num > max_output_) {
fail_check(
VerifyFailPrefix(node_name),
" has output size ",
output_num,
" not in range [min=",
min_output_,
", max=",
max_output_,
"].");
}

if (!num_outputs_allowed_(output_num)) {
fail_check(VerifyFailPrefix(node_name), " has output size ", output_num, " not in allowed output sizes.");
}
}

OpSchema& OpSchema::SinceVersion(OperatorSetVersion v) {
since_version_ = v;

Expand Down
22 changes: 22 additions & 0 deletions onnx/defs/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <ostream>
#include <set>
#include <string>
#include <string_view>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -1098,6 +1099,27 @@ class OpSchema final {
std::set<std::string>* updated_ops = nullptr) const;
void UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int opset_version) const;

/**
* @brief A common function to generate a prefix string for use in fail_check during the verify function.
* @param node_name If empty, the returned string will not include the node name.
* @return std::string The prefix string.
*/
std::string VerifyFailPrefix(std::string_view node_name) const;

/**
* @brief Verifies if the input number matches the pattern specified in the schema.
* @param input_num The number of inputs to be verified against the schema.
* @param node_info The prefix string used if the check fails.
*/
void VerifyInputNum(int input_num, std::string_view node_name = "") const;

/**
* @brief Verifies if the output number matches the pattern specified in the schema.
* @param output_num The number of outputs to be verified against the schema.
* @param node_info The prefix string used if the check fails.
*/
void VerifyOutputNum(int output_num, std::string_view node_name = "") const;

std::string name_;
std::string file_;
std::string doc_;
Expand Down

0 comments on commit 0bb2775

Please sign in to comment.