Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms] GreedyPatternRewriteDriver: Do not CSE constants during iterations #75897

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Dec 19, 2023

The GreedyPatternRewriteDriver tries to iteratively fold ops and apply rewrite patterns to ops. It has special handling for constants: they are CSE'd and sometimes moved to parent regions to allow for additional CSE'ing. This happens in OperationFolder.

To allow for efficient CSE'ing, OperationFolder maintains an internal lookup data structure to find the existing constant ops with the same value for each IsolatedFromAbove region:

/// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap<Region *, ConstantMap> foldScopes;

Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent.

We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different IsolatedFromAbove region (that is not accessible).

This commit changes the behavior of the GreedyPatternRewriteDriver such that OperationFolder is used to CSE constants at the beginning of each iteration (as the worklist is populated), but no longer during an iteration. OperationFolder is no longer used after populating the worklist, so we do not have to care about inconsistent state in the OperationFolder due to IR rewrites. The GreedyPatternRewriteDriver now performs the op folding by itself instead of calling OperationFolder::tryToFold.

This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning CHECK into CHECK-DAG.

Alternatives considered: The state of OperationFolder could be partially invalidated with every notifyOperationModified notification. That is more fragile than the solution in this commit because incorrect rewriter API usage can lead to missing notifications and hard-to-debug IsolatedFromAbove violations. (It did not fix the above mention bug in a downstream project, which could be due to incorrect rewriter API usage or due to another conceptual problem that I missed.) Moreover, ops are frequently getting modified during a greedy pattern rewrite, so we would likely keep invalidating large parts of the state of OperationFolder over and over.

Migration guide: Turn CHECK into CHECK-DAG in test cases. Constant ops are no longer folded during a greedy pattern rewrite. If you rely on folding (and rematerialization) of constant ops during a greedy pattern rewrite, turn the folder into a pattern.

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 19, 2023

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-math

Author: Matthias Springer (matthias-springer)

Changes

The GreedyPatternRewriteDriver tries to iteratively fold ops and apply rewrite patterns to ops. It has special handling for constants: they are CSE'd and sometimes moved to parent regions to allow for additional CSE'ing. This happens in OperationFolder.

To allow for efficient CSE'ing, OperationFolder maintains an internal lookup data structure to find the existing constant ops with the same value for each IsolatedFromAbove region:

/// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap&lt;Region *, ConstantMap&gt; foldScopes;

Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent.

We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different IsolatedFromAbove region (that is not accessible).

This commit changes the behavior of the GreedyPatternRewriteDriver such that OperationFolder is used to CSE constants at the beginning of each iteration (as the worklist is populated), but no longer during an iteration. OperationFolder is no longer used after populating the worklist, so we do not have to care about inconsistent state in the OperationFolder due to IR rewrites. The GreedyPatternRewriteDriver now performs the op folding by itself instead of calling OperationFolder::tryToFold.

This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning CHECK into CHECK-DAG.

Alternatives considered: The state of OperationFolder could be partially invalidated with every notifyOperationModified notification. That is more fragile than the solution in this commit because incorrect rewriter API usage can lead to missing notifications and hard-to-debug IsolatedFromAbove violations. (It did not fix the above mention bug in a downstream project, which could be due to incorrect rewriter API usage or due to another conceptual problem that I missed.) Moreover, ops are frequently getting modified during a greedy pattern rewrite, so we would likely keep invalidating large parts of the state of OperationFolder over and over.

Depends on #75887. Review only the top commit.


Patch is 86.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75897.diff

35 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+3-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-4)
  • (modified) mlir/lib/IR/Builders.cpp (+2-3)
  • (modified) mlir/lib/Transforms/Utils/FoldUtils.cpp (+2-4)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+55-19)
  • (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (+5-5)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+4-4)
  • (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+15-15)
  • (modified) mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir (+8-8)
  • (modified) mlir/test/Dialect/LLVMIR/type-consistency.mlir (+29-29)
  • (modified) mlir/test/Dialect/Linalg/loops.mlir (+8-8)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+48-48)
  • (modified) mlir/test/Dialect/Math/algebraic-simplification.mlir (+14-14)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+12-12)
  • (modified) mlir/test/Dialect/Math/polynomial-approximation.mlir (+36-36)
  • (modified) mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir (+6-6)
  • (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_1d.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_concat.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_storage.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+8-8)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+6-9)
  • (modified) mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+3-3)
  • (modified) mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir (+7-7)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (+17-17)
  • (modified) mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir (+2-2)
  • (modified) mlir/test/Transforms/test-canonicalize.mlir (-13)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (-10)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (-4)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (-7)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 56d5e0fed76185..ff72becc8dfa77 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
       setOperand(src);
       return getResult();
     }
+
     // trunci(zexti(a)) -> a
     // trunci(sexti(a)) -> a
-    return src;
+    if (srcType == dstType)
+      return src;
   }
 
   // trunci(trunci(a)) -> trunci(a))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..ac9485326a32ed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1600,11 +1600,8 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                        : 0;
   };
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0)
-    return source;
 
+  unsigned broadcastSrcRank = getRank(source.getType());
   unsigned extractResultRank = getRank(extractOp.getType());
   if (extractResultRank >= broadcastSrcRank)
     return Value();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2cabfcd24d3559..c28cbe109c3ffd 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -491,9 +491,8 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
     // Normal values get pushed back directly.
     if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
-      if (value.getType() != expectedType)
-        return cleanupFailure();
-
+      assert(value.getType() == expectedType &&
+             "folder produced value of incorrect type");
       results.push_back(value);
       continue;
     }
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 90ee5ba51de3ad..34b7117a035748 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -247,10 +247,8 @@ OperationFolder::processFoldResults(Operation *op,
 
     // Check if the result was an SSA value.
     if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
-      if (repl.getType() != op->getResult(i).getType()) {
-        results.clear();
-        return failure();
-      }
+      assert(repl.getType() == op->getResult(i).getType() &&
+             "folder produced value of incorrect type");
       results.emplace_back(repl);
       continue;
     }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7decbce018a878..eb63050c6c3354 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -313,9 +313,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   Worklist worklist;
 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
 
-  /// Non-pattern based folder for operations.
-  OperationFolder folder;
-
   /// Configuration information for how to simplify.
   const GreedyRewriteConfig config;
 
@@ -428,11 +425,47 @@ bool GreedyPatternRewriteDriver::processWorklist() {
       continue;
     }
 
-    // Try to fold this op.
-    if (succeeded(folder.tryToFold(op))) {
-      LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-      changed = true;
-      continue;
+    // Try to fold this op. Do not fold constant ops. That would lead to an
+    // infinite folding loop, as every constant op would be folded to an
+    // Attribute and then immediately be rematerialized as a constant op, which
+    // is then put on the worklist.
+    if (!op->hasTrait<OpTrait::ConstantLike>()) {
+      SmallVector<OpFoldResult> foldResults;
+      if (succeeded(op->fold(foldResults))) {
+        LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
+        changed = true;
+        if (foldResults.empty()) {
+          // Op was modified in-place.
+          notifyOperationModified(op);
+          continue;
+        }
+
+        // Op results can be replaced with `foldResults`.
+        assert(foldResults.size() == op->getNumResults() &&
+               "folder produced incorrect number of results");
+        OpBuilder::InsertionGuard g(*this);
+        setInsertionPoint(op);
+        SmallVector<Value> replacements;
+        for (auto [ofr, resultType] :
+             llvm::zip_equal(foldResults, op->getResultTypes())) {
+          if (auto value = ofr.dyn_cast<Value>()) {
+            assert(value.getType() == resultType &&
+                   "folder produced value of incorrect type");
+            replacements.push_back(value);
+            continue;
+          }
+          // Materialize Attributes as SSA values.
+          Operation *constOp = op->getDialect()->materializeConstant(
+              *this, ofr.get<Attribute>(), resultType, op->getLoc());
+          assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
+                 "materializeConstant produced op that is not a ConstantLike");
+          assert(constOp->getResultTypes()[0] == resultType &&
+                 "materializeConstant produced incorrect result type");
+          replacements.push_back(constOp->getResult(0));
+        }
+        replaceOp(op, replacements);
+        continue;
+      }
     }
 
     // Try to match one of the patterns. The rewriter is automatically
@@ -567,7 +600,6 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
 
   addOperandsToWorklist(op->getOperands());
   worklist.remove(op);
-  folder.notifyRemoval(op);
 
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.erase(op);
@@ -647,16 +679,6 @@ class GreedyPatternRewriteIteration
 } // namespace
 
 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
-  auto insertKnownConstant = [&](Operation *op) {
-    // Check for existing constants when populating the worklist. This avoids
-    // accidentally reversing the constant order during processing.
-    Attribute constValue;
-    if (matchPattern(op, m_Constant(&constValue)))
-      if (!folder.insertKnownConstant(op, constValue))
-        return true;
-    return false;
-  };
-
   bool continueRewrites = false;
   int64_t iteration = 0;
   MLIRContext *ctx = getContext();
@@ -666,8 +688,22 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
         config.maxIterations != GreedyRewriteConfig::kNoLimit)
       break;
 
+    // New iteration: start with an empty worklist.
     worklist.clear();
 
+    // `OperationFolder` CSE's constant ops (and may move them into parents
+    // regions to enable more aggressive CSE'ing).
+    OperationFolder folder(getContext(), this);
+    auto insertKnownConstant = [&](Operation *op) {
+      // Check for existing constants when populating the worklist. This avoids
+      // accidentally reversing the constant order during processing.
+      Attribute constValue;
+      if (matchPattern(op, m_Constant(&constValue)))
+        if (!folder.insertKnownConstant(op, constValue))
+          return true;
+      return false;
+    };
+
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
       region.walk([&](Operation *op) {
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 6783263c184961..d3f02c6288a240 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -309,9 +309,9 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
 
 // CHECK-LABEL:   func.func @broadcast_vec2d_from_i32(
 // CHECK-SAME:                                        %[[SRC:.*]]: i32) {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
 // CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
@@ -393,8 +393,8 @@ func.func @splat_vec2d_from_f16(%arg0: f16) {
 
 // CHECK-LABEL:   func.func @transpose_i8(
 // CHECK-SAME:                            %[[TILE:.*]]: vector<[16]x[16]xi8>)
-// CHECK:           %[[C16:.*]] = arith.constant 16 : index
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[VSCALE:.*]] = vector.vscale
 // CHECK:           %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
 // CHECK:           %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..90f134d57a3c13 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -196,10 +196,10 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
 }
 // CHECK-LABEL: @broadcast_vec3d_from_vec1d(
 // CHECK-SAME:  %[[A:.*]]: vector<2xf32>)
-// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK:       %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
-// CHECK:       %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
+// CHECK-DAG:   %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK-DAG:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+// CHECK-DAG:   %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK-DAG:   %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
 
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
 // CHECK:       %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>>
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..97eda0672c0299 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -533,9 +533,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
 }
 
 // CHECK-SAME:      %[[ARG_0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
-// CHECK:           %[[C_0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C_16:.*]] = arith.constant 16 : index
-// CHECK:           %[[STEP:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C_16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[STEP:.*]] = arith.constant 1 : index
 // CHECK:           %[[MASK_VEC:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : vector<[16]xi32>
 // CHECK:           %[[VSCALE:.*]] = vector.vscale
 // CHECK:           %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
@@ -556,8 +556,8 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
 }
 // CHECK-LABEL:   func.func @vector_print_vector_0d(
 // CHECK-SAME:                                      %[[VEC:.*]]: vector<f32>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
@@ -581,9 +581,9 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
 }
 // CHECK-LABEL:   func.func @vector_print_vector(
 // CHECK-SAME:                                   %[[VEC:.*]]: vector<2x2xf32>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C2:.*]] = arith.constant 2 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<2x2xf32> to vector<4xf32>
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
@@ -650,10 +650,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
 }
 // CHECK-LABEL:   func.func @transfer_read_array_of_scalable(
 // CHECK-SAME:                                               %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
-// CHECK:           %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C3:.*]] = arith.constant 3 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
 // CHECK:           %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
 // CHECK:           %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
@@ -684,9 +684,9 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
 // CHECK-LABEL:   func.func @transfer_write_array_of_scalable(
 // CHECK-SAME:                                                %[[VEC:.*]]: vector<3x[4]xf32>,
 // CHECK-SAME:                                                %[[MEMREF:.*]]: memref<3x?xf32>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C3:.*]] = arith.constant 3 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
 // CHECK:           %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
 // CHECK:           %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index ae2d0f40f03af5..e51f2485dadbcc 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -91,10 +91,10 @@ func.func @arith_constant_dense_2d_zero_f64() {
 // -----
 
 // CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
@@ -111,10 +111,10 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
 // -----
 
 // CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 3a1ab924ebdacb..021151b929d8e2 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -168,8 +168,8 @@ llvm.func @no_crash_on_negative_gep_index() {
 // CHECK-LABEL: llvm.func @coalesced_store_ints
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_ints(%arg: i64) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
 
   %0 = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
@@ -193,8 +193,8 @@ llvm.func @coalesced_store_ints(%arg: i64) {
 // CHECK-LABEL: llvm.func @coalesced_store_ints_offset
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_ints_offset(%arg: i64) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
   %0 = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32)> : (i32) -> !llvm.ptr
@@ -218,8 +218,8 @@ llvm.func @coalesced_store_ints_offset(%arg: i64) {
 // CHECK-LABEL: llvm.func @coalesced_store_floats
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_floats(%arg: i64) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
   %0 = llvm.mlir.constant(1 : i32) : i32
 
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)>
@@ -292,9 +292,9 @@ llvm.func @coalesced_store_past_end(%arg: i64) {
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_packed_struct(%arg: i64) {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
-  // CHECK: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+  // CHECK-DAG: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
 
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32, i16)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i16, i32, i16)> : (i32) -> !llvm.ptr
@@ -320,10 +320,10 @@ llvm.func @coalesced_store_packed_struct(%arg: i64) {
 // CHECK-LABEL: llvm.func @vector_write_split
 // CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
 llvm.func @vector_write_split(%arg: vector<4xi32>) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 19, 2023

@llvm/pr-subscribers-mlir-arith

Author: Matthias Springer (matthias-springer)

Changes

The GreedyPatternRewriteDriver tries to iteratively fold ops and apply rewrite patterns to ops. It has special handling for constants: they are CSE'd and sometimes moved to parent regions to allow for additional CSE'ing. This happens in OperationFolder.

To allow for efficient CSE'ing, OperationFolder maintains an internal lookup data structure to find the existing constant ops with the same value for each IsolatedFromAbove region:

/// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap&lt;Region *, ConstantMap&gt; foldScopes;

Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent.

We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different IsolatedFromAbove region (that is not accessible).

This commit changes the behavior of the GreedyPatternRewriteDriver such that OperationFolder is used to CSE constants at the beginning of each iteration (as the worklist is populated), but no longer during an iteration. OperationFolder is no longer used after populating the worklist, so we do not have to care about inconsistent state in the OperationFolder due to IR rewrites. The GreedyPatternRewriteDriver now performs the op folding by itself instead of calling OperationFolder::tryToFold.

This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning CHECK into CHECK-DAG.

Alternatives considered: The state of OperationFolder could be partially invalidated with every notifyOperationModified notification. That is more fragile than the solution in this commit because incorrect rewriter API usage can lead to missing notifications and hard-to-debug IsolatedFromAbove violations. (It did not fix the above mention bug in a downstream project, which could be due to incorrect rewriter API usage or due to another conceptual problem that I missed.) Moreover, ops are frequently getting modified during a greedy pattern rewrite, so we would likely keep invalidating large parts of the state of OperationFolder over and over.

Depends on #75887. Review only the top commit.


Patch is 86.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75897.diff

35 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+3-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-4)
  • (modified) mlir/lib/IR/Builders.cpp (+2-3)
  • (modified) mlir/lib/Transforms/Utils/FoldUtils.cpp (+2-4)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+55-19)
  • (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (+5-5)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+4-4)
  • (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+15-15)
  • (modified) mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir (+8-8)
  • (modified) mlir/test/Dialect/LLVMIR/type-consistency.mlir (+29-29)
  • (modified) mlir/test/Dialect/Linalg/loops.mlir (+8-8)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+48-48)
  • (modified) mlir/test/Dialect/Math/algebraic-simplification.mlir (+14-14)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+12-12)
  • (modified) mlir/test/Dialect/Math/polynomial-approximation.mlir (+36-36)
  • (modified) mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir (+6-6)
  • (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_1d.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_concat.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_storage.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+8-8)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+6-9)
  • (modified) mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+3-3)
  • (modified) mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir (+7-7)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (+17-17)
  • (modified) mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir (+2-2)
  • (modified) mlir/test/Transforms/test-canonicalize.mlir (-13)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (-10)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (-4)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (-7)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 56d5e0fed76185..ff72becc8dfa77 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
       setOperand(src);
       return getResult();
     }
+
     // trunci(zexti(a)) -> a
     // trunci(sexti(a)) -> a
-    return src;
+    if (srcType == dstType)
+      return src;
   }
 
   // trunci(trunci(a)) -> trunci(a))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..ac9485326a32ed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1600,11 +1600,8 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                        : 0;
   };
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0)
-    return source;
 
+  unsigned broadcastSrcRank = getRank(source.getType());
   unsigned extractResultRank = getRank(extractOp.getType());
   if (extractResultRank >= broadcastSrcRank)
     return Value();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2cabfcd24d3559..c28cbe109c3ffd 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -491,9 +491,8 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
     // Normal values get pushed back directly.
     if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
-      if (value.getType() != expectedType)
-        return cleanupFailure();
-
+      assert(value.getType() == expectedType &&
+             "folder produced value of incorrect type");
       results.push_back(value);
       continue;
     }
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 90ee5ba51de3ad..34b7117a035748 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -247,10 +247,8 @@ OperationFolder::processFoldResults(Operation *op,
 
     // Check if the result was an SSA value.
     if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
-      if (repl.getType() != op->getResult(i).getType()) {
-        results.clear();
-        return failure();
-      }
+      assert(repl.getType() == op->getResult(i).getType() &&
+             "folder produced value of incorrect type");
       results.emplace_back(repl);
       continue;
     }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7decbce018a878..eb63050c6c3354 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -313,9 +313,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   Worklist worklist;
 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
 
-  /// Non-pattern based folder for operations.
-  OperationFolder folder;
-
   /// Configuration information for how to simplify.
   const GreedyRewriteConfig config;
 
@@ -428,11 +425,47 @@ bool GreedyPatternRewriteDriver::processWorklist() {
       continue;
     }
 
-    // Try to fold this op.
-    if (succeeded(folder.tryToFold(op))) {
-      LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-      changed = true;
-      continue;
+    // Try to fold this op. Do not fold constant ops. That would lead to an
+    // infinite folding loop, as every constant op would be folded to an
+    // Attribute and then immediately be rematerialized as a constant op, which
+    // is then put on the worklist.
+    if (!op->hasTrait<OpTrait::ConstantLike>()) {
+      SmallVector<OpFoldResult> foldResults;
+      if (succeeded(op->fold(foldResults))) {
+        LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
+        changed = true;
+        if (foldResults.empty()) {
+          // Op was modified in-place.
+          notifyOperationModified(op);
+          continue;
+        }
+
+        // Op results can be replaced with `foldResults`.
+        assert(foldResults.size() == op->getNumResults() &&
+               "folder produced incorrect number of results");
+        OpBuilder::InsertionGuard g(*this);
+        setInsertionPoint(op);
+        SmallVector<Value> replacements;
+        for (auto [ofr, resultType] :
+             llvm::zip_equal(foldResults, op->getResultTypes())) {
+          if (auto value = ofr.dyn_cast<Value>()) {
+            assert(value.getType() == resultType &&
+                   "folder produced value of incorrect type");
+            replacements.push_back(value);
+            continue;
+          }
+          // Materialize Attributes as SSA values.
+          Operation *constOp = op->getDialect()->materializeConstant(
+              *this, ofr.get<Attribute>(), resultType, op->getLoc());
+          assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
+                 "materializeConstant produced op that is not a ConstantLike");
+          assert(constOp->getResultTypes()[0] == resultType &&
+                 "materializeConstant produced incorrect result type");
+          replacements.push_back(constOp->getResult(0));
+        }
+        replaceOp(op, replacements);
+        continue;
+      }
     }
 
     // Try to match one of the patterns. The rewriter is automatically
@@ -567,7 +600,6 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
 
   addOperandsToWorklist(op->getOperands());
   worklist.remove(op);
-  folder.notifyRemoval(op);
 
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.erase(op);
@@ -647,16 +679,6 @@ class GreedyPatternRewriteIteration
 } // namespace
 
 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
-  auto insertKnownConstant = [&](Operation *op) {
-    // Check for existing constants when populating the worklist. This avoids
-    // accidentally reversing the constant order during processing.
-    Attribute constValue;
-    if (matchPattern(op, m_Constant(&constValue)))
-      if (!folder.insertKnownConstant(op, constValue))
-        return true;
-    return false;
-  };
-
   bool continueRewrites = false;
   int64_t iteration = 0;
   MLIRContext *ctx = getContext();
@@ -666,8 +688,22 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
         config.maxIterations != GreedyRewriteConfig::kNoLimit)
       break;
 
+    // New iteration: start with an empty worklist.
     worklist.clear();
 
+    // `OperationFolder` CSE's constant ops (and may move them into parents
+    // regions to enable more aggressive CSE'ing).
+    OperationFolder folder(getContext(), this);
+    auto insertKnownConstant = [&](Operation *op) {
+      // Check for existing constants when populating the worklist. This avoids
+      // accidentally reversing the constant order during processing.
+      Attribute constValue;
+      if (matchPattern(op, m_Constant(&constValue)))
+        if (!folder.insertKnownConstant(op, constValue))
+          return true;
+      return false;
+    };
+
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
       region.walk([&](Operation *op) {
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 6783263c184961..d3f02c6288a240 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -309,9 +309,9 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
 
 // CHECK-LABEL:   func.func @broadcast_vec2d_from_i32(
 // CHECK-SAME:                                        %[[SRC:.*]]: i32) {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
 // CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
@@ -393,8 +393,8 @@ func.func @splat_vec2d_from_f16(%arg0: f16) {
 
 // CHECK-LABEL:   func.func @transpose_i8(
 // CHECK-SAME:                            %[[TILE:.*]]: vector<[16]x[16]xi8>)
-// CHECK:           %[[C16:.*]] = arith.constant 16 : index
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[VSCALE:.*]] = vector.vscale
 // CHECK:           %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
 // CHECK:           %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..90f134d57a3c13 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -196,10 +196,10 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
 }
 // CHECK-LABEL: @broadcast_vec3d_from_vec1d(
 // CHECK-SAME:  %[[A:.*]]: vector<2xf32>)
-// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK:       %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
-// CHECK:       %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
+// CHECK-DAG:   %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK-DAG:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+// CHECK-DAG:   %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK-DAG:   %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
 
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
 // CHECK:       %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>>
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..97eda0672c0299 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -533,9 +533,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
 }
 
 // CHECK-SAME:      %[[ARG_0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
-// CHECK:           %[[C_0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C_16:.*]] = arith.constant 16 : index
-// CHECK:           %[[STEP:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C_16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[STEP:.*]] = arith.constant 1 : index
 // CHECK:           %[[MASK_VEC:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : vector<[16]xi32>
 // CHECK:           %[[VSCALE:.*]] = vector.vscale
 // CHECK:           %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
@@ -556,8 +556,8 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
 }
 // CHECK-LABEL:   func.func @vector_print_vector_0d(
 // CHECK-SAME:                                      %[[VEC:.*]]: vector<f32>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
@@ -581,9 +581,9 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
 }
 // CHECK-LABEL:   func.func @vector_print_vector(
 // CHECK-SAME:                                   %[[VEC:.*]]: vector<2x2xf32>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C2:.*]] = arith.constant 2 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<2x2xf32> to vector<4xf32>
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
@@ -650,10 +650,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
 }
 // CHECK-LABEL:   func.func @transfer_read_array_of_scalable(
 // CHECK-SAME:                                               %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
-// CHECK:           %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C3:.*]] = arith.constant 3 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
 // CHECK:           %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
 // CHECK:           %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
@@ -684,9 +684,9 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
 // CHECK-LABEL:   func.func @transfer_write_array_of_scalable(
 // CHECK-SAME:                                                %[[VEC:.*]]: vector<3x[4]xf32>,
 // CHECK-SAME:                                                %[[MEMREF:.*]]: memref<3x?xf32>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C3:.*]] = arith.constant 3 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
 // CHECK:           %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
 // CHECK:           %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index ae2d0f40f03af5..e51f2485dadbcc 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -91,10 +91,10 @@ func.func @arith_constant_dense_2d_zero_f64() {
 // -----
 
 // CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
@@ -111,10 +111,10 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
 // -----
 
 // CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 3a1ab924ebdacb..021151b929d8e2 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -168,8 +168,8 @@ llvm.func @no_crash_on_negative_gep_index() {
 // CHECK-LABEL: llvm.func @coalesced_store_ints
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_ints(%arg: i64) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
 
   %0 = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
@@ -193,8 +193,8 @@ llvm.func @coalesced_store_ints(%arg: i64) {
 // CHECK-LABEL: llvm.func @coalesced_store_ints_offset
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_ints_offset(%arg: i64) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
   %0 = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32)> : (i32) -> !llvm.ptr
@@ -218,8 +218,8 @@ llvm.func @coalesced_store_ints_offset(%arg: i64) {
 // CHECK-LABEL: llvm.func @coalesced_store_floats
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_floats(%arg: i64) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
   %0 = llvm.mlir.constant(1 : i32) : i32
 
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)>
@@ -292,9 +292,9 @@ llvm.func @coalesced_store_past_end(%arg: i64) {
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_packed_struct(%arg: i64) {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
-  // CHECK: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+  // CHECK-DAG: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
 
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32, i16)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i16, i32, i16)> : (i32) -> !llvm.ptr
@@ -320,10 +320,10 @@ llvm.func @coalesced_store_packed_struct(%arg: i64) {
 // CHECK-LABEL: llvm.func @vector_write_split
 // CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
 llvm.func @vector_write_split(%arg: vector<4xi32>) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Dec 19, 2023

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

Changes

The GreedyPatternRewriteDriver tries to iteratively fold ops and apply rewrite patterns to ops. It has special handling for constants: they are CSE'd and sometimes moved to parent regions to allow for additional CSE'ing. This happens in OperationFolder.

To allow for efficient CSE'ing, OperationFolder maintains an internal lookup data structure to find the existing constant ops with the same value for each IsolatedFromAbove region:

/// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap&lt;Region *, ConstantMap&gt; foldScopes;

Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent.

We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different IsolatedFromAbove region (that is not accessible).

This commit changes the behavior of the GreedyPatternRewriteDriver such that OperationFolder is used to CSE constants at the beginning of each iteration (as the worklist is populated), but no longer during an iteration. OperationFolder is no longer used after populating the worklist, so we do not have to care about inconsistent state in the OperationFolder due to IR rewrites. The GreedyPatternRewriteDriver now performs the op folding by itself instead of calling OperationFolder::tryToFold.

This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning CHECK into CHECK-DAG.

Alternatives considered: The state of OperationFolder could be partially invalidated with every notifyOperationModified notification. That is more fragile than the solution in this commit because incorrect rewriter API usage can lead to missing notifications and hard-to-debug IsolatedFromAbove violations. (It did not fix the above mention bug in a downstream project, which could be due to incorrect rewriter API usage or due to another conceptual problem that I missed.) Moreover, ops are frequently getting modified during a greedy pattern rewrite, so we would likely keep invalidating large parts of the state of OperationFolder over and over.

Depends on #75887. Review only the top commit.


Patch is 86.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75897.diff

35 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+3-1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-4)
  • (modified) mlir/lib/IR/Builders.cpp (+2-3)
  • (modified) mlir/lib/Transforms/Utils/FoldUtils.cpp (+2-4)
  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+55-19)
  • (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (+5-5)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+4-4)
  • (modified) mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir (+15-15)
  • (modified) mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir (+8-8)
  • (modified) mlir/test/Dialect/LLVMIR/type-consistency.mlir (+29-29)
  • (modified) mlir/test/Dialect/Linalg/loops.mlir (+8-8)
  • (modified) mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+48-48)
  • (modified) mlir/test/Dialect/Math/algebraic-simplification.mlir (+14-14)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+12-12)
  • (modified) mlir/test/Dialect/Math/polynomial-approximation.mlir (+36-36)
  • (modified) mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir (+6-6)
  • (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+3-3)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_1d.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_affine.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_concat.mlir (+4-4)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_storage.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir (+8-8)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+6-9)
  • (modified) mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir (+3-3)
  • (modified) mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir (+7-7)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (+17-17)
  • (modified) mlir/test/Dialect/Vector/vector-scalable-create-mask-lowering.mlir (+2-2)
  • (modified) mlir/test/Transforms/test-canonicalize.mlir (-13)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (-10)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (-4)
  • (modified) mlir/test/lib/Dialect/Test/TestOps.td (-7)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 56d5e0fed76185..ff72becc8dfa77 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
       setOperand(src);
       return getResult();
     }
+
     // trunci(zexti(a)) -> a
     // trunci(sexti(a)) -> a
-    return src;
+    if (srcType == dstType)
+      return src;
   }
 
   // trunci(trunci(a)) -> trunci(a))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..ac9485326a32ed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1600,11 +1600,8 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
     return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
                                        : 0;
   };
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0)
-    return source;
 
+  unsigned broadcastSrcRank = getRank(source.getType());
   unsigned extractResultRank = getRank(extractOp.getType());
   if (extractResultRank >= broadcastSrcRank)
     return Value();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2cabfcd24d3559..c28cbe109c3ffd 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -491,9 +491,8 @@ LogicalResult OpBuilder::tryFold(Operation *op,
 
     // Normal values get pushed back directly.
     if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
-      if (value.getType() != expectedType)
-        return cleanupFailure();
-
+      assert(value.getType() == expectedType &&
+             "folder produced value of incorrect type");
       results.push_back(value);
       continue;
     }
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 90ee5ba51de3ad..34b7117a035748 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -247,10 +247,8 @@ OperationFolder::processFoldResults(Operation *op,
 
     // Check if the result was an SSA value.
     if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
-      if (repl.getType() != op->getResult(i).getType()) {
-        results.clear();
-        return failure();
-      }
+      assert(repl.getType() == op->getResult(i).getType() &&
+             "folder produced value of incorrect type");
       results.emplace_back(repl);
       continue;
     }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7decbce018a878..eb63050c6c3354 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -313,9 +313,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   Worklist worklist;
 #endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
 
-  /// Non-pattern based folder for operations.
-  OperationFolder folder;
-
   /// Configuration information for how to simplify.
   const GreedyRewriteConfig config;
 
@@ -428,11 +425,47 @@ bool GreedyPatternRewriteDriver::processWorklist() {
       continue;
     }
 
-    // Try to fold this op.
-    if (succeeded(folder.tryToFold(op))) {
-      LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-      changed = true;
-      continue;
+    // Try to fold this op. Do not fold constant ops. That would lead to an
+    // infinite folding loop, as every constant op would be folded to an
+    // Attribute and then immediately be rematerialized as a constant op, which
+    // is then put on the worklist.
+    if (!op->hasTrait<OpTrait::ConstantLike>()) {
+      SmallVector<OpFoldResult> foldResults;
+      if (succeeded(op->fold(foldResults))) {
+        LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
+        changed = true;
+        if (foldResults.empty()) {
+          // Op was modified in-place.
+          notifyOperationModified(op);
+          continue;
+        }
+
+        // Op results can be replaced with `foldResults`.
+        assert(foldResults.size() == op->getNumResults() &&
+               "folder produced incorrect number of results");
+        OpBuilder::InsertionGuard g(*this);
+        setInsertionPoint(op);
+        SmallVector<Value> replacements;
+        for (auto [ofr, resultType] :
+             llvm::zip_equal(foldResults, op->getResultTypes())) {
+          if (auto value = ofr.dyn_cast<Value>()) {
+            assert(value.getType() == resultType &&
+                   "folder produced value of incorrect type");
+            replacements.push_back(value);
+            continue;
+          }
+          // Materialize Attributes as SSA values.
+          Operation *constOp = op->getDialect()->materializeConstant(
+              *this, ofr.get<Attribute>(), resultType, op->getLoc());
+          assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
+                 "materializeConstant produced op that is not a ConstantLike");
+          assert(constOp->getResultTypes()[0] == resultType &&
+                 "materializeConstant produced incorrect result type");
+          replacements.push_back(constOp->getResult(0));
+        }
+        replaceOp(op, replacements);
+        continue;
+      }
     }
 
     // Try to match one of the patterns. The rewriter is automatically
@@ -567,7 +600,6 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
 
   addOperandsToWorklist(op->getOperands());
   worklist.remove(op);
-  folder.notifyRemoval(op);
 
   if (config.strictMode != GreedyRewriteStrictness::AnyOp)
     strictModeFilteredOps.erase(op);
@@ -647,16 +679,6 @@ class GreedyPatternRewriteIteration
 } // namespace
 
 LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
-  auto insertKnownConstant = [&](Operation *op) {
-    // Check for existing constants when populating the worklist. This avoids
-    // accidentally reversing the constant order during processing.
-    Attribute constValue;
-    if (matchPattern(op, m_Constant(&constValue)))
-      if (!folder.insertKnownConstant(op, constValue))
-        return true;
-    return false;
-  };
-
   bool continueRewrites = false;
   int64_t iteration = 0;
   MLIRContext *ctx = getContext();
@@ -666,8 +688,22 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
         config.maxIterations != GreedyRewriteConfig::kNoLimit)
       break;
 
+    // New iteration: start with an empty worklist.
     worklist.clear();
 
+    // `OperationFolder` CSE's constant ops (and may move them into parents
+    // regions to enable more aggressive CSE'ing).
+    OperationFolder folder(getContext(), this);
+    auto insertKnownConstant = [&](Operation *op) {
+      // Check for existing constants when populating the worklist. This avoids
+      // accidentally reversing the constant order during processing.
+      Attribute constValue;
+      if (matchPattern(op, m_Constant(&constValue)))
+        if (!folder.insertKnownConstant(op, constValue))
+          return true;
+      return false;
+    };
+
     if (!config.useTopDownTraversal) {
       // Add operations to the worklist in postorder.
       region.walk([&](Operation *op) {
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 6783263c184961..d3f02c6288a240 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -309,9 +309,9 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
 
 // CHECK-LABEL:   func.func @broadcast_vec2d_from_i32(
 // CHECK-SAME:                                        %[[SRC:.*]]: i32) {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
 // CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
@@ -393,8 +393,8 @@ func.func @splat_vec2d_from_f16(%arg0: f16) {
 
 // CHECK-LABEL:   func.func @transpose_i8(
 // CHECK-SAME:                            %[[TILE:.*]]: vector<[16]x[16]xi8>)
-// CHECK:           %[[C16:.*]] = arith.constant 16 : index
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[VSCALE:.*]] = vector.vscale
 // CHECK:           %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
 // CHECK:           %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..90f134d57a3c13 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -196,10 +196,10 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
 }
 // CHECK-LABEL: @broadcast_vec3d_from_vec1d(
 // CHECK-SAME:  %[[A:.*]]: vector<2xf32>)
-// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK:       %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
-// CHECK:       %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
+// CHECK-DAG:   %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK-DAG:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+// CHECK-DAG:   %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK-DAG:   %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
 
 // CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
 // CHECK:       %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>>
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..97eda0672c0299 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -533,9 +533,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
 }
 
 // CHECK-SAME:      %[[ARG_0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
-// CHECK:           %[[C_0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C_16:.*]] = arith.constant 16 : index
-// CHECK:           %[[STEP:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C_16:.*]] = arith.constant 16 : index
+// CHECK-DAG:       %[[STEP:.*]] = arith.constant 1 : index
 // CHECK:           %[[MASK_VEC:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : vector<[16]xi32>
 // CHECK:           %[[VSCALE:.*]] = vector.vscale
 // CHECK:           %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
@@ -556,8 +556,8 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
 }
 // CHECK-LABEL:   func.func @vector_print_vector_0d(
 // CHECK-SAME:                                      %[[VEC:.*]]: vector<f32>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
@@ -581,9 +581,9 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
 }
 // CHECK-LABEL:   func.func @vector_print_vector(
 // CHECK-SAME:                                   %[[VEC:.*]]: vector<2x2xf32>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C2:.*]] = arith.constant 2 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<2x2xf32> to vector<4xf32>
 // CHECK:           vector.print punctuation <open>
 // CHECK:           scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
@@ -650,10 +650,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
 }
 // CHECK-LABEL:   func.func @transfer_read_array_of_scalable(
 // CHECK-SAME:                                               %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
-// CHECK:           %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C3:.*]] = arith.constant 3 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
 // CHECK:           %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
 // CHECK:           %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
@@ -684,9 +684,9 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
 // CHECK-LABEL:   func.func @transfer_write_array_of_scalable(
 // CHECK-SAME:                                                %[[VEC:.*]]: vector<3x[4]xf32>,
 // CHECK-SAME:                                                %[[MEMREF:.*]]: memref<3x?xf32>) {
-// CHECK:           %[[C0:.*]] = arith.constant 0 : index
-// CHECK:           %[[C3:.*]] = arith.constant 3 : index
-// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:       %[[C1:.*]] = arith.constant 1 : index
 // CHECK:           %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
 // CHECK:           %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
 // CHECK:           %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index ae2d0f40f03af5..e51f2485dadbcc 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -91,10 +91,10 @@ func.func @arith_constant_dense_2d_zero_f64() {
 // -----
 
 // CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
@@ -111,10 +111,10 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
 // -----
 
 // CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 3a1ab924ebdacb..021151b929d8e2 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -168,8 +168,8 @@ llvm.func @no_crash_on_negative_gep_index() {
 // CHECK-LABEL: llvm.func @coalesced_store_ints
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_ints(%arg: i64) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
 
   %0 = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
@@ -193,8 +193,8 @@ llvm.func @coalesced_store_ints(%arg: i64) {
 // CHECK-LABEL: llvm.func @coalesced_store_ints_offset
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_ints_offset(%arg: i64) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
   %0 = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32)> : (i32) -> !llvm.ptr
@@ -218,8 +218,8 @@ llvm.func @coalesced_store_ints_offset(%arg: i64) {
 // CHECK-LABEL: llvm.func @coalesced_store_floats
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_floats(%arg: i64) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
   %0 = llvm.mlir.constant(1 : i32) : i32
 
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)>
@@ -292,9 +292,9 @@ llvm.func @coalesced_store_past_end(%arg: i64) {
 // CHECK-SAME: %[[ARG:.*]]: i64
 llvm.func @coalesced_store_packed_struct(%arg: i64) {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
-  // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
-  // CHECK: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
+  // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK-DAG: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+  // CHECK-DAG: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
 
   // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32, i16)>
   %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i16, i32, i16)> : (i32) -> !llvm.ptr
@@ -320,10 +320,10 @@ llvm.func @coalesced_store_packed_struct(%arg: i64) {
 // CHECK-LABEL: llvm.func @vector_write_split
 // CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
 llvm.func @vector_write_split(%arg: vector<4xi32>) {
-  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32...
[truncated]

@joker-eph
Copy link
Collaborator

Depends on #75887. Review only the top commit.

We have support for stacked PR by pushing branches to the LLVM repo and opening cascading PR (the second PR is targeting the branch from the first PR).

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense to me. Previously the constant folding optimization limited the patterns that can be used. This change gives most of the same folding optimization while avoiding that. (we may need to hoist out some helpers there eventually). I'll wait for someone else to chime in too given change centrality, but this looks close to least disruptive change that avoids the issue.

@@ -273,7 +273,14 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp,

/// If the op has only a single statement (apart from the yield), do nothing.
Block *body = genericOp.getBody();
if (body->getOperations().size() <= 2) {
int64_t numOps = 0;
for (Operation &op : body->getOperations()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change related?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. The rewrite pattern here was poorly designed and depends on constant hoisting to avoid getting into an infinite loop.

The rewrite pattern here decomposes a LinalgOp into multiple LinalgOps with one nested op per decomposed op. The rewrite pattern sometimes materializes constants inside of the region. These are now no longer hoisted immediately, so the rewrite pattern is applied again (decomposed LinalgOp now has 2 ops) and again (infinite loop). The hoisting now occurs at the beginning of a greedy pattern iteration, but there can be multiple pattern applications (of the same pattern) in one iteration.

The fix that I had here was not ideal. The issue in DecomposeLinalgOps.cpp is now fixed in a different way.

@matthias-springer matthias-springer force-pushed the operation_folder_state branch 3 times, most recently from 4b7bb8f to ccfef12 Compare December 22, 2023 07:48
@jpienaar
Copy link
Member

jpienaar commented Jan 2, 2024

Looks good to me, sufficient time for additional feedback so feel free to merge.

@matthias-springer
Copy link
Member Author

Looks good to me, sufficient time for additional feedback so feel free to merge.

Thx, can you also approve the PR on Github?

… during iterations

The `GreedyPatternRewriteDriver` tries to iteratively fold ops and apply rewrite patterns to ops. It has special handling for constants: they are CSE'd and sometimes moved to parent regions to allow for additional CSE'ing. This happens in `OperationFolder`.

To allow for efficient CSE'ing, `OperationFolder` maintains an internal lookup data structure to find the existing constant ops with the same value for each `IsolatedFromAbove` region:
```c++
/// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap<Region *, ConstantMap> foldScopes;
```

Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent.

We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different `IsolatedFromAbove` region (that is not accessible).

This commit changes the behavior of the `GreedyPatternRewriteDriver` such that `OperationFolder` is used to CSE constants at the beginning of each iteration (as the worklist is populated), but no longer during an iteration. `OperationFolder` is no longer used after populating the worklist, so we do not have to care about inconsistent state in the `OperationFolder` due to IR rewrites. The `GreedyPatternRewriteDriver` now performs the op folding by itself instead of calling `OperationFolder::tryToFold`.

This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning `CHECK` into `CHECK-DAG`.

Alternatives considered: The state of `OperationFolder` could be partially invalidated with every `notifyOperationModified` notification. That is more fragile than the solution in this commit because incorrect rewriter API usage can lead to missing notifications and hard-to-debug `IsolatedFromAbove` violations. (It did not fix the above mention bug in a downstream project, which could be due to incorrect rewriter API usage or due to another conceptual problem that I missed.) Moreover, ops are frequently getting modified during a greedy pattern rewrite, so we would likely keep invalidating large parts of the state of `OperationFolder` over and over.
@matthias-springer matthias-springer merged commit bb6d5c2 into llvm:main Jan 5, 2024
4 checks passed
@MaheshRavishankar
Copy link
Contributor

Just bubbling an issue from downstream project here. This pattern seems to have caused some major regression in folding downstream iree-org/iree#16073 (comment) .

I havent had time to triage this more to see if the fix is downstream or upstream, but posting here for visibility.

@joker-eph
Copy link
Collaborator

Is there a limit on the number of iterations in IREE? I understood the change here as in that it shouldn't affect (much) the result because CSE would still be done before every iteration and so reaching fixed-point means we have successfully applied CSE and no pattern triggered.

@sjain-stanford
Copy link

sjain-stanford commented Feb 3, 2024

@matthias-springer we're seeing a bunch of downstream tests break that we've bisected to this change. These are not just lit tests (those were easy to fix with s/CHECK/CHECK-DAG) but more like e2e integration tests, so we're still unsure what's the right fix here. The main phenomenon we seem to observe is when running the rewrite pass in debug mode:

The pattern rewrite did not converge after scanning 10 times

Any ideas on what might be going on? Sorry but the repo is internal so we're unable to point to the pass that's failing.
cc: @sjarus, @zezhang

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants