From e1f8ef3f4ca5b2aaa31ace6fa968bb50e5e4d1fa Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 10 Apr 2020 00:13:16 +0000 Subject: [PATCH] Fix duplicate output in partitiongraph --- src/relay/analysis/annotated_region_set.cc | 8 +++++++- src/relay/transforms/partition_graph.cc | 9 ++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index 94c7621e60afc..53b58d494365b 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -131,7 +131,13 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { CHECK_EQ(region->GetTarget(), target); } region->nodes_.insert(GetRef(call)); - region->outs_.push_back(GetRef(call)); + if (!std::any_of(region->outs_.begin(), region->outs_.end(), + [call](Expr& out) { + return Downcast(out)->args[0] == + GetRef(call)->args[0]; + })) { + region->outs_.push_back(GetRef(call)); + } } ExprVisitor::VisitExpr_(call); } diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index c8367fb140f21..172faa0c267aa 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -207,9 +207,12 @@ class Partitioner : public ExprMutator { if (region_function_calls.find(region) != region_function_calls.end()) { // This section is executed only if there are multiple outputs in the - // region Thus, the function is always created and at the end there + // region or the same output is being accessed multiple times. + // Thus, the function is always created and at the end there // would be a tuple node Therefore, we insert a tuple get item node. - + if (region->GetOutputs().size() == 1) { + return region_function_calls[region]; + } // Use the already created tuple node auto sg_call = region_function_calls[region]; int index = GetRetIdx(region, GetRef(call)); @@ -462,7 +465,7 @@ class Partitioner : public ExprMutator { int GetRetIdx(AnnotatedRegion sg, const Expr& arg) { int idx = 0; for (auto arg_ : sg->GetOutputs()) { - if (arg == arg_) { + if (Downcast(arg)->args[0] == Downcast(arg_)->args[0]) { return idx; } idx++;