Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,9 @@ struct ConversionConfig {
///
/// In the above example, %0 can be used instead of %3 and all cast ops are
/// folded away.
void reconcileUnrealizedCasts(
const DenseSet<UnrealizedConversionCastOp> &castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
void reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
Expand Down
151 changes: 112 additions & 39 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3100,6 +3100,7 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
//===----------------------------------------------------------------------===//
// OperationConverter
//===----------------------------------------------------------------------===//

namespace {
enum OpConversionMode {
/// In this mode, the conversion will ignore failed conversions to allow
Expand All @@ -3117,6 +3118,13 @@ enum OpConversionMode {
} // namespace

namespace mlir {

// Predeclaration only.
static void reconcileUnrealizedCasts(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this predeclaration here to keep the diff small, so that the PR is easier to review. Will move the entire function here in a follow-up NFC PR.

const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);

// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
Expand Down Expand Up @@ -3264,18 +3272,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// After a successful conversion, apply rewrites.
rewriterImpl.applyRewrites();

// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&materializations = rewriterImpl.unresolvedMaterializations;
for (auto it : materializations)
allCastOps.push_back(it.first);

// Reconcile all UnrealizedConversionCastOps that were inserted by the
// dialect conversion frameworks. (Not the one that were inserted by
// dialect conversion frameworks. (Not the ones that were inserted by
// patterns.)
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&materializations = rewriterImpl.unresolvedMaterializations;
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
reconcileUnrealizedCasts(materializations, &remainingCastOps);

// Drop markers.
for (UnrealizedConversionCastOp castOp : remainingCastOps)
Expand Down Expand Up @@ -3303,20 +3306,19 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//

void mlir::reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
/// Try to reconcile all given UnrealizedConversionCastOps and store the
/// left-over ops in `remainingCastOps` (if provided). See documentation in
/// DialectConversion.h for more details.
/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
/// algorithm may visit an operand (or user) which is a cast op, but will not
/// try to reconcile it if not in the filtered set.
template <typename RangeT>
static void reconcileUnrealizedCastsImpl(
RangeT castOps,
function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
// A worklist of cast ops to process.
SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
// This set is maintained only if `remainingCastOps` is provided.
DenseSet<Operation *> erasedOps;

// Helper function that adds all operands to the worklist that are an
// unrealized_conversion_cast op result.
auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
for (Value v : castOp.getInputs())
if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
worklist.insert(inputCastOp);
};

// Helper function that return the unrealized_conversion_cast op that
// defines all inputs of the given op (in the same order). Return "nullptr"
Expand All @@ -3337,39 +3339,110 @@ void mlir::reconcileUnrealizedCasts(
// Process ops in the worklist bottom-to-top.
while (!worklist.empty()) {
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
if (castOp->use_empty()) {
// DCE: If the op has no users, erase it. Add the operands to the
// worklist to find additional DCE opportunities.
enqueueOperands(castOp);
if (remainingCastOps)
erasedOps.insert(castOp.getOperation());
castOp->erase();
continue;
}

// Traverse the chain of input cast ops to see if an op with the same
// input types can be found.
UnrealizedConversionCastOp nextCast = castOp;
while (nextCast) {
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
return v.getDefiningOp() == castOp;
})) {
// Ran into a cycle.
break;
}

// Found a cast where the input types match the output types of the
// matched op. We can directly use those inputs and the matched op can
// be removed.
enqueueOperands(castOp);
// matched op. We can directly use those inputs.
castOp.replaceAllUsesWith(nextCast.getInputs());
if (remainingCastOps)
erasedOps.insert(castOp.getOperation());
castOp->erase();
break;
}
nextCast = getInputCast(nextCast);
}
}

if (remainingCastOps)
for (UnrealizedConversionCastOp op : castOps)
if (!erasedOps.contains(op.getOperation()))
// A set of all alive cast ops. I.e., ops whose results are (transitively)
// used by an op that is not a cast op.
DenseSet<Operation *> liveOps;

// Helper function that marks the given op and transitively reachable input
// cast ops as alive.
auto markOpLive = [&](Operation *rootOp) {
SmallVector<Operation *> worklist;
worklist.push_back(rootOp);
while (!worklist.empty()) {
Operation *op = worklist.pop_back_val();
if (liveOps.insert(op).second) {
// Successfully inserted: process reachable input cast ops.
for (Value v : op->getOperands())
if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
if (isCastOpOfInterestFn(castOp))
worklist.push_back(castOp);
}
}
};

// Find all alive cast ops.
for (UnrealizedConversionCastOp op : castOps) {
// The op may have been marked live already as being an operand of another
// live cast op.
if (liveOps.contains(op.getOperation()))
continue;
// If any of the users is not a cast op, mark the current op (and its
// input ops) as live.
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
return !castOp || !isCastOpOfInterestFn(castOp);
}))
markOpLive(op);
}

// Erase all dead cast ops.
for (UnrealizedConversionCastOp op : castOps) {
if (liveOps.contains(op)) {
// Op is alive and was not erased. Add it to the remaining cast ops.
if (remainingCastOps)
remainingCastOps->push_back(op);
continue;
}

// Op is dead. Erase it.
op->dropAllUses();
op->erase();
}
}

void mlir::reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
// Set of all cast ops for faster lookups.
DenseSet<UnrealizedConversionCastOp> castOpSet;
for (UnrealizedConversionCastOp op : castOps)
castOpSet.insert(op);
reconcileUnrealizedCasts(castOpSet, remainingCastOps);
}

void mlir::reconcileUnrealizedCasts(
const DenseSet<UnrealizedConversionCastOp> &castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
reconcileUnrealizedCastsImpl(
llvm::make_range(castOps.begin(), castOps.end()),
[&](UnrealizedConversionCastOp castOp) {
return castOps.contains(castOp);
},
remainingCastOps);
}

static void mlir::reconcileUnrealizedCasts(
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
reconcileUnrealizedCastsImpl(
castOps.keys(),
[&](UnrealizedConversionCastOp castOp) {
return castOps.contains(castOp);
},
remainingCastOps);
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,53 @@ func.func @emptyCast() -> index {
%0 = builtin.unrealized_conversion_cast to index
return %0 : index
}

// -----

// CHECK-LABEL: test.graph_region
// CHECK-NEXT: "test.return"() : () -> ()
test.graph_region {
%0 = builtin.unrealized_conversion_cast %2 : i32 to i64
%1 = builtin.unrealized_conversion_cast %0 : i64 to i16
%2 = builtin.unrealized_conversion_cast %1 : i16 to i32
"test.return"() : () -> ()
}

// -----

// CHECK-LABEL: test.graph_region
// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64
// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16
// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32
// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> ()
// CHECK-NEXT: "test.return"() : () -> ()
test.graph_region {
%0 = builtin.unrealized_conversion_cast %2 : i32 to i64
%1 = builtin.unrealized_conversion_cast %0 : i64 to i16
%2 = builtin.unrealized_conversion_cast %1 : i16 to i32
"test.user"(%2) : (i32) -> ()
"test.return"() : () -> ()
}

// -----

// CHECK-LABEL: test.graph_region
// CHECK-NEXT: "test.return"() : () -> ()
test.graph_region {
%0 = builtin.unrealized_conversion_cast %0 : i32 to i32
"test.return"() : () -> ()
}

// -----

// CHECK-LABEL: test.graph_region
// CHECK-NEXT: %[[c0:.*]] = arith.constant
// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32
// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> ()
// CHECK-NEXT: "test.return"() : () -> ()
test.graph_region {
%cst = arith.constant 0 : i32
%0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32
"test.user"(%0) : (i32) -> ()
"test.return"() : () -> ()
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -expand-strided-metadata \
// RUN: -test-cf-assert \
// RUN: -convert-to-llvm | \
// RUN: -convert-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
// RUN: -convert-to-llvm | \
// RUN: -convert-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
// RUN: -convert-to-llvm | \
// RUN: -convert-to-llvm \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
Expand Down
Loading