From 66fa4c5cdf65fdc3b48ab969d9e36eef4bf92b8c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 1 Feb 2024 16:39:41 -0800 Subject: [PATCH 1/5] Add valueinfos support for functions Signed-off-by: Ganesan Ramalingam --- docs/Syntax.md | 11 +++++--- onnx/defs/parser.cc | 43 ++++++++++++++++++++++++++--- onnx/defs/parser.h | 6 ++++- onnx/helper.py | 4 +++ onnx/onnx.in.proto | 6 +++++ onnx/test/cpp/parser_test.cc | 52 ++++++++++++++++++++++++++++++++++++ 6 files changed, 113 insertions(+), 9 deletions(-) diff --git a/docs/Syntax.md b/docs/Syntax.md index 6bca64b4969..ddd8b2c0731 100644 --- a/docs/Syntax.md +++ b/docs/Syntax.md @@ -56,7 +56,9 @@ The grammar below describes the syntax: | 'optional' '(' type ')' | 'sparse_tensor' '(' tensor-type ')' value-info ::= type id value-infos ::= value-info (',' value-info)* - value-info-list ::= '(' value-infos? ')' + value-info-list ::= '(' value-infos? ') + id-or-value-info ::= type? id + id-or-value-infos ::= id-or-value-info (',' id-or-value-info)* quoted-str :== '"' ([^"])* '"' str-str :== quoted-str ':' quoted-str str-str-list :== '[' str-str (',' str-str)* ']' @@ -84,8 +86,9 @@ The grammar below describes the syntax: other-data ::= id ':' value other-data-list ::= '<' other-data (',' other-data)* '>' fun-attr-list ::= '<' id | attr (',' id | attr)* '>' - fun-input-list ::= '(' id-list ')' - fun-output-list ::= '(' id-list ')' - function ::= other-data-list? id fun-attr-list? fun-input-list '=>' fun-output-list node-list + fun-input-list ::= '(' id-or-value-infos ')' + fun-output-list ::= '(' id-or-value-infos ')' + fun-value-infos ::= ( '<' value-infos '>' )? + function ::= other-data-list? id fun-attr-list? fun-input-list '=>' fun-output-list fun-value-infos node-list model ::= other-data-list? graph function* ``` diff --git a/onnx/defs/parser.cc b/onnx/defs/parser.cc index a8e9956107e..0d9148d8602 100644 --- a/onnx/defs/parser.cc +++ b/onnx/defs/parser.cc @@ -316,12 +316,44 @@ Status OnnxParser::Parse(ValueInfoProto& valueinfo) { return Status::OK(); } -Status OnnxParser::Parse(ValueInfoList& vilist) { +Status OnnxParser::Parse(char open, ValueInfoList& vilist, char close) { + MATCH(open); + if (!Matches(close)) { + do { + PARSE(*vilist.Add()); + } while (Matches(',')); + MATCH(close); + } + return Status::OK(); +} + +Status OnnxParser::ParseGraphInputOutput(ValueInfoList& vilist) { + vilist.Clear(); + PARSE('(', vilist, ')'); + return Status::OK(); +} + +Status OnnxParser::ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist) { + idlist.Clear(); vilist.Clear(); MATCH('('); if (!Matches(')')) { do { - PARSE(*vilist.Add()); + // Function inputs/outputs can be optionally typed. + // Syntax: Name | Type Name + // The name is added to idlist. If the optional type is present, an entry is + // added to vilist. + + std::string *name = *idlist.Add(); + ValueInfoProto *vi = nullptr; + + if (NextIsType()) { + vi = vilist.Add(); + PARSE(*valueinfo.mutable_type()); + } + CHECK_PARSER_STATUS(ParseIdentifier(*name)); + if (vi != nullptr) + vi->set_name(*name); } while (Matches(',')); MATCH(')'); } @@ -751,10 +783,13 @@ Status OnnxParser::Parse(FunctionProto& fn) { fn.set_name(id); PARSE('<', *fn.mutable_attribute(), *fn.mutable_attribute_proto(), '>'); - PARSE('(', *fn.mutable_input(), ')'); + CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_input(), *fn.value_info())); MATCH('='); MATCH('>', false); - PARSE('(', *fn.mutable_output(), ')'); + CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_output(), *fn.value_info())); + if (NextChar() == '<') { + PARSE('<', *fn.mutable_value_info(), '>'); + } return Parse(*fn.mutable_node()); } diff --git a/onnx/defs/parser.h b/onnx/defs/parser.h index 21bb777b3a1..23e8af6bcb2 100644 --- a/onnx/defs/parser.h +++ b/onnx/defs/parser.h @@ -431,7 +431,11 @@ class OnnxParser : public ParserBase { Status Parse(ValueInfoProto& valueinfo); - Status Parse(ValueInfoList& vilist); + Status ParseGraphInputOutput(ValueInfoList& vilist); + + Status ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist); + + Status Parse(char open, ValueInfoList& vilist, char close); Status ParseInput(ValueInfoList& vilist, TensorList& initializers); diff --git a/onnx/helper.py b/onnx/helper.py index 0e12dbc2207..353d296ab81 100644 --- a/onnx/helper.py +++ b/onnx/helper.py @@ -258,11 +258,14 @@ def make_function( attributes: Optional[Sequence[str]] = None, attribute_protos: Optional[Sequence[AttributeProto]] = None, doc_string: Optional[str] = None, + value_info: Optional[Sequence[ValueInfoProto]] = None, ) -> FunctionProto: if attributes is None: attributes = [] if attribute_protos is None: attribute_protos = [] + if value_info is None: + value_info = [] f = FunctionProto() f.domain = domain f.name = fname @@ -274,6 +277,7 @@ def make_function( f.attribute_proto.extend(attribute_protos) if doc_string: f.doc_string = doc_string + f.value_info.extend(value_info) return f diff --git a/onnx/onnx.in.proto b/onnx/onnx.in.proto index 9a6c85922fd..26414a1a54d 100644 --- a/onnx/onnx.in.proto +++ b/onnx/onnx.in.proto @@ -854,4 +854,10 @@ message FunctionProto { // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of // the FunctionProto. optional string domain = 10; + + // Information for the values in the function. The ValueInfoProto.name's + // must be distinct and refer to names in the function (including inputs, + // outputs, and intermediate values). It is optional for a value to appear + // in value_info list. + repeated ValueInfoProto value_info = 12; } diff --git a/onnx/test/cpp/parser_test.cc b/onnx/test/cpp/parser_test.cc index 34048d41e75..260f4c3550b 100644 --- a/onnx/test/cpp/parser_test.cc +++ b/onnx/test/cpp/parser_test.cc @@ -363,6 +363,58 @@ f (y, z) => (w) EXPECT_EQ(fp.opset_import_size(), 1); } +TEST(ParserTest, FunctionValueInfoTest) { + const char* code = R"ONNX( +< + opset_import: [ "" : 10 ], + domain: "ai.onnx.ml", + doc_string: "A function test case." +> +f (float[N] y, float[N] z) => (float[N] w) +{ + x = Add(y, z) + w = Mul(x, y) +} +)ONNX"; + + FunctionProto fp; + Parse(fp, code); + + EXPECT_EQ(fp.input_size(), 2); + EXPECT_EQ(fp.value_info_size(), 3); + EXPECT_EQ(fp.output_size(), 1); + EXPECT_EQ(fp.value_info(0).name(), "y"); + EXPECT_EQ(fp.value_info(1).name(), "z"); + EXPECT_EQ(fp.value_info(2).name(), "w"); +} + +TEST(ParserTest, FunctionValueInfoTest2) { + const char* code = R"ONNX( +< + opset_import: [ "" : 10 ], + domain: "ai.onnx.ml", + doc_string: "A function test case." +> +f (float[N] y, float[N] z) => (float[N] w) + +{ + x = Add(y, z) + w = Mul(x, y) +} +)ONNX"; + + FunctionProto fp; + Parse(fp, code); + + EXPECT_EQ(fp.input_size(), 2); + EXPECT_EQ(fp.value_info_size(), 4); + EXPECT_EQ(fp.output_size(), 1); + EXPECT_EQ(fp.value_info(0).name(), "y"); + EXPECT_EQ(fp.value_info(1).name(), "z"); + EXPECT_EQ(fp.value_info(2).name(), "w"); + EXPECT_EQ(fp.value_info(3).name(), "x"); +} + TEST(ParserTest, InitializerTest) { const char* code = R"ONNX( agraph (float y = {1.0}, float[N] z) => (float[N] w) From 1188ce48ca89685c13fc389a0ee23c2ab4617842 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 2 Feb 2024 09:28:00 -0800 Subject: [PATCH 2/5] Minor fixes Signed-off-by: Ganesan Ramalingam --- onnx/defs/parser.cc | 13 +++++++------ onnx/test/cpp/parser_test.cc | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/onnx/defs/parser.cc b/onnx/defs/parser.cc index 0d9148d8602..3d21a7efa19 100644 --- a/onnx/defs/parser.cc +++ b/onnx/defs/parser.cc @@ -334,8 +334,8 @@ Status OnnxParser::ParseGraphInputOutput(ValueInfoList& vilist) { } Status OnnxParser::ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist) { + // Do not clear vilist, as it accumulates values over inputs and outputs. idlist.Clear(); - vilist.Clear(); MATCH('('); if (!Matches(')')) { do { @@ -344,12 +344,12 @@ Status OnnxParser::ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilis // The name is added to idlist. If the optional type is present, an entry is // added to vilist. - std::string *name = *idlist.Add(); + std::string *name = idlist.Add(); ValueInfoProto *vi = nullptr; if (NextIsType()) { vi = vilist.Add(); - PARSE(*valueinfo.mutable_type()); + PARSE(*(vi->mutable_type())); } CHECK_PARSER_STATUS(ParseIdentifier(*name)); if (vi != nullptr) @@ -747,7 +747,7 @@ Status OnnxParser::Parse(std::string name, GraphProto& graph) { CHECK_PARSER_STATUS(ParseInput(*graph.mutable_input(), *graph.mutable_initializer())); MATCH('='); MATCH('>', false); - PARSE(*graph.mutable_output()); + CHECK_PARSER_STATUS(ParseGraphInputOutput(*graph.mutable_output())); CHECK_PARSER_STATUS(ParseValueInfo(*graph.mutable_value_info(), *graph.mutable_initializer())); return Parse(*graph.mutable_node()); } @@ -783,10 +783,11 @@ Status OnnxParser::Parse(FunctionProto& fn) { fn.set_name(id); PARSE('<', *fn.mutable_attribute(), *fn.mutable_attribute_proto(), '>'); - CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_input(), *fn.value_info())); + fn.mutable_value_info()->Clear(); + CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_input(), *fn.mutable_value_info())); MATCH('='); MATCH('>', false); - CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_output(), *fn.value_info())); + CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_output(), *fn.mutable_value_info())); if (NextChar() == '<') { PARSE('<', *fn.mutable_value_info(), '>'); } diff --git a/onnx/test/cpp/parser_test.cc b/onnx/test/cpp/parser_test.cc index 260f4c3550b..e347475cc42 100644 --- a/onnx/test/cpp/parser_test.cc +++ b/onnx/test/cpp/parser_test.cc @@ -381,8 +381,8 @@ f (float[N] y, float[N] z) => (float[N] w) Parse(fp, code); EXPECT_EQ(fp.input_size(), 2); - EXPECT_EQ(fp.value_info_size(), 3); EXPECT_EQ(fp.output_size(), 1); + ASSERT_EQ(fp.value_info_size(), 3); EXPECT_EQ(fp.value_info(0).name(), "y"); EXPECT_EQ(fp.value_info(1).name(), "z"); EXPECT_EQ(fp.value_info(2).name(), "w"); From 637768faf70b5661c965cf6b2fb2dd7a3f058e74 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 2 Feb 2024 10:12:14 -0800 Subject: [PATCH 3/5] Update IR documentation Signed-off-by: Ganesan Ramalingam --- docs/IR.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/IR.md b/docs/IR.md index 3255f86549e..aabcf353a73 100644 --- a/docs/IR.md +++ b/docs/IR.md @@ -200,6 +200,7 @@ input|string[]|The input parameters of the function output|string[]|The output parameters of the function. node|Node[]|A list of nodes, forming a partially ordered computation graph. It must be in topological order. |opset_import|OperatorSetId|A collection of operator set identifiers used by the function implementation. +value_info|ValueInfo[]|Used to store the type and shape information of values used in the function. The name and domain serve to identify the operator uniquely. An opset version is not explicitly identified in a FunctionProto, but it is implicitly determined by the opset version of the domain @@ -210,6 +211,8 @@ is explicitly included in the signature. The attribute_proto field describes att The opset_import and node fields describe the implementation of the function. +The value_info field (added in IR version 10) allows a model to store type and shape information about the values used in a function, including its inputs and outputs. Note that this is optional, and ONNX allows functions to be polymorphic. + ### Graphs A graph is used to describe a side-effect-free computation (function). From f57147dbeee057195cc9f237abd9c967191fabe1 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 2 Feb 2024 13:09:28 -0800 Subject: [PATCH 4/5] Add another test case Signed-off-by: Ganesan Ramalingam --- onnx/defs/parser.cc | 6 +++--- onnx/test/cpp/parser_test.cc | 30 +++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/onnx/defs/parser.cc b/onnx/defs/parser.cc index 3d21a7efa19..34bdc76d0c5 100644 --- a/onnx/defs/parser.cc +++ b/onnx/defs/parser.cc @@ -344,9 +344,9 @@ Status OnnxParser::ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilis // The name is added to idlist. If the optional type is present, an entry is // added to vilist. - std::string *name = idlist.Add(); - ValueInfoProto *vi = nullptr; - + std::string* name = idlist.Add(); + ValueInfoProto* vi = nullptr; + if (NextIsType()) { vi = vilist.Add(); PARSE(*(vi->mutable_type())); diff --git a/onnx/test/cpp/parser_test.cc b/onnx/test/cpp/parser_test.cc index e347475cc42..05fc3508f1a 100644 --- a/onnx/test/cpp/parser_test.cc +++ b/onnx/test/cpp/parser_test.cc @@ -408,13 +408,41 @@ f (float[N] y, float[N] z) => (float[N] w) EXPECT_EQ(fp.input_size(), 2); EXPECT_EQ(fp.value_info_size(), 4); - EXPECT_EQ(fp.output_size(), 1); + ASSERT_EQ(fp.output_size(), 1); EXPECT_EQ(fp.value_info(0).name(), "y"); EXPECT_EQ(fp.value_info(1).name(), "z"); EXPECT_EQ(fp.value_info(2).name(), "w"); EXPECT_EQ(fp.value_info(3).name(), "x"); } +TEST(ParserTest, FunctionValueInfoTest3) { + const char* code = R"ONNX( +< + opset_import: [ "" : 10 ], + domain: "ai.onnx.ml", + doc_string: "A function test case." +> +f (float[N] y, z) => (float[N] w) + +{ + x = Add(y, z) + t = Add(x, x) + w = Mul(t, y) +} +)ONNX"; + + FunctionProto fp; + Parse(fp, code); + + EXPECT_EQ(fp.input_size(), 2); + ASSERT_EQ(fp.value_info_size(), 4); + EXPECT_EQ(fp.output_size(), 1); + EXPECT_EQ(fp.value_info(0).name(), "y"); + EXPECT_EQ(fp.value_info(1).name(), "w"); + EXPECT_EQ(fp.value_info(2).name(), "x"); + EXPECT_EQ(fp.value_info(3).name(), "t"); +} + TEST(ParserTest, InitializerTest) { const char* code = R"ONNX( agraph (float y = {1.0}, float[N] z) => (float[N] w) From 76dd899eaeafc83b712e87860e52a6ba9597acc4 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 5 Feb 2024 16:10:47 -0800 Subject: [PATCH 5/5] Generate proto files Signed-off-by: Ganesan Ramalingam --- onnx/onnx-ml.proto | 6 ++++++ onnx/onnx-ml.proto3 | 6 ++++++ onnx/onnx.proto | 6 ++++++ onnx/onnx.proto3 | 6 ++++++ 4 files changed, 24 insertions(+) diff --git a/onnx/onnx-ml.proto b/onnx/onnx-ml.proto index 5e3ef1aebc5..fadd86d019a 100644 --- a/onnx/onnx-ml.proto +++ b/onnx/onnx-ml.proto @@ -853,6 +853,12 @@ message FunctionProto { // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of // the FunctionProto. optional string domain = 10; + + // Information for the values in the function. The ValueInfoProto.name's + // must be distinct and refer to names in the function (including inputs, + // outputs, and intermediate values). It is optional for a value to appear + // in value_info list. + repeated ValueInfoProto value_info = 12; } // For using protobuf-lite diff --git a/onnx/onnx-ml.proto3 b/onnx/onnx-ml.proto3 index 95dc5021f8c..958f685c1bf 100644 --- a/onnx/onnx-ml.proto3 +++ b/onnx/onnx-ml.proto3 @@ -853,6 +853,12 @@ message FunctionProto { // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of // the FunctionProto. string domain = 10; + + // Information for the values in the function. The ValueInfoProto.name's + // must be distinct and refer to names in the function (including inputs, + // outputs, and intermediate values). It is optional for a value to appear + // in value_info list. + repeated ValueInfoProto value_info = 12; } // For using protobuf-lite diff --git a/onnx/onnx.proto b/onnx/onnx.proto index 1e5c7015f85..5c8c3666ff6 100644 --- a/onnx/onnx.proto +++ b/onnx/onnx.proto @@ -837,6 +837,12 @@ message FunctionProto { // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of // the FunctionProto. optional string domain = 10; + + // Information for the values in the function. The ValueInfoProto.name's + // must be distinct and refer to names in the function (including inputs, + // outputs, and intermediate values). It is optional for a value to appear + // in value_info list. + repeated ValueInfoProto value_info = 12; } // For using protobuf-lite diff --git a/onnx/onnx.proto3 b/onnx/onnx.proto3 index 9e49b2b13b9..a4d21b1b30b 100644 --- a/onnx/onnx.proto3 +++ b/onnx/onnx.proto3 @@ -837,6 +837,12 @@ message FunctionProto { // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of // the FunctionProto. string domain = 10; + + // Information for the values in the function. The ValueInfoProto.name's + // must be distinct and refer to names in the function (including inputs, + // outputs, and intermediate values). It is optional for a value to appear + // in value_info list. + repeated ValueInfoProto value_info = 12; } // For using protobuf-lite