Skip to content

Commit

Permalink
[mlir][scf] Canonicalize nested scf.if's to scf.if + arith.and
Browse files Browse the repository at this point in the history
Differential Revision: https://reviews.llvm.org/D115930
  • Loading branch information
Hardcode84 committed Dec 20, 2021
1 parent de90490 commit c7f96d5
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 4 deletions.
54 changes: 50 additions & 4 deletions mlir/lib/Dialect/SCF/SCF.cpp
Expand Up @@ -1596,14 +1596,60 @@ struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
}
};

/// Convert nested `if`s into `arith.andi` + single `if`.
///
/// scf.if %arg0 {
/// scf.if %arg1 {
/// ...
/// scf.yield
/// }
/// scf.yield
/// }
/// becomes
///
/// %0 = arith.andi %arg0, %arg1
/// scf.if %0 {
/// ...
/// scf.yield
/// }
struct CombineNestedIfs : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
// Both `if` ops must not yield results and have only `then` block.
if (op->getNumResults() != 0 || op.elseBlock())
return failure();

auto nestedOps = op.thenBlock()->without_terminator();
// Nested `if` must be the only op in block.
if (!llvm::hasSingleElement(nestedOps))
return failure();

auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
if (!nestedIf || nestedIf->getNumResults() != 0 || nestedIf.elseBlock())
return failure();

Location loc = op.getLoc();
Value newCondition = rewriter.create<arith::AndIOp>(loc, op.condition(),
nestedIf.condition());
auto newIf = rewriter.create<IfOp>(loc, newCondition);
Block *newIfBlock = newIf.thenBlock();
rewriter.eraseOp(newIfBlock->getTerminator());
rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
rewriter.eraseOp(op);
return success();
}
};

} // namespace

void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
.add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
ConditionPropagation, ReplaceIfYieldWithConditionOrValue, CombineIfs,
RemoveEmptyElseBranch>(context);
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
RemoveStaticCondition, RemoveUnusedResults,
ReplaceIfYieldWithConditionOrValue>(context);
}

Block *IfOp::thenBlock() { return &getThenRegion().back(); }
Expand Down
18 changes: 18 additions & 0 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Expand Up @@ -429,6 +429,24 @@ func @replace_false_if_with_values() {

// -----

// CHECK-LABEL: @merge_nested_if
// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i1)
func @merge_nested_if(%arg0: i1, %arg1: i1) {
// CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]]
// CHECK: scf.if %[[COND]] {
// CHECK-NEXT: "test.op"()
scf.if %arg0 {
scf.if %arg1 {
"test.op"() : () -> ()
scf.yield
}
scf.yield
}
return
}

// -----

// CHECK-LABEL: @remove_zero_iteration_loop
func @remove_zero_iteration_loop() {
%c42 = arith.constant 42 : index
Expand Down

0 comments on commit c7f96d5

Please sign in to comment.