Skip to content

Commit

Permalink
[mlir] add canonicalization patterns for trivial SCF 'for' and 'if'
Browse files Browse the repository at this point in the history
Add canoncalization patterns to remove zero-iteration 'for' loops, replace
single-iteration 'for' loops with their bodies; remove known-false conditionals
with no 'else' branch and replace conditionals with known value by the
respective region. Although similar transformations are performed at the CFG
level, not all flows reach that level, e.g., the GPU flow may want to remove
single-iteration loops before deciding on loop mapping to thread dimensions.

Reviewed By: herhut

Differential Revision: https://reviews.llvm.org/D91865
  • Loading branch information
ftynse committed Nov 20, 2020
1 parent dbcc692 commit 18d0f7d
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 19 deletions.
77 changes: 75 additions & 2 deletions mlir/lib/Dialect/SCF/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,19 @@ LoopNest mlir::scf::buildLoopNest(
});
}

/// Replaces the given op with the contents of the given single-block region,
/// using the operands of the block terminator to replace operation results.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
Region &region, ValueRange blockArgs = {}) {
assert(llvm::hasSingleElement(region) && "expected single-region block");
Block *block = &region.front();
Operation *terminator = block->getTerminator();
ValueRange results = terminator->getOperands();
rewriter.mergeBlockBefore(block, op, blockArgs);
rewriter.replaceOp(op, results);
rewriter.eraseOp(terminator);
}

namespace {
// Fold away ForOp iter arguments that are also yielded by the op.
// These arguments must be defined outside of the ForOp region and can just be
Expand Down Expand Up @@ -500,11 +513,51 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
return success();
}
};

/// Rewriting pattern that erases loops that are known not to iterate and
/// replaces single-iteration loops with their bodies.
struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(ForOp op,
PatternRewriter &rewriter) const override {
auto lb = op.lowerBound().getDefiningOp<ConstantOp>();
auto ub = op.upperBound().getDefiningOp<ConstantOp>();
if (!lb || !ub)
return failure();

// If the loop is known to have 0 iterations, remove it.
llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
if (lbValue.sge(ubValue)) {
rewriter.replaceOp(op, op.getIterOperands());
return success();
}

auto step = op.step().getDefiningOp<ConstantOp>();
if (!step)
return failure();

// If the loop is known to have 1 iteration, inline its body and remove the
// loop.
llvm::APInt stepValue = lb.getValue().cast<IntegerAttr>().getValue();
if ((lbValue + stepValue).sge(ubValue)) {
SmallVector<Value, 4> blockArgs;
blockArgs.reserve(op.getNumIterOperands() + 1);
blockArgs.push_back(op.lowerBound());
llvm::append_range(blockArgs, op.getIterOperands());
replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
return success();
}

return failure();
}
};
} // namespace

void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ForOpIterArgsFolder>(context);
results.insert<ForOpIterArgsFolder, SimplifyTrivialLoops>(context);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -710,11 +763,31 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
return success();
}
};

struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IfOp op,
PatternRewriter &rewriter) const override {
auto constant = op.condition().getDefiningOp<ConstantOp>();
if (!constant)
return failure();

if (constant.getValue().cast<BoolAttr>().getValue())
replaceOpWithRegion(rewriter, op, op.thenRegion());
else if (!op.elseRegion().empty())
replaceOpWithRegion(rewriter, op, op.elseRegion());
else
rewriter.eraseOp(op);

return success();
}
};
} // namespace

void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RemoveUnusedResults>(context);
results.insert<RemoveUnusedResults, RemoveStaticCondition>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
140 changes: 123 additions & 17 deletions mlir/test/Dialect/SCF/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,10 @@ func @no_iteration(%A: memref<?x?xi32>) {

// -----

func @one_unused() -> (index) {
func @one_unused(%cond: i1) -> (index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%true = constant true
%0, %1 = scf.if %true -> (index, index) {
%0, %1 = scf.if %cond -> (index, index) {
scf.yield %c0, %c1 : index, index
} else {
scf.yield %c0, %c1 : index, index
Expand All @@ -70,8 +69,7 @@ func @one_unused() -> (index) {

// CHECK-LABEL: func @one_unused
// CHECK: [[C0:%.*]] = constant 1 : index
// CHECK: [[C1:%.*]] = constant true
// CHECK: [[V0:%.*]] = scf.if [[C1]] -> (index) {
// CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) {
// CHECK: scf.yield [[C0]] : index
// CHECK: } else
// CHECK: scf.yield [[C0]] : index
Expand All @@ -80,12 +78,11 @@ func @one_unused() -> (index) {

// -----

func @nested_unused() -> (index) {
func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%true = constant true
%0, %1 = scf.if %true -> (index, index) {
%2, %3 = scf.if %true -> (index, index) {
%0, %1 = scf.if %cond1 -> (index, index) {
%2, %3 = scf.if %cond2 -> (index, index) {
scf.yield %c0, %c1 : index, index
} else {
scf.yield %c0, %c1 : index, index
Expand All @@ -99,9 +96,8 @@ func @nested_unused() -> (index) {

// CHECK-LABEL: func @nested_unused
// CHECK: [[C0:%.*]] = constant 1 : index
// CHECK: [[C1:%.*]] = constant true
// CHECK: [[V0:%.*]] = scf.if [[C1]] -> (index) {
// CHECK: [[V1:%.*]] = scf.if [[C1]] -> (index) {
// CHECK: [[V0:%.*]] = scf.if {{.*}} -> (index) {
// CHECK: [[V1:%.*]] = scf.if {{.*}} -> (index) {
// CHECK: scf.yield [[C0]] : index
// CHECK: } else
// CHECK: scf.yield [[C0]] : index
Expand All @@ -115,11 +111,10 @@ func @nested_unused() -> (index) {
// -----

func private @side_effect() {}
func @all_unused() {
func @all_unused(%cond: i1) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%true = constant true
%0, %1 = scf.if %true -> (index, index) {
%0, %1 = scf.if %cond -> (index, index) {
call @side_effect() : () -> ()
scf.yield %c0, %c1 : index, index
} else {
Expand All @@ -130,8 +125,7 @@ func @all_unused() {
}

// CHECK-LABEL: func @all_unused
// CHECK: [[C1:%.*]] = constant true
// CHECK: scf.if [[C1]] {
// CHECK: scf.if %{{.*}} {
// CHECK: call @side_effect() : () -> ()
// CHECK: } else
// CHECK: call @side_effect() : () -> ()
Expand Down Expand Up @@ -172,3 +166,115 @@ func @for_yields_3(%lb : index, %ub : index, %step : index) -> (i32, i32, i32) {
// CHECK-NEXT: scf.yield %[[c]] : i32
// CHECK-NEXT: }
// CHECK-NEXT: return %[[a]], %[[r1]], %[[b]] : i32, i32, i32

// CHECK-LABEL: @replace_true_if
func @replace_true_if() {
%true = constant true
// CHECK-NOT: scf.if
// CHECK: "test.op"
scf.if %true {
"test.op"() : () -> ()
scf.yield
}
return
}

// CHECK-LABEL: @remove_false_if
func @remove_false_if() {
%false = constant false
// CHECK-NOT: scf.if
// CHECK-NOT: "test.op"
scf.if %false {
"test.op"() : () -> ()
scf.yield
}
return
}

// CHECK-LABEL: @replace_true_if_with_values
func @replace_true_if_with_values() {
%true = constant true
// CHECK-NOT: scf.if
// CHECK: %[[VAL:.*]] = "test.op"
%0 = scf.if %true -> (i32) {
%1 = "test.op"() : () -> i32
scf.yield %1 : i32
} else {
%2 = "test.other_op"() : () -> i32
scf.yield %2 : i32
}
// CHECK: "test.consume"(%[[VAL]])
"test.consume"(%0) : (i32) -> ()
return
}

// CHECK-LABEL: @replace_false_if_with_values
func @replace_false_if_with_values() {
%false = constant false
// CHECK-NOT: scf.if
// CHECK: %[[VAL:.*]] = "test.other_op"
%0 = scf.if %false -> (i32) {
%1 = "test.op"() : () -> i32
scf.yield %1 : i32
} else {
%2 = "test.other_op"() : () -> i32
scf.yield %2 : i32
}
// CHECK: "test.consume"(%[[VAL]])
"test.consume"(%0) : (i32) -> ()
return
}

// CHECK-LABEL: @remove_zero_iteration_loop
func @remove_zero_iteration_loop() {
%c42 = constant 42 : index
%c1 = constant 1 : index
// CHECK: %[[INIT:.*]] = "test.init"
%init = "test.init"() : () -> i32
// CHECK-NOT: scf.for
%0 = scf.for %i = %c42 to %c1 step %c1 iter_args(%arg = %init) -> (i32) {
%1 = "test.op"(%i, %arg) : (index, i32) -> i32
scf.yield %1 : i32
}
// CHECK: "test.consume"(%[[INIT]])
"test.consume"(%0) : (i32) -> ()
return
}

// CHECK-LABEL: @replace_single_iteration_loop
func @replace_single_iteration_loop() {
// CHECK: %[[LB:.*]] = constant 42
%c42 = constant 42 : index
%c43 = constant 43 : index
%c1 = constant 1 : index
// CHECK: %[[INIT:.*]] = "test.init"
%init = "test.init"() : () -> i32
// CHECK-NOT: scf.for
// CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]])
%0 = scf.for %i = %c42 to %c43 step %c1 iter_args(%arg = %init) -> (i32) {
%1 = "test.op"(%i, %arg) : (index, i32) -> i32
scf.yield %1 : i32
}
// CHECK: "test.consume"(%[[VAL]])
"test.consume"(%0) : (i32) -> ()
return
}

// CHECK-LABEL: @replace_single_iteration_loop_non_unit_step
func @replace_single_iteration_loop_non_unit_step() {
// CHECK: %[[LB:.*]] = constant 42
%c42 = constant 42 : index
%c47 = constant 47 : index
%c5 = constant 5 : index
// CHECK: %[[INIT:.*]] = "test.init"
%init = "test.init"() : () -> i32
// CHECK-NOT: scf.for
// CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]])
%0 = scf.for %i = %c42 to %c47 step %c5 iter_args(%arg = %init) -> (i32) {
%1 = "test.op"(%i, %arg) : (index, i32) -> i32
scf.yield %1 : i32
}
// CHECK: "test.consume"(%[[VAL]])
"test.consume"(%0) : (i32) -> ()
return
}

0 comments on commit 18d0f7d

Please sign in to comment.