Skip to content

Commit

Permalink
[MLIR] Correct block merge bug
Browse files Browse the repository at this point in the history
Block merging in MLIR will incorrectly merge blocks with operations whose values are used outside of that block. This change forbids this behavior and provides a test where it is illegal to perform such a merge.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D91745
  • Loading branch information
wsmoses authored and ftynse committed Nov 20, 2020
1 parent 88e6208 commit f5c5fd1
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 18 deletions.
23 changes: 10 additions & 13 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
Expand Up @@ -464,10 +464,6 @@ class BlockMergeCluster {
/// A set of operand+index pairs that correspond to operands that need to be
/// replaced by arguments when the cluster gets merged.
std::set<std::pair<int, int>> operandsToMerge;

/// A map of operations with external uses to a replacement within the leader
/// block.
DenseMap<Operation *, Operation *> opsToReplace;
};
} // end anonymous namespace

Expand All @@ -480,7 +476,6 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {

// A set of operands that mismatch between the leader and the new block.
SmallVector<std::pair<int, int>, 8> mismatchedOperands;
SmallVector<std::pair<Operation *, Operation *>, 2> newOpsToReplace;
auto lhsIt = leaderBlock->begin(), lhsE = leaderBlock->end();
auto rhsIt = blockData.block->begin(), rhsE = blockData.block->end();
for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
Expand Down Expand Up @@ -519,17 +514,23 @@ LogicalResult BlockMergeCluster::addToCluster(BlockEquivalenceData &blockData) {
return failure();
}

// If the rhs has external uses, it will need to be replaced.
if (rhsIt->isUsedOutsideOfBlock(mergeBlock))
newOpsToReplace.emplace_back(&*rhsIt, &*lhsIt);
// If the lhs or rhs has external uses, the blocks cannot be merged as the
// merged version of this operation will not be either the lhs or rhs
// alone (thus semantically incorrect), but some mix dependening on which
// block preceeded this.
// TODO allow merging of operations when one block does not dominate the
// other
if (rhsIt->isUsedOutsideOfBlock(mergeBlock) ||
lhsIt->isUsedOutsideOfBlock(leaderBlock)) {
return failure();
}
}
// Make sure that the block sizes are equivalent.
if (lhsIt != lhsE || rhsIt != rhsE)
return failure();

// If we get here, the blocks are equivalent and can be merged.
operandsToMerge.insert(mismatchedOperands.begin(), mismatchedOperands.end());
opsToReplace.insert(newOpsToReplace.begin(), newOpsToReplace.end());
blocksToMerge.insert(blockData.block);
return success();
}
Expand Down Expand Up @@ -561,10 +562,6 @@ LogicalResult BlockMergeCluster::merge() {
!llvm::all_of(blocksToMerge, ableToUpdatePredOperands))
return failure();

// Replace any necessary operations.
for (std::pair<Operation *, Operation *> &it : opsToReplace)
it.first->replaceAllUsesWith(it.second);

// Collect the iterators for each of the blocks to merge. We will walk all
// of the iterators at once to avoid operand index invalidation.
SmallVector<Block::iterator, 2> blockIterators;
Expand Down
37 changes: 32 additions & 5 deletions mlir/test/Transforms/canonicalize-block-merge.mlir
Expand Up @@ -174,26 +174,24 @@ func @contains_regions(%cond : i1) {
return
}

// Check that properly handles back edges and the case where a value from one
// block is used in another.
// Check that properly handles back edges.

// CHECK-LABEL: func @mismatch_loop(
// CHECK-SAME: %[[ARG:.*]]: i1, %[[ARG2:.*]]: i1
func @mismatch_loop(%cond : i1, %cond2 : i1) {
// CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op"
// CHECK: cond_br %{{.*}}, ^bb1(%[[ARG2]] : i1), ^bb2

%cond3 = "foo.op"() : () -> (i1)
cond_br %cond, ^bb2, ^bb3

^bb1:
// CHECK: ^bb1(%[[ARG3:.*]]: i1):
// CHECK-NEXT: %[[LOOP_CARRY:.*]] = "foo.op"
// CHECK-NEXT: cond_br %[[ARG3]], ^bb1(%[[LOOP_CARRY]] : i1), ^bb2

%ignored = "foo.op"() : () -> (i1)
cond_br %cond3, ^bb1, ^bb3

^bb2:
%cond3 = "foo.op"() : () -> (i1)
cond_br %cond2, ^bb1, ^bb3

^bb3:
Expand Down Expand Up @@ -224,3 +222,32 @@ func @mismatch_operand_types(%arg0 : i1, %arg1 : memref<i32>, %arg2 : memref<i1>
store %true, %arg2[] : memref<i1>
br ^bb1
}

// Check that it is illegal to merge blocks containing an operand
// with an external user. Incorrectly performing the optimization
// anyways will result in print(merged, merged) rather than
// distinct operands.
func private @print(%arg0: i32, %arg1: i32)
// CHECK-LABEL: @nomerge
func @nomerge(%arg0: i32, %i: i32) {
%c1_i32 = constant 1 : i32
%icmp = cmpi "slt", %i, %arg0 : i32
cond_br %icmp, ^bb2, ^bb3

^bb2: // pred: ^bb1
%ip1 = addi %i, %c1_i32 : i32
br ^bb4(%ip1 : i32)

^bb7: // pred: ^bb5
%jp1 = addi %j, %c1_i32 : i32
br ^bb4(%jp1 : i32)

^bb4(%j: i32): // 2 preds: ^bb2, ^bb7
%jcmp = cmpi "slt", %j, %arg0 : i32
// CHECK-NOT: call @print(%[[arg1:.+]], %[[arg1]])
call @print(%j, %ip1) : (i32, i32) -> ()
cond_br %jcmp, ^bb7, ^bb3

^bb3: // pred: ^bb1
return
}

0 comments on commit f5c5fd1

Please sign in to comment.