@@ -3306,9 +3306,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
3306
3306
void mlir::reconcileUnrealizedCasts (
3307
3307
ArrayRef<UnrealizedConversionCastOp> castOps,
3308
3308
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3309
+ // Set of all cast ops for faster lookups.
3310
+ DenseSet<Operation *> castOpSet;
3311
+ for (UnrealizedConversionCastOp op : castOps)
3312
+ castOpSet.insert (op);
3313
+
3314
+ // A worklist of cast ops to process.
3309
3315
SetVector<UnrealizedConversionCastOp> worklist (llvm::from_range, castOps);
3310
- // This set is maintained only if `remainingCastOps` is provided.
3311
- DenseSet<Operation *> erasedOps;
3312
3316
3313
3317
// Helper function that adds all operands to the worklist that are an
3314
3318
// unrealized_conversion_cast op result.
@@ -3337,39 +3341,73 @@ void mlir::reconcileUnrealizedCasts(
3337
3341
// Process ops in the worklist bottom-to-top.
3338
3342
while (!worklist.empty ()) {
3339
3343
UnrealizedConversionCastOp castOp = worklist.pop_back_val ();
3340
- if (castOp->use_empty ()) {
3341
- // DCE: If the op has no users, erase it. Add the operands to the
3342
- // worklist to find additional DCE opportunities.
3343
- enqueueOperands (castOp);
3344
- if (remainingCastOps)
3345
- erasedOps.insert (castOp.getOperation ());
3346
- castOp->erase ();
3347
- continue ;
3348
- }
3349
3344
3350
3345
// Traverse the chain of input cast ops to see if an op with the same
3351
3346
// input types can be found.
3352
3347
UnrealizedConversionCastOp nextCast = castOp;
3353
3348
while (nextCast) {
3354
3349
if (nextCast.getInputs ().getTypes () == castOp.getResultTypes ()) {
3350
+ if (llvm::any_of (nextCast.getInputs (), [&](Value v) {
3351
+ return v.getDefiningOp () == castOp;
3352
+ })) {
3353
+ // Ran into a cycle.
3354
+ break ;
3355
+ }
3356
+
3355
3357
// Found a cast where the input types match the output types of the
3356
- // matched op. We can directly use those inputs and the matched op can
3357
- // be removed.
3358
+ // matched op. We can directly use those inputs.
3358
3359
enqueueOperands (castOp);
3359
3360
castOp.replaceAllUsesWith (nextCast.getInputs ());
3360
- if (remainingCastOps)
3361
- erasedOps.insert (castOp.getOperation ());
3362
- castOp->erase ();
3363
3361
break ;
3364
3362
}
3365
3363
nextCast = getInputCast (nextCast);
3366
3364
}
3367
3365
}
3368
3366
3369
- if (remainingCastOps)
3370
- for (UnrealizedConversionCastOp op : castOps)
3371
- if (!erasedOps.contains (op.getOperation ()))
3367
+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
3368
+ // used by an op that is not a cast op.
3369
+ DenseSet<Operation *> liveOps;
3370
+
3371
+ // Helper function that marks the given op and all ops transitively reachable
3372
+ // input cast ops as alive.
3373
+ auto markOpLive = [&](Operation *op) {
3374
+ SmallVector<Operation *> worklist;
3375
+ worklist.push_back (op);
3376
+ while (!worklist.empty ()) {
3377
+ Operation *op = worklist.pop_back_val ();
3378
+ if (liveOps.insert (op).second ) {
3379
+ // Successfully inserted: the op is live. Add its operands to the
3380
+ // worklist to mark them live.
3381
+ for (Value v : op->getOperands ())
3382
+ if (castOpSet.contains (v.getDefiningOp ()))
3383
+ worklist.push_back (v.getDefiningOp ());
3384
+ }
3385
+ }
3386
+ };
3387
+
3388
+ // Find all alive cast ops.
3389
+ for (UnrealizedConversionCastOp op : castOps) {
3390
+ // If any of the users is not a cast op, mark the current op (and its
3391
+ // input ops) as live.
3392
+ if (llvm::any_of (op->getUsers (), [&](Operation *user) {
3393
+ return !castOpSet.contains (user);
3394
+ }))
3395
+ markOpLive (op);
3396
+ }
3397
+
3398
+ // Erase all dead cast ops.
3399
+ for (UnrealizedConversionCastOp op : castOps) {
3400
+ if (liveOps.contains (op)) {
3401
+ // Op is alive and was not erased. Add it to the remaining cast ops.
3402
+ if (remainingCastOps)
3372
3403
remainingCastOps->push_back (op);
3404
+ continue ;
3405
+ }
3406
+
3407
+ // Op is dead. Erase it.
3408
+ op->dropAllUses ();
3409
+ op->erase ();
3410
+ }
3373
3411
}
3374
3412
3375
3413
// ===----------------------------------------------------------------------===//
0 commit comments