diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 881e256a8797b..bb07291036667 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -26,6 +26,7 @@ #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -3687,6 +3688,133 @@ LogicalResult scf::WhileOp::verify() { } namespace { +/// Move a scf.if op that is directly before the scf.condition op in the while +/// before region, and whose condition matches the condition of the +/// scf.condition op, down into the while after region. +/// +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// %res = scf.if %cond -> (...) { +/// use(%additional_used_values) +/// ... // then block +/// scf.yield %then_value +/// } else { +/// scf.yield %else_value +/// } +/// scf.condition(%cond) %res, ... +/// } do { +/// ^bb0(%res_arg, ...): +/// use(%res_arg) +/// ... +/// +/// becomes +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// scf.condition(%cond) %else_value, ..., %additional_used_values +/// } do { +/// ^bb0(%res_arg ..., %additional_args): : +/// use(%additional_args) +/// ... // if then block +/// use(%then_value) +/// ... +struct WhileMoveIfDown : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + auto conditionOp = op.getConditionOp(); + + // Only support ifOp right before the condition at the moment. Relaxing this + // would require to: + // - check that the body does not have side-effects conflicting with + // operations between the if and the condition. + // - check that results of the if operation are only used as arguments to + // the condition. + auto ifOp = dyn_cast_or_null(conditionOp->getPrevNode()); + + // Check that the ifOp is directly before the conditionOp and that it + // matches the condition of the conditionOp. Also ensure that the ifOp has + // no else block with content, as that would complicate the transformation. + // TODO: support else blocks with content. + if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() || + (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty())) + return failure(); + + assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) && + *ifOp->user_begin() == conditionOp) && + "ifOp has unexpected uses"); + + Location loc = op.getLoc(); + + // Replace uses of ifOp results in the conditionOp with the yielded values + // from the ifOp branches. + for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) { + auto it = llvm::find(ifOp->getResults(), arg); + if (it != ifOp->getResults().end()) { + size_t ifOpIdx = it.getIndex(); + Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx); + Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx); + + rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue); + rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue); + } + } + + // Collect additional used values from before region. + SetVector additionalUsedValuesSet; + visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) { + if (&op.getBefore() == operand->get().getParentRegion()) + additionalUsedValuesSet.insert(operand->get()); + }); + + // Create new whileOp with additional used values as results. + auto additionalUsedValues = additionalUsedValuesSet.getArrayRef(); + auto additionalValueTypes = llvm::map_to_vector( + additionalUsedValues, [](Value val) { return val.getType(); }); + size_t additionalValueSize = additionalUsedValues.size(); + SmallVector newResultTypes(op.getResultTypes()); + newResultTypes.append(additionalValueTypes); + + auto newWhileOp = + scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits()); + + rewriter.modifyOpInPlace(newWhileOp, [&] { + newWhileOp.getBefore().takeBody(op.getBefore()); + newWhileOp.getAfter().takeBody(op.getAfter()); + newWhileOp.getAfter().addArguments( + additionalValueTypes, + SmallVector(additionalValueSize, loc)); + }); + + rewriter.modifyOpInPlace(conditionOp, [&] { + conditionOp.getArgsMutable().append(additionalUsedValues); + }); + + // Replace uses of additional used values inside the ifOp then region with + // the whileOp after region arguments. + rewriter.replaceUsesWithIf( + additionalUsedValues, + newWhileOp.getAfterArguments().take_back(additionalValueSize), + [&](OpOperand &use) { + return ifOp.getThenRegion().isAncestor( + use.getOwner()->getParentRegion()); + }); + + // Inline ifOp then region into new whileOp after region. + rewriter.eraseOp(ifOp.thenYield()); + rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(), + newWhileOp.getAfterBody()->begin()); + rewriter.eraseOp(ifOp); + rewriter.replaceOp(op, + newWhileOp->getResults().drop_back(additionalValueSize)); + return success(); + } +}; + /// Replace uses of the condition within the do block with true, since otherwise /// the block would not be evaluated. /// @@ -4399,7 +4527,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); + WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 084c3fc065de3..ac590fc0c47b9 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -974,6 +974,56 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) { // ----- +// CHECK-LABEL: @while_move_if_down +func.func @while_move_if_down() -> i32 { + %defined_outside = "test.get_some_value0" () : () -> (i32) + %0 = scf.while () : () -> (i32) { + %used_value = "test.get_some_value1" () : () -> (i32) + %used_by_subregion = "test.get_some_value2" () : () -> (i32) + %else_value = "test.get_some_value3" () : () -> (i32) + %condition = "test.condition"() : () -> i1 + %res = scf.if %condition -> (i32) { + "test.use0" (%defined_outside) : (i32) -> () + "test.use1" (%used_value) : (i32) -> () + test.alloca_scope_region { + "test.use2" (%used_by_subregion) : (i32) -> () + } + %then_value = "test.get_some_value4" () : () -> (i32) + scf.yield %then_value : i32 + } else { + scf.yield %else_value : i32 + } + scf.condition(%condition) %res : i32 + } do { + ^bb0(%res_arg: i32): + "test.use3" (%res_arg) : (i32) -> () + scf.yield + } + return %0 : i32 +} +// CHECK: %[[defined_outside:.*]] = "test.get_some_value0"() : () -> i32 +// CHECK: %[[WHILE_RES:.*]]:3 = scf.while : () -> (i32, i32, i32) { +// CHECK: %[[used_value:.*]] = "test.get_some_value1"() : () -> i32 +// CHECK: %[[used_by_subregion:.*]] = "test.get_some_value2"() : () -> i32 +// CHECK: %[[else_value:.*]] = "test.get_some_value3"() : () -> i32 +// CHECK: %[[condition:.*]] = "test.condition"() : () -> i1 +// CHECK: scf.condition(%[[condition]]) %[[else_value]], %[[used_value]], %[[used_by_subregion]] : i32, i32, i32 +// CHECK: } do { +// CHECK: ^bb0(%[[res_arg:.*]]: i32, %[[used_value_arg:.*]]: i32, %[[used_by_subregion_arg:.*]]: i32): +// CHECK: "test.use0"(%[[defined_outside]]) : (i32) -> () +// CHECK: "test.use1"(%[[used_value_arg]]) : (i32) -> () +// CHECK: test.alloca_scope_region { +// CHECK: "test.use2"(%[[used_by_subregion_arg]]) : (i32) -> () +// CHECK: } +// CHECK: %[[then_value:.*]] = "test.get_some_value4"() : () -> i32 +// CHECK: "test.use3"(%[[then_value]]) : (i32) -> () +// CHECK: scf.yield +// CHECK: } +// CHECK: return %[[WHILE_RES]]#0 : i32 +// CHECK: } + +// ----- + // CHECK-LABEL: @while_cond_true func.func @while_cond_true() -> i1 { %0 = scf.while () : () -> i1 {