-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] GreedyPatternRewriteDriver
: Do not CSE constants during iterations
#75897
[mlir][Transforms] GreedyPatternRewriteDriver
: Do not CSE constants during iterations
#75897
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir-math Author: Matthias Springer (matthias-springer) ChangesThe To allow for efficient CSE'ing, /// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap<Region *, ConstantMap> foldScopes; Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent. We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different This commit changes the behavior of the This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning Alternatives considered: The state of Depends on #75887. Review only the top commit. Patch is 86.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75897.diff 35 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 56d5e0fed76185..ff72becc8dfa77 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
setOperand(src);
return getResult();
}
+
// trunci(zexti(a)) -> a
// trunci(sexti(a)) -> a
- return src;
+ if (srcType == dstType)
+ return src;
}
// trunci(trunci(a)) -> trunci(a))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..ac9485326a32ed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1600,11 +1600,8 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
: 0;
};
- // If splat or broadcast from a scalar, just return the source scalar.
- unsigned broadcastSrcRank = getRank(source.getType());
- if (broadcastSrcRank == 0)
- return source;
+ unsigned broadcastSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
if (extractResultRank >= broadcastSrcRank)
return Value();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2cabfcd24d3559..c28cbe109c3ffd 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -491,9 +491,8 @@ LogicalResult OpBuilder::tryFold(Operation *op,
// Normal values get pushed back directly.
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
- if (value.getType() != expectedType)
- return cleanupFailure();
-
+ assert(value.getType() == expectedType &&
+ "folder produced value of incorrect type");
results.push_back(value);
continue;
}
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 90ee5ba51de3ad..34b7117a035748 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -247,10 +247,8 @@ OperationFolder::processFoldResults(Operation *op,
// Check if the result was an SSA value.
if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
- if (repl.getType() != op->getResult(i).getType()) {
- results.clear();
- return failure();
- }
+ assert(repl.getType() == op->getResult(i).getType() &&
+ "folder produced value of incorrect type");
results.emplace_back(repl);
continue;
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7decbce018a878..eb63050c6c3354 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -313,9 +313,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
Worklist worklist;
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
- /// Non-pattern based folder for operations.
- OperationFolder folder;
-
/// Configuration information for how to simplify.
const GreedyRewriteConfig config;
@@ -428,11 +425,47 @@ bool GreedyPatternRewriteDriver::processWorklist() {
continue;
}
- // Try to fold this op.
- if (succeeded(folder.tryToFold(op))) {
- LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
- changed = true;
- continue;
+ // Try to fold this op. Do not fold constant ops. That would lead to an
+ // infinite folding loop, as every constant op would be folded to an
+ // Attribute and then immediately be rematerialized as a constant op, which
+ // is then put on the worklist.
+ if (!op->hasTrait<OpTrait::ConstantLike>()) {
+ SmallVector<OpFoldResult> foldResults;
+ if (succeeded(op->fold(foldResults))) {
+ LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
+ changed = true;
+ if (foldResults.empty()) {
+ // Op was modified in-place.
+ notifyOperationModified(op);
+ continue;
+ }
+
+ // Op results can be replaced with `foldResults`.
+ assert(foldResults.size() == op->getNumResults() &&
+ "folder produced incorrect number of results");
+ OpBuilder::InsertionGuard g(*this);
+ setInsertionPoint(op);
+ SmallVector<Value> replacements;
+ for (auto [ofr, resultType] :
+ llvm::zip_equal(foldResults, op->getResultTypes())) {
+ if (auto value = ofr.dyn_cast<Value>()) {
+ assert(value.getType() == resultType &&
+ "folder produced value of incorrect type");
+ replacements.push_back(value);
+ continue;
+ }
+ // Materialize Attributes as SSA values.
+ Operation *constOp = op->getDialect()->materializeConstant(
+ *this, ofr.get<Attribute>(), resultType, op->getLoc());
+ assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
+ "materializeConstant produced op that is not a ConstantLike");
+ assert(constOp->getResultTypes()[0] == resultType &&
+ "materializeConstant produced incorrect result type");
+ replacements.push_back(constOp->getResult(0));
+ }
+ replaceOp(op, replacements);
+ continue;
+ }
}
// Try to match one of the patterns. The rewriter is automatically
@@ -567,7 +600,6 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
addOperandsToWorklist(op->getOperands());
worklist.remove(op);
- folder.notifyRemoval(op);
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
@@ -647,16 +679,6 @@ class GreedyPatternRewriteIteration
} // namespace
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
- auto insertKnownConstant = [&](Operation *op) {
- // Check for existing constants when populating the worklist. This avoids
- // accidentally reversing the constant order during processing.
- Attribute constValue;
- if (matchPattern(op, m_Constant(&constValue)))
- if (!folder.insertKnownConstant(op, constValue))
- return true;
- return false;
- };
-
bool continueRewrites = false;
int64_t iteration = 0;
MLIRContext *ctx = getContext();
@@ -666,8 +688,22 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
config.maxIterations != GreedyRewriteConfig::kNoLimit)
break;
+ // New iteration: start with an empty worklist.
worklist.clear();
+ // `OperationFolder` CSE's constant ops (and may move them into parents
+ // regions to enable more aggressive CSE'ing).
+ OperationFolder folder(getContext(), this);
+ auto insertKnownConstant = [&](Operation *op) {
+ // Check for existing constants when populating the worklist. This avoids
+ // accidentally reversing the constant order during processing.
+ Attribute constValue;
+ if (matchPattern(op, m_Constant(&constValue)))
+ if (!folder.insertKnownConstant(op, constValue))
+ return true;
+ return false;
+ };
+
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 6783263c184961..d3f02c6288a240 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -309,9 +309,9 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
// CHECK-LABEL: func.func @broadcast_vec2d_from_i32(
// CHECK-SAME: %[[SRC:.*]]: i32) {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
@@ -393,8 +393,8 @@ func.func @splat_vec2d_from_f16(%arg0: f16) {
// CHECK-LABEL: func.func @transpose_i8(
// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>)
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
// CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..90f134d57a3c13 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -196,10 +196,10 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
}
// CHECK-LABEL: @broadcast_vec3d_from_vec1d(
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
-// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
-// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
+// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>>
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..97eda0672c0299 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -533,9 +533,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
}
// CHECK-SAME: %[[ARG_0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
-// CHECK: %[[C_0:.*]] = arith.constant 0 : index
-// CHECK: %[[C_16:.*]] = arith.constant 16 : index
-// CHECK: %[[STEP:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C_16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: %[[MASK_VEC:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : vector<[16]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
@@ -556,8 +556,8 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
}
// CHECK-LABEL: func.func @vector_print_vector_0d(
// CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
@@ -581,9 +581,9 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
}
// CHECK-LABEL: func.func @vector_print_vector(
// CHECK-SAME: %[[VEC:.*]]: vector<2x2xf32>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<2x2xf32> to vector<4xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
@@ -650,10 +650,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
}
// CHECK-LABEL: func.func @transfer_read_array_of_scalable(
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
-// CHECK: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C3:.*]] = arith.constant 3 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
@@ -684,9 +684,9 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
// CHECK-LABEL: func.func @transfer_write_array_of_scalable(
// CHECK-SAME: %[[VEC:.*]]: vector<3x[4]xf32>,
// CHECK-SAME: %[[MEMREF:.*]]: memref<3x?xf32>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C3:.*]] = arith.constant 3 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index ae2d0f40f03af5..e51f2485dadbcc 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -91,10 +91,10 @@ func.func @arith_constant_dense_2d_zero_f64() {
// -----
// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
@@ -111,10 +111,10 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
// -----
// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 3a1ab924ebdacb..021151b929d8e2 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -168,8 +168,8 @@ llvm.func @no_crash_on_negative_gep_index() {
// CHECK-LABEL: llvm.func @coalesced_store_ints
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_ints(%arg: i64) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
@@ -193,8 +193,8 @@ llvm.func @coalesced_store_ints(%arg: i64) {
// CHECK-LABEL: llvm.func @coalesced_store_ints_offset
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_ints_offset(%arg: i64) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32)> : (i32) -> !llvm.ptr
@@ -218,8 +218,8 @@ llvm.func @coalesced_store_ints_offset(%arg: i64) {
// CHECK-LABEL: llvm.func @coalesced_store_floats
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_floats(%arg: i64) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)>
@@ -292,9 +292,9 @@ llvm.func @coalesced_store_past_end(%arg: i64) {
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_packed_struct(%arg: i64) {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+ // CHECK-DAG: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32, i16)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i16, i32, i16)> : (i32) -> !llvm.ptr
@@ -320,10 +320,10 @@ llvm.func @coalesced_store_packed_struct(%arg: i64) {
// CHECK-LABEL: llvm.func @vector_write_split
// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
llvm.func @vector_write_split(%arg: vector<4xi32>) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32...
[truncated]
|
@llvm/pr-subscribers-mlir-arith Author: Matthias Springer (matthias-springer) ChangesThe To allow for efficient CSE'ing, /// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap<Region *, ConstantMap> foldScopes; Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent. We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different This commit changes the behavior of the This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning Alternatives considered: The state of Depends on #75887. Review only the top commit. Patch is 86.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75897.diff 35 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 56d5e0fed76185..ff72becc8dfa77 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
setOperand(src);
return getResult();
}
+
// trunci(zexti(a)) -> a
// trunci(sexti(a)) -> a
- return src;
+ if (srcType == dstType)
+ return src;
}
// trunci(trunci(a)) -> trunci(a))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..ac9485326a32ed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1600,11 +1600,8 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
: 0;
};
- // If splat or broadcast from a scalar, just return the source scalar.
- unsigned broadcastSrcRank = getRank(source.getType());
- if (broadcastSrcRank == 0)
- return source;
+ unsigned broadcastSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
if (extractResultRank >= broadcastSrcRank)
return Value();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2cabfcd24d3559..c28cbe109c3ffd 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -491,9 +491,8 @@ LogicalResult OpBuilder::tryFold(Operation *op,
// Normal values get pushed back directly.
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
- if (value.getType() != expectedType)
- return cleanupFailure();
-
+ assert(value.getType() == expectedType &&
+ "folder produced value of incorrect type");
results.push_back(value);
continue;
}
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 90ee5ba51de3ad..34b7117a035748 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -247,10 +247,8 @@ OperationFolder::processFoldResults(Operation *op,
// Check if the result was an SSA value.
if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
- if (repl.getType() != op->getResult(i).getType()) {
- results.clear();
- return failure();
- }
+ assert(repl.getType() == op->getResult(i).getType() &&
+ "folder produced value of incorrect type");
results.emplace_back(repl);
continue;
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7decbce018a878..eb63050c6c3354 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -313,9 +313,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
Worklist worklist;
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
- /// Non-pattern based folder for operations.
- OperationFolder folder;
-
/// Configuration information for how to simplify.
const GreedyRewriteConfig config;
@@ -428,11 +425,47 @@ bool GreedyPatternRewriteDriver::processWorklist() {
continue;
}
- // Try to fold this op.
- if (succeeded(folder.tryToFold(op))) {
- LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
- changed = true;
- continue;
+ // Try to fold this op. Do not fold constant ops. That would lead to an
+ // infinite folding loop, as every constant op would be folded to an
+ // Attribute and then immediately be rematerialized as a constant op, which
+ // is then put on the worklist.
+ if (!op->hasTrait<OpTrait::ConstantLike>()) {
+ SmallVector<OpFoldResult> foldResults;
+ if (succeeded(op->fold(foldResults))) {
+ LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
+ changed = true;
+ if (foldResults.empty()) {
+ // Op was modified in-place.
+ notifyOperationModified(op);
+ continue;
+ }
+
+ // Op results can be replaced with `foldResults`.
+ assert(foldResults.size() == op->getNumResults() &&
+ "folder produced incorrect number of results");
+ OpBuilder::InsertionGuard g(*this);
+ setInsertionPoint(op);
+ SmallVector<Value> replacements;
+ for (auto [ofr, resultType] :
+ llvm::zip_equal(foldResults, op->getResultTypes())) {
+ if (auto value = ofr.dyn_cast<Value>()) {
+ assert(value.getType() == resultType &&
+ "folder produced value of incorrect type");
+ replacements.push_back(value);
+ continue;
+ }
+ // Materialize Attributes as SSA values.
+ Operation *constOp = op->getDialect()->materializeConstant(
+ *this, ofr.get<Attribute>(), resultType, op->getLoc());
+ assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
+ "materializeConstant produced op that is not a ConstantLike");
+ assert(constOp->getResultTypes()[0] == resultType &&
+ "materializeConstant produced incorrect result type");
+ replacements.push_back(constOp->getResult(0));
+ }
+ replaceOp(op, replacements);
+ continue;
+ }
}
// Try to match one of the patterns. The rewriter is automatically
@@ -567,7 +600,6 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
addOperandsToWorklist(op->getOperands());
worklist.remove(op);
- folder.notifyRemoval(op);
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
@@ -647,16 +679,6 @@ class GreedyPatternRewriteIteration
} // namespace
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
- auto insertKnownConstant = [&](Operation *op) {
- // Check for existing constants when populating the worklist. This avoids
- // accidentally reversing the constant order during processing.
- Attribute constValue;
- if (matchPattern(op, m_Constant(&constValue)))
- if (!folder.insertKnownConstant(op, constValue))
- return true;
- return false;
- };
-
bool continueRewrites = false;
int64_t iteration = 0;
MLIRContext *ctx = getContext();
@@ -666,8 +688,22 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
config.maxIterations != GreedyRewriteConfig::kNoLimit)
break;
+ // New iteration: start with an empty worklist.
worklist.clear();
+ // `OperationFolder` CSE's constant ops (and may move them into parents
+ // regions to enable more aggressive CSE'ing).
+ OperationFolder folder(getContext(), this);
+ auto insertKnownConstant = [&](Operation *op) {
+ // Check for existing constants when populating the worklist. This avoids
+ // accidentally reversing the constant order during processing.
+ Attribute constValue;
+ if (matchPattern(op, m_Constant(&constValue)))
+ if (!folder.insertKnownConstant(op, constValue))
+ return true;
+ return false;
+ };
+
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 6783263c184961..d3f02c6288a240 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -309,9 +309,9 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
// CHECK-LABEL: func.func @broadcast_vec2d_from_i32(
// CHECK-SAME: %[[SRC:.*]]: i32) {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
@@ -393,8 +393,8 @@ func.func @splat_vec2d_from_f16(%arg0: f16) {
// CHECK-LABEL: func.func @transpose_i8(
// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>)
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
// CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..90f134d57a3c13 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -196,10 +196,10 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
}
// CHECK-LABEL: @broadcast_vec3d_from_vec1d(
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
-// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
-// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
+// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>>
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..97eda0672c0299 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -533,9 +533,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
}
// CHECK-SAME: %[[ARG_0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
-// CHECK: %[[C_0:.*]] = arith.constant 0 : index
-// CHECK: %[[C_16:.*]] = arith.constant 16 : index
-// CHECK: %[[STEP:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C_16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: %[[MASK_VEC:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : vector<[16]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
@@ -556,8 +556,8 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
}
// CHECK-LABEL: func.func @vector_print_vector_0d(
// CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
@@ -581,9 +581,9 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
}
// CHECK-LABEL: func.func @vector_print_vector(
// CHECK-SAME: %[[VEC:.*]]: vector<2x2xf32>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<2x2xf32> to vector<4xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
@@ -650,10 +650,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
}
// CHECK-LABEL: func.func @transfer_read_array_of_scalable(
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
-// CHECK: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C3:.*]] = arith.constant 3 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
@@ -684,9 +684,9 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
// CHECK-LABEL: func.func @transfer_write_array_of_scalable(
// CHECK-SAME: %[[VEC:.*]]: vector<3x[4]xf32>,
// CHECK-SAME: %[[MEMREF:.*]]: memref<3x?xf32>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C3:.*]] = arith.constant 3 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index ae2d0f40f03af5..e51f2485dadbcc 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -91,10 +91,10 @@ func.func @arith_constant_dense_2d_zero_f64() {
// -----
// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
@@ -111,10 +111,10 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
// -----
// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 3a1ab924ebdacb..021151b929d8e2 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -168,8 +168,8 @@ llvm.func @no_crash_on_negative_gep_index() {
// CHECK-LABEL: llvm.func @coalesced_store_ints
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_ints(%arg: i64) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
@@ -193,8 +193,8 @@ llvm.func @coalesced_store_ints(%arg: i64) {
// CHECK-LABEL: llvm.func @coalesced_store_ints_offset
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_ints_offset(%arg: i64) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32)> : (i32) -> !llvm.ptr
@@ -218,8 +218,8 @@ llvm.func @coalesced_store_ints_offset(%arg: i64) {
// CHECK-LABEL: llvm.func @coalesced_store_floats
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_floats(%arg: i64) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)>
@@ -292,9 +292,9 @@ llvm.func @coalesced_store_past_end(%arg: i64) {
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_packed_struct(%arg: i64) {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+ // CHECK-DAG: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32, i16)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i16, i32, i16)> : (i32) -> !llvm.ptr
@@ -320,10 +320,10 @@ llvm.func @coalesced_store_packed_struct(%arg: i64) {
// CHECK-LABEL: llvm.func @vector_write_split
// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
llvm.func @vector_write_split(%arg: vector<4xi32>) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32...
[truncated]
|
@llvm/pr-subscribers-mlir-linalg Author: Matthias Springer (matthias-springer) ChangesThe To allow for efficient CSE'ing, /// A mapping between an insertion region and the constants that have been
/// created within it.
DenseMap<Region *, ConstantMap> foldScopes; Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent. We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different This commit changes the behavior of the This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning Alternatives considered: The state of Depends on #75887. Review only the top commit. Patch is 86.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75897.diff 35 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 56d5e0fed76185..ff72becc8dfa77 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1352,9 +1352,11 @@ OpFoldResult arith::TruncIOp::fold(FoldAdaptor adaptor) {
setOperand(src);
return getResult();
}
+
// trunci(zexti(a)) -> a
// trunci(sexti(a)) -> a
- return src;
+ if (srcType == dstType)
+ return src;
}
// trunci(trunci(a)) -> trunci(a))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..ac9485326a32ed 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1600,11 +1600,8 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
: 0;
};
- // If splat or broadcast from a scalar, just return the source scalar.
- unsigned broadcastSrcRank = getRank(source.getType());
- if (broadcastSrcRank == 0)
- return source;
+ unsigned broadcastSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
if (extractResultRank >= broadcastSrcRank)
return Value();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 2cabfcd24d3559..c28cbe109c3ffd 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -491,9 +491,8 @@ LogicalResult OpBuilder::tryFold(Operation *op,
// Normal values get pushed back directly.
if (auto value = llvm::dyn_cast_if_present<Value>(std::get<0>(it))) {
- if (value.getType() != expectedType)
- return cleanupFailure();
-
+ assert(value.getType() == expectedType &&
+ "folder produced value of incorrect type");
results.push_back(value);
continue;
}
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 90ee5ba51de3ad..34b7117a035748 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -247,10 +247,8 @@ OperationFolder::processFoldResults(Operation *op,
// Check if the result was an SSA value.
if (auto repl = llvm::dyn_cast_if_present<Value>(foldResults[i])) {
- if (repl.getType() != op->getResult(i).getType()) {
- results.clear();
- return failure();
- }
+ assert(repl.getType() == op->getResult(i).getType() &&
+ "folder produced value of incorrect type");
results.emplace_back(repl);
continue;
}
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 7decbce018a878..eb63050c6c3354 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -313,9 +313,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
Worklist worklist;
#endif // MLIR_GREEDY_REWRITE_RANDOMIZER_SEED
- /// Non-pattern based folder for operations.
- OperationFolder folder;
-
/// Configuration information for how to simplify.
const GreedyRewriteConfig config;
@@ -428,11 +425,47 @@ bool GreedyPatternRewriteDriver::processWorklist() {
continue;
}
- // Try to fold this op.
- if (succeeded(folder.tryToFold(op))) {
- LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
- changed = true;
- continue;
+ // Try to fold this op. Do not fold constant ops. That would lead to an
+ // infinite folding loop, as every constant op would be folded to an
+ // Attribute and then immediately be rematerialized as a constant op, which
+ // is then put on the worklist.
+ if (!op->hasTrait<OpTrait::ConstantLike>()) {
+ SmallVector<OpFoldResult> foldResults;
+ if (succeeded(op->fold(foldResults))) {
+ LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
+ changed = true;
+ if (foldResults.empty()) {
+ // Op was modified in-place.
+ notifyOperationModified(op);
+ continue;
+ }
+
+ // Op results can be replaced with `foldResults`.
+ assert(foldResults.size() == op->getNumResults() &&
+ "folder produced incorrect number of results");
+ OpBuilder::InsertionGuard g(*this);
+ setInsertionPoint(op);
+ SmallVector<Value> replacements;
+ for (auto [ofr, resultType] :
+ llvm::zip_equal(foldResults, op->getResultTypes())) {
+ if (auto value = ofr.dyn_cast<Value>()) {
+ assert(value.getType() == resultType &&
+ "folder produced value of incorrect type");
+ replacements.push_back(value);
+ continue;
+ }
+ // Materialize Attributes as SSA values.
+ Operation *constOp = op->getDialect()->materializeConstant(
+ *this, ofr.get<Attribute>(), resultType, op->getLoc());
+ assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
+ "materializeConstant produced op that is not a ConstantLike");
+ assert(constOp->getResultTypes()[0] == resultType &&
+ "materializeConstant produced incorrect result type");
+ replacements.push_back(constOp->getResult(0));
+ }
+ replaceOp(op, replacements);
+ continue;
+ }
}
// Try to match one of the patterns. The rewriter is automatically
@@ -567,7 +600,6 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
addOperandsToWorklist(op->getOperands());
worklist.remove(op);
- folder.notifyRemoval(op);
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
@@ -647,16 +679,6 @@ class GreedyPatternRewriteIteration
} // namespace
LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
- auto insertKnownConstant = [&](Operation *op) {
- // Check for existing constants when populating the worklist. This avoids
- // accidentally reversing the constant order during processing.
- Attribute constValue;
- if (matchPattern(op, m_Constant(&constValue)))
- if (!folder.insertKnownConstant(op, constValue))
- return true;
- return false;
- };
-
bool continueRewrites = false;
int64_t iteration = 0;
MLIRContext *ctx = getContext();
@@ -666,8 +688,22 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
config.maxIterations != GreedyRewriteConfig::kNoLimit)
break;
+ // New iteration: start with an empty worklist.
worklist.clear();
+ // `OperationFolder` CSE's constant ops (and may move them into parents
+ // regions to enable more aggressive CSE'ing).
+ OperationFolder folder(getContext(), this);
+ auto insertKnownConstant = [&](Operation *op) {
+ // Check for existing constants when populating the worklist. This avoids
+ // accidentally reversing the constant order during processing.
+ Attribute constValue;
+ if (matchPattern(op, m_Constant(&constValue)))
+ if (!folder.insertKnownConstant(op, constValue))
+ return true;
+ return false;
+ };
+
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 6783263c184961..d3f02c6288a240 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -309,9 +309,9 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
// CHECK-LABEL: func.func @broadcast_vec2d_from_i32(
// CHECK-SAME: %[[SRC:.*]]: i32) {
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
@@ -393,8 +393,8 @@ func.func @splat_vec2d_from_f16(%arg0: f16) {
// CHECK-LABEL: func.func @transpose_i8(
// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>)
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
// CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 012d30d96799f2..90f134d57a3c13 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -196,10 +196,10 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
}
// CHECK-LABEL: @broadcast_vec3d_from_vec1d(
// CHECK-SAME: %[[A:.*]]: vector<2xf32>)
-// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
-// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
-// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
+// CHECK-DAG: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32>
+// CHECK-DAG: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
+// CHECK-DAG: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32>
+// CHECK-DAG: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>>
// CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>>
// CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>>
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c945b24d..97eda0672c0299 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -533,9 +533,9 @@ func.func @transfer_write_scalable(%arg0: memref<?xf32, strided<[?], offset: ?>>
}
// CHECK-SAME: %[[ARG_0:.*]]: memref<?xf32, strided<[?], offset: ?>>,
-// CHECK: %[[C_0:.*]] = arith.constant 0 : index
-// CHECK: %[[C_16:.*]] = arith.constant 16 : index
-// CHECK: %[[STEP:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C_16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: %[[MASK_VEC:.*]] = arith.cmpi slt, %{{.*}}, %{{.*}} : vector<[16]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[UB:.*]] = arith.muli %[[VSCALE]], %[[C_16]] : index
@@ -556,8 +556,8 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
}
// CHECK-LABEL: func.func @vector_print_vector_0d(
// CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
@@ -581,9 +581,9 @@ func.func @vector_print_vector(%arg0: vector<2x2xf32>) {
}
// CHECK-LABEL: func.func @vector_print_vector(
// CHECK-SAME: %[[VEC:.*]]: vector<2x2xf32>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<2x2xf32> to vector<4xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
@@ -650,10 +650,10 @@ func.func @transfer_read_array_of_scalable(%arg0: memref<3x?xf32>) -> vector<3x[
}
// CHECK-LABEL: func.func @transfer_read_array_of_scalable(
// CHECK-SAME: %[[ARG:.*]]: memref<3x?xf32>) -> vector<3x[4]xf32> {
-// CHECK: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C3:.*]] = arith.constant 3 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[PADDING:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[ARG]], %[[C1]] : memref<3x?xf32>
@@ -684,9 +684,9 @@ func.func @transfer_write_array_of_scalable(%vec: vector<3x[4]xf32>, %arg0: memr
// CHECK-LABEL: func.func @transfer_write_array_of_scalable(
// CHECK-SAME: %[[VEC:.*]]: vector<3x[4]xf32>,
// CHECK-SAME: %[[MEMREF:.*]]: memref<3x?xf32>) {
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C3:.*]] = arith.constant 3 : index
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[ALLOCA_VEC:.*]] = memref.alloca() : memref<vector<3x[4]xf32>>
// CHECK: %[[ALLOCA_MASK:.*]] = memref.alloca() : memref<vector<3x[4]xi1>>
// CHECK: %[[DIM_SIZE:.*]] = memref.dim %[[MEMREF]], %[[C1]] : memref<3x?xf32>
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index ae2d0f40f03af5..e51f2485dadbcc 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -91,10 +91,10 @@ func.func @arith_constant_dense_2d_zero_f64() {
// -----
// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_i8() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C16:.*]] = arith.constant 16 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2> : vector<[16]xi8>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
@@ -111,10 +111,10 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
// -----
// CHECK-LABEL: func.func @arith_constant_dense_2d_nonzero_f64() {
-// CHECK: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
-// CHECK: %[[C1:.*]] = arith.constant 1 : index
-// CHECK: %[[C2:.*]] = arith.constant 2 : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2_SPLAT:.*]] = arith.constant dense<2.000000e+00> : vector<[2]xf64>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[INIT_TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
index 3a1ab924ebdacb..021151b929d8e2 100644
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -168,8 +168,8 @@ llvm.func @no_crash_on_negative_gep_index() {
// CHECK-LABEL: llvm.func @coalesced_store_ints
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_ints(%arg: i64) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
@@ -193,8 +193,8 @@ llvm.func @coalesced_store_ints(%arg: i64) {
// CHECK-LABEL: llvm.func @coalesced_store_ints_offset
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_ints_offset(%arg: i64) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32)> : (i32) -> !llvm.ptr
@@ -218,8 +218,8 @@ llvm.func @coalesced_store_ints_offset(%arg: i64) {
// CHECK-LABEL: llvm.func @coalesced_store_floats
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_floats(%arg: i64) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)>
@@ -292,9 +292,9 @@ llvm.func @coalesced_store_past_end(%arg: i64) {
// CHECK-SAME: %[[ARG:.*]]: i64
llvm.func @coalesced_store_packed_struct(%arg: i64) {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
- // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
- // CHECK: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
+ // CHECK-DAG: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK-DAG: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+ // CHECK-DAG: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
// CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32, i16)>
%1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i16, i32, i16)> : (i32) -> !llvm.ptr
@@ -320,10 +320,10 @@ llvm.func @coalesced_store_packed_struct(%arg: i64) {
// CHECK-LABEL: llvm.func @vector_write_split
// CHECK-SAME: %[[ARG:.*]]: vector<4xi32>
llvm.func @vector_write_split(%arg: vector<4xi32>) {
- // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
- // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
- // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32...
[truncated]
|
d709c52
to
d37d2fc
Compare
d37d2fc
to
27d13e2
Compare
We have support for stacked PR by pushing branches to the LLVM repo and opening cascading PR (the second PR is targeting the branch from the first PR). |
27d13e2
to
8be671e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes sense to me. Previously the constant folding optimization limited the patterns that can be used. This change gives most of the same folding optimization while avoiding that. (we may need to hoist out some helpers there eventually). I'll wait for someone else to chime in too given change centrality, but this looks close to least disruptive change that avoids the issue.
@@ -273,7 +273,14 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, | |||
|
|||
/// If the op has only a single statement (apart from the yield), do nothing. | |||
Block *body = genericOp.getBody(); | |||
if (body->getOperations().size() <= 2) { | |||
int64_t numOps = 0; | |||
for (Operation &op : body->getOperations()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change related?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. The rewrite pattern here was poorly designed and depends on constant hoisting to avoid getting into an infinite loop.
The rewrite pattern here decomposes a LinalgOp into multiple LinalgOps with one nested op per decomposed op. The rewrite pattern sometimes materializes constants inside of the region. These are now no longer hoisted immediately, so the rewrite pattern is applied again (decomposed LinalgOp now has 2 ops) and again (infinite loop). The hoisting now occurs at the beginning of a greedy pattern iteration, but there can be multiple pattern applications (of the same pattern) in one iteration.
The fix that I had here was not ideal. The issue in DecomposeLinalgOps.cpp
is now fixed in a different way.
4b7bb8f
to
ccfef12
Compare
Looks good to me, sufficient time for additional feedback so feel free to merge. |
Thx, can you also approve the PR on Github? |
ccfef12
to
9aa5e20
Compare
… during iterations The `GreedyPatternRewriteDriver` tries to iteratively fold ops and apply rewrite patterns to ops. It has special handling for constants: they are CSE'd and sometimes moved to parent regions to allow for additional CSE'ing. This happens in `OperationFolder`. To allow for efficient CSE'ing, `OperationFolder` maintains an internal lookup data structure to find the existing constant ops with the same value for each `IsolatedFromAbove` region: ```c++ /// A mapping between an insertion region and the constants that have been /// created within it. DenseMap<Region *, ConstantMap> foldScopes; ``` Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent. We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different `IsolatedFromAbove` region (that is not accessible). This commit changes the behavior of the `GreedyPatternRewriteDriver` such that `OperationFolder` is used to CSE constants at the beginning of each iteration (as the worklist is populated), but no longer during an iteration. `OperationFolder` is no longer used after populating the worklist, so we do not have to care about inconsistent state in the `OperationFolder` due to IR rewrites. The `GreedyPatternRewriteDriver` now performs the op folding by itself instead of calling `OperationFolder::tryToFold`. This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning `CHECK` into `CHECK-DAG`. Alternatives considered: The state of `OperationFolder` could be partially invalidated with every `notifyOperationModified` notification. That is more fragile than the solution in this commit because incorrect rewriter API usage can lead to missing notifications and hard-to-debug `IsolatedFromAbove` violations. (It did not fix the above mention bug in a downstream project, which could be due to incorrect rewriter API usage or due to another conceptual problem that I missed.) Moreover, ops are frequently getting modified during a greedy pattern rewrite, so we would likely keep invalidating large parts of the state of `OperationFolder` over and over.
9aa5e20
to
017e99e
Compare
Just bubbling an issue from downstream project here. This pattern seems to have caused some major regression in folding downstream iree-org/iree#16073 (comment) . I havent had time to triage this more to see if the fix is downstream or upstream, but posting here for visibility. |
Is there a limit on the number of iterations in IREE? I understood the change here as in that it shouldn't affect (much) the result because CSE would still be done before every iteration and so reaching fixed-point means we have successfully applied CSE and no pattern triggered. |
@matthias-springer we're seeing a bunch of downstream tests break that we've bisected to this change. These are not just lit tests (those were easy to fix with
Any ideas on what might be going on? Sorry but the repo is internal so we're unable to point to the pass that's failing. |
The
GreedyPatternRewriteDriver
tries to iteratively fold ops and apply rewrite patterns to ops. It has special handling for constants: they are CSE'd and sometimes moved to parent regions to allow for additional CSE'ing. This happens inOperationFolder
.To allow for efficient CSE'ing,
OperationFolder
maintains an internal lookup data structure to find the existing constant ops with the same value for eachIsolatedFromAbove
region:Rewrite patterns are allowed to modify operations. In particular, they may move operations (including constants) from one region to another one. Such an IR rewrite can make the above lookup data structure inconsistent.
We encountered such a bug in a downstream project. This bug materialized in the form of an op that uses the result of a constant op from a different
IsolatedFromAbove
region (that is not accessible).This commit changes the behavior of the
GreedyPatternRewriteDriver
such thatOperationFolder
is used to CSE constants at the beginning of each iteration (as the worklist is populated), but no longer during an iteration.OperationFolder
is no longer used after populating the worklist, so we do not have to care about inconsistent state in theOperationFolder
due to IR rewrites. TheGreedyPatternRewriteDriver
now performs the op folding by itself instead of callingOperationFolder::tryToFold
.This change changes the order of constant ops in test cases, but not the region in which they appear. All broken test cases were fixed by turning
CHECK
intoCHECK-DAG
.Alternatives considered: The state of
OperationFolder
could be partially invalidated with everynotifyOperationModified
notification. That is more fragile than the solution in this commit because incorrect rewriter API usage can lead to missing notifications and hard-to-debugIsolatedFromAbove
violations. (It did not fix the above mention bug in a downstream project, which could be due to incorrect rewriter API usage or due to another conceptual problem that I missed.) Moreover, ops are frequently getting modified during a greedy pattern rewrite, so we would likely keep invalidating large parts of the state ofOperationFolder
over and over.Migration guide: Turn
CHECK
intoCHECK-DAG
in test cases. Constant ops are no longer folded during a greedy pattern rewrite. If you rely on folding (and rematerialization) of constant ops during a greedy pattern rewrite, turn the folder into a pattern.