diff --git a/onnx/cpp2py_export.cc b/onnx/cpp2py_export.cc
index 3adb6f4dee8..e1474b83c72 100644
--- a/onnx/cpp2py_export.cc
+++ b/onnx/cpp2py_export.cc
@@ -492,6 +492,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
   parser.def("parse_model", Parse<ModelProto>);
   parser.def("parse_graph", Parse<GraphProto>);
   parser.def("parse_function", Parse<FunctionProto>);
+  parser.def("parse_node", Parse<NodeProto>);
 
   // Submodule `printer`
   auto printer = onnx_cpp2py_export.def_submodule("printer");
diff --git a/onnx/defs/parser.cc b/onnx/defs/parser.cc
index c4d6ea76bf8..93ef445d7c8 100644
--- a/onnx/defs/parser.cc
+++ b/onnx/defs/parser.cc
@@ -392,6 +392,12 @@ Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypePr
   return Status::OK();
 }
 
+bool OnnxParser::NextIsIdentifier() {
+  std::string id("");
+  (void)PeekIdentifier(id);
+  return !(id.empty());
+}
+
 bool OnnxParser::NextIsType() {
   std::string id("");
   (void)PeekIdentifier(id);
@@ -413,8 +419,19 @@ Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr) {
   auto next = NextChar();
   if (isalpha(next) || next == '_') {
     if (NextIsType()) {
-      attr.set_type(AttributeProto_AttributeType_TENSOR);
-      Parse(*attr.mutable_t());
+      TypeProto typeProto;
+      Parse(typeProto);
+      next = NextChar();
+      if ((next == '{') || (next == '=') || (NextIsIdentifier())) {
+        attr.set_type(AttributeProto_AttributeType_TENSOR);
+        auto& tensorProto = *attr.mutable_t();
+        ParseOptionalIdentifier(*tensorProto.mutable_name());
+        (void)Matches('='); // Optional, to unify handling of initializers
+        Parse(tensorProto, typeProto);
+      } else {
+        attr.set_type(AttributeProto_AttributeType_TYPE_PROTO);
+        attr.mutable_tp()->CopyFrom(typeProto);
+      }
     } else {
       attr.set_type(AttributeProto_AttributeType_GRAPH);
       Parse(*attr.mutable_g());
diff --git a/onnx/defs/parser.h b/onnx/defs/parser.h
index 0888a51a177..03e45061279 100644
--- a/onnx/defs/parser.h
+++ b/onnx/defs/parser.h
@@ -432,6 +432,8 @@ class OnnxParser : public ParserBase {
   Status Parse(OpsetIdList& opsets);
 
   bool NextIsType();
+
+  bool NextIsIdentifier();
 };
 
 } // namespace ONNX_NAMESPACE
diff --git a/onnx/defs/printer.cc b/onnx/defs/printer.cc
index d68e9ff47e9..43a590e24d3 100644
--- a/onnx/defs/printer.cc
+++ b/onnx/defs/printer.cc
@@ -307,6 +307,12 @@ void ProtoPrinter::print(const AttributeProto& attr) {
     case AttributeProto_AttributeType_TENSORS:
      printSet("[", ", ", "]", attr.tensors());
       break;
+    case AttributeProto_AttributeType_TYPE_PROTO:
+      print(attr.tp());
+      break;
+    case AttributeProto_AttributeType_TYPE_PROTOS:
+      printSet("[", ", ", "]", attr.type_protos());
+      break;
     default:
       break;
   }
diff --git a/onnx/helper.py b/onnx/helper.py
index b9c0a2128c6..abeb41a7193 100644
--- a/onnx/helper.py
+++ b/onnx/helper.py
@@ -658,6 +658,15 @@ def get_attribute_value(attr: AttributeProto) -> Any:
     raise ValueError(f"Unsupported ONNX attribute: {attr}")
 
 
+def get_node_attr_value(node: NodeProto, attr_name: str) -> Any:
+    matching = [x for x in node.attribute if x.name == attr_name]
+    if len(matching) > 1:
+        raise ValueError(f"Node has multiple attributes with name {attr_name}")
+    if len(matching) < 1:
+        raise ValueError(f"Node has no attribute with name {attr_name}")
+    return get_attribute_value(matching[0])
+
+
 def make_empty_tensor_value_info(name: str) -> ValueInfoProto:
     value_info_proto = ValueInfoProto()
     value_info_proto.name = name
diff --git a/onnx/onnx_cpp2py_export/parser.pyi b/onnx/onnx_cpp2py_export/parser.pyi
index 1cdb1c0a9cc..5cbed57c5c9 100644
--- a/onnx/onnx_cpp2py_export/parser.pyi
+++ b/onnx/onnx_cpp2py_export/parser.pyi
@@ -23,3 +23,11 @@ def parse_function(function: str) -> Tuple[bool, bytes, bytes]:
     Otherwise, error-message contains a string describing the parse error.
     """
     ...
+
+def parse_node(node: str) -> Tuple[bool, bytes, bytes]:
+    """
+    Returns (success-flag, error-message, serialized-proto).
+    If success-flag is true, then serialized-proto contains the parsed NodeProto.
+    Otherwise, error-message contains a string describing the parse error.
+    """
+    ...
diff --git a/onnx/parser.py b/onnx/parser.py
index 5fb61d8f2e1..28e74745a65 100644
--- a/onnx/parser.py
+++ b/onnx/parser.py
@@ -51,3 +51,19 @@ def parse_function(function_text: str) -> onnx.FunctionProto:
         function_proto.ParseFromString(function_proto_str)
         return function_proto
     raise ParseError(msg)
+
+
+def parse_node(node_text: str) -> onnx.NodeProto:
+    """Parse a string to build a NodeProto.
+
+    Arguments:
+        node_text: formatted string
+    Returns:
+        NodeProto
+    """
+    (success, msg, node_proto_str) = C.parse_node(node_text)
+    if success:
+        node_proto = onnx.NodeProto()
+        node_proto.ParseFromString(node_proto_str)
+        return node_proto
+    raise ParseError(msg)
diff --git a/onnx/shape_inference/implementation.cc b/onnx/shape_inference/implementation.cc
index 8604c107b41..ef6f1b3d775 100644
--- a/onnx/shape_inference/implementation.cc
+++ b/onnx/shape_inference/implementation.cc
@@ -569,18 +569,24 @@ class ShapeInferenceImplBase {
     reuse_constant_tensors = false;
 
     // Get a temporary tensor-shape map
+    const int num_actual_inputs = static_cast<int>(ctx.getNumInputs());
     const auto num_func_inputs = func_proto.input_size();
     std::vector<TypeProto> types_cache(num_func_inputs);
     for (int i = 0; i < num_func_inputs; ++i) {
-      if (ctx.getInputType(i) == nullptr) {
-        fail_type_inference("Input ", i, " type is missing.");
-      }
-      types_cache[i] = *ctx.getInputType(i); // TODO: investigate whether we can remove cache
-      value_types_by_name[func_proto.input().Get(i)] = &types_cache[i];
+      auto& parameter_name = func_proto.input().Get(i);
+      auto* type_ptr = (i < num_actual_inputs) ? ctx.getInputType(i) : nullptr;
+      // nullptr is valid, and indicates a missing optional input
+      if (type_ptr != nullptr) {
+        // Use a temporary copy of original type.
+        // TODO: investigate whether we can eliminate use of temporary copy
+        types_cache[i] = *type_ptr;
+        value_types_by_name[parameter_name] = &types_cache[i];
+      } else
+        value_types_by_name[parameter_name] = nullptr;
     }
 
     // Create a temporary initializer value map
-    for (int i = 0; i < static_cast<int>(ctx.getNumInputs()) && i < num_func_inputs; ++i) {
+    for (int i = 0; i < num_actual_inputs && i < num_func_inputs; ++i) {
       const TypeProto* type = ctx.getInputType(i);
       if (type->value_case() == TypeProto::kTensorType && ctx.getInputData(i) != nullptr) {
         input_data_by_name[func_proto.input().Get(i)] = ctx.getInputData(i);
diff --git a/onnx/test/cpp/parser_test.cc b/onnx/test/cpp/parser_test.cc
index 7a89989e159..b1ee1ec00d6 100644
--- a/onnx/test/cpp/parser_test.cc
+++ b/onnx/test/cpp/parser_test.cc
@@ -196,6 +196,12 @@ TEST(ParserTest, AttributeTest) {
   )ONNX");
   EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH);
   EXPECT_EQ(attr.g().node_size(), 2);
+
+  Parse(attr, "type = float[3]");
+  EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_TYPE_PROTO);
+  EXPECT_TRUE(attr.tp().has_tensor_type());
+  int float_type = static_cast<int>(TensorProto_DataType::TensorProto_DataType_FLOAT);
+  EXPECT_EQ(attr.tp().tensor_type().elem_type(), float_type);
 }
 
 TEST(ParserTest, AttrListTest) {
diff --git a/onnx/test/parser_test.py b/onnx/test/parser_test.py
index 7d97001d1fb..6eb3e2a5efd 100644
--- a/onnx/test/parser_test.py
+++ b/onnx/test/parser_test.py
@@ -207,6 +207,18 @@ def expect_model_function_attribute(model):
         expect_model_function_attribute(model)
         expect_custom_node_attribute(model.graph.node[0], expected_attribute)
 
+    def test_parse_node(self):
+        node = onnx.parser.parse_node(
+            "out1, out2 = SomeDomain.SomeOp <attr1 = 1> (in1, in2)"
+        )
+        self.assertEqual(list(node.input), ["in1", "in2"])
+        self.assertEqual(list(node.output), ["out1", "out2"])
+        self.assertEqual(len(node.attribute), 1)
+        attr_val = onnx.helper.get_node_attr_value(node, "attr1")
+        self.assertEqual(attr_val, 1)
+        self.assertEqual(node.domain, "SomeDomain")
+        self.assertEqual(node.op_type, "SomeOp")
+
 
 if __name__ == "__main__":
     unittest.main()
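Usage sketch of the new Python surface, mirroring the tests in this diff. The first node text and the attr1 lookup come from test_parse_node; the second node text ("y = SomeDomain.SomeOp <t = float[3]> (x)") is hypothetical and only illustrates the bare-type attribute syntax exercised by the C++ AttributeTest.

import onnx
import onnx.helper
import onnx.parser

# Round-trip a single node through the new parse_node entry point.
node = onnx.parser.parse_node("out1, out2 = SomeDomain.SomeOp <attr1 = 1> (in1, in2)")
assert node.op_type == "SomeOp" and node.domain == "SomeDomain"
assert list(node.input) == ["in1", "in2"]

# get_node_attr_value looks up one attribute by name and decodes its value;
# it raises ValueError when the name is absent or duplicated.
assert onnx.helper.get_node_attr_value(node, "attr1") == 1

# With the parser change, an attribute value may also be a bare type,
# stored as AttributeProto.TYPE_PROTO rather than as a tensor initializer.
typed = onnx.parser.parse_node("y = SomeDomain.SomeOp <t = float[3]> (x)")
assert typed.attribute[0].type == onnx.AttributeProto.TYPE_PROTO
assert typed.attribute[0].tp.tensor_type.elem_type == onnx.TensorProto.FLOAT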