diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index df2eb9643f61..f7b9b421a767 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -70,12 +70,12 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, regions_.erase(src); } -void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion region, const Expr& expr) { - auto region2 = GetRegion(expr); - if (region2.defined()) { - MergeRegions(region, region2); +void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) { + auto src = GetRegion(expr); + if (src.defined()) { + MergeRegions(src, dest); } else { - region->nodes.insert(expr); + dest->nodes.insert(expr); } } diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index c5db2cc3d202..0b9301133d1c 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -178,10 +178,10 @@ class AnnotatedRegionSetNode : public Object { /*! * \brief Add an expression to a region. * - * \param region The region to add the expression to. + * \param dest The region to add the expression to. * \param expr The expression. */ - void AddToRegion(AnnotatedRegion region, const Expr& expr); + void AddToRegion(AnnotatedRegion dest, const Expr& expr); /*! * \brief Make a new region. diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index c2f7b804cb6a..b546f05b46e4 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -32,6 +32,9 @@ namespace tvm { namespace relay { namespace annotate_target { +// Cache compiler_begin op for equivalence check. +static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); + // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. class AnnotateTargetWrapper : public ExprMutator { @@ -52,6 +55,13 @@ class AnnotateTargetWrapper : public ExprMutator { return fannotate[op](call->attrs, call->args); } } + if (expr->IsInstance()) { + TupleGetItem get = Downcast(expr); + if (get->tuple->IsInstance() && + get->tuple.as()->op == compiler_begin_op) { + return true; + } + } return false; } @@ -110,9 +120,14 @@ class AnnotateTargetWrapper : public ExprMutator { auto new_e = ExprMutator::VisitExpr_(op); auto get = Downcast(new_e); - return TupleGetItem( - InsertEnd(get->tuple), - get->index); + if (IsSupported(get->tuple)) { + const auto* begin_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); + CHECK(begin_op); + return TupleGetItem((*begin_op)(InsertEnd(get->tuple), target_), get->index); + } else { + return TupleGetItem(InsertEnd(get->tuple), get->index); + } } Expr VisitExpr_(const FunctionNode* op) { diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index e6ec93aecd42..4a8ff64a24b8 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -30,10 +30,10 @@ * as external functions. */ +#include #include #include #include -#include #include #include @@ -44,7 +44,6 @@ #include "../analysis/annotated_region_set.h" - namespace tvm { namespace relay { namespace partitioning { @@ -63,7 +62,7 @@ static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); class AnnotateRestDefault : public ExprMutator { public: explicit AnnotateRestDefault(const Expr& expr) { - regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); + regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); } Expr Annotate(const Expr& expr) { @@ -71,141 +70,158 @@ class AnnotateRestDefault : public ExprMutator { func_ = Downcast(expr); // Corner Case CC1 : If the last node does not belong - // to a region nede to add a compiler_end + // to a region node to add a compiler_end auto region = regions_->GetRegion(func_->body); auto mutated_expr = this->VisitExpr(expr); if (!region.defined()) { func_ = Downcast(mutated_expr); // CC1 : add that compiler end after mutation - auto body = AddCompilerEnd_(func_->body); - func_ = Function(func_->params, body, - body->checked_type_, {}, DictAttrs()); + auto body = InsertEnd(func_->body); + func_ = Function(func_->params, body, body->checked_type_, {}, DictAttrs()); return Downcast(func_); } return mutated_expr; } /*! \brief This function adds compiler ends to nodes that - * have a region AND they should not be arguments of the - * original function + * don't belong to a region already (default). * \param expr The expression to add a compiler end to. * \return expr The expression with or without a compiler end added. */ - Expr AddCompilerEnd(const Expr& expr) { - auto region = regions_->GetRegion(expr); - auto visited_expr = VisitExpr(expr); - - // The compiler ends are added to nodes that does have a region - // AND they should not be arguments of the original function - if (!region.defined() && - std::find(func_->params.begin(), - func_->params.end(), visited_expr) - == func_->params.end()) { - return AddCompilerEnd_(visited_expr); + Expr InsertEnd(const Expr& expr) { + if (annotated_nodes_.find(expr) == annotated_nodes_.end() && !expr->IsInstance() && + !expr->IsInstance()) { + const auto* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); + CHECK(end_op); + Expr end = (*end_op)(expr, target_); + return end; } - return visited_expr; + return expr; } - Expr AddCompilerEnd_(const Expr& expr) { - const auto* end_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_end"); - CHECK(end_op); - Expr end = (*end_op)(expr, target_); - return end; + /*! \brief This function adds compiler begins to nodes that + * don't belong to a region already (default). + * \param expr The expression to add a compiler begin to. + * \return expr The expression with or without a compiler begin added. + */ + Expr InsertBegin(const Expr& expr) { + const auto* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); + CHECK(begin_op); + Expr begin = (*begin_op)(expr, target_); + annotated_nodes_.insert(begin); + return begin; } - Expr VisitExpr_(const CallNode* call) final { - auto op_node = call->op.as(); - auto ret = GetRef(call); + Expr VisitExpr_(const CallNode* cn) final { + auto region = regions_->GetRegion(GetRef(cn)); + auto new_e = ExprMutator::VisitExpr_(cn); + Call call = Downcast(new_e); + // Add compiler ends if the parent isn't annotated Array args; - - // Add compiler ends if the parent is supported for (auto arg : call->args) { - args.push_back(AddCompilerEnd(arg)); + args.push_back(InsertEnd(arg)); } - if (op_node == nullptr || call->attrs.as() == nullptr) { - // Skip annotatation ops, only add default compiler to actual compute nodes - - auto region = regions_->GetRegion(ret); - if (!region.defined()) { - // if the current node does not belong to annotated region - // annotate the all incoming edges (args) - // with "default" compile_begin and compiler_end annotations. - tvm::Array compiler_begins; - for (auto arg : args) { - const auto* begin_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); - CHECK(begin_op); - Expr begin = (*begin_op)(arg, target_); - compiler_begins.push_back(begin); - } - Expr update_call = Call(call->op, compiler_begins, call->attrs); - return update_call; + Expr updated_call = Call(call->op, args, call->attrs); + if (!region.defined()) { + // if the current node does not belong to annotated region + // annotate the all incoming edges (args) + // with "default" compiler_begin annotations. + Array compiler_begins; + for (auto arg : args) { + compiler_begins.push_back(InsertBegin(arg)); } + updated_call = Call(call->op, compiler_begins, call->attrs); + } else { + annotated_nodes_.insert(updated_call); } - return Call(call->op, args, call->attrs); + return updated_call; }; - Expr VisitExpr_(const TupleNode *op) { + Expr VisitExpr_(const TupleNode* op) { + auto region = regions_->GetRegion(GetRef(op)); auto new_e = ExprMutator::VisitExpr_(op); - auto tup = Downcast(new_e); - Array new_fields; + Tuple tup = Downcast(new_e); + + Array fields; for (auto field : tup->fields) { - new_fields.push_back(AddCompilerEnd(field)); + fields.push_back(InsertEnd(field)); } - return Tuple(new_fields); + + Expr updated_tuple = Tuple(fields); + if (!region.defined()) { + Array compiler_begins; + for (const auto& field : fields) { + compiler_begins.push_back(InsertBegin(field)); + } + updated_tuple = Tuple(compiler_begins); + } else { + annotated_nodes_.insert(updated_tuple); + } + return updated_tuple; } - Expr VisitExpr_(const TupleGetItemNode *op) { + Expr VisitExpr_(const TupleGetItemNode* op) { + auto region = regions_->GetRegion(GetRef(op)); auto new_e = ExprMutator::VisitExpr_(op); auto get = Downcast(new_e); - return TupleGetItem(AddCompilerEnd(get->tuple), get->index); + + auto updated_tuple = InsertEnd(get->tuple); + Expr updated_get = TupleGetItem(updated_tuple, get->index); + if (!region.defined()) { + updated_get = TupleGetItem(InsertBegin(updated_tuple), get->index); + } else { + annotated_nodes_.insert(updated_get); + } + return updated_get; } - Expr VisitExpr_(const LetNode *op) { + Expr VisitExpr_(const IfNode* op) { + auto region = regions_->GetRegion(GetRef(op)); auto new_e = ExprMutator::VisitExpr_(op); - auto let = Downcast(new_e); - return Let( - let->var, - AddCompilerEnd(let->value), - AddCompilerEnd(let->body)); + auto iff = Downcast(new_e); + + if (!region.defined()) { + return If(InsertBegin(InsertEnd(iff->cond)), InsertBegin(InsertEnd(iff->true_branch)), + InsertBegin(InsertEnd(iff->false_branch))); + } else { + Expr updated_iff = + If(InsertEnd(iff->cond), InsertEnd(iff->true_branch), InsertEnd(iff->false_branch)); + annotated_nodes_.insert(updated_iff); + return updated_iff; + } } - Expr VisitExpr_(const IfNode *op) { + Expr VisitExpr_(const LetNode* op) { auto new_e = ExprMutator::VisitExpr_(op); - auto iff = Downcast(new_e); - return If( - AddCompilerEnd(iff->cond), - AddCompilerEnd(iff->true_branch), - AddCompilerEnd(iff->false_branch)); + auto let = Downcast(new_e); + return Let(let->var, InsertEnd(let->value), InsertEnd(let->body)); } - Expr VisitExpr_(const RefCreateNode *op) { + Expr VisitExpr_(const RefCreateNode* op) { auto new_e = ExprMutator::VisitExpr_(op); auto create = Downcast(new_e); - return RefCreate(AddCompilerEnd(create->value)); + return RefCreate(InsertEnd(create->value)); } - Expr VisitExpr_(const RefReadNode *op) { + Expr VisitExpr_(const RefReadNode* op) { auto new_e = ExprMutator::VisitExpr_(op); auto read = Downcast(new_e); - return RefRead(AddCompilerEnd(read->ref)); + return RefRead(InsertEnd(read->ref)); } - Expr VisitExpr_(const RefWriteNode *op) { + Expr VisitExpr_(const RefWriteNode* op) { auto new_e = ExprMutator::VisitExpr_(op); auto write = Downcast(new_e); - return RefWrite( - AddCompilerEnd(write->ref), - AddCompilerEnd(write->value)); + return RefWrite(InsertEnd(write->ref), InsertEnd(write->value)); } private: - AnnotatedRegionSet regions_; - const std::string target_ = "default"; - Function func_; + AnnotatedRegionSet regions_; + const std::string target_ = "default"; + Function func_; + std::unordered_set annotated_nodes_; }; class MergeAnnotations : public ExprMutator { @@ -213,6 +229,14 @@ class MergeAnnotations : public ExprMutator { explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} Expr VisitExpr_(const CallNode* call) final { + // remove 'default' annotations + auto attrs = call->attrs.as(); + if (attrs != nullptr && attrs->compiler == "default") { + return VisitExpr(call->args[0]); + } + // Merge annotations which are now internal to a region. + // This happens if we see a compiler begin next to a + // compiler end and they're both in the same region. if (call->op == compiler_begin_op) { if (call->args[0]->IsInstance()) { auto arg = Downcast(call->args[0]); @@ -220,7 +244,7 @@ class MergeAnnotations : public ExprMutator { auto region1 = regions_->GetRegion(GetRef(call)); auto region2 = regions_->GetRegion(arg); if (region1 == region2) { - return ExprMutator::VisitExpr(arg->args[0]); + return VisitExpr(arg->args[0]); } } } @@ -242,7 +266,6 @@ class RegionMerger : public ExprVisitor { // set the region target auto compiler_attrs = call->attrs.as(); region_targets_[region->GetID()] = compiler_attrs->compiler; - std::vector mergeable_regions; // first look at the region args to determine the parent regions for (const auto& arg : region->GetInputs()) { // all args should be begin annotations @@ -256,14 +279,21 @@ class RegionMerger : public ExprVisitor { if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) { VisitExpr(begin->args[0]); } + } + // get the mergeable regions now all the parents have been visited + std::vector mergeable_regions; + for (const auto& arg : region->GetInputs()) { + auto begin = Downcast(arg); + CHECK_EQ(begin->op, compiler_begin_op); + auto parent_region = regions_->GetRegion(begin->args[0]); + if (!parent_region.defined()) continue; mergeable_regions.push_back(parent_region); } auto& region_restrictions = region_restrictions_[region->GetID()]; for (const auto& parent_region : mergeable_regions) { // add all the parent restrictions to the current region auto parent_restrictions = region_restrictions_[parent_region->GetID()]; - region_restrictions.insert(parent_restrictions.begin(), - parent_restrictions.end()); + region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end()); } for (const auto& parent_region : mergeable_regions) { bool merged = false; @@ -273,7 +303,8 @@ class RegionMerger : public ExprVisitor { if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) { // merge the parent region into the current region regions_->MergeRegions(parent_region, region); - // update the restrictions of all other regions to reflect the change in id + // update the restrictions of all other regions to reflect the + // change in id for (const auto& r : regions_) { auto& restrictions = region_restrictions_[r->GetID()]; if (restrictions.find(parent_region->GetID()) != restrictions.end()) { @@ -284,9 +315,9 @@ class RegionMerger : public ExprVisitor { merged = true; } } - // if the parent wasn't merged, add it as a restriction to the current region - if (!merged) - region_restrictions.insert(parent_region->GetID()); + // if the parent wasn't merged, add it as a restriction to the current + // region + if (!merged) region_restrictions.insert(parent_region->GetID()); } merged_regions_.insert(region->GetID()); } @@ -300,15 +331,14 @@ class RegionMerger : public ExprVisitor { std::map region_targets_; }; - Expr MergeCompilerRegions(const Expr& expr) { // Annotate all the nodes that aren't annotated as 'default'. AnnotateRestDefault anno_default(expr); auto expr_all_annotated = anno_default.Annotate(expr); // Create regions using the annotations. - AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr_all_annotated, - compiler_begin_op, compiler_end_op); + AnnotatedRegionSet regions = + AnnotatedRegionSet::Create(expr_all_annotated, compiler_begin_op, compiler_end_op); // By now, all the nodes have some sort of annotation. // Region merger is an ExprVisitor that will update the @@ -336,7 +366,7 @@ Pass MergeCompilerRegions() { } TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions") -.set_body_typed(transform::MergeCompilerRegions); + .set_body_typed(transform::MergeCompilerRegions); } // namespace transform diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py index 98de3f5e9b06..87cf7616e232 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_annotate_target.py @@ -113,7 +113,8 @@ def expected(dtype, ishape, w1shape): padding=(1, 1), groups=32) end0 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl") - begin2 = relay.annotation.compiler_begin(end0, "dnnl") + end1 = relay.annotation.compiler_end(depthwise_conv2d_1, "dnnl") + begin2 = relay.annotation.compiler_begin(end1, "dnnl") begin3 = relay.annotation.compiler_begin(end0, "dnnl") begin4 = relay.annotation.compiler_begin(weight1, "dnnl") depthwise_conv2d_2 = relay.nn.conv2d(begin3, @@ -121,11 +122,11 @@ def expected(dtype, ishape, w1shape): kernel_size=(3, 3), padding=(1, 1), groups=32) - end1 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl") - begin5 = relay.annotation.compiler_begin(end1, "dnnl") + end2 = relay.annotation.compiler_end(depthwise_conv2d_2, "dnnl") + begin5 = relay.annotation.compiler_begin(end2, "dnnl") out = relay.add(begin2, begin5) - end2 = relay.annotation.compiler_end(out, "dnnl") - f = relay.Function([data, weight1], end2) + end3 = relay.annotation.compiler_end(out, "dnnl") + f = relay.Function([data, weight1], end3) mod = tvm.IRModule.from_expr(f) return mod @@ -137,7 +138,7 @@ def test_annotate(): mod = annotated(dtype, ishape, w1shape) mod = transform.AnnotateTarget("dnnl")(mod) ref_mod = expected(dtype, ishape, w1shape) - # tvm.ir.assert_structural_equal(mod, ref_mod) + tvm.ir.assert_structural_equal(mod, ref_mod) def test_run(): if not tvm.get_global_func("relay.ext.dnnl", True): diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py index 364973f0ce8a..f316a41a88da 100644 --- a/tests/python/relay/test_pass_merge_compiler_regions.py +++ b/tests/python/relay/test_pass_merge_compiler_regions.py @@ -66,13 +66,10 @@ def expected(): O_2 = relay.nn.relu(O_1) ce_3 = compiler_end(O_2, "test") - cb_x = compiler_begin(ce_2, "default") - X = relay.tanh(cb_x) - ce_x1 = compiler_end(X, "default") - ce_x2 = compiler_end(X, "default") + X = relay.tanh(ce_2) cb_3 = compiler_begin(ce_3, "test") - cb_4 = compiler_begin(ce_x1, "test") + cb_4 = compiler_begin(X, "test") O_3 = relay.add(cb_3, cb_4) ce_4 = compiler_end(O_3, "test") @@ -162,36 +159,28 @@ def expected(): node1 = relay.add(begin2, begin3) node2 = relay.add(node0, node1) - begin4 = compiler_begin(in_5, "default") - begin5 = compiler_begin(in_6, "default") - begin6 = compiler_begin(in_7, "default") - node3 = relay.subtract(begin4, begin5) - node4 = relay.subtract(begin6, node3) - end0 = compiler_end(node4, "default") - - begin7 = compiler_begin(end0, "test") - begin8 = compiler_begin(in_9, "test") + node3 = relay.subtract(in_5, in_6) + node4 = relay.subtract(in_7, node3) - node5 = relay.add(node2, begin7) + begin4 = compiler_begin(node4, "test") + begin5 = compiler_begin(in_9, "test") + node5 = relay.add(node2, begin4) end1 = compiler_end(node5, "test") - begin9 = compiler_begin(end1, "default") - begin10 = compiler_begin(in_8, "default") - node6 = relay.subtract(begin10, begin9) - end2 = compiler_end(node6, "default") + node6 = relay.subtract(in_8, end1) - node7 = relay.add(begin8, node5) - end3 = compiler_end(node7, "test") - begin11 = compiler_begin(end3, "test") - begin12 = compiler_begin(end2, "test") + node7 = relay.add(begin5, node5) + end2 = compiler_end(node7, "test") + begin6 = compiler_begin(end2, "test") + begin7 = compiler_begin(node6, "test") - node8 = relay.add(begin12, begin11) + node8 = relay.add(begin7, begin6) - begin13 = compiler_begin(in_10, "test") - node9 = relay.add(begin13, node8) - end4 = compiler_end(node9, "test") + begin8 = compiler_begin(in_10, "test") + node9 = relay.add(begin8, node8) + end3 = compiler_end(node9, "test") - f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end4) + f = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end3) mod = tvm.IRModule.from_expr(f) return mod diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 0dfc89d469ca..9d4d71179fd7 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -725,12 +725,12 @@ def expected(): mod = tvm.IRModule() # function 0 - data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32")) - weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32")) - bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32")) - bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32")) - bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32")) - bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32")) + data = relay.var("test_target_2_i0", relay.TensorType((1, 3, 224, 224), "float32")) + weight = relay.var("test_target_2_i1", relay.TensorType((16, 3, 3, 3), "float32")) + bn_gamma = relay.var("test_target_2_i2", relay.TensorType((16, ), "float32")) + bn_beta = relay.var("test_target_2_i3", relay.TensorType((16, ), "float32")) + bn_mean = relay.var("test_target_2_i4", relay.TensorType((16, ), "float32")) + bn_var = relay.var("test_target_2_i5", relay.TensorType((16, ), "float32")) conv_o = relay.nn.conv2d( data=data, @@ -743,7 +743,7 @@ def expected(): bn_var) relu_o = relay.nn.relu(bn_o[0]) - tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2])) + tuple_o = relay.Tuple((bn_o[2], bn_o[1], relu_o)) func0 = relay.Function([data, weight, bn_gamma, bn_beta, bn_mean, bn_var], tuple_o) @@ -752,8 +752,8 @@ def expected(): func0 = func0.with_attr("Compiler", tvm.tir.StringImm("test_target")) func0 = func0.with_attr("ExternalSymbol", - tvm.tir.StringImm("test_target_0")) - gv0 = relay.GlobalVar("test_target_0") + tvm.tir.StringImm("test_target_2")) + gv0 = relay.GlobalVar("test_target_2") mod[gv0] = func0 # body @@ -765,9 +765,9 @@ def expected(): bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32")) f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var) - f0_relu_o = relay.TupleGetItem(f0_o, 0) + f0_relu_o = relay.TupleGetItem(f0_o, 2) f0_mean_o = relay.TupleGetItem(f0_o, 1) - f0_var_o = relay.TupleGetItem(f0_o, 2) + f0_var_o = relay.TupleGetItem(f0_o, 0) f0_mean_abs = relay.abs(f0_mean_o) f0_var_abs = relay.abs(f0_var_o)