From 0ca31fd5c3f3401b51d59bd5caad747183c6c25f Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Wed, 15 Apr 2020 13:33:31 -0700
Subject: [PATCH] [BYOC] Prevent duplicate outputs in subgraph Tuple (#5320)

* Fix duplicate output in partitiongraph

* Add test case

* Fix test_annotated_regions with duplicate compiler_end outputs

* Revert "Fix duplicate output in partitiongraph"

This reverts commit e1f8ef3f4ca5b2aaa31ace6fa968bb50e5e4d1fa.

* Prevent duplicate outputs in Tuple in PartitionGraph

* Fix lint

* Add another test case for when regions are merged, and when TupleGetItem was duplicated

* Pull GetFunctionOutput out of branch, improve description of GetFunctionOutput

* Use std::move for GetFunctionOutput. Fix typo with testcase name

* Use tvm.transform.Sequential
---
 src/relay/transforms/partition_graph.cc       | 226 ++++++++++--------
 .../python/relay/test_pass_partition_graph.py | 135 +++++++++++
 2 files changed, 260 insertions(+), 101 deletions(-)

diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index c8367fb140f21..15ad60be3a955 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -205,99 +205,13 @@ class Partitioner : public ExprMutator {
 
       // region_function_calls is map that maintains
       // (each annotated regions) --> created function
-      if (region_function_calls.find(region) != region_function_calls.end()) {
-        // This section is executed only if there are multiple outputs in the
-        // region Thus, the function is always created and at the end there
-        // would be a tuple node Therefore, we insert a tuple get item node.
-
-        // Use the already created tuple node
-        auto sg_call = region_function_calls[region];
-        int index = GetRetIdx(region, GetRef<Call>(call));
-        CHECK_NE(index, -1);
-
-        auto tuple_get_item_ = TupleGetItem(sg_call, index);
-        tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
-        return std::move(tuple_get_item_);
-      } else {
-        // First time this region is encountered in the traversal
-        // Creating the function
-
-        Array<Expr> fields;
-
-        for (auto ret : region->GetOutputs()) {
-          auto ret_expr = VisitExpr(Downcast<Call>(ret)->args[0]);
-          fields.push_back(ret_expr);
-        }
-        int index = GetRetIdx(region, GetRef<Call>(call));
-        CHECK_NE(index, -1);
-
-        Array<Var> params;
-        Array<Expr> param_expr;
-        std::unordered_map<std::string, runtime::NDArray> params_bind;
-
-        for (auto pair : region_args[region]) {
-          params.push_back(pair.first);
-          if (const auto* cn = pair.second.as<ConstantNode>()) {
-            params_bind[pair.first->name_hint()] = cn->data;
-          } else {
-            param_expr.push_back(pair.second);
-          }
-        }
-
-        Function global_region_func;
-        if (region->GetOutputs().size() == 1) {
-          // If there are only a single output; no need to add a tuple
-          global_region_func =
-              Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
-        } else {
-          auto tuple = Tuple(fields);
-          global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
-        }
-
-        std::string target = call->attrs.as<CompilerAttrs>()->compiler;
-        std::string name = target + "_" + std::to_string(region->GetID());
-
-        global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
-                                      runtime::String(name));
-        global_region_func =
-            WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
-        global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
-                                      tvm::runtime::String(target));
-        global_region_func =
-            WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
-
-        // Constant propagation
-        if (!params_bind.empty()) {
-          global_region_func = backend::BindParamsByName(global_region_func, params_bind);
-        }
-
-        std::string fname = name;
-        CHECK(!module_->ContainGlobalVar(fname))
-            << "Global function " << fname << " already exists";
-        // Create a global function and add it to the IRModule for the region.
-        // This way we lift the functions that should be handled by external
-        // codegen to the module scope and rely on the pass manager to prevent
-        // relay function level passes (i.e. simplify inference and fusion)
-        // optimizing it.
-        GlobalVar glob_func(fname);
-        module_->Add(glob_func, global_region_func);
-
-        // The return type of callnode is the same as the type of the
-        // compiler_end node.
-        auto ret = Call(glob_func, param_expr);
-        region_function_calls[region] = ret;
-
-        if (region->GetOutputs().size() == 1) {
-          // If there is only a single output; no need to add a tuplegetitem
-          // node
-          return std::move(ret);
-        } else {
-          // Add a tuplegetitem node to select this output out of many
-          auto tuple_get_item_ = TupleGetItem(ret, index);
-          tuple_get_item_->checked_type_ = GetRef<Call>(call)->args[0]->checked_type_;
-          return std::move(tuple_get_item_);
-        }
-      }
+      if (region_function_calls.find(region) == region_function_calls.end()) {
+        // First time this region is encountered in the traversal.
+        // Creating the function.
+        CreateFunction(region, call);
       }
+      // Retrieve this particular output of the function.
+      return GetFunctionOutput(region, GetRef<Call>(call));
     }
   }
@@ -456,18 +370,111 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief Get the index of the return(output);
-   * this is to be used as tuplegetitem idx
+   * \brief This function is called the first time that we encounter a
+   * compiler_end node, to create the function for the subgraph.
    */
-  int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
-    int idx = 0;
-    for (auto arg_ : sg->GetOutputs()) {
-      if (arg == arg_) {
-        return idx;
+  void CreateFunction(AnnotatedRegion region, const CallNode* call) {
+    // Create fields, which is a unique list of outputs. Also populate the
+    // region_return_indices_ map, which maps the parent of a compiler_end
+    // node to its corresponding index in fields.
+    Array<Expr> fields;
+    int i = 0;
+    for (auto ret : region->GetOutputs()) {
+      auto ret_node = Downcast<Call>(ret)->args[0];
+      // Don't duplicate outputs.
+      if (!region_return_indices_.count(region) ||
+          !region_return_indices_[region].count(ret_node)) {
+        auto ret_expr = VisitExpr(ret_node);
+        fields.push_back(ret_expr);
+        region_return_indices_[region][ret_node] = i;
+        i++;
       }
-      idx++;
     }
-    return -1;
+
+    Array<Var> params;
+    Array<Expr> param_expr;
+    std::unordered_map<std::string, runtime::NDArray> params_bind;
+
+    for (auto pair : region_args[region]) {
+      params.push_back(pair.first);
+      if (const auto* cn = pair.second.as<ConstantNode>()) {
+        params_bind[pair.first->name_hint()] = cn->data;
+      } else {
+        param_expr.push_back(pair.second);
+      }
+    }
+
+    Function global_region_func;
+    if (fields.size() == 1) {
+      // If there is only a single output, no need to add a tuple
+      global_region_func =
+          Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
+    } else {
+      auto tuple = Tuple(fields);
+      global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
+    }
+
+    std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+    std::string name = target + "_" + std::to_string(region->GetID());
+
+    global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
+                                  runtime::String(name));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
+    global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
+                                  tvm::runtime::String(target));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
+
+    // Constant propagation
+    if (!params_bind.empty()) {
+      global_region_func = backend::BindParamsByName(global_region_func, params_bind);
+    }
+
+    std::string fname = name;
+    CHECK(!module_->ContainGlobalVar(fname))
+        << "Global function " << fname << " already exists";
+    // Create a global function and add it to the IRModule for the region.
+    // This way we lift the functions that should be handled by external
+    // codegen to the module scope and rely on the pass manager to prevent
+    // relay function level passes (i.e. simplify inference and fusion)
+    // from optimizing it.
+    GlobalVar glob_func(fname);
+    module_->Add(glob_func, global_region_func);
+
+    // The return type of the call node is the same as the type of the
+    // compiler_end node.
+    auto ret = Call(glob_func, param_expr);
+    region_function_calls[region] = ret;
+  }
+
+  /*!
+   * \brief Get the return (output) of the function for the compiler_end node
+   * "end_arg". This will return either a Call (for a function with a single
+   * output) or a TupleGetItem (for a function with multiple outputs).
+   */
+  Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
+    Expr arg = Downcast<Call>(end_arg)->args[0];
+    // Function has one output.
+    if (region_return_indices_[region].size() == 1) {
+      return region_function_calls[region];
+    }
+    // Function has multiple outputs.
+    // Reuse an already created TupleGetItem.
+    if (region_return_tuplegetitem_.count(region) &&
+        region_return_tuplegetitem_[region].count(arg)) {
+      return region_return_tuplegetitem_[region][arg];
+    }
+    // Create a new TupleGetItem.
+    CHECK(region_return_indices_.count(region) &&
+          region_return_indices_[region].count(arg));
+    int index = region_return_indices_[region][arg];
+
+    auto func_call = region_function_calls[region];
+    auto tuple_get_item_ = TupleGetItem(func_call, index);
+    tuple_get_item_->checked_type_ = arg->checked_type_;
+    region_return_tuplegetitem_[region][arg] = tuple_get_item_;
+    return std::move(tuple_get_item_);
   }
 
   /*!
@@ -485,6 +492,23 @@ class Partitioner : public ExprMutator {
   std::unordered_map<AnnotatedRegion, std::vector<std::pair<Var, Expr>>, ObjectHash, ObjectEqual>
       region_args;
 
+  /*!
+   * \brief This map maintains the index of an output in the subgraph function
+   * for a given region. If there are multiple entries for a region, then the
+   * function has a tuple of multiple outputs for its return.
+   */
+  using RegionRetIndexMap = std::unordered_map<Expr, int, ObjectHash, ObjectEqual>;
+  std::unordered_map<AnnotatedRegion, RegionRetIndexMap, ObjectHash, ObjectEqual>
+      region_return_indices_;
+
+  /*!
+   * \brief This map holds already created TupleGetItem nodes for accessing
+   * outputs of a function.
+   */
+  using RegionRetTupleGetItemMap = std::unordered_map<Expr, TupleGetItem, ObjectHash, ObjectEqual>;
+  std::unordered_map<AnnotatedRegion, RegionRetTupleGetItemMap, ObjectHash, ObjectEqual>
+      region_return_tuplegetitem_;
+
   /*!
    * \brief Each region set is associated with a function in the module.
    * This map maintains the mapping between regionsets and the function it
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 2ee8538e30ed9..8827fbf1b8b0a 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -23,6 +23,7 @@
 import tvm
 import tvm.relay.testing
+import tvm.relay.op as reg
 from tvm import relay
 from tvm import runtime
 from tvm.relay import transform
@@ -1036,6 +1037,138 @@ def test_different_output_region():
     test_same_output_region()
     test_different_output_region()
 
+def test_duplicate_outputs():
+    target = "test_duplicate_outputs"
+
+    @reg.register("abs", "target." + target)
+    def abs(attrs, args): # pylint: disable=unused-variable
+        return True
+
+    def create_graph():
+        data = relay.var('data', shape=(10, 10))
+        x = relay.abs(data)
+        out_1 = relay.nn.relu(x)
+        out_2 = relay.tanh(x)
+        out_3 = relay.log(x)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data], out)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 0
+        f0_i0 = relay.var(target+"_0_i0", shape=(10, 10))
+        f0_o0 = relay.abs(f0_i0)
+        func0 = relay.Function([f0_i0], f0_o0)
+
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler", target)
+        func0 = func0.with_attr("global_symbol", target+"_0")
+        gv0 = relay.GlobalVar(target+"_0")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var('data', shape=(10, 10))
+        function_out = gv0(data)
+        out_1 = relay.nn.relu(function_out)
+        out_2 = relay.tanh(function_out)
+        out_3 = relay.log(function_out)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data], out)
+        mod["main"] = func
+        return mod
+
+    mod = tvm.IRModule()
+    mod["main"] = create_graph()
+
+    seq = tvm.transform.Sequential([
+        transform.AnnotateTarget(target),
+        transform.MergeCompilerRegions(),
+        transform.PartitionGraph(),
+    ])
+
+    ref_mod = expected()
+    partitioned = seq(mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
+
+def test_duplicate_merge_and_tuplegetitem():
+    target = "test_duplicate_merge_and_tuplegetitem"
+
+    @reg.register("nn.batch_norm", "target." + target)
+    def abs(attrs, args): # pylint: disable=unused-variable
+        return True
+
+    @reg.register("nn.relu", "target." + target)
+    def abs(attrs, args): # pylint: disable=unused-variable
+        return True
+
+    def create_graph():
+        data = relay.var('data', shape=(10, 10))
+        bn_gamma = relay.var("bn_gamma")
+        bn_beta = relay.var("bn_beta")
+        bn_mmean = relay.var("bn_mean")
+        bn_mvar = relay.var("bn_var")
+        x = relay.nn.batch_norm(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        out_1 = relay.nn.relu(x[0])
+        bn_out_1 = x[1]
+        out_2 = relay.tanh(bn_out_1)
+        out_3 = relay.log(bn_out_1)
+        out = relay.Tuple([out_1, out_2, out_3])
+        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
+        return func
+
+    def expected():
+        mod = tvm.IRModule()
+
+        # function 0
+        f0_i0 = relay.var(target+"_1_i0", shape=(10, 10))
+        f0_i1 = relay.var(target+"_1_i1")
+        f0_i2 = relay.var(target+"_1_i2")
+        f0_i3 = relay.var(target+"_1_i3")
+        f0_i4 = relay.var(target+"_1_i4")
+        f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
+        f0_n1 = f0_n0[1]
+        f0_n2 = relay.nn.relu(f0_n0[0])
+        f0_o0 = relay.Tuple([f0_n1, f0_n2])
+        func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)
+
+        func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
+        func0 = func0.with_attr("Compiler", target)
+        func0 = func0.with_attr("global_symbol", target+"_1")
+        gv0 = relay.GlobalVar(target+"_1")
+        mod[gv0] = func0
+
+        # body
+        data = relay.var('data', shape=(10, 10))
+        bn_gamma = relay.var("bn_gamma")
+        bn_beta = relay.var("bn_beta")
+        bn_mmean = relay.var("bn_mean")
+        bn_mvar = relay.var("bn_var")
+        function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
+        get_out0 = relay.TupleGetItem(function_out, 0)
+        get_out1 = relay.TupleGetItem(function_out, 1)
+        out_2 = relay.tanh(get_out0)
+        out_3 = relay.log(get_out0)
+        out = relay.Tuple([get_out1, out_2, out_3])
+        func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
+        mod["main"] = func
+        return mod
+
+    mod = tvm.IRModule()
+    mod["main"] = create_graph()
+
+    seq = tvm.transform.Sequential([
+        transform.AnnotateTarget(target),
+        transform.MergeCompilerRegions(),
+        transform.PartitionGraph(),
+    ])
+
+    ref_mod = expected()
+    partitioned = seq(mod)
+    assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)
 
 if __name__ == "__main__":
     test_multi_node_compiler()
@@ -1051,3 +1184,5 @@ def test_different_output_region():
     test_mixed_single_multiple_outputs()
     test_dnnl_fuse()
     test_multiple_use_of_an_output()
+    test_duplicate_outputs()
+    test_duplicate_merge_and_tuplegetitem()
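
---
Reviewer note: below is a minimal standalone sketch of the behavior this patch
fixes, distilled from the new test_duplicate_outputs test above. The compiler
name "demo_target" is illustrative only (not a real TVM backend); any string
registered the same way behaves identically. Before this change, one annotated
op feeding several consumers produced a partitioned function whose return
Tuple contained the same output several times; with the change, the function
returns each distinct output once and consumers share the TupleGetItem/Call.

    import tvm
    import tvm.relay.op as reg
    from tvm import relay
    from tvm.relay import transform

    target = "demo_target"  # illustrative external-compiler name

    # Mark abs() as supported by the (hypothetical) external compiler.
    @reg.register("abs", "target." + target)
    def _abs(attrs, args):  # pylint: disable=unused-variable
        return True

    # One annotated op feeding three consumers: previously the partitioned
    # subgraph function returned abs(data) three times in its output Tuple.
    data = relay.var("data", shape=(10, 10))
    x = relay.abs(data)
    out = relay.Tuple([relay.nn.relu(x), relay.tanh(x), relay.log(x)])
    mod = tvm.IRModule.from_expr(relay.Function([data], out))

    seq = tvm.transform.Sequential([
        transform.AnnotateTarget(target),
        transform.MergeCompilerRegions(),
        transform.PartitionGraph(),
    ])
    print(seq(mod))  # demo_target_0 should now return a single output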