[mlir][Transforms] Fix crash in reconcile-unrealized-casts
#158298
The `reconcile-unrealized-casts` pass used to crash when the input contains circular chains of `unrealized_conversion_cast` ops.

Furthermore, the `reconcileUnrealizedCasts` helper functions used to erase ops that were not passed via the `castOps` operand. Such ops are now preserved. That's why some integration tests had to be changed.

Also avoid copying the set of all unresolved materializations in `convertOperations`.

This commit is in preparation of turning `RewriterBase::replaceOp` into a non-virtual function.

---------

Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
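The cycle handling can be illustrated with a small standalone model. The sketch below is a toy, assuming single-input, single-result casts stored in a vector; `ToyCast` and `findFoldTarget` are hypothetical names and not part of MLIR. It mirrors the traversal in the diff: walk the chain of input casts until one is found whose input type equals the result type of the starting cast, and give up if that input is produced by the starting cast itself.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for a single-input, single-result unrealized_conversion_cast.
// This is not the MLIR API; casts are identified by their index in a vector.
struct ToyCast {
  int inputCast;          // index of the cast defining the input, or -1
  std::string inputType;  // type of the input value
  std::string resultType; // type of the result value
};

// Walks the chain of input casts starting at `start`. Returns the index of
// the cast whose input could replace the result of `start`, or -1 if there
// is no such cast or the chain loops back to `start`.
int findFoldTarget(const std::vector<ToyCast> &casts, int start) {
  std::set<int> seen; // extra termination guard of this toy, not in the diff
  for (int next = start; next != -1; next = casts[next].inputCast) {
    if (!seen.insert(next).second)
      return -1; // revisited a cast without finding a match
    if (casts[next].inputType != casts[start].resultType)
      continue; // types differ: keep walking up the chain
    // Types match. If the matching input is produced by `start` itself, the
    // chain is circular and nothing can be folded.
    if (casts[next].inputCast == start)
      return -1;
    return next;
  }
  return -1;
}
```

For three casts wired in a circle (`i32 to i64`, `i64 to i16`, `i16 to i32`), `findFoldTarget` returns -1 for the first cast, so no uses are replaced; whether the casts are erased is decided separately by the liveness step.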
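@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

The `reconcile-unrealized-casts` pass used to crash when the input contains circular chains of `unrealized_conversion_cast` ops.

Furthermore, the `reconcileUnrealizedCasts` helper functions used to erase ops that were not passed via the `castOps` operand. Such ops are now preserved. That's why some integration tests had to be changed.

Also avoid copying the set of all unresolved materializations in `convertOperations`.

This commit is in preparation of turning `RewriterBase::replaceOp` into a non-virtual function.

This is a re-upload of #158067, which was reverted due to CI failures.

Note for LLVM integration: If you are seeing tests that are failing with `error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast`, you may have to add the `-reconcile-unrealized-casts` pass to your pass pipeline. (Or switch to the `-convert-to-llvm` pass instead of combining the various `-convert-*-to-llvm` passes.)

Co-authored-by: Mehdi Amini <joker.eph@gmail.com>

Full diff: https://github.com/llvm/llvm-project/pull/158298.diff

7 Files Affected: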
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index a096f82a4cfd8..f8caae3ce9995 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1428,6 +1428,9 @@ struct ConversionConfig {
///
/// In the above example, %0 can be used instead of %3 and all cast ops are
/// folded away.
+void reconcileUnrealizedCasts(
+ const DenseSet<UnrealizedConversionCastOp> &castOps,
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
void reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index df9700f11200f..d53e1e78f2027 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3100,6 +3100,7 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
//===----------------------------------------------------------------------===//
// OperationConverter
//===----------------------------------------------------------------------===//
+
namespace {
enum OpConversionMode {
/// In this mode, the conversion will ignore failed conversions to allow
@@ -3117,6 +3118,13 @@ enum OpConversionMode {
} // namespace
namespace mlir {
+
+// Predeclaration only.
+static void reconcileUnrealizedCasts(
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
+ &castOps,
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);
+
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
// conversion mode.
@@ -3264,18 +3272,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// After a successful conversion, apply rewrites.
rewriterImpl.applyRewrites();
- // Gather all unresolved materializations.
- SmallVector<UnrealizedConversionCastOp> allCastOps;
- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
- &materializations = rewriterImpl.unresolvedMaterializations;
- for (auto it : materializations)
- allCastOps.push_back(it.first);
-
// Reconcile all UnrealizedConversionCastOps that were inserted by the
- // dialect conversion frameworks. (Not the one that were inserted by
+ // dialect conversion frameworks. (Not the ones that were inserted by
// patterns.)
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
+ &materializations = rewriterImpl.unresolvedMaterializations;
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
- reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
+ reconcileUnrealizedCasts(materializations, &remainingCastOps);
// Drop markers.
for (UnrealizedConversionCastOp castOp : remainingCastOps)
@@ -3303,20 +3306,19 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
-void mlir::reconcileUnrealizedCasts(
- ArrayRef<UnrealizedConversionCastOp> castOps,
+/// Try to reconcile all given UnrealizedConversionCastOps and store the
+/// left-over ops in `remainingCastOps` (if provided). See documentation in
+/// DialectConversion.h for more details.
+/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
+/// algorithm may visit an operand (or user) which is a cast op, but will not
+/// try to reconcile it if not in the filtered set.
+template <typename RangeT>
+static void reconcileUnrealizedCastsImpl(
+ RangeT castOps,
+ function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+ // A worklist of cast ops to process.
SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
- // This set is maintained only if `remainingCastOps` is provided.
- DenseSet<Operation *> erasedOps;
-
- // Helper function that adds all operands to the worklist that are an
- // unrealized_conversion_cast op result.
- auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
- for (Value v : castOp.getInputs())
- if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
- worklist.insert(inputCastOp);
- };
// Helper function that return the unrealized_conversion_cast op that
// defines all inputs of the given op (in the same order). Return "nullptr"
@@ -3337,39 +3339,110 @@ void mlir::reconcileUnrealizedCasts(
// Process ops in the worklist bottom-to-top.
while (!worklist.empty()) {
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
- if (castOp->use_empty()) {
- // DCE: If the op has no users, erase it. Add the operands to the
- // worklist to find additional DCE opportunities.
- enqueueOperands(castOp);
- if (remainingCastOps)
- erasedOps.insert(castOp.getOperation());
- castOp->erase();
- continue;
- }
// Traverse the chain of input cast ops to see if an op with the same
// input types can be found.
UnrealizedConversionCastOp nextCast = castOp;
while (nextCast) {
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
+ return v.getDefiningOp() == castOp;
+ })) {
+ // Ran into a cycle.
+ break;
+ }
+
// Found a cast where the input types match the output types of the
- // matched op. We can directly use those inputs and the matched op can
- // be removed.
- enqueueOperands(castOp);
+ // matched op. We can directly use those inputs.
castOp.replaceAllUsesWith(nextCast.getInputs());
- if (remainingCastOps)
- erasedOps.insert(castOp.getOperation());
- castOp->erase();
break;
}
nextCast = getInputCast(nextCast);
}
}
- if (remainingCastOps)
- for (UnrealizedConversionCastOp op : castOps)
- if (!erasedOps.contains(op.getOperation()))
+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
+ // used by an op that is not a cast op.
+ DenseSet<Operation *> liveOps;
+
+ // Helper function that marks the given op and transitively reachable input
+ // cast ops as alive.
+ auto markOpLive = [&](Operation *rootOp) {
+ SmallVector<Operation *> worklist;
+ worklist.push_back(rootOp);
+ while (!worklist.empty()) {
+ Operation *op = worklist.pop_back_val();
+ if (liveOps.insert(op).second) {
+ // Successfully inserted: process reachable input cast ops.
+ for (Value v : op->getOperands())
+ if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+ if (isCastOpOfInterestFn(castOp))
+ worklist.push_back(castOp);
+ }
+ }
+ };
+
+ // Find all alive cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ // The op may have been marked live already as being an operand of another
+ // live cast op.
+ if (liveOps.contains(op.getOperation()))
+ continue;
+ // If any of the users is not a cast op, mark the current op (and its
+ // input ops) as live.
+ if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
+ return !castOp || !isCastOpOfInterestFn(castOp);
+ }))
+ markOpLive(op);
+ }
+
+ // Erase all dead cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ if (liveOps.contains(op)) {
+ // Op is alive and was not erased. Add it to the remaining cast ops.
+ if (remainingCastOps)
remainingCastOps->push_back(op);
+ continue;
+ }
+
+ // Op is dead. Erase it.
+ op->dropAllUses();
+ op->erase();
+ }
+}
+
+void mlir::reconcileUnrealizedCasts(
+ ArrayRef<UnrealizedConversionCastOp> castOps,
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+ // Set of all cast ops for faster lookups.
+ DenseSet<UnrealizedConversionCastOp> castOpSet;
+ for (UnrealizedConversionCastOp op : castOps)
+ castOpSet.insert(op);
+ reconcileUnrealizedCasts(castOpSet, remainingCastOps);
+}
+
+void mlir::reconcileUnrealizedCasts(
+ const DenseSet<UnrealizedConversionCastOp> &castOps,
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+ reconcileUnrealizedCastsImpl(
+ llvm::make_range(castOps.begin(), castOps.end()),
+ [&](UnrealizedConversionCastOp castOp) {
+ return castOps.contains(castOp);
+ },
+ remainingCastOps);
+}
+
+static void mlir::reconcileUnrealizedCasts(
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
+ &castOps,
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+ reconcileUnrealizedCastsImpl(
+ castOps.keys(),
+ [&](UnrealizedConversionCastOp castOp) {
+ return castOps.contains(castOp);
+ },
+ remainingCastOps);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
index 3573114f5e038..ac5ca321c066f 100644
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -194,3 +194,53 @@ func.func @emptyCast() -> index {
%0 = builtin.unrealized_conversion_cast to index
return %0 : index
}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+ %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64
+// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16
+// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32
+// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> ()
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+ %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+ "test.user"(%2) : (i32) -> ()
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %0 : i32 to i32
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: %[[c0:.*]] = arith.constant
+// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32
+// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> ()
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %cst = arith.constant 0 : i32
+ %0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32
+ "test.user"(%0) : (i32) -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
index 25a338df8d790..01a826a638606 100644
--- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -1,7 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -expand-strided-metadata \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
index 4c6a48d577a6c..1144a7caf36e8 100644
--- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
index dd000c6904bcb..82e63805cd027 100644
--- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
index f5a6fc5ea2b20..e30c31693fae7 100644
--- a/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
+++ b/mlir/test/lib/Pass/TestVulkanRunnerPipeline.cpp
@@ -13,6 +13,7 @@
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
@@ -73,6 +74,7 @@ void buildTestVulkanRunnerPipeline(OpPassManager &passManager,
opt.kernelBarePtrCallConv = true;
opt.kernelIntersperseSizeCallConv = true;
passManager.addPass(createGpuToLLVMConversionPass(opt));
+ passManager.addPass(createReconcileUnrealizedCastsPass());
}
} // namespace
Feel free to re-land, you don't need approval, but I'm curious: what was the source of the failure?
This PR replaces the combination of `replaceAllUsesWith` followed by `eraseOp` with `replaceOp`, which has the same effect. The original pattern, however, will be unsupported by the upcoming version of the dialect conversion; see llvm/llvm-project#158298.

Signed-off-by: Ingo Müller <ingomueller@google.com>
- Add the `-reconcile-unrealized-casts` pass to tests that involve `RewriterBase::replaceOp`, following llvm/llvm-project#158298
- Update the operand type of `rocdl.make.buffer.rsrc` following llvm/llvm-project#159702
The `reconcile-unrealized-casts` pass used to crash when the input contains circular chains of `unrealized_conversion_cast` ops.

Furthermore, the `reconcileUnrealizedCasts` helper functions used to erase ops that were not passed via the `castOps` operand. Such ops are now preserved. That's why some integration tests had to be changed.

Also avoid copying the set of all unresolved materializations in `convertOperations`.

This commit is in preparation of turning `RewriterBase::replaceOp` into a non-virtual function.

This is a re-upload of #158067, which was reverted due to CI failures.

Note for LLVM integration: If you are seeing tests that are failing with `error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast`, you may have to add the `-reconcile-unrealized-casts` pass to your pass pipeline. (Or switch to the `-convert-to-llvm` pass instead of combining the various `-convert-*-to-llvm` passes.)
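The preservation rule described above (only ops passed via `castOps` may be erased, and only if nothing outside that set depends on them) can be sketched as a mark phase over a toy graph. This is a simplified model, assuming ops are plain structs identified by vector index; `ToyOp` and `findLiveCasts` are hypothetical names and not the MLIR API. Casts in the set that do not end up in the returned live set are the ones the helper would erase.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Toy stand-in for an operation. This is not the MLIR API; ops are
// identified by their index in a vector.
struct ToyOp {
  bool isCast;               // true for an unrealized_conversion_cast stand-in
  std::vector<int> operands; // indices of the ops defining the operands
};

// Returns the members of `castOps` that are live, i.e. whose results are
// (transitively) used by an op that is not in `castOps`.
std::set<int> findLiveCasts(const std::vector<ToyOp> &ops,
                            const std::set<int> &castOps) {
  // Derive user lists from the operand lists.
  std::vector<std::vector<int>> users(ops.size());
  for (int i = 0; i < static_cast<int>(ops.size()); ++i)
    for (int def : ops[i].operands)
      users[def].push_back(i);

  std::set<int> live;
  for (int root : castOps) {
    // A root needs at least one user outside the set: either a non-cast op
    // or a cast that was not passed in.
    bool hasOutsideUser = false;
    for (int user : users[root])
      if (!castOps.count(user))
        hasOutsideUser = true;
    if (!hasOutsideUser)
      continue;
    // Mark the root and all input casts reachable from it.
    std::vector<int> worklist{root};
    while (!worklist.empty()) {
      int op = worklist.back();
      worklist.pop_back();
      if (!live.insert(op).second)
        continue; // already marked: this also terminates on cycles
      for (int def : ops[op].operands)
        if (castOps.count(def))
          worklist.push_back(def);
    }
  }
  return live;
}
```

In this model, a three-cast cycle without other users yields an empty live set, so all three casts would be erased. Adding one non-cast user of any cast in the cycle makes all three live, which matches the first two new test cases in the diff.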