Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
295 changes: 145 additions & 150 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3097,6 +3097,151 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
return minDepth;
}

//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//

/// Try to reconcile all given UnrealizedConversionCastOps and store the
/// left-over ops in `remainingCastOps` (if provided). See documentation in
/// DialectConversion.h for more details.
/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
/// algorithm may visit an operand (or user) which is a cast op, but will not
/// try to reconcile it if not in the filtered set.
template <typename RangeT>
static void reconcileUnrealizedCastsImpl(
RangeT castOps,
function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
// A worklist of cast ops to process.
SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);

// Helper function that return the unrealized_conversion_cast op that
// defines all inputs of the given op (in the same order). Return "nullptr"
// if there is no such op.
auto getInputCast =
[](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
if (castOp.getInputs().empty())
return {};
auto inputCastOp =
castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
if (!inputCastOp)
return {};
if (inputCastOp.getOutputs() != castOp.getInputs())
return {};
return inputCastOp;
};

// Process ops in the worklist bottom-to-top.
while (!worklist.empty()) {
UnrealizedConversionCastOp castOp = worklist.pop_back_val();

// Traverse the chain of input cast ops to see if an op with the same
// input types can be found.
UnrealizedConversionCastOp nextCast = castOp;
while (nextCast) {
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
return v.getDefiningOp() == castOp;
})) {
// Ran into a cycle.
break;
}

// Found a cast where the input types match the output types of the
// matched op. We can directly use those inputs.
castOp.replaceAllUsesWith(nextCast.getInputs());
break;
}
nextCast = getInputCast(nextCast);
}
}

// A set of all alive cast ops. I.e., ops whose results are (transitively)
// used by an op that is not a cast op.
DenseSet<Operation *> liveOps;

// Helper function that marks the given op and transitively reachable input
// cast ops as alive.
auto markOpLive = [&](Operation *rootOp) {
SmallVector<Operation *> worklist;
worklist.push_back(rootOp);
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
if (liveOps.insert(op).second) {
// Successfully inserted: process reachable input cast ops.
for (Value v : op->getOperands())
if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
if (isCastOpOfInterestFn(castOp))
worklist.push_back(castOp);
}
}
};

// Find all alive cast ops.
for (UnrealizedConversionCastOp op : castOps) {
// The op may have been marked live already as being an operand of another
// live cast op.
if (liveOps.contains(op.getOperation()))
continue;
// If any of the users is not a cast op, mark the current op (and its
// input ops) as live.
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
return !castOp || !isCastOpOfInterestFn(castOp);
}))
markOpLive(op);
}

// Erase all dead cast ops.
for (UnrealizedConversionCastOp op : castOps) {
if (liveOps.contains(op)) {
// Op is alive and was not erased. Add it to the remaining cast ops.
if (remainingCastOps)
remainingCastOps->push_back(op);
continue;
}

// Op is dead. Erase it.
op->dropAllUses();
op->erase();
}
}

void mlir::reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
// Set of all cast ops for faster lookups.
DenseSet<UnrealizedConversionCastOp> castOpSet;
for (UnrealizedConversionCastOp op : castOps)
castOpSet.insert(op);
reconcileUnrealizedCasts(castOpSet, remainingCastOps);
}

void mlir::reconcileUnrealizedCasts(
const DenseSet<UnrealizedConversionCastOp> &castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
reconcileUnrealizedCastsImpl(
llvm::make_range(castOps.begin(), castOps.end()),
[&](UnrealizedConversionCastOp castOp) {
return castOps.contains(castOp);
},
remainingCastOps);
}

namespace mlir {
static void reconcileUnrealizedCasts(
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
reconcileUnrealizedCastsImpl(
castOps.keys(),
[&](UnrealizedConversionCastOp castOp) {
return castOps.contains(castOp);
},
remainingCastOps);
}
} // namespace mlir

//===----------------------------------------------------------------------===//
// OperationConverter
//===----------------------------------------------------------------------===//
Expand All @@ -3118,13 +3263,6 @@ enum OpConversionMode {
} // namespace

namespace mlir {

// Predeclaration only.
static void reconcileUnrealizedCasts(
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);

// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
Expand Down Expand Up @@ -3302,149 +3440,6 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
return success();
}

//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//

/// Try to reconcile all given UnrealizedConversionCastOps and store the
/// left-over ops in `remainingCastOps` (if provided). See documentation in
/// DialectConversion.h for more details.
/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
/// algorithm may visit an operand (or user) which is a cast op, but will not
/// try to reconcile it if not in the filtered set.
template <typename RangeT>
static void reconcileUnrealizedCastsImpl(
RangeT castOps,
function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
// A worklist of cast ops to process.
SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);

// Helper function that return the unrealized_conversion_cast op that
// defines all inputs of the given op (in the same order). Return "nullptr"
// if there is no such op.
auto getInputCast =
[](UnrealizedConversionCastOp castOp) -> UnrealizedConversionCastOp {
if (castOp.getInputs().empty())
return {};
auto inputCastOp =
castOp.getInputs().front().getDefiningOp<UnrealizedConversionCastOp>();
if (!inputCastOp)
return {};
if (inputCastOp.getOutputs() != castOp.getInputs())
return {};
return inputCastOp;
};

// Process ops in the worklist bottom-to-top.
while (!worklist.empty()) {
UnrealizedConversionCastOp castOp = worklist.pop_back_val();

// Traverse the chain of input cast ops to see if an op with the same
// input types can be found.
UnrealizedConversionCastOp nextCast = castOp;
while (nextCast) {
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
return v.getDefiningOp() == castOp;
})) {
// Ran into a cycle.
break;
}

// Found a cast where the input types match the output types of the
// matched op. We can directly use those inputs.
castOp.replaceAllUsesWith(nextCast.getInputs());
break;
}
nextCast = getInputCast(nextCast);
}
}

// A set of all alive cast ops. I.e., ops whose results are (transitively)
// used by an op that is not a cast op.
DenseSet<Operation *> liveOps;

// Helper function that marks the given op and transitively reachable input
// cast ops as alive.
auto markOpLive = [&](Operation *rootOp) {
SmallVector<Operation *> worklist;
worklist.push_back(rootOp);
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
if (liveOps.insert(op).second) {
// Successfully inserted: process reachable input cast ops.
for (Value v : op->getOperands())
if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
if (isCastOpOfInterestFn(castOp))
worklist.push_back(castOp);
}
}
};

// Find all alive cast ops.
for (UnrealizedConversionCastOp op : castOps) {
// The op may have been marked live already as being an operand of another
// live cast op.
if (liveOps.contains(op.getOperation()))
continue;
// If any of the users is not a cast op, mark the current op (and its
// input ops) as live.
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
return !castOp || !isCastOpOfInterestFn(castOp);
}))
markOpLive(op);
}

// Erase all dead cast ops.
for (UnrealizedConversionCastOp op : castOps) {
if (liveOps.contains(op)) {
// Op is alive and was not erased. Add it to the remaining cast ops.
if (remainingCastOps)
remainingCastOps->push_back(op);
continue;
}

// Op is dead. Erase it.
op->dropAllUses();
op->erase();
}
}

void mlir::reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
// Set of all cast ops for faster lookups.
DenseSet<UnrealizedConversionCastOp> castOpSet;
for (UnrealizedConversionCastOp op : castOps)
castOpSet.insert(op);
reconcileUnrealizedCasts(castOpSet, remainingCastOps);
}

void mlir::reconcileUnrealizedCasts(
const DenseSet<UnrealizedConversionCastOp> &castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
reconcileUnrealizedCastsImpl(
llvm::make_range(castOps.begin(), castOps.end()),
[&](UnrealizedConversionCastOp castOp) {
return castOps.contains(castOp);
},
remainingCastOps);
}

static void mlir::reconcileUnrealizedCasts(
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
reconcileUnrealizedCastsImpl(
castOps.keys(),
[&](UnrealizedConversionCastOp castOp) {
return castOps.contains(castOp);
},
remainingCastOps);
}

//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
Expand Down