From 7da049a5fca5a47007002d4f99980e6f4c00413b Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 12 Sep 2025 13:39:17 +0000 Subject: [PATCH] [mlir][Transforms][NFC] Remove `reconcileUnrealizedCasts` predeclaration This is a follow-up to https://github.com/llvm/llvm-project/pull/158067/files#r2343711946. --- .../Transforms/Utils/DialectConversion.cpp | 295 +++++++++--------- 1 file changed, 145 insertions(+), 150 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d53e1e78f2027..f7565cfb0e45e 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -3097,6 +3097,151 @@ unsigned OperationLegalizer::applyCostModelToPatterns( return minDepth; } +//===----------------------------------------------------------------------===// +// Reconcile Unrealized Casts +//===----------------------------------------------------------------------===// + +/// Try to reconcile all given UnrealizedConversionCastOps and store the +/// left-over ops in `remainingCastOps` (if provided). See documentation in +/// DialectConversion.h for more details. +/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the +/// algorithm may visit an operand (or user) which is a cast op, but will not +/// try to reconcile it if not in the filtered set. +template +static void reconcileUnrealizedCastsImpl( + RangeT castOps, + function_ref isCastOpOfInterestFn, + SmallVectorImpl *remainingCastOps) { + // A worklist of cast ops to process. + SetVector worklist(llvm::from_range, castOps); + + // Helper function that return the unrealized_conversion_cast op that + // defines all inputs of the given op (in the same order). Return "nullptr" + // if there is no such op. + auto getInputCast = + [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp { + if (castOp.getInputs().empty()) + return {}; + auto inputCastOp = + castOp.getInputs().front().getDefiningOp(); + if (!inputCastOp) + return {}; + if (inputCastOp.getOutputs() != castOp.getInputs()) + return {}; + return inputCastOp; + }; + + // Process ops in the worklist bottom-to-top. + while (!worklist.empty()) { + UnrealizedConversionCastOp castOp = worklist.pop_back_val(); + + // Traverse the chain of input cast ops to see if an op with the same + // input types can be found. + UnrealizedConversionCastOp nextCast = castOp; + while (nextCast) { + if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) { + if (llvm::any_of(nextCast.getInputs(), [&](Value v) { + return v.getDefiningOp() == castOp; + })) { + // Ran into a cycle. + break; + } + + // Found a cast where the input types match the output types of the + // matched op. We can directly use those inputs. + castOp.replaceAllUsesWith(nextCast.getInputs()); + break; + } + nextCast = getInputCast(nextCast); + } + } + + // A set of all alive cast ops. I.e., ops whose results are (transitively) + // used by an op that is not a cast op. + DenseSet liveOps; + + // Helper function that marks the given op and transitively reachable input + // cast ops as alive. + auto markOpLive = [&](Operation *rootOp) { + SmallVector worklist; + worklist.push_back(rootOp); + while (!worklist.empty()) { + Operation *op = worklist.pop_back_val(); + if (liveOps.insert(op).second) { + // Successfully inserted: process reachable input cast ops. + for (Value v : op->getOperands()) + if (auto castOp = v.getDefiningOp()) + if (isCastOpOfInterestFn(castOp)) + worklist.push_back(castOp); + } + } + }; + + // Find all alive cast ops. + for (UnrealizedConversionCastOp op : castOps) { + // The op may have been marked live already as being an operand of another + // live cast op. + if (liveOps.contains(op.getOperation())) + continue; + // If any of the users is not a cast op, mark the current op (and its + // input ops) as live. + if (llvm::any_of(op->getUsers(), [&](Operation *user) { + auto castOp = dyn_cast(user); + return !castOp || !isCastOpOfInterestFn(castOp); + })) + markOpLive(op); + } + + // Erase all dead cast ops. + for (UnrealizedConversionCastOp op : castOps) { + if (liveOps.contains(op)) { + // Op is alive and was not erased. Add it to the remaining cast ops. + if (remainingCastOps) + remainingCastOps->push_back(op); + continue; + } + + // Op is dead. Erase it. + op->dropAllUses(); + op->erase(); + } +} + +void mlir::reconcileUnrealizedCasts( + ArrayRef castOps, + SmallVectorImpl *remainingCastOps) { + // Set of all cast ops for faster lookups. + DenseSet castOpSet; + for (UnrealizedConversionCastOp op : castOps) + castOpSet.insert(op); + reconcileUnrealizedCasts(castOpSet, remainingCastOps); +} + +void mlir::reconcileUnrealizedCasts( + const DenseSet &castOps, + SmallVectorImpl *remainingCastOps) { + reconcileUnrealizedCastsImpl( + llvm::make_range(castOps.begin(), castOps.end()), + [&](UnrealizedConversionCastOp castOp) { + return castOps.contains(castOp); + }, + remainingCastOps); +} + +namespace mlir { +static void reconcileUnrealizedCasts( + const DenseMap + &castOps, + SmallVectorImpl *remainingCastOps) { + reconcileUnrealizedCastsImpl( + castOps.keys(), + [&](UnrealizedConversionCastOp castOp) { + return castOps.contains(castOp); + }, + remainingCastOps); +} +} // namespace mlir + //===----------------------------------------------------------------------===// // OperationConverter //===----------------------------------------------------------------------===// @@ -3118,13 +3263,6 @@ enum OpConversionMode { } // namespace namespace mlir { - -// Predeclaration only. -static void reconcileUnrealizedCasts( - const DenseMap - &castOps, - SmallVectorImpl *remainingCastOps); - // This class converts operations to a given conversion target via a set of // rewrite patterns. The conversion behaves differently depending on the // conversion mode. @@ -3302,149 +3440,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { return success(); } -//===----------------------------------------------------------------------===// -// Reconcile Unrealized Casts -//===----------------------------------------------------------------------===// - -/// Try to reconcile all given UnrealizedConversionCastOps and store the -/// left-over ops in `remainingCastOps` (if provided). See documentation in -/// DialectConversion.h for more details. -/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the -/// algorithm may visit an operand (or user) which is a cast op, but will not -/// try to reconcile it if not in the filtered set. -template -static void reconcileUnrealizedCastsImpl( - RangeT castOps, - function_ref isCastOpOfInterestFn, - SmallVectorImpl *remainingCastOps) { - // A worklist of cast ops to process. - SetVector worklist(llvm::from_range, castOps); - - // Helper function that return the unrealized_conversion_cast op that - // defines all inputs of the given op (in the same order). Return "nullptr" - // if there is no such op. - auto getInputCast = - [](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp { - if (castOp.getInputs().empty()) - return {}; - auto inputCastOp = - castOp.getInputs().front().getDefiningOp(); - if (!inputCastOp) - return {}; - if (inputCastOp.getOutputs() != castOp.getInputs()) - return {}; - return inputCastOp; - }; - - // Process ops in the worklist bottom-to-top. - while (!worklist.empty()) { - UnrealizedConversionCastOp castOp = worklist.pop_back_val(); - - // Traverse the chain of input cast ops to see if an op with the same - // input types can be found. - UnrealizedConversionCastOp nextCast = castOp; - while (nextCast) { - if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) { - if (llvm::any_of(nextCast.getInputs(), [&](Value v) { - return v.getDefiningOp() == castOp; - })) { - // Ran into a cycle. - break; - } - - // Found a cast where the input types match the output types of the - // matched op. We can directly use those inputs. - castOp.replaceAllUsesWith(nextCast.getInputs()); - break; - } - nextCast = getInputCast(nextCast); - } - } - - // A set of all alive cast ops. I.e., ops whose results are (transitively) - // used by an op that is not a cast op. - DenseSet liveOps; - - // Helper function that marks the given op and transitively reachable input - // cast ops as alive. - auto markOpLive = [&](Operation *rootOp) { - SmallVector worklist; - worklist.push_back(rootOp); - while (!worklist.empty()) { - Operation *op = worklist.pop_back_val(); - if (liveOps.insert(op).second) { - // Successfully inserted: process reachable input cast ops. - for (Value v : op->getOperands()) - if (auto castOp = v.getDefiningOp()) - if (isCastOpOfInterestFn(castOp)) - worklist.push_back(castOp); - } - } - }; - - // Find all alive cast ops. - for (UnrealizedConversionCastOp op : castOps) { - // The op may have been marked live already as being an operand of another - // live cast op. - if (liveOps.contains(op.getOperation())) - continue; - // If any of the users is not a cast op, mark the current op (and its - // input ops) as live. - if (llvm::any_of(op->getUsers(), [&](Operation *user) { - auto castOp = dyn_cast(user); - return !castOp || !isCastOpOfInterestFn(castOp); - })) - markOpLive(op); - } - - // Erase all dead cast ops. - for (UnrealizedConversionCastOp op : castOps) { - if (liveOps.contains(op)) { - // Op is alive and was not erased. Add it to the remaining cast ops. - if (remainingCastOps) - remainingCastOps->push_back(op); - continue; - } - - // Op is dead. Erase it. - op->dropAllUses(); - op->erase(); - } -} - -void mlir::reconcileUnrealizedCasts( - ArrayRef castOps, - SmallVectorImpl *remainingCastOps) { - // Set of all cast ops for faster lookups. - DenseSet castOpSet; - for (UnrealizedConversionCastOp op : castOps) - castOpSet.insert(op); - reconcileUnrealizedCasts(castOpSet, remainingCastOps); -} - -void mlir::reconcileUnrealizedCasts( - const DenseSet &castOps, - SmallVectorImpl *remainingCastOps) { - reconcileUnrealizedCastsImpl( - llvm::make_range(castOps.begin(), castOps.end()), - [&](UnrealizedConversionCastOp castOp) { - return castOps.contains(castOp); - }, - remainingCastOps); -} - -static void mlir::reconcileUnrealizedCasts( - const DenseMap - &castOps, - SmallVectorImpl *remainingCastOps) { - reconcileUnrealizedCastsImpl( - castOps.keys(), - [&](UnrealizedConversionCastOp castOp) { - return castOps.contains(castOp); - }, - remainingCastOps); -} - //===----------------------------------------------------------------------===// // Type Conversion //===----------------------------------------------------------------------===//