Skip to content

Commit

Permalink
Parsing NodeProto and handling optional inputs of functions (#5066)
Browse files Browse the repository at this point in the history
* A couple of fixes

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Missed ref

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Update onnx/parser.py

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: G. Ramalingam <grama@microsoft.com>

* Add python unit test

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* lint changes

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* lint changes

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

---------

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: G. Ramalingam <grama@microsoft.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Andreas Fehlner <fehlner@arcor.de>
  • Loading branch information
3 people committed Apr 3, 2023
1 parent 50eacec commit 44ce303
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 8 deletions.
1 change: 1 addition & 0 deletions onnx/cpp2py_export.cc
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ PYBIND11_MODULE(onnx_cpp2py_export, onnx_cpp2py_export) {
parser.def("parse_model", Parse<ModelProto>);
parser.def("parse_graph", Parse<GraphProto>);
parser.def("parse_function", Parse<FunctionProto>);
parser.def("parse_node", Parse<NodeProto>);

// Submodule `printer`
auto printer = onnx_cpp2py_export.def_submodule("printer");
Expand Down
21 changes: 19 additions & 2 deletions onnx/defs/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,12 @@ Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypePr
return Status::OK();
}

bool OnnxParser::NextIsIdentifier() {
std::string id("");
(void)PeekIdentifier(id);
return !(id.empty());
}

bool OnnxParser::NextIsType() {
std::string id("");
(void)PeekIdentifier(id);
Expand All @@ -413,8 +419,19 @@ Status OnnxParser::ParseSingleAttributeValue(AttributeProto& attr) {
auto next = NextChar();
if (isalpha(next) || next == '_') {
if (NextIsType()) {
attr.set_type(AttributeProto_AttributeType_TENSOR);
Parse(*attr.mutable_t());
TypeProto typeProto;
Parse(typeProto);
next = NextChar();
if ((next == '{') || (next == '=') || (NextIsIdentifier())) {
attr.set_type(AttributeProto_AttributeType_TENSOR);
auto& tensorProto = *attr.mutable_t();
ParseOptionalIdentifier(*tensorProto.mutable_name());
(void)Matches('='); // Optional, to unify handling of initializers
Parse(tensorProto, typeProto);
} else {
attr.set_type(AttributeProto_AttributeType_TYPE_PROTO);
attr.mutable_tp()->CopyFrom(typeProto);
}
} else {
attr.set_type(AttributeProto_AttributeType_GRAPH);
Parse(*attr.mutable_g());
Expand Down
2 changes: 2 additions & 0 deletions onnx/defs/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,8 @@ class OnnxParser : public ParserBase {
Status Parse(OpsetIdList& opsets);

bool NextIsType();

bool NextIsIdentifier();
};

} // namespace ONNX_NAMESPACE
6 changes: 6 additions & 0 deletions onnx/defs/printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,12 @@ void ProtoPrinter::print(const AttributeProto& attr) {
case AttributeProto_AttributeType_TENSORS:
printSet("[", ", ", "]", attr.tensors());
break;
case AttributeProto_AttributeType_TYPE_PROTO:
print(attr.tp());
break;
case AttributeProto_AttributeType_TYPE_PROTOS:
printSet("[", ", ", "]", attr.type_protos());
break;
default:
break;
}
Expand Down
9 changes: 9 additions & 0 deletions onnx/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,15 @@ def get_attribute_value(attr: AttributeProto) -> Any:
raise ValueError(f"Unsupported ONNX attribute: {attr}")


def get_node_attr_value(node: NodeProto, attr_name: str) -> Any:
matching = [x for x in node.attribute if x.name == attr_name]
if len(matching) > 1:
raise ValueError(f"Node has multiple attributes with name {attr_name}")
if len(matching) < 1:
raise ValueError(f"Node has no attribute with name {attr_name}")
return get_attribute_value(matching[0])


def make_empty_tensor_value_info(name: str) -> ValueInfoProto:
value_info_proto = ValueInfoProto()
value_info_proto.name = name
Expand Down
8 changes: 8 additions & 0 deletions onnx/onnx_cpp2py_export/parser.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,11 @@ def parse_function(function: str) -> Tuple[bool, bytes, bytes]:
Otherwise, error-message contains a string describing the parse error.
"""
...

def parse_node(node: str) -> Tuple[bool, bytes, bytes]:
"""
Returns (success-flag, error-message, serialized-proto).
If success-flag is true, then serialized-proto contains the parsed NodeProto.
Otherwise, error-message contains a string describing the parse error.
"""
...
16 changes: 16 additions & 0 deletions onnx/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,19 @@ def parse_function(function_text: str) -> onnx.FunctionProto:
function_proto.ParseFromString(function_proto_str)
return function_proto
raise ParseError(msg)


def parse_node(node_text: str) -> onnx.NodeProto:
"""Parse a string to build a NodeProto.
Arguments:
node_text: formatted string
Returns:
NodeProto
"""
(success, msg, node_proto_str) = C.parse_node(node_text)
if success:
node_proto = onnx.NodeProto()
node_proto.ParseFromString(node_proto_str)
return node_proto
raise ParseError(msg)
18 changes: 12 additions & 6 deletions onnx/shape_inference/implementation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -569,18 +569,24 @@ class ShapeInferenceImplBase {
reuse_constant_tensors = false;

// Get a temporary tensor-shape map
const int num_actual_inputs = static_cast<int>(ctx.getNumInputs());
const auto num_func_inputs = func_proto.input_size();
std::vector<TypeProto> types_cache(num_func_inputs);
for (int i = 0; i < num_func_inputs; ++i) {
if (ctx.getInputType(i) == nullptr) {
fail_type_inference("Input ", i, " type is missing.");
}
types_cache[i] = *ctx.getInputType(i); // TODO: investigate whether we can remove cache
value_types_by_name[func_proto.input().Get(i)] = &types_cache[i];
auto& parameter_name = func_proto.input().Get(i);
auto* type_ptr = (i < num_actual_inputs) ? ctx.getInputType(i) : nullptr;
// nullptr is valid, and indicates a missing optional input
if (type_ptr != nullptr) {
// Use a temporary copy of original type.
// TODO: investigate whether we can eliminate use of temporary copy
types_cache[i] = *type_ptr;
value_types_by_name[parameter_name] = &types_cache[i];
} else
value_types_by_name[parameter_name] = nullptr;
}

// Create a temporary initializer value map
for (int i = 0; i < static_cast<int>(ctx.getNumInputs()) && i < num_func_inputs; ++i) {
for (int i = 0; i < num_actual_inputs && i < num_func_inputs; ++i) {
const TypeProto* type = ctx.getInputType(i);
if (type->value_case() == TypeProto::kTensorType && ctx.getInputData(i) != nullptr) {
input_data_by_name[func_proto.input().Get(i)] = ctx.getInputData(i);
Expand Down
6 changes: 6 additions & 0 deletions onnx/test/cpp/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,12 @@ TEST(ParserTest, AttributeTest) {
)ONNX");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_GRAPH);
EXPECT_EQ(attr.g().node_size(), 2);

Parse(attr, "type = float[3]");
EXPECT_EQ(attr.type(), AttributeProto_AttributeType::AttributeProto_AttributeType_TYPE_PROTO);
EXPECT_TRUE(attr.tp().has_tensor_type());
int float_type = static_cast<int>(TensorProto_DataType::TensorProto_DataType_FLOAT);
EXPECT_EQ(attr.tp().tensor_type().elem_type(), float_type);
}

TEST(ParserTest, AttrListTest) {
Expand Down
12 changes: 12 additions & 0 deletions onnx/test/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,18 @@ def expect_model_function_attribute(model):
expect_model_function_attribute(model)
expect_custom_node_attribute(model.graph.node[0], expected_attribute)

def test_parse_node(self):
node = onnx.parser.parse_node(
"out1, out2 = SomeDomain.SomeOp <attr1 = 1> (in1, in2)"
)
self.assertEqual(list(node.input), ["in1", "in2"])
self.assertEqual(list(node.output), ["out1", "out2"])
self.assertEqual(len(node.attribute), 1)
attr_val = onnx.helper.get_node_attr_value(node, "attr1")
self.assertEqual(attr_val, 1)
self.assertEqual(node.domain, "SomeDomain")
self.assertEqual(node.op_type, "SomeOp")


if __name__ == "__main__":
unittest.main()

0 comments on commit 44ce303

Please sign in to comment.