diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index a096f82a4cfd8..f8caae3ce9995 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -1428,6 +1428,9 @@ struct ConversionConfig { /// /// In the above example, %0 can be used instead of %3 and all cast ops are /// folded away. +void reconcileUnrealizedCasts( + const DenseSet &castOps, + SmallVectorImpl *remainingCastOps = nullptr); void reconcileUnrealizedCasts( ArrayRef castOps, SmallVectorImpl *remainingCastOps = nullptr); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index df9700f11200f..d53e1e78f2027 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3100,6 +3100,7 @@ unsigned OperationLegalizer::applyCostModelToPatterns( //===----------------------------------------------------------------------===// // OperationConverter //===----------------------------------------------------------------------===// + namespace { enum OpConversionMode { /// In this mode, the conversion will ignore failed conversions to allow @@ -3117,6 +3118,13 @@ enum OpConversionMode { } // namespace namespace mlir { + +// Predeclaration only. +static void reconcileUnrealizedCasts( + const DenseMap + &castOps, + SmallVectorImpl *remainingCastOps); + // This class converts operations to a given conversion target via a set of // rewrite patterns. The conversion behaves differently depending on the // conversion mode. @@ -3264,18 +3272,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // After a successful conversion, apply rewrites. rewriterImpl.applyRewrites(); - // Gather all unresolved materializations. - SmallVector allCastOps; - const DenseMap - &materializations = rewriterImpl.unresolvedMaterializations; - for (auto it : materializations) - allCastOps.push_back(it.first); - // Reconcile all UnrealizedConversionCastOps that were inserted by the - // dialect conversion frameworks. (Not the one that were inserted by + // dialect conversion frameworks. (Not the ones that were inserted by // patterns.) + const DenseMap + &materializations = rewriterImpl.unresolvedMaterializations; SmallVector remainingCastOps; - reconcileUnrealizedCasts(allCastOps, &remainingCastOps); + reconcileUnrealizedCasts(materializations, &remainingCastOps); // Drop markers. for (UnrealizedConversionCastOp castOp : remainingCastOps) @@ -3303,20 +3306,19 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // Reconcile Unrealized Casts //===----------------------------------------------------------------------===// -void mlir::reconcileUnrealizedCasts( - ArrayRef castOps, +/// Try to reconcile all given UnrealizedConversionCastOps and store the +/// left-over ops in `remainingCastOps` (if provided). See documentation in +/// DialectConversion.h for more details. +/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the +/// algorithm may visit an operand (or user) which is a cast op, but will not +/// try to reconcile it if not in the filtered set. +template +static void reconcileUnrealizedCastsImpl( + RangeT castOps, + function_ref isCastOpOfInterestFn, SmallVectorImpl *remainingCastOps) { + // A worklist of cast ops to process. SetVector worklist(llvm::from_range, castOps); - // This set is maintained only if `remainingCastOps` is provided. - DenseSet erasedOps; - - // Helper function that adds all operands to the worklist that are an - // unrealized_conversion_cast op result. - auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) { - for (Value v : castOp.getInputs()) - if (auto inputCastOp = v.getDefiningOp()) - worklist.insert(inputCastOp); - }; // Helper function that return the unrealized_conversion_cast op that // defines all inputs of the given op (in the same order). Return "nullptr" @@ -3337,39 +3339,110 @@ void mlir::reconcileUnrealizedCasts( // Process ops in the worklist bottom-to-top. while (!worklist.empty()) { UnrealizedConversionCastOp castOp = worklist.pop_back_val(); - if (castOp->use_empty()) { - // DCE: If the op has no users, erase it. Add the operands to the - // worklist to find additional DCE opportunities. - enqueueOperands(castOp); - if (remainingCastOps) - erasedOps.insert(castOp.getOperation()); - castOp->erase(); - continue; - } // Traverse the chain of input cast ops to see if an op with the same // input types can be found. UnrealizedConversionCastOp nextCast = castOp; while (nextCast) { if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) { + if (llvm::any_of(nextCast.getInputs(), [&](Value v) { + return v.getDefiningOp() == castOp; + })) { + // Ran into a cycle. + break; + } + // Found a cast where the input types match the output types of the - // matched op. We can directly use those inputs and the matched op can - // be removed. - enqueueOperands(castOp); + // matched op. We can directly use those inputs. castOp.replaceAllUsesWith(nextCast.getInputs()); - if (remainingCastOps) - erasedOps.insert(castOp.getOperation()); - castOp->erase(); break; } nextCast = getInputCast(nextCast); } } - if (remainingCastOps) - for (UnrealizedConversionCastOp op : castOps) - if (!erasedOps.contains(op.getOperation())) + // A set of all alive cast ops. I.e., ops whose results are (transitively) + // used by an op that is not a cast op. + DenseSet liveOps; + + // Helper function that marks the given op and transitively reachable input + // cast ops as alive. + auto markOpLive = [&](Operation *rootOp) { + SmallVector worklist; + worklist.push_back(rootOp); + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + if (liveOps.insert(op).second) { + // Successfully inserted: process reachable input cast ops. + for (Value v : op->getOperands()) + if (auto castOp = v.getDefiningOp()) + if (isCastOpOfInterestFn(castOp)) + worklist.push_back(castOp); + } + } + }; + + // Find all alive cast ops. + for (UnrealizedConversionCastOp op : castOps) { + // The op may have been marked live already as being an operand of another + // live cast op. + if (liveOps.contains(op.getOperation())) + continue; + // If any of the users is not a cast op, mark the current op (and its + // input ops) as live. + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + auto castOp = dyn_cast(user); + return !castOp || !isCastOpOfInterestFn(castOp); + })) + markOpLive(op); + } + + // Erase all dead cast ops. + for (UnrealizedConversionCastOp op : castOps) { + if (liveOps.contains(op)) { + // Op is alive and was not erased. Add it to the remaining cast ops. + if (remainingCastOps) remainingCastOps->push_back(op); + continue; + } + + // Op is dead. Erase it. + op->dropAllUses(); + op->erase(); + } +} + +void mlir::reconcileUnrealizedCasts( + ArrayRef castOps, + SmallVectorImpl *remainingCastOps) { + // Set of all cast ops for faster lookups. + DenseSet castOpSet; + for (UnrealizedConversionCastOp op : castOps) + castOpSet.insert(op); + reconcileUnrealizedCasts(castOpSet, remainingCastOps); +} + +void mlir::reconcileUnrealizedCasts( + const DenseSet &castOps, + SmallVectorImpl *remainingCastOps) { + reconcileUnrealizedCastsImpl( + llvm::make_range(castOps.begin(), castOps.end()), + [&](UnrealizedConversionCastOp castOp) { + return castOps.contains(castOp); + }, + remainingCastOps); +} + +static void mlir::reconcileUnrealizedCasts( + const DenseMap + &castOps, + SmallVectorImpl *remainingCastOps) { + reconcileUnrealizedCastsImpl( + castOps.keys(), + [&](UnrealizedConversionCastOp castOp) { + return castOps.contains(castOp); + }, + remainingCastOps); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir index 3573114f5e038..ac5ca321c066f 100644 --- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir +++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir @@ -194,3 +194,53 @@ func.func @emptyCast() -> index { %0 = builtin.unrealized_conversion_cast to index return %0 : index } + +// ----- + +// CHECK-LABEL: test.graph_region +// CHECK-NEXT: "test.return"() : () -> () +test.graph_region { + %0 = builtin.unrealized_conversion_cast %2 : i32 to i64 + %1 = builtin.unrealized_conversion_cast %0 : i64 to i16 + %2 = builtin.unrealized_conversion_cast %1 : i16 to i32 + "test.return"() : () -> () +} + +// ----- + +// CHECK-LABEL: test.graph_region +// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64 +// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16 +// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32 +// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> () +// CHECK-NEXT: "test.return"() : () -> () +test.graph_region { + %0 = builtin.unrealized_conversion_cast %2 : i32 to i64 + %1 = builtin.unrealized_conversion_cast %0 : i64 to i16 + %2 = builtin.unrealized_conversion_cast %1 : i16 to i32 + "test.user"(%2) : (i32) -> () + "test.return"() : () -> () +} + +// ----- + +// CHECK-LABEL: test.graph_region +// CHECK-NEXT: "test.return"() : () -> () +test.graph_region { + %0 = builtin.unrealized_conversion_cast %0 : i32 to i32 + "test.return"() : () -> () +} + +// ----- + +// CHECK-LABEL: test.graph_region +// CHECK-NEXT: %[[c0:.*]] = arith.constant +// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32 +// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> () +// CHECK-NEXT: "test.return"() : () -> () +test.graph_region { + %cst = arith.constant 0 : i32 + %0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32 + "test.user"(%0) : (i32) -> () + "test.return"() : () -> () +} diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir index 25a338df8d790..01a826a638606 100644 --- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir @@ -1,7 +1,8 @@ // RUN: mlir-opt %s -generate-runtime-verification \ // RUN: -expand-strided-metadata \ // RUN: -test-cf-assert \ -// RUN: -convert-to-llvm | \ +// RUN: -convert-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir index 4c6a48d577a6c..1144a7caf36e8 100644 --- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -generate-runtime-verification \ // RUN: -test-cf-assert \ -// RUN: -convert-to-llvm | \ +// RUN: -convert-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir index dd000c6904bcb..82e63805cd027 100644 --- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir @@ -1,6 +1,7 @@ // RUN: mlir-opt %s -generate-runtime-verification \ // RUN: -test-cf-assert \ -// RUN: -convert-to-llvm | \ +// RUN: -convert-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s diff --git a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp index f5a6fc5ea2b20..e30c31693fae7 100644 --- a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp +++ b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" @@ -73,6 +74,7 @@ void buildTestVulkanRunnerPipeline(OpPassManager &passManager, opt.kernelBarePtrCallConv = true; opt.kernelIntersperseSizeCallConv = true; passManager.addPass(createGpuToLLVMConversionPass(opt)); + passManager.addPass(createReconcileUnrealizedCastsPass()); } } // namespace