From 31f82075c3fd174d19c6b80b8aed721f0953369b Mon Sep 17 00:00:00 2001 From: yanming Date: Mon, 27 Oct 2025 16:35:03 +0800 Subject: [PATCH 1/3] [MLIR][SCF] Sink scf.if from scf.while before region into after region. --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 125 +++++++++++++++++++++++- mlir/test/Dialect/SCF/canonicalize.mlir | 37 +++++++ 2 files changed, 161 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 881e256a8797b..79dcf562db993 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -26,6 +26,7 @@ #include "mlir/Interfaces/ParallelCombiningOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" @@ -3687,6 +3688,127 @@ LogicalResult scf::WhileOp::verify() { } namespace { +/// Move a scf.if op that is directly before the scf.condition op in the while +/// before region, and whose condition matches the condition of the +/// scf.condition op, down into the while after region. +/// +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// %res = scf.if %cond -> (...) { +/// use(%additional_used_values) +/// ... // then block +/// scf.yield %then_value +/// } else { +/// scf.yield %else_value +/// } +/// scf.condition(%cond) %res, ... +/// } do { +/// ^bb0(%res_arg, ...): +/// use(%res_arg) +/// ... +/// +/// becomes +/// scf.while (..) : (...) -> ... { +/// %additional_used_values = ... +/// %cond = ... +/// ... +/// scf.condition(%cond) %else_value, ..., %additional_used_values +/// } do { +/// ^bb0(%res_arg, ..., %additional_args): +/// use(%additional_args) +/// ... // if then block +/// use(%then_value) +/// ... 
+struct WhileMoveIfDown : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::WhileOp op, + PatternRewriter &rewriter) const override { + auto conditionOp = + cast(op.getBeforeBody()->getTerminator()); + auto ifOp = dyn_cast_or_null(conditionOp->getPrevNode()); + + // Check that the ifOp is directly before the conditionOp and that it + // matches the condition of the conditionOp. Also ensure that the ifOp has + // no else block with content, as that would complicate the transformation. + // TODO: support else blocks with content. + if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() || + (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty())) + return failure(); + + assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) && + *ifOp->user_begin() == conditionOp) && + "ifOp has unexpected uses"); + + Location loc = op.getLoc(); + + // Replace uses of ifOp results in the conditionOp with the yielded values + // from the ifOp branches. + for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) { + auto it = llvm::find(ifOp->getResults(), arg); + if (it != ifOp->getResults().end()) { + size_t ifOpIdx = it.getIndex(); + Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx); + Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx); + + rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue); + rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue); + } + } + + // Collect additional used values from before region. + SetVector additionalUsedValues; + visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) { + if (op.getBefore().isAncestor(operand->get().getParentRegion())) + additionalUsedValues.insert(operand->get()); + }); + + // Create new whileOp with additional used values as results. 
+ auto additionalValueTypes = llvm::map_to_vector( + additionalUsedValues, [](Value val) { return val.getType(); }); + size_t additionalValueSize = additionalUsedValues.size(); + SmallVector newResultTypes(op.getResultTypes()); + newResultTypes.append(additionalValueTypes); + + auto newWhileOp = + scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits()); + + newWhileOp.getBefore().takeBody(op.getBefore()); + newWhileOp.getAfter().takeBody(op.getAfter()); + newWhileOp.getAfter().addArguments( + additionalValueTypes, SmallVector(additionalValueSize, loc)); + + SmallVector conditionArgs = conditionOp.getArgs(); + llvm::append_range(conditionArgs, additionalUsedValues); + + // Update conditionOp inside new whileOp before region. + rewriter.setInsertionPoint(conditionOp); + rewriter.replaceOpWithNewOp( + conditionOp, conditionOp.getCondition(), conditionArgs); + + // Replace uses of additional used values inside the ifOp then region with + // the whileOp after region arguments. + rewriter.replaceUsesWithIf( + additionalUsedValues.takeVector(), + newWhileOp.getAfterArguments().take_back(additionalValueSize), + [&](OpOperand &use) { + return ifOp.getThenRegion().isAncestor( + use.getOwner()->getParentRegion()); + }); + + // Inline ifOp then region into new whileOp after region. + rewriter.eraseOp(ifOp.thenYield()); + rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(), + newWhileOp.getAfterBody()->begin()); + rewriter.eraseOp(ifOp); + rewriter.replaceOp(op, + newWhileOp->getResults().drop_back(additionalValueSize)); + return success(); + } +}; + /// Replace uses of the condition within the do block with true, since otherwise /// the block would not be evaluated. 
/// @@ -4399,7 +4521,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); + WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs, WhileMoveIfDown>( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 084c3fc065de3..b02cbc07880b9 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -974,6 +974,43 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) { // ----- +// CHECK-LABEL: @while_move_if_down +func.func @while_move_if_down() -> i32 { + %0 = scf.while () : () -> (i32) { + %additional_used_value = "test.get_some_value1" () : () -> (i32) + %else_value = "test.get_some_value2" () : () -> (i32) + %condition = "test.condition"() : () -> i1 + %res = scf.if %condition -> (i32) { + "test.use1" (%additional_used_value) : (i32) -> () + %then_value = "test.get_some_value3" () : () -> (i32) + scf.yield %then_value : i32 + } else { + scf.yield %else_value : i32 + } + scf.condition(%condition) %res : i32 + } do { + ^bb0(%res_arg: i32): + "test.use2" (%res_arg) : (i32) -> () + scf.yield + } + return %0 : i32 +} +// CHECK-NEXT: %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) { +// CHECK-NEXT: %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32 +// CHECK-NEXT: %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32 +// CHECK-NEXT: %[[VAL_2:.*]] = "test.condition"() : () -> i1 +// CHECK-NEXT: scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32): +// CHECK-NEXT: "test.use1"(%[[VAL_4]]) : (i32) -> () +// CHECK-NEXT: %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32 +// CHECK-NEXT: "test.use2"(%[[VAL_5]]) : (i32) -> () +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: return %[[VAL_6:.*]]#0 : i32 + +// ----- + // CHECK-LABEL: 
@while_cond_true func.func @while_cond_true() -> i1 { %0 = scf.while () : () -> i1 { From 2c61dcd03ca4cd4e1d8f6c03481ba94774564820 Mon Sep 17 00:00:00 2001 From: Ming Yan Date: Fri, 28 Nov 2025 22:22:59 +0800 Subject: [PATCH 2/3] Simplify the code and update the tests. --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 40 ++++++++++++++----------- mlir/test/Dialect/SCF/canonicalize.mlir | 34 +++++++++++---------- 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 79dcf562db993..bb07291036667 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -3726,8 +3726,14 @@ struct WhileMoveIfDown : public OpRewritePattern { LogicalResult matchAndRewrite(scf::WhileOp op, PatternRewriter &rewriter) const override { - auto conditionOp = - cast(op.getBeforeBody()->getTerminator()); + auto conditionOp = op.getConditionOp(); + + // Only support ifOp right before the condition at the moment. Relaxing this + // would require us to: + // - check that the body does not have side-effects conflicting with + // operations between the if and the condition. + // - check that results of the if operation are only used as arguments to + // the condition. auto ifOp = dyn_cast_or_null(conditionOp->getPrevNode()); // Check that the ifOp is directly before the conditionOp and that it @@ -3759,13 +3765,14 @@ struct WhileMoveIfDown : public OpRewritePattern { } // Collect additional used values from before region. - SetVector additionalUsedValues; + SetVector additionalUsedValuesSet; visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) { - if (op.getBefore().isAncestor(operand->get().getParentRegion())) - additionalUsedValues.insert(operand->get()); + if (&op.getBefore() == operand->get().getParentRegion()) + additionalUsedValuesSet.insert(operand->get()); }); // Create new whileOp with additional used values as results. 
+ auto additionalUsedValues = additionalUsedValuesSet.getArrayRef(); auto additionalValueTypes = llvm::map_to_vector( additionalUsedValues, [](Value val) { return val.getType(); }); size_t additionalValueSize = additionalUsedValues.size(); @@ -3775,23 +3782,22 @@ struct WhileMoveIfDown : public OpRewritePattern { auto newWhileOp = scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits()); - newWhileOp.getBefore().takeBody(op.getBefore()); - newWhileOp.getAfter().takeBody(op.getAfter()); - newWhileOp.getAfter().addArguments( - additionalValueTypes, SmallVector(additionalValueSize, loc)); - - SmallVector conditionArgs = conditionOp.getArgs(); - llvm::append_range(conditionArgs, additionalUsedValues); + rewriter.modifyOpInPlace(newWhileOp, [&] { + newWhileOp.getBefore().takeBody(op.getBefore()); + newWhileOp.getAfter().takeBody(op.getAfter()); + newWhileOp.getAfter().addArguments( + additionalValueTypes, + SmallVector(additionalValueSize, loc)); + }); - // Update conditionOp inside new whileOp before region. - rewriter.setInsertionPoint(conditionOp); - rewriter.replaceOpWithNewOp( - conditionOp, conditionOp.getCondition(), conditionArgs); + rewriter.modifyOpInPlace(conditionOp, [&] { + conditionOp.getArgsMutable().append(additionalUsedValues); + }); // Replace uses of additional used values inside the ifOp then region with // the whileOp after region arguments. 
rewriter.replaceUsesWithIf( - additionalUsedValues.takeVector(), + additionalUsedValues, newWhileOp.getAfterArguments().take_back(additionalValueSize), [&](OpOperand &use) { return ifOp.getThenRegion().isAncestor( diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index b02cbc07880b9..3b9e219403986 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -976,12 +976,14 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) { // CHECK-LABEL: @while_move_if_down func.func @while_move_if_down() -> i32 { + %defined_outside = "test.get_some_value" () : () -> (i32) %0 = scf.while () : () -> (i32) { %additional_used_value = "test.get_some_value1" () : () -> (i32) %else_value = "test.get_some_value2" () : () -> (i32) %condition = "test.condition"() : () -> i1 %res = scf.if %condition -> (i32) { - "test.use1" (%additional_used_value) : (i32) -> () + "test.use1" (%defined_outside) : (i32) -> () + "test.use2" (%additional_used_value) : (i32) -> () %then_value = "test.get_some_value3" () : () -> (i32) scf.yield %then_value : i32 } else { @@ -990,24 +992,26 @@ func.func @while_move_if_down() -> i32 { scf.condition(%condition) %res : i32 } do { ^bb0(%res_arg: i32): - "test.use2" (%res_arg) : (i32) -> () + "test.use3" (%res_arg) : (i32) -> () scf.yield } return %0 : i32 } -// CHECK-NEXT: %[[WHILE_0:.*]]:2 = scf.while : () -> (i32, i32) { -// CHECK-NEXT: %[[VAL_0:.*]] = "test.get_some_value1"() : () -> i32 -// CHECK-NEXT: %[[VAL_1:.*]] = "test.get_some_value2"() : () -> i32 -// CHECK-NEXT: %[[VAL_2:.*]] = "test.condition"() : () -> i1 -// CHECK-NEXT: scf.condition(%[[VAL_2]]) %[[VAL_1]], %[[VAL_0]] : i32, i32 -// CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32): -// CHECK-NEXT: "test.use1"(%[[VAL_4]]) : (i32) -> () -// CHECK-NEXT: %[[VAL_5:.*]] = "test.get_some_value3"() : () -> i32 -// CHECK-NEXT: "test.use2"(%[[VAL_5]]) : (i32) -> () -// 
CHECK-NEXT: scf.yield -// CHECK-NEXT: } -// CHECK-NEXT: return %[[VAL_6:.*]]#0 : i32 +// CHECK-NEXT: %[[defined_outside:.*]] = "test.get_some_value"() : () -> i32 +// CHECK-NEXT: %[[while_res:.*]]:2 = scf.while : () -> (i32, i32) { +// CHECK-NEXT: %[[additional_used_value:.*]] = "test.get_some_value1"() : () -> i32 +// CHECK-NEXT: %[[else_value:.*]] = "test.get_some_value2"() : () -> i32 +// CHECK-NEXT: %[[condition:.*]] = "test.condition"() : () -> i1 +// CHECK-NEXT: scf.condition(%[[condition]]) %[[else_value]], %[[additional_used_value]] : i32, i32 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[res_arg:.*]]: i32, %[[additional_used_value_arg:.*]]: i32): +// CHECK-NEXT: "test.use1"(%[[defined_outside]]) : (i32) -> () +// CHECK-NEXT: "test.use2"(%[[additional_used_value_arg]]) : (i32) -> () +// CHECK-NEXT: %[[then_value:.*]] = "test.get_some_value3"() : () -> i32 +// CHECK-NEXT: "test.use3"(%[[then_value]]) : (i32) -> () +// CHECK-NEXT: scf.yield +// CHECK-NEXT: } +// CHECK-NEXT: return %[[while_res:.*]]#0 : i32 // ----- From b33005812866d7047f5cf31430dbdcf2fd305fe2 Mon Sep 17 00:00:00 2001 From: Ming Yan Date: Sat, 29 Nov 2025 23:35:48 +0800 Subject: [PATCH 3/3] Update test. 
--- mlir/test/Dialect/SCF/canonicalize.mlir | 51 +++++++++++++++---------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 3b9e219403986..ac590fc0c47b9 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -976,15 +976,19 @@ func.func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) { // CHECK-LABEL: @while_move_if_down func.func @while_move_if_down() -> i32 { - %defined_outside = "test.get_some_value" () : () -> (i32) + %defined_outside = "test.get_some_value0" () : () -> (i32) %0 = scf.while () : () -> (i32) { - %additional_used_value = "test.get_some_value1" () : () -> (i32) - %else_value = "test.get_some_value2" () : () -> (i32) + %used_value = "test.get_some_value1" () : () -> (i32) + %used_by_subregion = "test.get_some_value2" () : () -> (i32) + %else_value = "test.get_some_value3" () : () -> (i32) %condition = "test.condition"() : () -> i1 %res = scf.if %condition -> (i32) { - "test.use1" (%defined_outside) : (i32) -> () - "test.use2" (%additional_used_value) : (i32) -> () - %then_value = "test.get_some_value3" () : () -> (i32) + "test.use0" (%defined_outside) : (i32) -> () + "test.use1" (%used_value) : (i32) -> () + test.alloca_scope_region { + "test.use2" (%used_by_subregion) : (i32) -> () + } + %then_value = "test.get_some_value4" () : () -> (i32) scf.yield %then_value : i32 } else { scf.yield %else_value : i32 @@ -997,21 +1001,26 @@ func.func @while_move_if_down() -> i32 { } return %0 : i32 } -// CHECK-NEXT: %[[defined_outside:.*]] = "test.get_some_value"() : () -> i32 -// CHECK-NEXT: %[[while_res:.*]]:2 = scf.while : () -> (i32, i32) { -// CHECK-NEXT: %[[additional_used_value:.*]] = "test.get_some_value1"() : () -> i32 -// CHECK-NEXT: %[[else_value:.*]] = "test.get_some_value2"() : () -> i32 -// CHECK-NEXT: %[[condition:.*]] = "test.condition"() : () -> i1 -// CHECK-NEXT: 
scf.condition(%[[condition]]) %[[else_value]], %[[additional_used_value]] : i32, i32 -// CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[res_arg:.*]]: i32, %[[additional_used_value_arg:.*]]: i32): -// CHECK-NEXT: "test.use1"(%[[defined_outside]]) : (i32) -> () -// CHECK-NEXT: "test.use2"(%[[additional_used_value_arg]]) : (i32) -> () -// CHECK-NEXT: %[[then_value:.*]] = "test.get_some_value3"() : () -> i32 -// CHECK-NEXT: "test.use3"(%[[then_value]]) : (i32) -> () -// CHECK-NEXT: scf.yield -// CHECK-NEXT: } -// CHECK-NEXT: return %[[while_res:.*]]#0 : i32 +// CHECK: %[[defined_outside:.*]] = "test.get_some_value0"() : () -> i32 +// CHECK: %[[WHILE_RES:.*]]:3 = scf.while : () -> (i32, i32, i32) { +// CHECK: %[[used_value:.*]] = "test.get_some_value1"() : () -> i32 +// CHECK: %[[used_by_subregion:.*]] = "test.get_some_value2"() : () -> i32 +// CHECK: %[[else_value:.*]] = "test.get_some_value3"() : () -> i32 +// CHECK: %[[condition:.*]] = "test.condition"() : () -> i1 +// CHECK: scf.condition(%[[condition]]) %[[else_value]], %[[used_value]], %[[used_by_subregion]] : i32, i32, i32 +// CHECK: } do { +// CHECK: ^bb0(%[[res_arg:.*]]: i32, %[[used_value_arg:.*]]: i32, %[[used_by_subregion_arg:.*]]: i32): +// CHECK: "test.use0"(%[[defined_outside]]) : (i32) -> () +// CHECK: "test.use1"(%[[used_value_arg]]) : (i32) -> () +// CHECK: test.alloca_scope_region { +// CHECK: "test.use2"(%[[used_by_subregion_arg]]) : (i32) -> () +// CHECK: } +// CHECK: %[[then_value:.*]] = "test.get_some_value4"() : () -> i32 +// CHECK: "test.use3"(%[[then_value]]) : (i32) -> () +// CHECK: scf.yield +// CHECK: } +// CHECK: return %[[WHILE_RES]]#0 : i32 +// CHECK: } // -----