@Hardcode84 Hardcode84 commented Apr 18, 2024

`scf.while` -> `scf.for` uplifting expects the `before` block to consist of a single cmp op, so we need to clean it up before running the uplifting. One possible cleanup is LICM.
The second is moving and duplicating ops from the `before` block to the `after` block and to after the loop. Add a pattern for this transformation.
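
To see why moving a pure op out of `before` preserves behaviour, here is a minimal standalone C++ sketch that models the third test case in this PR (the `arith.addi` one) as plain loops. It is not MLIR code: the function names `runOriginal` and `runPrepared` are hypothetical, and the loop condition `arg1 < bound` is an assumption standing in for the opaque `"test.cond"` op.

```cpp
#include <cassert>
#include <cstdint>

// Toy model of the loop BEFORE the rewrite: the add lives in `before`,
// and its result is forwarded by scf.condition.
int64_t runOriginal(int64_t arg0, int64_t bound) {
  assert(arg0 > 0 && "toy model needs a positive step to terminate");
  int64_t arg1 = arg0;          // loop-carried value, initialised to arg0
  while (true) {
    int64_t sum = arg0 + arg1;  // the op sitting in `before`
    bool cond = arg1 < bound;   // assumed stand-in for "test.cond"
    if (!cond)
      return sum;               // forwarded value becomes the loop result
    arg1 = sum;                 // `after` simply yields its argument
  }
}

// Toy model of the loop AFTER the rewrite: `before` only evaluates the
// condition and forwards arg1; the add is cloned into `after` and once
// more after the loop.
int64_t runPrepared(int64_t arg0, int64_t bound) {
  assert(arg0 > 0 && "toy model needs a positive step to terminate");
  int64_t arg1 = arg0;
  while (arg1 < bound) {        // `before` now holds only the condition
    arg1 = arg0 + arg1;         // clone at the start of `after`
  }
  return arg0 + arg1;           // clone placed after the loop
}
```

Under these assumptions both functions return the same value for any positive `arg0`, including the zero-iteration case, which is why the op has to be duplicated after the loop as well as in `after`.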

llvmbot commented Apr 18, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

scf.while -> scf.for uplifting expects the before block to consist of a single cmp op, so we need to clean it up before running the uplifting. One possible cleanup is LICM.
The second is moving and duplicating ops from the before block to the after block and to after the loop. Add a pattern for this transformation.


Full diff: https://github.com/llvm/llvm-project/pull/89222.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h (+4)
  • (modified) mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp (+192)
  • (added) mlir/test/Dialect/SCF/uplift-while-prepare.mlir (+74)
  • (modified) mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp (+23)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
index fdf25706269803..244423274c0555 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h
@@ -79,6 +79,10 @@ void populateSCFLoopPipeliningPatterns(RewritePatternSet &patterns,
 /// loop bounds and loop steps are canonicalized.
 void populateSCFForLoopCanonicalizationPatterns(RewritePatternSet &patterns);
 
+/// Populate patterns to prepare `scf.while` loops for uplifting, e.g. by
+/// cleaning up the `before` block.
+void populatePrepareUpliftWhileToForPatterns(RewritePatternSet &patterns);
+
 /// Populate patterns to uplift `scf.while` ops to `scf.for`.
 /// Uplifitng expects a specific ops pattern:
 ///  * `before` block consisting of single arith.cmp op
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index 7b4024b6861a72..959c30315ae8af 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -20,7 +20,194 @@
 
 using namespace mlir;
 
+static Operation *findOpToMoveFromBefore(scf::WhileOp loop) {
+  Block *body = loop.getBeforeBody();
+  if (body->without_terminator().empty())
+    return nullptr;
+
+  // Check last op first.
+  // TODO: It's usually safe to move and duplicate the last op even if it has
+  // side effects, as long as the sequence of ops executed on each path stays
+  // the same. Exceptions are GPU barrier/group ops; LLVM proper has
+  // convergent attribute/semantics to check this, but we don't model it yet.
+  Operation *lastOp = &(*std::prev(body->without_terminator().end()));
+
+  auto term = loop.getConditionOp();
+  Operation *termCondOp = term.getCondition().getDefiningOp();
+  if (lastOp != termCondOp)
+    return lastOp;
+
+  // Try to move terminator args producers.
+  for (Value termArg : term.getArgs()) {
+    Operation *op = termArg.getDefiningOp();
+    if (!op || op->getParentOp() != loop || op == termCondOp || !isPure(op))
+      continue;
+
+    // Each result must be used only as a terminator arg, meaning it can have
+    // at most one use; duplicated terminator args must already be cleaned up
+    // by canonicalizations at this point.
+    if (!llvm::all_of(op->getResults(), [&](Value val) {
+          return val.hasOneUse() || val.use_empty();
+        }))
+      continue;
+
+    return op;
+  }
+  return nullptr;
+}
+
 namespace {
+/// `scf.while` uplifting expects a `before` block consisting of a single cmp
+/// op; try to move ops from `before` to `after` and to after the loop.
+///
+/// ```
+/// scf.while(...) {
+/// before:
+///   ...
+///   some_op()
+///   scf.condition ..
+/// after:
+///   ...
+/// }
+/// ```
+/// to
+/// ```
+/// scf.while(...) {
+/// before:
+///   ...
+///   scf.condition ..
+/// after:
+///   some_op()
+///   ...
+/// }
+/// some_op()
+/// ```
+struct MoveOpsFromBefore : public OpRewritePattern<scf::WhileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(scf::WhileOp loop,
+                                PatternRewriter &rewriter) const override {
+    Operation *opToMove = findOpToMoveFromBefore(loop);
+    if (!opToMove)
+      return rewriter.notifyMatchFailure(loop, "No suitable ops found");
+
+    auto condOp = loop.getConditionOp();
+    SmallVector<Value> newCondArgs;
+
+    // Populate new terminator args.
+
+    // Add original terminator args, except args produced by the op we decided
+    // to move.
+    for (Value arg : condOp.getArgs()) {
+      if (arg.getDefiningOp() == opToMove)
+        continue;
+
+      newCondArgs.emplace_back(arg);
+    }
+    auto originalArgsOffset = newCondArgs.size();
+
+    // Add moved op operands to terminator args, if they are defined in loop
+    // block.
+    DominanceInfo dom;
+    for (Value arg : opToMove->getOperands()) {
+      if (dom.properlyDominates(arg, loop))
+        continue;
+
+      newCondArgs.emplace_back(arg);
+    }
+
+    // Create new loop.
+    ValueRange tempRange(newCondArgs);
+    auto newLoop = rewriter.create<mlir::scf::WhileOp>(
+        loop.getLoc(), TypeRange(tempRange), loop.getInits(), nullptr, nullptr);
+
+    OpBuilder::InsertionGuard g(rewriter);
+
+    // Create new terminator, old terminator will be deleted later.
+    rewriter.setInsertionPoint(condOp);
+    rewriter.create<scf::ConditionOp>(condOp.getLoc(), condOp.getCondition(),
+                                      newCondArgs);
+
+    Block *oldBefore = loop.getBeforeBody();
+    Block *newBefore = newLoop.getBeforeBody();
+
+    // Inline before block as is.
+    rewriter.inlineBlockBefore(oldBefore, newBefore, newBefore->begin(),
+                               newBefore->getArguments());
+
+    Block *oldAfter = loop.getAfterBody();
+    Block *newAfter = newLoop.getAfterBody();
+
+    // Build mapping between original op args and new after block args/new loop
+    // results.
+    IRMapping afterBodyMapping;
+    IRMapping afterLoopMapping;
+    {
+      ValueRange blockArgs =
+          newAfter->getArguments().drop_front(originalArgsOffset);
+      ValueRange newLoopArgs =
+          newLoop.getResults().drop_front(originalArgsOffset);
+      for (Value arg : opToMove->getOperands()) {
+        if (dom.properlyDominates(arg, loop))
+          continue;
+
+        assert(!blockArgs.empty());
+        assert(!newLoopArgs.empty());
+        afterBodyMapping.map(arg, blockArgs.front());
+        afterLoopMapping.map(arg, newLoopArgs.front());
+        blockArgs = blockArgs.drop_front();
+        newLoopArgs = newLoopArgs.drop_front();
+      }
+    }
+
+    {
+      // Clone op into after body.
+      rewriter.setInsertionPointToStart(oldAfter);
+      Operation *newAfterBodyOp = rewriter.clone(*opToMove, afterBodyMapping);
+
+      // Clone op after loop.
+      rewriter.setInsertionPointAfter(newLoop);
+      Operation *newAfterLoopOp = rewriter.clone(*opToMove, afterLoopMapping);
+
+      // Build mapping between old and new after block args and between old and
+      // new loop results.
+      ValueRange blockArgs =
+          newAfter->getArguments().take_front(originalArgsOffset);
+      ValueRange newLoopArgs =
+          newLoop.getResults().take_front(originalArgsOffset);
+      SmallVector<Value> argsMapping;
+      SmallVector<Value> newLoopResults;
+      for (Value arg : condOp.getArgs()) {
+        if (arg.getDefiningOp() == opToMove) {
+          auto resNumber = cast<OpResult>(arg).getResultNumber();
+          argsMapping.emplace_back(newAfterBodyOp->getResult(resNumber));
+          newLoopResults.emplace_back(newAfterLoopOp->getResult(resNumber));
+          continue;
+        }
+
+        assert(!blockArgs.empty());
+        assert(!newLoopArgs.empty());
+        argsMapping.emplace_back(blockArgs.front());
+        newLoopResults.emplace_back(newLoopArgs.front());
+        blockArgs = blockArgs.drop_front();
+        newLoopArgs = newLoopArgs.drop_front();
+      }
+
+      // Inline after block.
+      rewriter.inlineBlockBefore(oldAfter, newAfter, newAfter->begin(),
+                                 argsMapping);
+
+      // Replace loop.
+      rewriter.replaceOp(loop, newLoopResults);
+    }
+
+    // Finally, we can remove old terminator and the original op.
+    rewriter.eraseOp(condOp);
+    rewriter.eraseOp(opToMove);
+    return success();
+  }
+};
+
 struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -209,6 +396,11 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
   return newLoop;
 }
 
+void mlir::scf::populatePrepareUpliftWhileToForPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<MoveOpsFromBefore>(patterns.getContext());
+}
+
 void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
   patterns.add<UpliftWhileOp>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SCF/uplift-while-prepare.mlir b/mlir/test/Dialect/SCF/uplift-while-prepare.mlir
new file mode 100644
index 00000000000000..fd359efa20ba0c
--- /dev/null
+++ b/mlir/test/Dialect/SCF/uplift-while-prepare.mlir
@@ -0,0 +1,74 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(test-scf-prepare-uplift-while-to-for))' -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: func.func @test()
+//       CHECK:  scf.while
+//   CHECK-NOT:  "test.test1"
+//       CHECK:  scf.condition(%{{.*}})
+//       CHECK:  } do {
+//       CHECK:  "test.test1"() : () -> ()
+//       CHECK:  "test.test2"() : () -> ()
+//       CHECK:  scf.yield
+//       CHECK:  "test.test1"() : () -> ()
+//       CHECK:  return
+func.func @test() {
+  scf.while () : () -> () {
+    %1 = "test.cond"() : () -> i1
+    "test.test1"() : () -> ()
+    scf.condition(%1)
+  } do {
+  ^bb0():
+    "test.test2"() : () -> ()
+    scf.yield
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test()
+//       CHECK:  scf.while
+//   CHECK-NOT:  "test.test1"
+//       CHECK:  scf.condition(%{{.*}})
+//       CHECK:  } do {
+//       CHECK:  %[[R1:.*]]:2 = "test.test1"() : () -> (i32, i64)
+//       CHECK:  "test.test2"(%[[R1]]#1, %[[R1]]#0) : (i64, i32) -> ()
+//       CHECK:  scf.yield
+//       CHECK:  %[[R2:.*]]:2 = "test.test1"() : () -> (i32, i64)
+//       CHECK:  return %[[R2]]#1, %[[R2]]#0 : i64, i32
+func.func @test() -> (i64, i32) {
+  %0:2 = scf.while () : () -> (i64, i32) {
+    %1 = "test.cond"() : () -> i1
+    %2:2 = "test.test1"() : () -> (i32, i64)
+    scf.condition(%1) %2#1, %2#0 : i64, i32
+  } do {
+  ^bb0(%arg1: i64, %arg2: i32):
+    "test.test2"(%arg1, %arg2) : (i64, i32) -> ()
+    scf.yield
+  }
+  return %0#0, %0#1 : i64, i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test
+//  CHECK-SAME:  (%[[ARG0:.*]]: index)
+//       CHECK:  %[[RES:.*]] = scf.while (%[[ARG1:.*]] = %[[ARG0]]) : (index) -> index {
+//   CHECK-NOT:  arith.addi
+//       CHECK:  scf.condition(%{{.*}}) %[[ARG1]] : index
+//       CHECK:  } do {
+//       CHECK:  ^bb0(%[[ARG2:.*]]: index):
+//       CHECK:  %[[A1:.*]] = arith.addi %[[ARG0]], %[[ARG2]] : index
+//       CHECK:  scf.yield %[[A1]]
+//       CHECK:  %[[A2:.*]] = arith.addi %[[ARG0]], %[[RES]] : index
+//       CHECK:  return %[[A2]]
+func.func @test(%arg0: index) -> index {
+  %res = scf.while (%arg1 = %arg0) : (index) -> (index) {
+    %0 = arith.addi %arg0, %arg1 : index
+    %1 = "test.cond"() : () -> i1
+    scf.condition(%1) %0 : index
+  } do {
+  ^bb0(%arg2: index):
+    scf.yield %arg2 : index
+  }
+  return %res : index
+}
diff --git a/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp
index 468bc0ca78489f..3eaad9eaa8a731 100644
--- a/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp
@@ -19,6 +19,28 @@ using namespace mlir;
 
 namespace {
 
+struct TestSCFPrepareUpliftWhileToFor
+    : public PassWrapper<TestSCFPrepareUpliftWhileToFor, OperationPass<void>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPrepareUpliftWhileToFor)
+
+  StringRef getArgument() const final {
+    return "test-scf-prepare-uplift-while-to-for";
+  }
+
+  StringRef getDescription() const final {
+    return "test scf while to for uplifting preparation";
+  }
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *ctx = op->getContext();
+    RewritePatternSet patterns(ctx);
+    scf::populatePrepareUpliftWhileToForPatterns(patterns);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 struct TestSCFUpliftWhileToFor
     : public PassWrapper<TestSCFUpliftWhileToFor, OperationPass<void>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFUpliftWhileToFor)
@@ -44,6 +66,7 @@ struct TestSCFUpliftWhileToFor
 namespace mlir {
 namespace test {
 void registerTestSCFUpliftWhileToFor() {
+  PassRegistration<TestSCFPrepareUpliftWhileToFor>();
   PassRegistration<TestSCFUpliftWhileToFor>();
 }
 } // namespace test

@joker-eph
Collaborator

Can we also expose the APIs without forcing the use of patterns?
I can imagine an uplifting pass that avoids the greedy rewriter.

@Hardcode84
Contributor Author

> Can we also expose the APIs without forcing the use of patterns?
> I can imagine an uplifting pass that avoids the greedy rewriter.

I can do that, but this pattern sinks one op at a time and is expected to be applied repeatedly. Should we ask the user to do that, or add an implicit loop to the version exposed to the user?
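
The "implicit loop" option could look roughly like the following. This is a toy, self-contained C++ model rather than MLIR code: `ToyWhile`, `sinkOneOp`, and `sinkAllOps` are hypothetical names, ops are plain strings, operand remapping is ignored, and the condition-producing op is assumed to be the first op in `before`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for an scf.while loop plus the code that follows it.
struct ToyWhile {
  std::vector<std::string> before;    // before[0] produces the condition (assumption)
  std::vector<std::string> after;     // ops in the `after` block
  std::vector<std::string> afterLoop; // ops placed right after the loop
};

// One rewrite step: take the last op in `before` that is not the condition
// producer and duplicate it at the start of `after` and right after the loop.
// Returns false when nothing is left to sink.
bool sinkOneOp(ToyWhile &loop) {
  assert(!loop.before.empty() && "expected a condition-producing op");
  if (loop.before.size() == 1)
    return false;
  std::string op = loop.before.back();
  loop.before.pop_back();
  loop.after.insert(loop.after.begin(), op);
  loop.afterLoop.insert(loop.afterLoop.begin(), op);
  return true;
}

// Driver with the implicit loop: apply the single-op step until it fails.
// Returns the number of ops that were sunk.
int sinkAllOps(ToyWhile &loop) {
  int steps = 0;
  while (sinkOneOp(loop))
    ++steps;
  return steps;
}
```

In this model each step inserts at the front of `after` and at the front of the post-loop code, so repeated application keeps the sunk ops in their original relative order.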

@Hardcode84
Contributor Author

ping

@joker-eph
Collaborator

> I can do that, but this pattern sinks one op at a time

That seems quite inefficient; why can't we sink whatever needs to be sunk?

@Hardcode84
Contributor Author

> That seems quite inefficient; why can't we sink whatever needs to be sunk?

It's possible, but the current logic for selecting candidates and updating the `scf.while` block args/results is already quite convoluted, and it is easier to reason about when it handles a single op at a time.
