Skip to content

Commit

Permalink
Fix duplicate output in partitiongraph
Browse files Browse the repository at this point in the history
  • Loading branch information
Trevor Morris committed Apr 13, 2020
1 parent 5958d60 commit e1f8ef3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/relay/analysis/annotated_region_set.cc
Expand Up @@ -131,7 +131,13 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
CHECK_EQ(region->GetTarget(), target);
}
region->nodes_.insert(GetRef<Call>(call));
region->outs_.push_back(GetRef<Call>(call));
if (!std::any_of(region->outs_.begin(), region->outs_.end(),
[call](Expr& out) {
return Downcast<Call>(out)->args[0] ==
GetRef<Call>(call)->args[0];
})) {
region->outs_.push_back(GetRef<Call>(call));
}
}
ExprVisitor::VisitExpr_(call);
}
Expand Down
9 changes: 6 additions & 3 deletions src/relay/transforms/partition_graph.cc
Expand Up @@ -207,9 +207,12 @@ class Partitioner : public ExprMutator {

if (region_function_calls.find(region) != region_function_calls.end()) {
// This section is executed only if there are multiple outputs in the
// region Thus, the function is always created and at the end there
// region or the same output is being accessed multiple times.
// Thus, the function is always created and at the end there
// would be a tuple node Therefore, we insert a tuple get item node.

if (region->GetOutputs().size() == 1) {
return region_function_calls[region];
}
// Use the already created tuple node
auto sg_call = region_function_calls[region];
int index = GetRetIdx(region, GetRef<Call>(call));
Expand Down Expand Up @@ -462,7 +465,7 @@ class Partitioner : public ExprMutator {
int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
int idx = 0;
for (auto arg_ : sg->GetOutputs()) {
if (arg == arg_) {
if (Downcast<Call>(arg)->args[0] == Downcast<Call>(arg_)->args[0]) {
return idx;
}
idx++;
Expand Down

0 comments on commit e1f8ef3

Please sign in to comment.