Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process subgraphs in inliner #5841

Merged
merged 5 commits into from Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
71 changes: 66 additions & 5 deletions onnx/inliner/inliner.cc
Expand Up @@ -445,13 +445,37 @@ using FunctionMap = std::unordered_map<FunctionIdKey, std::pair<const FunctionPr

using NodeList = google::protobuf::RepeatedPtrField<NodeProto>;

// Shared utility used for inlining into either a GraphProto or a FunctionProto.
void InlineFunctions(NodeList& nodes, const FunctionMap& map, NameGenerator& name_generator, ModelProto* model) {
/** Utility function used for inlining into a GraphProto.
* @param graph Mutable graph
* @param map Map from function-id to function for functions to be inlined
* @param name_generator Name generator for generating unique names for inlined variables
* @param model If non-null, the model being inlined into. Used for version conversion.
* @param inline_count Mutable counter for number of inlined calls. Used for name generation.
*/
void InlineFunctions(
GraphProto& graph,
const FunctionMap& map,
NameGenerator& name_generator,
ModelProto* model,
int& inline_count);

/** Shared utility function used for inlining into either a GraphProto or a FunctionProto.
* @param nodes Mutable list of nodes (of function or graph)
* @param map Map from function-id to function for functions to be inlined
* @param name_generator Name generator for generating unique names for inlined variables
* @param model If non-null, the model being inlined into. Used for version conversion.
* @param inline_count Mutable counter for number of inlined calls. Used for name generation.
*/
void InlineFunctions(
NodeList& nodes,
const FunctionMap& map,
NameGenerator& name_generator,
ModelProto* model,
int& inline_count) {
NodeList original_nodes;
// Move all nodes into original_nodes
original_nodes.Swap(&nodes);

int inline_count = 0;
std::function<void(NodeProto & node)> append_node = [&](NodeProto& node) {
FunctionProto callee;
auto iter = map.find(GetCalleeId(node));
Expand Down Expand Up @@ -479,6 +503,16 @@ void InlineFunctions(NodeList& nodes, const FunctionMap& map, NameGenerator& nam
// Append node without inlining.
// TODO: use std::move instead of copying. Use of move doesn't seem to work with
// protobuf in some platforms/settings. [nodes->Add(std::move(node));]

for (auto& attr : *node.mutable_attribute()) {
if (attr.has_g()) {
InlineFunctions(*attr.mutable_g(), map, name_generator, model, inline_count);
}
for (auto& g : *attr.mutable_graphs()) {
InlineFunctions(g, map, name_generator, model, inline_count);
}
}

*nodes.Add() = node;
}
};
Expand All @@ -487,16 +521,43 @@ void InlineFunctions(NodeList& nodes, const FunctionMap& map, NameGenerator& nam
}
}

/** Utility function used for inlining into a GraphProto.
* @param graph Mutable graph
* @param map Map from function-id to function for functions to be inlined
* @param name_generator Name generator for generating unique names for inlined variables
* @param model If non-null, the model being inlined into. Used for version conversion.
* @param inline_count Mutable counter for number of inlined calls. Used for name generation.
*/
void InlineFunctions(
GraphProto& graph,
const FunctionMap& map,
NameGenerator& name_generator,
ModelProto* model,
int& inline_count) {
auto* nodes = graph.mutable_node();
InlineFunctions(*nodes, map, name_generator, model, inline_count);
}

/** Utility function used for inlining into a ModelProto.
* @param model Mutable model
* @param map Map from function-id to function for functions to be inlined
*/
void InlineFunctions(ModelProto& model, FunctionMap& map) {
int inline_count = 0;
auto* graph = model.mutable_graph();
NameGenerator name_generator(*graph);
auto* nodes = graph->mutable_node();
InlineFunctions(*nodes, map, name_generator, &model);
InlineFunctions(*nodes, map, name_generator, &model, inline_count);
}

/** Utility function used for inlining into a FunctionProto.
* @param function Mutable function
* @param map Map from function-id to function for functions to be inlined
*/
void InlineFunctions(FunctionProto& function, FunctionMap& map) {
int inline_count = 0;
NameGenerator name_generator(function);
InlineFunctions(*function.mutable_node(), map, name_generator, nullptr);
InlineFunctions(*function.mutable_node(), map, name_generator, nullptr, inline_count);
}

class VectorSet : public FunctionIdSet {
Expand Down
40 changes: 40 additions & 0 deletions onnx/test/cpp/inliner_test.cc
Expand Up @@ -80,6 +80,46 @@
ASSERT_EQ(num_functions, 0);
}

// Test that inlining processes subgraphs.
TEST(FunctionInliner, SubgraphTest) {

Check warning on line 84 in onnx/test/cpp/inliner_test.cc

View workflow job for this annotation

GitHub Actions / clang-tidy-review

clang-tidy

warning: all parameters should be named in a function [readability-named-parameter] ```suggestion TEST(FunctionInliner /*unused*/, SubgraphTest /*unused*/) { ```
const char* code = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 10, "local" : 1 ]
>
agraph (bool cond, float[N] X) => (float[N] Y)
{
Y = If (cond) <
then_branch = then_graph () => (y) {
y = local.square (X)
},
else_branch = else_graph () => (y) {
y = local.square (X)
}
>
}

<
opset_import: [ "" : 10 ],
domain: "local",
doc_string: "Function square."
>
square (x) => (y) {
y = Mul (x, x)
}
)ONNX";

ModelProto model;

Check warning on line 112 in onnx/test/cpp/inliner_test.cc

View workflow job for this annotation

GitHub Actions / clang-tidy-review

clang-tidy

warning: variable 'model' is not initialized [cppcoreguidelines-init-variables] ```suggestion ModelProto model = 0; ```
InlineFunctions(model, code);
auto& if_node = model.graph().node(0);
auto& graph1 = if_node.attribute(0).g();
ASSERT_EQ(graph1.node(0).op_type(), "Mul");
auto& graph2 = if_node.attribute(1).g();
ASSERT_EQ(graph2.node(0).op_type(), "Mul");
auto num_functions = model.functions_size();
ASSERT_EQ(num_functions, 0);
}

TEST(FunctionInliner, Nested) {
const char* code = R"ONNX(
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
Expand Down