From 10892921aab959c62e5a0e37db0fba721d76e817 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 22 Dec 2023 00:37:45 +0100 Subject: [PATCH 1/4] [mlir][scf] Align `scf.while` `before` block args in canonicalizer If `before` block args are directly forwarded to `scf.condition` make sure they are passes in the same order. This is needed for `scf.while` uplifting https://github.com/llvm/llvm-project/pull/76108 --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 77 ++++++++++++++++++++++++- mlir/test/Dialect/SCF/canonicalize.mlir | 29 ++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 5bca8e85f889d..3c3eb6de5986d 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -3884,6 +3884,81 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern { return success(); } }; + +/// If both ranges contain same values return mappping indices from args1 to +/// args2. Otherwise return std::nullopt +static std::optional> getArgsMapping(ValueRange args1, + ValueRange args2) { + if (args1.size() != args2.size()) + return std::nullopt; + + SmallVector ret(args1.size()); + for (auto &&[i, arg1] : llvm::enumerate(args1)) { + auto it = llvm::find(args2, arg1); + if (it == args2.end()) + return std::nullopt; + + auto j = it - args2.begin(); + ret[j] = static_cast(i); + } + + return ret; +} + +/// If `before` block args are directly forwarded to `scf.condition`, rearrange +/// `scf.condition` args into same order as block args. Update `after` block +// args and results values accordingly. +/// Needed to simplify `scf.while` -> `scf.for` uplifting. +struct WhileOpAlignBeforeArgs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp loop, + PatternRewriter &rewriter) const override { + auto oldBefore = loop.getBeforeBody(); + ConditionOp oldTerm = loop.getConditionOp(); + ValueRange beforeArgs = oldBefore->getArguments(); + ValueRange termArgs = oldTerm.getArgs(); + if (beforeArgs == termArgs) + return failure(); + + auto mapping = getArgsMapping(beforeArgs, termArgs); + if (!mapping) + return failure(); + + { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(oldTerm); + rewriter.replaceOpWithNewOp(oldTerm, oldTerm.getCondition(), + beforeArgs); + } + + auto oldAfter = loop.getAfterBody(); + + SmallVector newResultTypes(beforeArgs.size()); + for (auto &&[i, j] : llvm::enumerate(*mapping)) + newResultTypes[j] = loop.getResult(i).getType(); + + auto newLoop = rewriter.create(loop.getLoc(), newResultTypes, + loop.getInits(), nullptr, nullptr); + auto newBefore = newLoop.getBeforeBody(); + auto newAfter = newLoop.getAfterBody(); + + SmallVector newResults(beforeArgs.size()); + SmallVector newAfterArgs(beforeArgs.size()); + for (auto &&[i, j] : llvm::enumerate(*mapping)) { + newResults[i] = newLoop.getResult(j); + newAfterArgs[i] = newAfter->getArgument(j); + } + + rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(), + newBefore->getArguments()); + rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(), + newAfterArgs); + + rewriter.replaceOp(loop, newResults); + return success(); + } +}; } // namespace void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -3891,7 +3966,7 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); + WhileRemoveUnusedArgs, WhileOpAlignBeforeArgs>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index 52e0fdfa36d6c..b4c9ed4db94e0 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -1198,6 +1198,35 @@ func.func @while_unused_arg2(%val0: i32) -> i32 { // CHECK: return %[[RES]] : i32 +// ----- + +// CHECK-LABEL: func @test_align_args +// CHECK: %[[RES:.*]]:3 = scf.while (%[[ARG0:.*]] = %{{.*}}, %[[ARG1:.*]] = %{{.*}}, %[[ARG2:.*]] = %{{.*}}) : (f32, i32, i64) -> (f32, i32, i64) { +// CHECK: scf.condition(%{{.*}}) %[[ARG0]], %[[ARG1]], %[[ARG2]] : f32, i32, i64 +// CHECK: ^bb0(%[[ARG3:.*]]: f32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i64): +// CHECK: %[[R1:.*]] = "test.test"(%[[ARG5]]) : (i64) -> f32 +// CHECK: %[[R2:.*]] = "test.test"(%[[ARG3]]) : (f32) -> i32 +// CHECK: %[[R3:.*]] = "test.test"(%[[ARG4]]) : (i32) -> i64 +// CHECK: scf.yield %[[R1]], %[[R2]], %[[R3]] : f32, i32, i64 +// CHECK: return %[[RES]]#2, %[[RES]]#0, %[[RES]]#1 +func.func @test_align_args() -> (i64, f32, i32) { + %0 = "test.test"() : () -> (f32) + %1 = "test.test"() : () -> (i32) + %2 = "test.test"() : () -> (i64) + %3:3 = scf.while (%arg0 = %0, %arg1 = %1, %arg2 = %2) : (f32, i32, i64) -> (i64, f32, i32) { + %cond = "test.test"() : () -> (i1) + scf.condition(%cond) %arg2, %arg0, %arg1 : i64, f32, i32 + } do { + ^bb0(%arg3: i64, %arg4: f32, %arg5: i32): + %4 = "test.test"(%arg3) : (i64) -> (f32) + %5 = "test.test"(%arg4) : (f32) -> (i32) + %6 = "test.test"(%arg5) : (i32) -> (i64) + scf.yield %4, %5, %6 : f32, i32, i64 + } + return %3#0, %3#1, %3#2 : i64, f32, i32 +} + + // ----- // CHECK-LABEL: @combineIfs From 496a0d4185a5463c0d92a9399af97d35c3098544 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 30 Mar 2024 14:43:44 +0100 Subject: [PATCH 2/4] review comments --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 3c3eb6de5986d..a04913610fcb6 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -3885,8 +3885,8 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern { } }; -/// If both ranges contain same values return mappping indices from args1 to -/// args2. Otherwise return std::nullopt +/// If both ranges contain same values return mappping indices from args2 to +/// args1. Otherwise return std::nullopt. static std::optional> getArgsMapping(ValueRange args1, ValueRange args2) { if (args1.size() != args2.size()) @@ -3898,16 +3898,26 @@ static std::optional> getArgsMapping(ValueRange args1, if (it == args2.end()) return std::nullopt; - auto j = it - args2.begin(); - ret[j] = static_cast(i); + ret[std::distance(args2.begin(), it)] = static_cast(i); } return ret; } +static bool hasDuplicates(ValueRange args) { + llvm::SmallDenseSet set; + for (Value arg : args) { + if (set.contains(arg)) + return true; + + set.insert(arg); + } + return false; +} + /// If `before` block args are directly forwarded to `scf.condition`, rearrange /// `scf.condition` args into same order as block args. Update `after` block -// args and results values accordingly. +// args and op result values accordingly. /// Needed to simplify `scf.while` -> `scf.for` uplifting. struct WhileOpAlignBeforeArgs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -3921,6 +3931,9 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern { if (beforeArgs == termArgs) return failure(); + if (hasDuplicates(termArgs)) + return failure(); + auto mapping = getArgsMapping(beforeArgs, termArgs); if (!mapping) return failure(); From 1f8b1b106424eb970763698f3711336eb4e2f12a Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 30 Mar 2024 14:47:23 +0100 Subject: [PATCH 3/4] missing / --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index a04913610fcb6..9b652cdc07dec 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -3917,7 +3917,7 @@ static bool hasDuplicates(ValueRange args) { /// If `before` block args are directly forwarded to `scf.condition`, rearrange /// `scf.condition` args into same order as block args. Update `after` block -// args and op result values accordingly. +/// args and op result values accordingly. /// Needed to simplify `scf.while` -> `scf.for` uplifting. struct WhileOpAlignBeforeArgs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; From 0bd2faaaf8416f1229c21d0f3d9f15f9a23c4bc4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 2 Apr 2024 14:20:57 +0200 Subject: [PATCH 4/4] add args comments --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 9b652cdc07dec..7a1aafc9f1c2f 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -3951,8 +3951,9 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern { for (auto &&[i, j] : llvm::enumerate(*mapping)) newResultTypes[j] = loop.getResult(i).getType(); - auto newLoop = rewriter.create(loop.getLoc(), newResultTypes, - loop.getInits(), nullptr, nullptr); + auto newLoop = rewriter.create( + loop.getLoc(), newResultTypes, loop.getInits(), + /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr); auto newBefore = newLoop.getBeforeBody(); auto newAfter = newLoop.getAfterBody();