Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process subgraphs in inliner #5841

Merged
merged 5 commits into from Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
20 changes: 18 additions & 2 deletions onnx/inliner/inliner.cc
Expand Up @@ -445,13 +445,14 @@

using NodeList = google::protobuf::RepeatedPtrField<NodeProto>;

void InlineFunctions(GraphProto& graph, const FunctionMap& map, NameGenerator& name_generator, ModelProto* model, int inline_count = 0) ;

Check warning on line 448 in onnx/inliner/inliner.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnx/inliner/inliner.cc:448: Lines should be <= 120 characters long [whitespace/line_length] [2]

Check warning on line 448 in onnx/inliner/inliner.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 Extra space before last semicolon. If this should be an empty statement, use {} instead. [whitespace/semicolon] [5] Raw Output: onnx/inliner/inliner.cc:448: Extra space before last semicolon. If this should be an empty statement, use {} instead. [whitespace/semicolon] [5]
gramalingam marked this conversation as resolved.
Show resolved Hide resolved

// Shared utility used for inlining into either a GraphProto or a FunctionProto.
void InlineFunctions(NodeList& nodes, const FunctionMap& map, NameGenerator& name_generator, ModelProto* model) {
void InlineFunctions(NodeList& nodes, const FunctionMap& map, NameGenerator& name_generator, ModelProto* model, int inline_count = 0) {

Check warning on line 451 in onnx/inliner/inliner.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnx/inliner/inliner.cc:451: Lines should be <= 120 characters long [whitespace/line_length] [2]
NodeList original_nodes;
// Move all nodes into original_nodes
original_nodes.Swap(&nodes);

int inline_count = 0;
std::function<void(NodeProto & node)> append_node = [&](NodeProto& node) {
FunctionProto callee;
auto iter = map.find(GetCalleeId(node));
Expand Down Expand Up @@ -479,6 +480,16 @@
// Append node without inlining.
// TODO: use std::move instead of copying. Use of move doesn't seem to work with
// protobuf in some platforms/settings. [nodes->Add(std::move(node));]

for (auto& attr : *node.mutable_attribute()) {
if (attr.has_g()) {
InlineFunctions(*attr.mutable_g(), map, name_generator, model, inline_count);
}
for (auto&g : *attr.mutable_graphs()) {
InlineFunctions(g, map, name_generator, model, inline_count);
}
}

*nodes.Add() = node;
}
};
Expand All @@ -487,6 +498,11 @@
}
}

void InlineFunctions(GraphProto& graph, const FunctionMap& map, NameGenerator& name_generator, ModelProto* model, int inline_count) {

Check warning on line 501 in onnx/inliner/inliner.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnx/inliner/inliner.cc:501: Lines should be <= 120 characters long [whitespace/line_length] [2]
auto* nodes = graph.mutable_node();
InlineFunctions(*nodes, map, name_generator, model, inline_count);
}

void InlineFunctions(ModelProto& model, FunctionMap& map) {
auto* graph = model.mutable_graph();
NameGenerator name_generator(*graph);
Expand Down
40 changes: 40 additions & 0 deletions onnx/test/cpp/inliner_test.cc
Expand Up @@ -80,6 +80,46 @@
ASSERT_EQ(num_functions, 0);
}

// Test that inlining processes subgraphs.
TEST(FunctionInliner, SubgraphTest) {

Check warning on line 84 in onnx/test/cpp/inliner_test.cc

View workflow job for this annotation

GitHub Actions / clang-tidy-review

clang-tidy

warning: all parameters should be named in a function [readability-named-parameter] ```suggestion TEST(FunctionInliner /*unused*/, SubgraphTest /*unused*/) { ```
const char* code = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 10, "local" : 1 ]
>
agraph (bool cond, float[N] X) => (float[N] Y)
{
Y = If (cond) <
then_branch = then_graph () => (y) {
y = local.square (X)
},
else_branch = else_graph () => (y) {
y = local.square (X)
}
>
}

<
opset_import: [ "" : 10 ],
domain: "local",
doc_string: "Function square."
>
square (x) => (y) {
y = Mul (x, x)
}
)ONNX";

ModelProto model;

Check warning on line 112 in onnx/test/cpp/inliner_test.cc

View workflow job for this annotation

GitHub Actions / clang-tidy-review

clang-tidy

warning: variable 'model' is not initialized [cppcoreguidelines-init-variables] ```suggestion ModelProto model = 0; ```
InlineFunctions(model, code);
auto& if_node = model.graph().node(0);
auto& graph1 = if_node.attribute(0).g();
ASSERT_EQ(graph1.node(0).op_type(), "Mul");
auto& graph2 = if_node.attribute(1).g();
ASSERT_EQ(graph2.node(0).op_type(), "Mul");
auto num_functions = model.functions_size();
ASSERT_EQ(num_functions, 0);
}

TEST(FunctionInliner, Nested) {
const char* code = R"ONNX(
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
Expand Down