Skip to content

Commit

Permalink
Process subgraphs in inliner (#5841)
Browse files Browse the repository at this point in the history
Addresses issue mentioned in #5817

The inliner was not processing subgraphs for inlining.

---------

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
  • Loading branch information
gramalingam committed Jan 3, 2024
1 parent 75c6892 commit 75671b3
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 5 deletions.
71 changes: 66 additions & 5 deletions onnx/inliner/inliner.cc
Expand Up @@ -445,13 +445,37 @@ using FunctionMap = std::unordered_map<FunctionIdKey, std::pair<const FunctionPr

using NodeList = google::protobuf::RepeatedPtrField<NodeProto>;

// Shared utility used for inlining into either a GraphProto or a FunctionProto.
void InlineFunctions(NodeList& nodes, const FunctionMap& map, NameGenerator& name_generator, ModelProto* model) {
/** Utility function used for inlining into a GraphProto.
* @param graph Mutable graph
* @param map Map from function-id to function for functions to be inlined
* @param name_generator Name generator for generating unique names for inlined variables
* @param model If non-null, the model being inlined into. Used for version conversion.
* @param inline_count Mutable counter for number of inlined calls. Used for name generation.
*/
void InlineFunctions(
GraphProto& graph,
const FunctionMap& map,
NameGenerator& name_generator,
ModelProto* model,
int& inline_count);

/** Shared utility function used for inlining into either a GraphProto or a FunctionProto.
* @param nodes Mutable list of nodes (of function or graph)
* @param map Map from function-id to function for functions to be inlined
* @param name_generator Name generator for generating unique names for inlined variables
* @param model If non-null, the model being inlined into. Used for version conversion.
* @param inline_count Mutable counter for number of inlined calls. Used for name generation.
*/
void InlineFunctions(
NodeList& nodes,
const FunctionMap& map,
NameGenerator& name_generator,
ModelProto* model,
int& inline_count) {
NodeList original_nodes;
// Move all nodes into original_nodes
original_nodes.Swap(&nodes);

int inline_count = 0;
std::function<void(NodeProto & node)> append_node = [&](NodeProto& node) {
FunctionProto callee;
auto iter = map.find(GetCalleeId(node));
Expand Down Expand Up @@ -479,6 +503,16 @@ void InlineFunctions(NodeList& nodes, const FunctionMap& map, NameGenerator& nam
// Append node without inlining.
// TODO: use std::move instead of copying. Use of move doesn't seem to work with
// protobuf in some platforms/settings. [nodes->Add(std::move(node));]

for (auto& attr : *node.mutable_attribute()) {
if (attr.has_g()) {
InlineFunctions(*attr.mutable_g(), map, name_generator, model, inline_count);
}
for (auto& g : *attr.mutable_graphs()) {
InlineFunctions(g, map, name_generator, model, inline_count);
}
}

*nodes.Add() = node;
}
};
Expand All @@ -487,16 +521,43 @@ void InlineFunctions(NodeList& nodes, const FunctionMap& map, NameGenerator& nam
}
}

/** Utility function used for inlining into a GraphProto.
* @param graph Mutable graph
* @param map Map from function-id to function for functions to be inlined
* @param name_generator Name generator for generating unique names for inlined variables
* @param model If non-null, the model being inlined into. Used for version conversion.
* @param inline_count Mutable counter for number of inlined calls. Used for name generation.
*/
void InlineFunctions(
GraphProto& graph,
const FunctionMap& map,
NameGenerator& name_generator,
ModelProto* model,
int& inline_count) {
auto* nodes = graph.mutable_node();
InlineFunctions(*nodes, map, name_generator, model, inline_count);
}

/** Utility function used for inlining into a ModelProto.
* @param model Mutable model
* @param map Map from function-id to function for functions to be inlined
*/
void InlineFunctions(ModelProto& model, FunctionMap& map) {
int inline_count = 0;
auto* graph = model.mutable_graph();
NameGenerator name_generator(*graph);
auto* nodes = graph->mutable_node();
InlineFunctions(*nodes, map, name_generator, &model);
InlineFunctions(*nodes, map, name_generator, &model, inline_count);
}

/** Utility function used for inlining into a FunctionProto.
* @param function Mutable function
* @param map Map from function-id to function for functions to be inlined
*/
void InlineFunctions(FunctionProto& function, FunctionMap& map) {
int inline_count = 0;
NameGenerator name_generator(function);
InlineFunctions(*function.mutable_node(), map, name_generator, nullptr);
InlineFunctions(*function.mutable_node(), map, name_generator, nullptr, inline_count);
}

class VectorSet : public FunctionIdSet {
Expand Down
40 changes: 40 additions & 0 deletions onnx/test/cpp/inliner_test.cc
Expand Up @@ -80,6 +80,46 @@ square (x) => (y) {
ASSERT_EQ(num_functions, 0);
}

// Test that inlining processes subgraphs.
TEST(FunctionInliner, SubgraphTest) {
const char* code = R"ONNX(
<
ir_version: 8,
opset_import: [ "" : 10, "local" : 1 ]
>
agraph (bool cond, float[N] X) => (float[N] Y)
{
Y = If (cond) <
then_branch = then_graph () => (y) {
y = local.square (X)
},
else_branch = else_graph () => (y) {
y = local.square (X)
}
>
}
<
opset_import: [ "" : 10 ],
domain: "local",
doc_string: "Function square."
>
square (x) => (y) {
y = Mul (x, x)
}
)ONNX";

ModelProto model;
InlineFunctions(model, code);
auto& if_node = model.graph().node(0);
auto& graph1 = if_node.attribute(0).g();
ASSERT_EQ(graph1.node(0).op_type(), "Mul");
auto& graph2 = if_node.attribute(1).g();
ASSERT_EQ(graph2.node(0).op_type(), "Mul");
auto num_functions = model.functions_size();
ASSERT_EQ(num_functions, 0);
}

TEST(FunctionInliner, Nested) {
const char* code = R"ONNX(
<ir_version: 8, opset_import: [ "" : 17, "local" : 1 ]>
Expand Down

0 comments on commit 75671b3

Please sign in to comment.