Skip to content

Commit 66a0f0b

Browse files
[mlir][Transforms] Fix crash in reconcile-unrealized-casts
1 parent 3a7da9a commit 66a0f0b

File tree

5 files changed

+113
-22
lines changed

5 files changed

+113
-22
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3306,9 +3306,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
33063306
void mlir::reconcileUnrealizedCasts(
33073307
ArrayRef<UnrealizedConversionCastOp> castOps,
33083308
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3309+
// Set of all cast ops for faster lookups.
3310+
DenseSet<Operation *> castOpSet;
3311+
for (UnrealizedConversionCastOp op : castOps)
3312+
castOpSet.insert(op);
3313+
3314+
// A worklist of cast ops to process.
33093315
SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
3310-
// This set is maintained only if `remainingCastOps` is provided.
3311-
DenseSet<Operation *> erasedOps;
33123316

33133317
// Helper function that adds all operands to the worklist that are an
33143318
// unrealized_conversion_cast op result.
@@ -3337,39 +3341,73 @@ void mlir::reconcileUnrealizedCasts(
33373341
// Process ops in the worklist bottom-to-top.
33383342
while (!worklist.empty()) {
33393343
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
3340-
if (castOp->use_empty()) {
3341-
// DCE: If the op has no users, erase it. Add the operands to the
3342-
// worklist to find additional DCE opportunities.
3343-
enqueueOperands(castOp);
3344-
if (remainingCastOps)
3345-
erasedOps.insert(castOp.getOperation());
3346-
castOp->erase();
3347-
continue;
3348-
}
33493344

33503345
// Traverse the chain of input cast ops to see if an op with the same
33513346
// input types can be found.
33523347
UnrealizedConversionCastOp nextCast = castOp;
33533348
while (nextCast) {
33543349
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
3350+
if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
3351+
return v.getDefiningOp() == castOp;
3352+
})) {
3353+
// Ran into a cycle.
3354+
break;
3355+
}
3356+
33553357
// Found a cast where the input types match the output types of the
3356-
// matched op. We can directly use those inputs and the matched op can
3357-
// be removed.
3358+
// matched op. We can directly use those inputs.
33583359
enqueueOperands(castOp);
33593360
castOp.replaceAllUsesWith(nextCast.getInputs());
3360-
if (remainingCastOps)
3361-
erasedOps.insert(castOp.getOperation());
3362-
castOp->erase();
33633361
break;
33643362
}
33653363
nextCast = getInputCast(nextCast);
33663364
}
33673365
}
33683366

3369-
if (remainingCastOps)
3370-
for (UnrealizedConversionCastOp op : castOps)
3371-
if (!erasedOps.contains(op.getOperation()))
3367+
// A set of all alive cast ops. I.e., ops whose results are (transitively)
3368+
// used by an op that is not a cast op.
3369+
DenseSet<Operation *> liveOps;
3370+
3371+
// Helper function that marks the given op and all ops transitively reachable
3372+
// input cast ops as alive.
3373+
auto markOpLive = [&](Operation *op) {
3374+
SmallVector<Operation *> worklist;
3375+
worklist.push_back(op);
3376+
while (!worklist.empty()) {
3377+
Operation *op = worklist.pop_back_val();
3378+
if (liveOps.insert(op).second) {
3379+
// Successfully inserted: the op is live. Add its operands to the
3380+
// worklist to mark them live.
3381+
for (Value v : op->getOperands())
3382+
if (castOpSet.contains(v.getDefiningOp()))
3383+
worklist.push_back(v.getDefiningOp());
3384+
}
3385+
}
3386+
};
3387+
3388+
// Find all alive cast ops.
3389+
for (UnrealizedConversionCastOp op : castOps) {
3390+
// If any of the users is not a cast op, mark the current op (and its
3391+
// input ops) as live.
3392+
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
3393+
return !castOpSet.contains(user);
3394+
}))
3395+
markOpLive(op);
3396+
}
3397+
3398+
// Erase all dead cast ops.
3399+
for (UnrealizedConversionCastOp op : castOps) {
3400+
if (liveOps.contains(op)) {
3401+
// Op is alive and was not erased. Add it to the remaining cast ops.
3402+
if (remainingCastOps)
33723403
remainingCastOps->push_back(op);
3404+
continue;
3405+
}
3406+
3407+
// Op is dead. Erase it.
3408+
op->dropAllUses();
3409+
op->erase();
3410+
}
33733411
}
33743412

33753413
//===----------------------------------------------------------------------===//

mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,53 @@ func.func @emptyCast() -> index {
194194
%0 = builtin.unrealized_conversion_cast to index
195195
return %0 : index
196196
}
197+
198+
// -----
199+
200+
// CHECK-LABEL: test.graph_region
201+
// CHECK-NEXT: "test.return"() : () -> ()
202+
test.graph_region {
203+
%0 = builtin.unrealized_conversion_cast %2 : i32 to i64
204+
%1 = builtin.unrealized_conversion_cast %0 : i64 to i16
205+
%2 = builtin.unrealized_conversion_cast %1 : i16 to i32
206+
"test.return"() : () -> ()
207+
}
208+
209+
// -----
210+
211+
// CHECK-LABEL: test.graph_region
212+
// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64
213+
// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16
214+
// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32
215+
// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> ()
216+
// CHECK-NEXT: "test.return"() : () -> ()
217+
test.graph_region {
218+
%0 = builtin.unrealized_conversion_cast %2 : i32 to i64
219+
%1 = builtin.unrealized_conversion_cast %0 : i64 to i16
220+
%2 = builtin.unrealized_conversion_cast %1 : i16 to i32
221+
"test.user"(%2) : (i32) -> ()
222+
"test.return"() : () -> ()
223+
}
224+
225+
// -----
226+
227+
// CHECK-LABEL: test.graph_region
228+
// CHECK-NEXT: "test.return"() : () -> ()
229+
test.graph_region {
230+
%0 = builtin.unrealized_conversion_cast %0 : i32 to i32
231+
"test.return"() : () -> ()
232+
}
233+
234+
// -----
235+
236+
// CHECK-LABEL: test.graph_region
237+
// CHECK-NEXT: %[[c0:.*]] = arith.constant
238+
// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32
239+
// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> ()
240+
// CHECK-NEXT: "test.return"() : () -> ()
241+
test.graph_region {
242+
%cst = arith.constant 0 : i32
243+
%0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32
244+
"test.user"(%0) : (i32) -> ()
245+
"test.return"() : () -> ()
246+
}

mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// RUN: mlir-opt %s -generate-runtime-verification \
22
// RUN: -expand-strided-metadata \
33
// RUN: -test-cf-assert \
4-
// RUN: -convert-to-llvm | \
4+
// RUN: -convert-to-llvm \
5+
// RUN: -reconcile-unrealized-casts | \
56
// RUN: mlir-runner -e main -entry-point-result=void \
67
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
78
// RUN: FileCheck %s

mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt %s -generate-runtime-verification \
22
// RUN: -test-cf-assert \
3-
// RUN: -convert-to-llvm | \
3+
// RUN: -convert-to-llvm \
4+
// RUN: -reconcile-unrealized-casts | \
45
// RUN: mlir-runner -e main -entry-point-result=void \
56
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
67
// RUN: FileCheck %s

mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt %s -generate-runtime-verification \
22
// RUN: -test-cf-assert \
3-
// RUN: -convert-to-llvm | \
3+
// RUN: -convert-to-llvm \
4+
// RUN: -reconcile-unrealized-casts | \
45
// RUN: mlir-runner -e main -entry-point-result=void \
56
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
67
// RUN: FileCheck %s

0 commit comments

Comments
 (0)