Skip to content

Commit

Permalink
Add valueinfos field to FunctionProto (#5903)
Browse files Browse the repository at this point in the history
### Description

* Extend FunctionProto by adding a field to store ValueInfos for
variables in the function.
* Extend the helper function and parser to support this.

### Motivation and Context

This enables type/shape information to be stored for variables in a
FunctionProto, just like in a GraphProto. Note that this is optional,
just like in a graph. This field may store type/shape info for the
inputs and outputs of a function as well. (Note that, unlike in a graph,
the input/output field in a FunctionProto do not have explicit
types/shapes, they are just strings, not ValueInfo.)

---------

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
  • Loading branch information
gramalingam committed Feb 6, 2024
1 parent 3d976ff commit ffaa707
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 10 deletions.
3 changes: 3 additions & 0 deletions docs/IR.md
Expand Up @@ -200,6 +200,7 @@ input|string[]|The input parameters of the function
output|string[]|The output parameters of the function.
node|Node[]|A list of nodes, forming a partially ordered computation graph. It must be in topological order.
|opset_import|OperatorSetId|A collection of operator set identifiers used by the function implementation.
value_info|ValueInfo[]|Used to store the type and shape information of values used in the function.

The name and domain serve to identify the operator uniquely. An opset version is not explicitly
identified in a FunctionProto, but it is implicitly determined by the opset version of the domain
Expand All @@ -210,6 +211,8 @@ is explicitly included in the signature. The attribute_proto field describes att

The opset_import and node fields describe the implementation of the function.

The value_info field (added in IR version 10) allows a model to store type and shape information about the values used in a function, including its inputs and outputs. Note that this is optional, and ONNX allows functions to be polymorphic.

### Graphs

A graph is used to describe a side-effect-free computation (function).
Expand Down
11 changes: 7 additions & 4 deletions docs/Syntax.md
Expand Up @@ -56,7 +56,9 @@ The grammar below describes the syntax:
| 'optional' '(' type ')' | 'sparse_tensor' '(' tensor-type ')'
value-info ::= type id
value-infos ::= value-info (',' value-info)*
value-info-list ::= '(' value-infos? ')'
value-info-list ::= '(' value-infos? ')
id-or-value-info ::= type? id
id-or-value-infos ::= id-or-value-info (',' id-or-value-info)*
quoted-str :== '"' ([^"])* '"'
str-str :== quoted-str ':' quoted-str
str-str-list :== '[' str-str (',' str-str)* ']'
Expand Down Expand Up @@ -84,8 +86,9 @@ The grammar below describes the syntax:
other-data ::= id ':' value
other-data-list ::= '<' other-data (',' other-data)* '>'
fun-attr-list ::= '<' id | attr (',' id | attr)* '>'
fun-input-list ::= '(' id-list ')'
fun-output-list ::= '(' id-list ')'
function ::= other-data-list? id fun-attr-list? fun-input-list '=>' fun-output-list node-list
fun-input-list ::= '(' id-or-value-infos ')'
fun-output-list ::= '(' id-or-value-infos ')'
fun-value-infos ::= ( '<' value-infos '>' )?
function ::= other-data-list? id fun-attr-list? fun-input-list '=>' fun-output-list fun-value-infos node-list
model ::= other-data-list? graph function*
```
46 changes: 41 additions & 5 deletions onnx/defs/parser.cc
Expand Up @@ -316,12 +316,44 @@ Status OnnxParser::Parse(ValueInfoProto& valueinfo) {
return Status::OK();
}

Status OnnxParser::Parse(ValueInfoList& vilist) {
Status OnnxParser::Parse(char open, ValueInfoList& vilist, char close) {
MATCH(open);
if (!Matches(close)) {
do {
PARSE(*vilist.Add());
} while (Matches(','));
MATCH(close);
}
return Status::OK();
}

Status OnnxParser::ParseGraphInputOutput(ValueInfoList& vilist) {
vilist.Clear();
PARSE('(', vilist, ')');
return Status::OK();
}

Status OnnxParser::ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist) {
// Do not clear vilist, as it accumulates values over inputs and outputs.
idlist.Clear();
MATCH('(');
if (!Matches(')')) {
do {
PARSE(*vilist.Add());
// Function inputs/outputs can be optionally typed.
// Syntax: Name | Type Name
// The name is added to idlist. If the optional type is present, an entry is
// added to vilist.

std::string* name = idlist.Add();
ValueInfoProto* vi = nullptr;

if (NextIsType()) {
vi = vilist.Add();
PARSE(*(vi->mutable_type()));
}
CHECK_PARSER_STATUS(ParseIdentifier(*name));
if (vi != nullptr)
vi->set_name(*name);
} while (Matches(','));
MATCH(')');
}
Expand Down Expand Up @@ -715,7 +747,7 @@ Status OnnxParser::Parse(std::string name, GraphProto& graph) {
CHECK_PARSER_STATUS(ParseInput(*graph.mutable_input(), *graph.mutable_initializer()));
MATCH('=');
MATCH('>', false);
PARSE(*graph.mutable_output());
CHECK_PARSER_STATUS(ParseGraphInputOutput(*graph.mutable_output()));
CHECK_PARSER_STATUS(ParseValueInfo(*graph.mutable_value_info(), *graph.mutable_initializer()));
return Parse(*graph.mutable_node());
}
Expand Down Expand Up @@ -751,10 +783,14 @@ Status OnnxParser::Parse(FunctionProto& fn) {
fn.set_name(id);

PARSE('<', *fn.mutable_attribute(), *fn.mutable_attribute_proto(), '>');
PARSE('(', *fn.mutable_input(), ')');
fn.mutable_value_info()->Clear();
CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_input(), *fn.mutable_value_info()));
MATCH('=');
MATCH('>', false);
PARSE('(', *fn.mutable_output(), ')');
CHECK_PARSER_STATUS(ParseFunctionInputOutput(*fn.mutable_output(), *fn.mutable_value_info()));
if (NextChar() == '<') {
PARSE('<', *fn.mutable_value_info(), '>');
}
return Parse(*fn.mutable_node());
}

Expand Down
6 changes: 5 additions & 1 deletion onnx/defs/parser.h
Expand Up @@ -431,7 +431,11 @@ class OnnxParser : public ParserBase {

Status Parse(ValueInfoProto& valueinfo);

Status Parse(ValueInfoList& vilist);
Status ParseGraphInputOutput(ValueInfoList& vilist);

Status ParseFunctionInputOutput(IdList& idlist, ValueInfoList& vilist);

Status Parse(char open, ValueInfoList& vilist, char close);

Status ParseInput(ValueInfoList& vilist, TensorList& initializers);

Expand Down
4 changes: 4 additions & 0 deletions onnx/helper.py
Expand Up @@ -259,11 +259,14 @@ def make_function(
attributes: Optional[Sequence[str]] = None,
attribute_protos: Optional[Sequence[AttributeProto]] = None,
doc_string: Optional[str] = None,
value_info: Optional[Sequence[ValueInfoProto]] = None,
) -> FunctionProto:
if attributes is None:
attributes = []
if attribute_protos is None:
attribute_protos = []
if value_info is None:
value_info = []
f = FunctionProto()
f.domain = domain
f.name = fname
Expand All @@ -275,6 +278,7 @@ def make_function(
f.attribute_proto.extend(attribute_protos)
if doc_string:
f.doc_string = doc_string
f.value_info.extend(value_info)
return f


Expand Down
6 changes: 6 additions & 0 deletions onnx/onnx-ml.proto
Expand Up @@ -853,6 +853,12 @@ message FunctionProto {
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
optional string domain = 10;

// Information for the values in the function. The ValueInfoProto.name's
// must be distinct and refer to names in the function (including inputs,
// outputs, and intermediate values). It is optional for a value to appear
// in value_info list.
repeated ValueInfoProto value_info = 12;
}

// For using protobuf-lite
Expand Down
6 changes: 6 additions & 0 deletions onnx/onnx-ml.proto3
Expand Up @@ -853,6 +853,12 @@ message FunctionProto {
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
string domain = 10;

// Information for the values in the function. The ValueInfoProto.name's
// must be distinct and refer to names in the function (including inputs,
// outputs, and intermediate values). It is optional for a value to appear
// in value_info list.
repeated ValueInfoProto value_info = 12;
}

// For using protobuf-lite
Expand Down
6 changes: 6 additions & 0 deletions onnx/onnx.in.proto
Expand Up @@ -854,4 +854,10 @@ message FunctionProto {
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
optional string domain = 10;

// Information for the values in the function. The ValueInfoProto.name's
// must be distinct and refer to names in the function (including inputs,
// outputs, and intermediate values). It is optional for a value to appear
// in value_info list.
repeated ValueInfoProto value_info = 12;
}
6 changes: 6 additions & 0 deletions onnx/onnx.proto
Expand Up @@ -837,6 +837,12 @@ message FunctionProto {
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
optional string domain = 10;

// Information for the values in the function. The ValueInfoProto.name's
// must be distinct and refer to names in the function (including inputs,
// outputs, and intermediate values). It is optional for a value to appear
// in value_info list.
repeated ValueInfoProto value_info = 12;
}

// For using protobuf-lite
Expand Down
6 changes: 6 additions & 0 deletions onnx/onnx.proto3
Expand Up @@ -837,6 +837,12 @@ message FunctionProto {
// The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
// the FunctionProto.
string domain = 10;

// Information for the values in the function. The ValueInfoProto.name's
// must be distinct and refer to names in the function (including inputs,
// outputs, and intermediate values). It is optional for a value to appear
// in value_info list.
repeated ValueInfoProto value_info = 12;
}

// For using protobuf-lite
Expand Down
80 changes: 80 additions & 0 deletions onnx/test/cpp/parser_test.cc
Expand Up @@ -363,6 +363,86 @@ f (y, z) => (w)
EXPECT_EQ(fp.opset_import_size(), 1);
}

TEST(ParserTest, FunctionValueInfoTest) {
const char* code = R"ONNX(
<
opset_import: [ "" : 10 ],
domain: "ai.onnx.ml",
doc_string: "A function test case."
>
f (float[N] y, float[N] z) => (float[N] w)
{
x = Add(y, z)
w = Mul(x, y)
}
)ONNX";

FunctionProto fp;
Parse(fp, code);

EXPECT_EQ(fp.input_size(), 2);
EXPECT_EQ(fp.output_size(), 1);
ASSERT_EQ(fp.value_info_size(), 3);
EXPECT_EQ(fp.value_info(0).name(), "y");
EXPECT_EQ(fp.value_info(1).name(), "z");
EXPECT_EQ(fp.value_info(2).name(), "w");
}

TEST(ParserTest, FunctionValueInfoTest2) {
const char* code = R"ONNX(
<
opset_import: [ "" : 10 ],
domain: "ai.onnx.ml",
doc_string: "A function test case."
>
f (float[N] y, float[N] z) => (float[N] w)
<float[N] x>
{
x = Add(y, z)
w = Mul(x, y)
}
)ONNX";

FunctionProto fp;
Parse(fp, code);

EXPECT_EQ(fp.input_size(), 2);
EXPECT_EQ(fp.value_info_size(), 4);
ASSERT_EQ(fp.output_size(), 1);
EXPECT_EQ(fp.value_info(0).name(), "y");
EXPECT_EQ(fp.value_info(1).name(), "z");
EXPECT_EQ(fp.value_info(2).name(), "w");
EXPECT_EQ(fp.value_info(3).name(), "x");
}

TEST(ParserTest, FunctionValueInfoTest3) {
const char* code = R"ONNX(
<
opset_import: [ "" : 10 ],
domain: "ai.onnx.ml",
doc_string: "A function test case."
>
f (float[N] y, z) => (float[N] w)
<float[N] x, float[N] t>
{
x = Add(y, z)
t = Add(x, x)
w = Mul(t, y)
}
)ONNX";

FunctionProto fp;
Parse(fp, code);

EXPECT_EQ(fp.input_size(), 2);
ASSERT_EQ(fp.value_info_size(), 4);
EXPECT_EQ(fp.output_size(), 1);
EXPECT_EQ(fp.value_info(0).name(), "y");
EXPECT_EQ(fp.value_info(1).name(), "w");
EXPECT_EQ(fp.value_info(2).name(), "x");
EXPECT_EQ(fp.value_info(3).name(), "t");
}

TEST(ParserTest, InitializerTest) {
const char* code = R"ONNX(
agraph (float y = {1.0}, float[N] z) => (float[N] w)
Expand Down

0 comments on commit ffaa707

Please sign in to comment.