diff --git a/onnx/defs/schema.cc b/onnx/defs/schema.cc index 257fddd3604..74bd0a63043 100644 --- a/onnx/defs/schema.cc +++ b/onnx/defs/schema.cc @@ -5,6 +5,8 @@ #include "onnx/defs/schema.h" #include +#include +#include #include #include @@ -107,30 +109,10 @@ OpSchemaRegistry* OpSchemaRegistry::Instance() { void OpSchema::CheckInputOutputType(struct InferenceContext& ctx) const { std::unordered_map type_constraints; - if (inputs_.empty() && ctx.getNumInputs() > 0) { - fail_check( - "Node (", - domain(), - "::", - Name(), - ":", - since_version(), - ") takes zero inputs, but got ", - ctx.getNumInputs(), - " in graph"); - } - if (outputs_.empty() && ctx.getNumOutputs() > 0) { - fail_check( - "Node (", - domain(), - "::", - Name(), - ":", - since_version(), - ") yields zero outputs, but got ", - ctx.getNumOutputs(), - " in graph"); - } + // Check the number of inputs / output. + VerifyInputNum(ctx.getNumInputs()); + VerifyOutputNum(ctx.getNumOutputs()); + // check all input types for (size_t in_idx = 0; in_idx < ctx.getNumInputs(); ++in_idx) { // If the last input is Variadic by definition, checker still needs to check the rest of actual input's type @@ -200,41 +182,8 @@ void OpSchema::Verify(const NodeProto& node) const { fail_check("Operator '", name_, "' has been deprecated since version ", since_version_); } - // Check the number of inputs. - if (node.input_size() < min_input_ || node.input_size() > max_input_) { - fail_check( - "Node (", - node.name(), - ") has input size ", - node.input_size(), - " not in range [min=", - min_input_, - ", max=", - max_input_, - "]."); - } - - if (!num_inputs_allowed_(node.input_size())) { - fail_check("Node (", node.name(), ") has input size ", node.input_size(), " not in allowed input sizes."); - } - - // Check the number of outputs. - if (node.output_size() < min_output_ || node.output_size() > max_output_) { - fail_check( - "Node (", - node.name(), - ") has output size ", - node.output_size(), - " not in range [min=", - min_output_, - ", max=", - max_output_, - "]."); - } - - if (!num_outputs_allowed_(node.output_size())) { - fail_check("Node (", node.name(), "has output size ", node.output_size(), " not in allowed output sizes."); - } + VerifyInputNum(node.input_size(), node.name()); + VerifyOutputNum(node.output_size(), node.name()); // Check the values of inputs / outputs for (int in_idx = 0; in_idx < node.input_size(); ++in_idx) { @@ -381,6 +330,51 @@ void OpSchema::Verify(const NodeProto& node) const { // Phew. All verifications passed. } +std::string OpSchema::VerifyFailPrefix(std::string_view node_name) const { + std::string str = "Node"; + if (!node_name.empty()) { + str = str + "(" + std::string(node_name) + ")"; + } + str = str + " with schema(" + domain() + "::" + Name() + ":" + std::to_string(since_version()) + ")"; + return str; +} + +void OpSchema::VerifyInputNum(int input_num, std::string_view node_name) const { + if (input_num < min_input_ || input_num > max_input_) { + fail_check( + VerifyFailPrefix(node_name), + " has input size ", + input_num, + " not in range [min=", + min_input_, + ", max=", + max_input_, + "]."); + } + + if (!num_inputs_allowed_(input_num)) { + fail_check(VerifyFailPrefix(node_name), " has input size ", input_num, " not in allowed input sizes."); + } +} + +void OpSchema::VerifyOutputNum(int output_num, std::string_view node_name) const { + if (output_num < min_output_ || output_num > max_output_) { + fail_check( + VerifyFailPrefix(node_name), + " has output size ", + output_num, + " not in range [min=", + min_output_, + ", max=", + max_output_, + "]."); + } + + if (!num_outputs_allowed_(output_num)) { + fail_check(VerifyFailPrefix(node_name), " has output size ", output_num, " not in allowed output sizes."); + } +} + OpSchema& OpSchema::SinceVersion(OperatorSetVersion v) { since_version_ = v; diff --git a/onnx/defs/schema.h b/onnx/defs/schema.h index 3659ae74ca9..e786a96b299 100644 --- a/onnx/defs/schema.h +++ b/onnx/defs/schema.h @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -1098,6 +1099,27 @@ class OpSchema final { std::set* updated_ops = nullptr) const; void UpdateFunctionProtoOpsetImportVersion(FunctionProto& function_proto, int opset_version) const; + /** + * @brief A common function to generate a prefix string for use in fail_check during the verify function. + * @param node_name If empty, the returned string will not include the node name. + * @return std::string The prefix string. + */ + std::string VerifyFailPrefix(std::string_view node_name) const; + + /** + * @brief Verifies if the input number matches the pattern specified in the schema. + * @param input_num The number of inputs to be verified against the schema. + * @param node_info The prefix string used if the check fails. + */ + void VerifyInputNum(int input_num, std::string_view node_name = "") const; + + /** + * @brief Verifies if the output number matches the pattern specified in the schema. + * @param output_num The number of outputs to be verified against the schema. + * @param node_info The prefix string used if the check fails. + */ + void VerifyOutputNum(int output_num, std::string_view node_name = "") const; + std::string name_; std::string file_; std::string doc_;