Skip to content

Commit

Permalink
Add a PatternRewriter hook to merge blocks, and use it to support for…
Browse files Browse the repository at this point in the history
… folding branches.

A pattern rewriter hook, mergeBlock, is added that allows for merging the operations of one block into the end of another. This is used to support a canonicalization pattern for branch operations that folds the branch when the successor has a single predecessor(the branch block).

Example:
  ^bb0:
    %c0_i32 = constant 0 : i32
    br ^bb1(%c0_i32 : i32)
  ^bb1(%x : i32):
    return %x : i32

becomes:
  ^bb0:
    %c0_i32 = constant 0 : i32
    return %c0_i32 : i32
PiperOrigin-RevId: 278677825
  • Loading branch information
River707 authored and tensorflower-gardener committed Nov 5, 2019
1 parent 6d24325 commit 2366561
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 7 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/StandardOps/Ops.td
Expand Up @@ -232,6 +232,8 @@ def BranchOp : Std_Op<"br", [Terminator]> {
/// Erase the operand at 'index' from the operand list.
void eraseOperand(unsigned index);
}];

let hasCanonicalizer = 1;
}

def CallOp : Std_Op<"call", [CallOpInterface]> {
Expand Down
11 changes: 8 additions & 3 deletions mlir/include/mlir/IR/PatternMatch.h
Expand Up @@ -359,11 +359,16 @@ class PatternRewriter : public OpBuilder {
/// This method erases an operation that is known to have no uses.
virtual void eraseOp(Operation *op);

/// Merge the operations of block 'source' into the end of block 'dest'.
/// 'source's predecessors must either be empty or only contain 'dest`.
/// 'argValues' is used to replace the block arguments of 'source' after
/// merging.
virtual void mergeBlocks(Block *source, Block *dest,
ArrayRef<Value *> argValues = llvm::None);

/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
virtual Block *splitBlock(Block *block, Block::iterator before) {
return block->splitBlock(before);
}
virtual Block *splitBlock(Block *block, Block::iterator before);

/// This method is used as the final notification hook for patterns that end
/// up modifying the pattern root in place, by changing its operands. This is
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Expand Up @@ -352,6 +352,10 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// PatternRewriter hook for splitting a block into two parts.
Block *splitBlock(Block *block, Block::iterator before) override;

/// PatternRewriter hook for merging a block into another.
void mergeBlocks(Block *source, Block *dest,
ArrayRef<Value *> argValues) override;

/// PatternRewriter hook for moving blocks out of a region.
void inlineRegionBefore(Region &region, Region &parent,
Region::iterator before) override;
Expand Down
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/StandardOps/Ops.cpp
Expand Up @@ -473,6 +473,28 @@ void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
// BranchOp
//===----------------------------------------------------------------------===//

namespace {
/// Simplify a branch to a block that has a single predecessor. This effectively
/// merges the two blocks.
struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
using OpRewritePattern<BranchOp>::OpRewritePattern;

PatternMatchResult matchAndRewrite(BranchOp op,
PatternRewriter &rewriter) const override {
// Check that the successor block has a single predecessor.
Block *succ = op.getDest();
Block *opParent = op.getOperation()->getBlock();
if (succ == opParent || !has_single_element(succ->getPredecessors()))
return matchFailure();

// Merge the successor into the current block and erase the branch.
rewriter.mergeBlocks(succ, opParent, llvm::to_vector<1>(op.getOperands()));
rewriter.eraseOp(op);
return matchSuccess();
}
};
} // end anonymous namespace.

static ParseResult parseBranchOp(OpAsmParser &parser, OperationState &result) {
Block *dest;
SmallVector<Value *, 4> destOperands;
Expand All @@ -495,6 +517,11 @@ void BranchOp::eraseOperand(unsigned index) {
getOperation()->eraseSuccessorOperand(0, index);
}

void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SimplifyBrToBlockWithSinglePred>(context);
}

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 29 additions & 0 deletions mlir/lib/IR/PatternMatch.cpp
Expand Up @@ -109,6 +109,35 @@ void PatternRewriter::eraseOp(Operation *op) {
op->erase();
}

/// Merge the operations of block 'source' into the end of block 'dest'.
/// 'source's predecessors must be empty or only contain 'dest`.
/// 'argValues' is used to replace the block arguments of 'source' after
/// merging.
void PatternRewriter::mergeBlocks(Block *source, Block *dest,
ArrayRef<Value *> argValues) {
assert(llvm::all_of(source->getPredecessors(),
[dest](Block *succ) { return succ == dest; }) &&
"expected 'source' to have no predecessors or only 'dest'");
assert(argValues.size() == source->getNumArguments() &&
"incorrect # of argument replacement values");

// Replace all of the successor arguments with the provided values.
for (auto it : llvm::zip(source->getArguments(), argValues))
std::get<0>(it)->replaceAllUsesWith(std::get<1>(it));

// Splice the operations of the 'source' block into the 'dest' block and erase
// it.
dest->getOperations().splice(dest->end(), source->getOperations());
source->dropAllUses();
source->erase();
}

/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it.
Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
return block->splitBlock(before);
}

/// op and newOp are known to have the same number of results, replace the
/// uses of op with uses of newOp
void PatternRewriter::replaceOpWithResultsOfAnotherOp(
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Transforms/DialectConversion.cpp
Expand Up @@ -789,6 +789,14 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
return continuation;
}

/// PatternRewriter hook for merging a block into another.
void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
ArrayRef<Value *> argValues) {
// TODO(riverriddle) This requires fixing the implementation of
// 'replaceUsesOfBlockArgument', which currently isn't undoable.
llvm_unreachable("block merging updates are currently not supported");
}

/// PatternRewriter hook for moving blocks out of a region.
void ConversionPatternRewriter::inlineRegionBefore(Region &region,
Region &parent,
Expand Down
38 changes: 34 additions & 4 deletions mlir/test/Transforms/canonicalize.mlir
Expand Up @@ -406,16 +406,46 @@ func @const_fold_propagate() -> memref<?x?xf32> {
return %Av : memref<?x?xf32>
}

// CHECK-LABEL: func @br_folding
func @br_folding() -> i32 {
// CHECK-NEXT: %[[CST:.*]] = constant 0 : i32
// CHECK-NEXT: return %[[CST]] : i32
%c0_i32 = constant 0 : i32
br ^bb1(%c0_i32 : i32)
^bb1(%x : i32):
return %x : i32
}

// CHECK-LABEL: func @cond_br_folding
func @cond_br_folding(%a : i32) {
func @cond_br_folding(%cond : i1, %a : i32) {
%false_cond = constant 0 : i1
%true_cond = constant 1 : i1
cond_br %cond, ^bb1, ^bb2(%a : i32)

^bb1:
// CHECK: ^bb1:
// CHECK-NEXT: br ^bb3
cond_br %true_cond, ^bb3, ^bb2(%a : i32)

^bb2(%x : i32):
// CHECK: ^bb2
// CHECK: br ^bb3
cond_br %false_cond, ^bb2(%x : i32), ^bb3

// CHECK-NEXT: br ^bb1(%arg0 : i32)
cond_br %true_cond, ^bb1(%a : i32), ^bb2
^bb3:
return
}

// CHECK-LABEL: func @cond_br_and_br_folding
func @cond_br_and_br_folding(%a : i32) {
// Test the compound folding of conditional and unconditional branches.
// CHECK-NEXT: return

%false_cond = constant 0 : i1
%true_cond = constant 1 : i1
cond_br %true_cond, ^bb2, ^bb1(%a : i32)

^bb1(%x : i32):
// CHECK: br ^bb2
cond_br %false_cond, ^bb1(%x : i32), ^bb2

^bb2:
Expand Down

0 comments on commit 2366561

Please sign in to comment.