From f7686bd8c84024bb9d6e17ad0d234711f09ef3ad Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 11 Sep 2025 13:14:56 +0000 Subject: [PATCH 1/3] [mlir][Transforms] Fix crash in `reconcile-unrealized-casts` --- .../mlir/Transforms/DialectConversion.h | 3 + .../Transforms/Utils/DialectConversion.cpp | 143 +++++++++++++----- .../reconcile-unrealized-casts.mlir | 50 ++++++ ...assume-alignment-runtime-verification.mlir | 3 +- .../atomic-rmw-runtime-verification.mlir | 3 +- .../MemRef/store-runtime-verification.mlir | 3 +- 6 files changed, 163 insertions(+), 42 deletions(-) diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 6949f4a14fdba..5bf73aa350534 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -1421,6 +1421,9 @@ struct ConversionConfig { /// /// In the above example, %0 can be used instead of %3 and all cast ops are /// folded away. +void reconcileUnrealizedCasts( + const DenseSet &castOps, + SmallVectorImpl *remainingCastOps = nullptr); void reconcileUnrealizedCasts( ArrayRef castOps, SmallVectorImpl *remainingCastOps = nullptr); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 36ee87b533b3b..627e5d50c42ac 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3100,6 +3100,7 @@ unsigned OperationLegalizer::applyCostModelToPatterns( //===----------------------------------------------------------------------===// // OperationConverter //===----------------------------------------------------------------------===// + namespace { enum OpConversionMode { /// In this mode, the conversion will ignore failed conversions to allow @@ -3117,6 +3118,13 @@ enum OpConversionMode { } // namespace namespace mlir { + +// Predeclaration only. +static void reconcileUnrealizedCasts( + const DenseMap + &castOps, + SmallVectorImpl *remainingCastOps); + // This class converts operations to a given conversion target via a set of // rewrite patterns. The conversion behaves differently depending on the // conversion mode. @@ -3264,18 +3272,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // After a successful conversion, apply rewrites. rewriterImpl.applyRewrites(); - // Gather all unresolved materializations. - SmallVector allCastOps; - const DenseMap - &materializations = rewriterImpl.unresolvedMaterializations; - for (auto it : materializations) - allCastOps.push_back(it.first); - // Reconcile all UnrealizedConversionCastOps that were inserted by the - // dialect conversion frameworks. (Not the one that were inserted by + // dialect conversion frameworks. (Not the ones that were inserted by // patterns.) + const DenseMap + &materializations = rewriterImpl.unresolvedMaterializations; SmallVector remainingCastOps; - reconcileUnrealizedCasts(allCastOps, &remainingCastOps); + reconcileUnrealizedCasts(materializations, &remainingCastOps); // Drop markers. for (UnrealizedConversionCastOp castOp : remainingCastOps) @@ -3303,20 +3306,15 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // Reconcile Unrealized Casts //===----------------------------------------------------------------------===// -void mlir::reconcileUnrealizedCasts( - ArrayRef castOps, +/// Try to reconcile all given UnrealizedConversionCastOps and store the +/// left-over ops in `remainingCastOps` (if provided). See documentation in +/// DialectConversion.h for more details. +template +static void reconcileUnrealizedCastsImpl( + RangeT castOps, function_ref isCastOpFn, SmallVectorImpl *remainingCastOps) { + // A worklist of cast ops to process. SetVector worklist(llvm::from_range, castOps); - // This set is maintained only if `remainingCastOps` is provided. - DenseSet erasedOps; - - // Helper function that adds all operands to the worklist that are an - // unrealized_conversion_cast op result. - auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) { - for (Value v : castOp.getInputs()) - if (auto inputCastOp = v.getDefiningOp()) - worklist.insert(inputCastOp); - }; // Helper function that return the unrealized_conversion_cast op that // defines all inputs of the given op (in the same order). Return "nullptr" @@ -3337,39 +3335,106 @@ void mlir::reconcileUnrealizedCasts( // Process ops in the worklist bottom-to-top. while (!worklist.empty()) { UnrealizedConversionCastOp castOp = worklist.pop_back_val(); - if (castOp->use_empty()) { - // DCE: If the op has no users, erase it. Add the operands to the - // worklist to find additional DCE opportunities. - enqueueOperands(castOp); - if (remainingCastOps) - erasedOps.insert(castOp.getOperation()); - castOp->erase(); - continue; - } // Traverse the chain of input cast ops to see if an op with the same // input types can be found. UnrealizedConversionCastOp nextCast = castOp; while (nextCast) { if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) { + if (llvm::any_of(nextCast.getInputs(), [&](Value v) { + return v.getDefiningOp() == castOp; + })) { + // Ran into a cycle. + break; + } + // Found a cast where the input types match the output types of the - // matched op. We can directly use those inputs and the matched op can - // be removed. - enqueueOperands(castOp); + // matched op. We can directly use those inputs. castOp.replaceAllUsesWith(nextCast.getInputs()); - if (remainingCastOps) - erasedOps.insert(castOp.getOperation()); - castOp->erase(); break; } nextCast = getInputCast(nextCast); } } - if (remainingCastOps) - for (UnrealizedConversionCastOp op : castOps) - if (!erasedOps.contains(op.getOperation())) + // A set of all alive cast ops. I.e., ops whose results are (transitively) + // used by an op that is not a cast op. + DenseSet liveOps; + + // Helper function that marks the given op and transitively reachable input + // cast ops as alive. + auto markOpLive = [&](Operation *op) { + SmallVector worklist; + worklist.push_back(op); + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + if (liveOps.insert(op).second) { + // Successfully inserted: process reachable input cast ops. + for (Value v : op->getOperands()) + if (auto castOp = v.getDefiningOp()) + if (isCastOpFn(castOp)) + worklist.push_back(castOp); + } + } + }; + + // Find all alive cast ops. + for (UnrealizedConversionCastOp op : castOps) { + // If any of the users is not a cast op, mark the current op (and its + // input ops) as live. + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + auto castOp = dyn_cast(user); + return !castOp || !isCastOpFn(castOp); + })) + markOpLive(op); + } + + // Erase all dead cast ops. + for (UnrealizedConversionCastOp op : castOps) { + if (liveOps.contains(op)) { + // Op is alive and was not erased. Add it to the remaining cast ops. + if (remainingCastOps) remainingCastOps->push_back(op); + continue; + } + + // Op is dead. Erase it. + op->dropAllUses(); + op->erase(); + } +} + +void mlir::reconcileUnrealizedCasts( + ArrayRef castOps, + SmallVectorImpl *remainingCastOps) { + // Set of all cast ops for faster lookups. + DenseSet castOpSet; + for (UnrealizedConversionCastOp op : castOps) + castOpSet.insert(op); + reconcileUnrealizedCasts(castOpSet, remainingCastOps); +} + +void mlir::reconcileUnrealizedCasts( + const DenseSet &castOps, + SmallVectorImpl *remainingCastOps) { + reconcileUnrealizedCastsImpl( + llvm::make_range(castOps.begin(), castOps.end()), + [&](UnrealizedConversionCastOp castOp) { + return castOps.contains(castOp); + }, + remainingCastOps); +} + +static void mlir::reconcileUnrealizedCasts( + const DenseMap + &castOps, + SmallVectorImpl *remainingCastOps) { + reconcileUnrealizedCastsImpl( + castOps.keys(), + [&](UnrealizedConversionCastOp castOp) { + return castOps.contains(castOp); + }, + remainingCastOps); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir index 3573114f5e038..ac5ca321c066f 100644 --- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir +++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir @@ -194,3 +194,53 @@ func.func @emptyCast() -> index { %0 = builtin.unrealized_conversion_cast to index return %0 : index } + +// ----- + +// CHECK-LABEL: test.graph_region +// CHECK-NEXT: "test.return"() : () -> () +test.graph_region { + %0 = builtin.unrealized_conversion_cast %2 : i32 to i64 + %1 = builtin.unrealized_conversion_cast %0 : i64 to i16 + %2 = builtin.unrealized_conversion_cast %1 : i16 to i32 + "test.return"() : () -> () +} + +// ----- + +// CHECK-LABEL: test.graph_region +// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64 +// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16 +// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32 +// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> () +// CHECK-NEXT: "test.return"() : () -> () +test.graph_region { + %0 = builtin.unrealized_conversion_cast %2 : i32 to i64 + %1 = builtin.unrealized_conversion_cast %0 : i64 to i16 + %2 = builtin.unrealized_conversion_cast %1 : i16 to i32 + "test.user"(%2) : (i32) -> () + "test.return"() : () -> () +} + +// ----- + +// CHECK-LABEL: test.graph_region +// CHECK-NEXT: "test.return"() : () -> () +test.graph_region { + %0 = builtin.unrealized_conversion_cast %0 : i32 to i32 + "test.return"() : () -> () +} + +// ----- + +// CHECK-LABEL: test.graph_region +// CHECK-NEXT: %[[c0:.*]] = arith.constant +// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32 +// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> () +// CHECK-NEXT: "test.return"() : () -> () +test.graph_region { + %cst = arith.constant 0 : i32 + %0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32 + "test.user"(%0) : (i32) -> () + "test.return"() : () -> () +} diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir index 25a338df8d790..01a826a638606 100644 --- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir @@ -1,7 +1,8 @@ // RUN: mlir-opt %s -generate-runtime-verification \ // RUN: -expand-strided-metadata \ // RUN: -test-cf-assert \ -// RUN: -convert-to-llvm | \ +// RUN: -convert-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir index 4c6a48d577a6c..1144a7caf36e8 100644 --- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -generate-runtime-verification \ // RUN: -test-cf-assert \ -// RUN: -convert-to-llvm | \ +// RUN: -convert-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir index dd000c6904bcb..82e63805cd027 100644 --- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -generate-runtime-verification \ // RUN: -test-cf-assert \ -// RUN: -convert-to-llvm | \ +// RUN: -convert-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s From fa3bebd3835a77898404c6516df675ed116fa249 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 12 Sep 2025 14:16:09 +0100 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: Mehdi Amini --- mlir/lib/Transforms/Utils/DialectConversion.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 627e5d50c42ac..0c29b168d771a 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3309,9 +3309,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { /// Try to reconcile all given UnrealizedConversionCastOps and store the /// left-over ops in `remainingCastOps` (if provided). See documentation in /// DialectConversion.h for more details. +/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the algorithm may +/// visit an operand (or user) which is a cast op, but will not try to reconcile it if not in the +/// filtered set. template static void reconcileUnrealizedCastsImpl( - RangeT castOps, function_ref isCastOpFn, + RangeT castOps, function_ref isCastOpOfInterestFn, SmallVectorImpl *remainingCastOps) { // A worklist of cast ops to process. SetVector worklist(llvm::from_range, castOps); @@ -3363,9 +3366,9 @@ static void reconcileUnrealizedCastsImpl( // Helper function that marks the given op and transitively reachable input // cast ops as alive. - auto markOpLive = [&](Operation *op) { + auto markOpLive = [&](Operation *rootOp) { SmallVector worklist; - worklist.push_back(op); + worklist.push_back(rootOp); while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); if (liveOps.insert(op).second) { @@ -3380,6 +3383,8 @@ static void reconcileUnrealizedCastsImpl( // Find all alive cast ops. for (UnrealizedConversionCastOp op : castOps) { + // The op may have been marked live already as being an operand of another live cast op. + if (liveOps.contains(op.getOperation()) continue; // If any of the users is not a cast op, mark the current op (and its // input ops) as live. if (llvm::any_of(op->getUsers(), [&](Operation *user) { From 6183b1df34c779e9b456880fc8913e4d8d608cfe Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 12 Sep 2025 13:21:47 +0000 Subject: [PATCH 3/3] fix --- .../Transforms/Utils/DialectConversion.cpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 0c29b168d771a..e4e2ad1729b14 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3309,12 +3309,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { /// Try to reconcile all given UnrealizedConversionCastOps and store the /// left-over ops in `remainingCastOps` (if provided). See documentation in /// DialectConversion.h for more details. -/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the algorithm may -/// visit an operand (or user) which is a cast op, but will not try to reconcile it if not in the -/// filtered set. +/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the +/// algorithm may visit an operand (or user) which is a cast op, but will not +/// try to reconcile it if not in the filtered set. template static void reconcileUnrealizedCastsImpl( - RangeT castOps, function_ref isCastOpOfInterestFn, + RangeT castOps, + function_ref isCastOpOfInterestFn, SmallVectorImpl *remainingCastOps) { // A worklist of cast ops to process. SetVector worklist(llvm::from_range, castOps); @@ -3375,7 +3376,7 @@ static void reconcileUnrealizedCastsImpl( // Successfully inserted: process reachable input cast ops. for (Value v : op->getOperands()) if (auto castOp = v.getDefiningOp()) - if (isCastOpFn(castOp)) + if (isCastOpOfInterestFn(castOp)) worklist.push_back(castOp); } } @@ -3383,13 +3384,15 @@ static void reconcileUnrealizedCastsImpl( // Find all alive cast ops. for (UnrealizedConversionCastOp op : castOps) { - // The op may have been marked live already as being an operand of another live cast op. - if (liveOps.contains(op.getOperation()) continue; + // The op may have been marked live already as being an operand of another + // live cast op. + if (liveOps.contains(op.getOperation())) + continue; // If any of the users is not a cast op, mark the current op (and its // input ops) as live. if (llvm::any_of(op->getUsers(), [&](Operation *user) { auto castOp = dyn_cast(user); - return !castOp || !isCastOpFn(castOp); + return !castOp || !isCastOpOfInterestFn(castOp); })) markOpLive(op); }