
[mlir][SCF] Retire SCF-specific to_memref/to_tensor canonicalization patterns #74551

Conversation

matthias-springer
Member

The partial bufferization framework has been replaced with One-Shot Bufferize. SCF-specific canonicalization patterns for `to_memref`/`to_tensor` are no longer needed.
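For context, a minimal sketch of the kind of tensor loop that One-Shot Bufferize now handles wholesale (hypothetical example, not taken from this PR), e.g. via something like `mlir-opt --one-shot-bufferize="bufferize-function-boundaries"`. Because the whole function is bufferized in a single pass, no intermediate `to_memref`/`to_tensor` pairs are left behind for the retired canonicalization patterns to clean up:

```
// Hypothetical input, not from this PR: an scf.for over tensors.
// One-Shot Bufferize rewrites the loop and its iter_args to memref form in
// one pass, so no bufferization.to_memref/to_tensor ops remain for SCF
// canonicalization to fold away.
func.func @loop(%t: tensor<128xf32>, %lb: index, %ub: index, %step: index)
    -> tensor<128xf32> {
  %0 = scf.for %iv = %lb to %ub step %step iter_args(%arg = %t)
      -> (tensor<128xf32>) {
    scf.yield %arg : tensor<128xf32>
  }
  return %0 : tensor<128xf32>
}
```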

@llvmbot
Collaborator

llvmbot commented Dec 6, 2023

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The partial bufferization framework has been replaced with One-Shot Bufferize. SCF-specific canonicalization patterns for `to_memref`/`to_tensor` are no longer needed.


Full diff: https://github.com/llvm/llvm-project/pull/74551.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/CMakeLists.txt (-1)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+2-130)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (-50)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (-1)
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index 9882b843c285e..d4bfe3285c987 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRSCFDialect
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
-  MLIRBufferizationDialect
   MLIRControlFlowDialect
   MLIRFunctionInterfaces
   MLIRIR
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 3b55704c4ea07..cf807a2adc10e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
@@ -1082,139 +1081,12 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
   }
 };
 
-/// Canonicalize the iter_args of an scf::ForOp that involve a
-/// `bufferization.to_tensor` and for which only the last loop iteration is
-/// actually visible outside of the loop. The canonicalization looks for a
-/// pattern such as:
-/// ```
-///    %t0 = ... : tensor_type
-///    %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
-///      ...
-///      // %m is either buffer_cast(%bb0) or defined above the loop
-///      %m... : memref_type
-///      ... // uses of %m with potential inplace updates
-///      %new_tensor = bufferization.to_tensor %m : memref_type
-///      ...
-///      scf.yield %new_tensor : tensor_type
-///    }
-/// ```
-///
-/// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
-/// `%m = buffer_cast %bb0` op that feeds into the yielded
-/// `bufferization.to_tensor` op.
-///
-/// If no aliasing write to the memref `%m`, from which `%new_tensor` is loaded,
-/// occurs between `bufferization.to_tensor` and yield, then the value %0
-/// visible outside of the loop is the last `bufferization.to_tensor`
-/// produced in the loop.
-///
-/// For now, we approximate the absence of aliasing by only supporting the case
-/// when the bufferization.to_tensor is the operation immediately preceding
-/// the yield.
-//
-/// The canonicalization rewrites the pattern as:
-/// ```
-///    // %m is either a buffer_cast or defined above
-///    %m... : memref_type
-///    scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
-///      ... // uses of %m with potential inplace updates
-///      scf.yield %bb0: tensor_type
-///    }
-///    %0 = bufferization.to_tensor %m : memref_type
-/// ```
-///
-/// A later bbArg canonicalization will further rewrite as:
-/// ```
-///    // %m is either a buffer_cast or defined above
-///    %m... : memref_type
-///    scf.for ... { // no iter_args
-///      ... // uses of %m with potential inplace updates
-///    }
-///    %0 = bufferization.to_tensor %m : memref_type
-/// ```
-struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
-  using OpRewritePattern<ForOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ForOp forOp,
-                                PatternRewriter &rewriter) const override {
-    assert(std::next(forOp.getRegion().begin()) == forOp.getRegion().end() &&
-           "unexpected multiple blocks");
-
-    Location loc = forOp.getLoc();
-    DenseMap<Value, Value> replacements;
-    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
-      unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
-      auto yieldOp =
-          cast<scf::YieldOp>(forOp.getRegion().front().getTerminator());
-      Value yieldVal = yieldOp->getOperand(idx);
-      auto tensorLoadOp = yieldVal.getDefiningOp<bufferization::ToTensorOp>();
-      bool isTensor = llvm::isa<TensorType>(bbArg.getType());
-
-      bufferization::ToMemrefOp tensorToMemref;
-      // Either bbArg has no use or it has a single buffer_cast use.
-      if (bbArg.hasOneUse())
-        tensorToMemref =
-            dyn_cast<bufferization::ToMemrefOp>(*bbArg.getUsers().begin());
-      if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !tensorToMemref))
-        continue;
-      // If tensorToMemref is present, it must feed into the `ToTensorOp`.
-      if (tensorToMemref && tensorLoadOp.getMemref() != tensorToMemref)
-        continue;
-      // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
-      // must be before `ToTensorOp` in the block so that the lastWrite
-      // property is not subject to additional side-effects.
-      // For now, we only support the case when ToTensorOp appears
-      // immediately before the terminator.
-      if (tensorLoadOp->getNextNode() != yieldOp)
-        continue;
-
-      // Clone the optional tensorToMemref before forOp.
-      if (tensorToMemref) {
-        rewriter.setInsertionPoint(forOp);
-        rewriter.replaceOpWithNewOp<bufferization::ToMemrefOp>(
-            tensorToMemref, tensorToMemref.getMemref().getType(),
-            tensorToMemref.getTensor());
-      }
-
-      // Clone the tensorLoad after forOp.
-      rewriter.setInsertionPointAfter(forOp);
-      Value newTensorLoad = rewriter.create<bufferization::ToTensorOp>(
-          loc, tensorLoadOp.getMemref());
-      Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
-      replacements.insert(std::make_pair(forOpResult, newTensorLoad));
-
-      // Make the terminator just yield the bbArg, the old tensorLoadOp + the
-      // old bbArg (that is now directly yielded) will canonicalize away.
-      rewriter.startRootUpdate(yieldOp);
-      yieldOp.setOperand(idx, bbArg);
-      rewriter.finalizeRootUpdate(yieldOp);
-    }
-    if (replacements.empty())
-      return failure();
-
-    // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
-    // replaces the whole op and erases it unconditionally. This is wrong for
-    // `forOp` as it generally contains ops with side effects.
-    // Instead, use `rewriter.replaceOpWithIf`.
-    SmallVector<Value> newResults;
-    newResults.reserve(forOp.getNumResults());
-    for (Value v : forOp.getResults()) {
-      auto it = replacements.find(v);
-      newResults.push_back((it != replacements.end()) ? it->second : v);
-    }
-    unsigned idx = 0;
-    rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
-      return op.get() != newResults[idx++];
-    });
-    return success();
-  }
-};
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
-              LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
+  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops, ForOpTensorCastFolder>(
+      context);
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 9dbf8d5dab11a..41e028028616a 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -773,56 +773,6 @@ func.func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
 
 // -----
 
-func.func private @process(%0 : memref<128x128xf32>)
-func.func private @process_tensor(%0 : tensor<128x128xf32>) -> memref<128x128xf32>
-
-// CHECK-LABEL: last_value
-//  CHECK-SAME:   %[[T0:[0-9a-z]*]]: tensor<128x128xf32>
-//  CHECK-SAME:   %[[T1:[0-9a-z]*]]: tensor<128x128xf32>
-//  CHECK-SAME:   %[[T2:[0-9a-z]*]]: tensor<128x128xf32>
-//  CHECK-SAME:   %[[M0:[0-9a-z]*]]: memref<128x128xf32>
-func.func @last_value(%t0: tensor<128x128xf32>, %t1: tensor<128x128xf32>,
-                 %t2: tensor<128x128xf32>, %m0: memref<128x128xf32>,
-                 %lb : index, %ub : index, %step : index)
-  -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
-{
-  // CHECK-NEXT: %[[M1:.*]] = bufferization.to_memref %[[T1]] : memref<128x128xf32>
-  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args(%[[BBARG_T2:.*]] = %[[T2]]) -> (tensor<128x128xf32>) {
-  %0:3 = scf.for %arg0 = %lb to %ub step %step iter_args(%arg1 = %t0, %arg2 = %t1, %arg3 = %t2)
-    -> (tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>)
-  {
-    %m1 = bufferization.to_memref %arg2 : memref<128x128xf32>
-
-    // CHECK-NEXT:   call @process(%[[M0]]) : (memref<128x128xf32>) -> ()
-    func.call @process(%m0) : (memref<128x128xf32>) -> ()
-
-    // CHECK-NEXT:   call @process(%[[M1]]) : (memref<128x128xf32>) -> ()
-    func.call @process(%m1) : (memref<128x128xf32>) -> ()
-
-    // This does not hoist (fails the check that the bbArg has at most a single use).
-    // CHECK-NEXT:   %[[T:.*]] = func.call @process_tensor(%[[BBARG_T2]]) : (tensor<128x128xf32>) -> memref<128x128xf32>
-    // CHECK-NEXT:   %[[YIELD_T:.*]] = bufferization.to_tensor %[[T:.*]]
-    %m2 = func.call @process_tensor(%arg3): (tensor<128x128xf32>) -> memref<128x128xf32>
-    %3 = bufferization.to_tensor %m2 : memref<128x128xf32>
-
-    // All this stuff goes away, incrementally
-    %1 = bufferization.to_tensor %m0 : memref<128x128xf32>
-    %2 = bufferization.to_tensor %m1 : memref<128x128xf32>
-
-    // CHECK-NEXT:   scf.yield %[[YIELD_T]] : tensor<128x128xf32>
-    scf.yield %1, %2, %3 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
-
-  // CHECK-NEXT: }
-  }
-
-  // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-  // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
-  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
-  return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
-}
-
-// -----
-
 // CHECK-LABEL: fold_away_iter_with_no_use_and_yielded_input
 //  CHECK-SAME:   %[[A0:[0-9a-z]*]]: i32
 func.func @fold_away_iter_with_no_use_and_yielded_input(%arg0 : i32,
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 4fb6a50a174c2..2a3ebbba02384 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3994,7 +3994,6 @@ cc_library(
     deps = [
         ":ArithDialect",
         ":ArithUtils",
-        ":BufferizationDialect",
         ":ControlFlowDialect",
         ":ControlFlowInterfaces",
         ":DestinationStyleOpInterface",

matthias-springer added a commit to matthias-springer/llvm-project that referenced this pull request Dec 6, 2023
…der` pattern

`ParallelOpSingleOrZeroIterationDimsFolder` used to produce invalid IR:
```
within split at mlir/test/Dialect/SCF/canonicalize.mlir:1 offset :11:3: error: 'scf.parallel' op expects region #0 to have 0 or 1 blocks
  scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c2, %c3) {
  ^
within split at mlir/test/Dialect/SCF/canonicalize.mlir:1 offset :11:3: note: see current operation:
"scf.parallel"(%4, %5, %3) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>}> ({
^bb0(%arg1: index):
  "memref.store"(%0, %arg0, %1, %arg1, %6) : (i32, memref<?x?x?xi32>, index, index, index) -> ()
  "scf.yield"() : () -> ()
^bb1(%8: index):  // no predecessors
  "scf.yield"() : () -> ()
}) : (index, index, index) -> ()
```

Together with llvm#74551, this commit fixes `mlir/test/Dialect/SCF/canonicalize.mlir` when verifying the IR after each pattern application (llvm#74270).
…ion patterns

The partial bufferization framework has been replaced with One-Shot Bufferize. SCF-specific canonicalization patterns for `to_memref`/`to_tensor` are no longer needed.
matthias-springer added a commit that referenced this pull request Dec 6, 2023
…der` pattern (#74552)

`ParallelOpSingleOrZeroIterationDimsFolder` used to produce invalid IR:
```
within split at mlir/test/Dialect/SCF/canonicalize.mlir:1 offset :11:3: error: 'scf.parallel' op expects region #0 to have 0 or 1 blocks
  scf.parallel (%i0, %i1, %i2) = (%c0, %c3, %c7) to (%c1, %c6, %c10) step (%c1, %c2, %c3) {
  ^
within split at mlir/test/Dialect/SCF/canonicalize.mlir:1 offset :11:3: note: see current operation:
"scf.parallel"(%4, %5, %3) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>}> ({
^bb0(%arg1: index):
  "memref.store"(%0, %arg0, %1, %arg1, %6) : (i32, memref<?x?x?xi32>, index, index, index) -> ()
  "scf.yield"() : () -> ()
^bb1(%8: index):  // no predecessors
  "scf.yield"() : () -> ()
}) : (index, index, index) -> ()
```

Together with #74551, this commit fixes
`mlir/test/Dialect/SCF/canonicalize.mlir` when verifying the IR after
each pattern application (#74270).
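For reference, a minimal sketch of the form the verifier expects (hypothetical example, not taken from the commit): the region of `scf.parallel` must hold at most one block, whereas the buggy fold left a dead `^bb1` behind, as in the trace above.

```
// Hypothetical example: a well-formed scf.parallel whose region consists of
// a single block; the terminator is implicit in the custom syntax.
func.func @single_block(%buf: memref<4xi32>, %v: i32) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  scf.parallel (%i) = (%c0) to (%c4) step (%c1) {
    memref.store %v, %buf[%i] : memref<4xi32>
  }
  return
}
```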
matthias-springer merged commit 77f5b33 into llvm:main Dec 6, 2023
4 checks passed
3 participants