Skip to content

Commit

Permalink
[MLIR] Simplify affine.if having yield values and trivial conditions
Browse files Browse the repository at this point in the history
When an affine.if operation is returning/yielding results and has a
trivially true or false condition, then its 'then' or 'else' block,
respectively, is promoted to the affine.if's parent block and then, the
affine.if operation is replaced by the correct results/yield values.
Relevant test cases are also added.

Signed-off-by: Srishti Srivastava <srishti.srivastava@polymagelabs.com>

Differential Revision: https://reviews.llvm.org/D105418
  • Loading branch information
srishti-pm authored and bondhugula committed Jul 7, 2021
1 parent 1894c89 commit 0c1a773
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 28 deletions.
52 changes: 33 additions & 19 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1905,43 +1905,56 @@ struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
}
};

/// Removes Affine.If cond if the condition is always true or false in certain
/// Removes affine.if cond if the condition is always true or false in certain
/// trivial cases. Promotes the then/else block in the parent operation block.
struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
using OpRewritePattern<AffineIfOp>::OpRewritePattern;

LogicalResult matchAndRewrite(AffineIfOp op,
PatternRewriter &rewriter) const override {

// If affine.if is returning results then don't remove it.
// TODO: Similar simplication can be done when affine.if return results.
if (op.getNumResults() > 0)
return failure();
auto isTriviallyFalse = [](IntegerSet iSet) {
return iSet.isEmptyIntegerSet();
};

IntegerSet conditionSet = op.getIntegerSet();
auto isTriviallyTrue = [](IntegerSet iSet) {
return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
iSet.getConstraint(0) == 0);
};

IntegerSet affineIfConditions = op.getIntegerSet();
Block *blockToMove;
if (conditionSet.isEmptyIntegerSet()) {
// If the else region is not there, simply remove the Affine.if
// operation.
if (!op.hasElse()) {
if (isTriviallyFalse(affineIfConditions)) {
// The absence, or equivalently, the emptiness of the else region need not
// be checked when affine.if is returning results because if an affine.if
// operation is returning results, it always has a non-empty else region.
if (op.getNumResults() == 0 && !op.hasElse()) {
// If the else region is absent, or equivalently, empty, remove the
// affine.if operation (which is not returning any results).
rewriter.eraseOp(op);
return success();
}
blockToMove = op.getElseBlock();
} else if (conditionSet.getNumEqualities() == 1 &&
conditionSet.getNumInequalities() == 0 &&
conditionSet.getConstraint(0) == 0) {
// Condition to check for trivially true condition (0==0).
} else if (isTriviallyTrue(affineIfConditions)) {
blockToMove = op.getThenBlock();
} else {
return failure();
}
// Remove the terminator from the block as it already exists in parent
// block.
Operation *blockTerminator = blockToMove->getTerminator();
rewriter.eraseOp(blockTerminator);
Operation *blockToMoveTerminator = blockToMove->getTerminator();
// Promote the "blockToMove" block to the parent operation block between the
// prologue and epilogue of "op".
rewriter.mergeBlockBefore(blockToMove, op);
rewriter.eraseOp(op);
// Replace the "op" operation with the operands of the
// "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
// the affine.yield operation present in the "blockToMove" block. It has no
// operands when affine.if is not returning results and therefore, in that
// case, replaceOp just erases "op". When affine.if is not returning
// results, the affine.yield operation can be omitted. It gets inserted
// implicitly.
rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
// Erase the "blockToMoveTerminator" operation since it is now in the parent
// operation block, which already has its own terminator.
rewriter.eraseOp(blockToMoveTerminator);
return success();
}
};
Expand Down Expand Up @@ -2051,6 +2064,7 @@ IntegerSet AffineIfOp::getIntegerSet() {
->getAttrOfType<IntegerSetAttr>(getConditionAttrName())
.getValue();
}

void AffineIfOp::setIntegerSet(IntegerSet newSet) {
(*this)->setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
}
Expand Down
99 changes: 90 additions & 9 deletions mlir/test/Dialect/Affine/simplify-affine-structures.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ func @test_always_false_if_elimination() {
}


// Testing: Affine.If is not trivially true or false, nothing happens.
// Testing: affine.if is not trivially true or false, nothing happens.
// CHECK-LABEL: func @test_dimensional_if_elimination() {
func @test_dimensional_if_elimination() {
affine.for %arg0 = 1 to 10 {
Expand All @@ -385,16 +385,97 @@ func @test_dimensional_if_elimination() {
return
}

// Testing: Affine.If don't get removed if it is returning results.
// Testing: affine.if gets removed.
// CHECK-LABEL: func @test_num_results_if_elimination
func @test_num_results_if_elimination() -> f32 {
%zero = constant 0.0 : f32
func @test_num_results_if_elimination() -> index {
// CHECK: %[[zero:.*]] = constant 0 : index
%zero = constant 0 : index
%0 = affine.if affine_set<() : ()> () -> index {
affine.yield %zero : index
} else {
affine.yield %zero : index
}
// CHECK-NEXT: return %[[zero]] : index
return %0 : index
}


// Three more test functions involving affine.if operations which are
// returning results:

// Testing: affine.if gets removed. `Else` block get promoted.
// CHECK-LABEL: func @test_trivially_false_returning_two_results
// CHECK-SAME: (%[[arg0:.*]]: index)
func @test_trivially_false_returning_two_results(%arg0: index) -> (index, index) {
// CHECK: %[[c7:.*]] = constant 7 : index
// CHECK: %[[c13:.*]] = constant 13 : index
%c7 = constant 7 : index
%c13 = constant 13 : index
// CHECK: %[[c2:.*]] = constant 2 : index
// CHECK: %[[c3:.*]] = constant 3 : index
%res:2 = affine.if affine_set<(d0, d1) : (5 >= 0, -2 >= 0)> (%c7, %c13) -> (index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
affine.yield %c0, %c1 : index, index
} else {
%c2 = constant 2 : index
%c3 = constant 3 : index
affine.yield %c7, %arg0 : index, index
}
// CHECK-NEXT: return %[[c7]], %[[arg0]] : index, index
return %res#0, %res#1 : index, index
}

// Testing: affine.if gets removed. `Then` block get promoted.
// CHECK-LABEL: func @test_trivially_true_returning_five_results
func @test_trivially_true_returning_five_results() -> (index, index, index, index, index) {
// CHECK: %[[c12:.*]] = constant 12 : index
// CHECK: %[[c13:.*]] = constant 13 : index
%c12 = constant 12 : index
%c13 = constant 13 : index
// CHECK: %[[c0:.*]] = constant 0 : index
// CHECK: %[[c1:.*]] = constant 1 : index
// CHECK: %[[c2:.*]] = constant 2 : index
// CHECK: %[[c3:.*]] = constant 3 : index
// CHECK: %[[c4:.*]] = constant 4 : index
%res:5 = affine.if affine_set<(d0, d1) : (1 >= 0, 3 >= 0)>(%c12, %c13) -> (index, index, index, index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
%c4 = constant 4 : index
affine.yield %c0, %c1, %c2, %c3, %c4 : index, index, index, index, index
} else {
%c5 = constant 5 : index
%c6 = constant 6 : index
%c7 = constant 7 : index
%c8 = constant 8 : index
%c9 = constant 9 : index
affine.yield %c5, %c6, %c7, %c8, %c9 : index, index, index, index, index
}
// CHECK-NEXT: return %[[c0]], %[[c1]], %[[c2]], %[[c3]], %[[c4]] : index, index, index, index, index
return %res#0, %res#1, %res#2, %res#3, %res#4 : index, index, index, index, index
}

// Testing: affine.if doesn't get removed.
// CHECK-LABEL: func @test_not_trivially_true_or_false_returning_three_results
func @test_not_trivially_true_or_false_returning_three_results() -> (index, index, index) {
// CHECK: %[[c8:.*]] = constant 8 : index
// CHECK: %[[c13:.*]] = constant 13 : index
%c8 = constant 8 : index
%c13 = constant 13 : index
// CHECK: affine.if
%0 = affine.if affine_set<() : ()> () -> f32 {
affine.yield %zero : f32
// CHECK: else {
%res:3 = affine.if affine_set<(d0, d1) : (d0 - 1 == 0)>(%c8, %c13) -> (index, index, index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
affine.yield %c0, %c1, %c2 : index, index, index
// CHECK: } else {
} else {
affine.yield %zero : f32
%c3 = constant 3 : index
%c4 = constant 4 : index
%c5 = constant 5 : index
affine.yield %c3, %c4, %c5 : index, index, index
}
return %0 : f32
return %res#0, %res#1, %res#2 : index, index, index
}

0 comments on commit 0c1a773

Please sign in to comment.