@@ -3265,11 +3265,11 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
3265
3265
rewriterImpl.applyRewrites ();
3266
3266
3267
3267
// Gather all unresolved materializations.
3268
- SmallVector <UnrealizedConversionCastOp> allCastOps;
3268
+ DenseSet <UnrealizedConversionCastOp> allCastOps;
3269
3269
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3270
3270
&materializations = rewriterImpl.unresolvedMaterializations ;
3271
3271
for (auto it : materializations)
3272
- allCastOps.push_back (it.first );
3272
+ allCastOps.insert (it.first );
3273
3273
3274
3274
// Reconcile all UnrealizedConversionCastOps that were inserted by the
3275
3275
// dialect conversion frameworks. (Not the one that were inserted by
@@ -3306,17 +3306,18 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
3306
3306
void mlir::reconcileUnrealizedCasts (
3307
3307
ArrayRef<UnrealizedConversionCastOp> castOps,
3308
3308
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3309
- SetVector<UnrealizedConversionCastOp> worklist (llvm::from_range, castOps);
3310
- // This set is maintained only if `remainingCastOps` is provided.
3311
- DenseSet<Operation *> erasedOps;
3309
+ // Set of all cast ops for faster lookups.
3310
+ DenseSet<UnrealizedConversionCastOp> castOpSet;
3311
+ for (UnrealizedConversionCastOp op : castOps)
3312
+ castOpSet.insert (op);
3313
+ reconcileUnrealizedCasts (castOpSet, remainingCastOps);
3314
+ }
3312
3315
3313
- // Helper function that adds all operands to the worklist that are an
3314
- // unrealized_conversion_cast op result.
3315
- auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
3316
- for (Value v : castOp.getInputs ())
3317
- if (auto inputCastOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3318
- worklist.insert (inputCastOp);
3319
- };
3316
+ void mlir::reconcileUnrealizedCasts (
3317
+ DenseSet<UnrealizedConversionCastOp> castOps,
3318
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3319
+ // A worklist of cast ops to process.
3320
+ SetVector<UnrealizedConversionCastOp> worklist (llvm::from_range, castOps);
3320
3321
3321
3322
// Helper function that return the unrealized_conversion_cast op that
3322
3323
// defines all inputs of the given op (in the same order). Return "nullptr"
@@ -3337,39 +3338,74 @@ void mlir::reconcileUnrealizedCasts(
3337
3338
// Process ops in the worklist bottom-to-top.
3338
3339
while (!worklist.empty ()) {
3339
3340
UnrealizedConversionCastOp castOp = worklist.pop_back_val ();
3340
- if (castOp->use_empty ()) {
3341
- // DCE: If the op has no users, erase it. Add the operands to the
3342
- // worklist to find additional DCE opportunities.
3343
- enqueueOperands (castOp);
3344
- if (remainingCastOps)
3345
- erasedOps.insert (castOp.getOperation ());
3346
- castOp->erase ();
3347
- continue ;
3348
- }
3349
3341
3350
3342
// Traverse the chain of input cast ops to see if an op with the same
3351
3343
// input types can be found.
3352
3344
UnrealizedConversionCastOp nextCast = castOp;
3353
3345
while (nextCast) {
3354
3346
if (nextCast.getInputs ().getTypes () == castOp.getResultTypes ()) {
3347
+ if (llvm::any_of (nextCast.getInputs (), [&](Value v) {
3348
+ return v.getDefiningOp () == castOp;
3349
+ })) {
3350
+ // Ran into a cycle.
3351
+ break ;
3352
+ }
3353
+
3355
3354
// Found a cast where the input types match the output types of the
3356
- // matched op. We can directly use those inputs and the matched op can
3357
- // be removed.
3358
- enqueueOperands (castOp);
3355
+ // matched op. We can directly use those inputs.
3359
3356
castOp.replaceAllUsesWith (nextCast.getInputs ());
3360
- if (remainingCastOps)
3361
- erasedOps.insert (castOp.getOperation ());
3362
- castOp->erase ();
3363
3357
break ;
3364
3358
}
3365
3359
nextCast = getInputCast (nextCast);
3366
3360
}
3367
3361
}
3368
3362
3369
- if (remainingCastOps)
3370
- for (UnrealizedConversionCastOp op : castOps)
3371
- if (!erasedOps.contains (op.getOperation ()))
3363
+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
3364
+ // used by an op that is not a cast op.
3365
+ DenseSet<Operation *> liveOps;
3366
+
3367
+ // Helper function that marks the given op and transitively reachable input
3368
+ // cast ops as alive.
3369
+ auto markOpLive = [&](Operation *op) {
3370
+ SmallVector<Operation *> worklist;
3371
+ worklist.push_back (op);
3372
+ while (!worklist.empty ()) {
3373
+ Operation *op = worklist.pop_back_val ();
3374
+ if (liveOps.insert (op).second ) {
3375
+ // Successfully inserted: the op is live. Add its operands to the
3376
+ // worklist to mark them live.
3377
+ for (Value v : op->getOperands ())
3378
+ if (auto castOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3379
+ if (castOps.contains (castOp))
3380
+ worklist.push_back (castOp);
3381
+ }
3382
+ }
3383
+ };
3384
+
3385
+ // Find all alive cast ops.
3386
+ for (UnrealizedConversionCastOp op : castOps) {
3387
+ // If any of the users is not a cast op, mark the current op (and its
3388
+ // input ops) as live.
3389
+ if (llvm::any_of (op->getUsers (), [&](Operation *user) {
3390
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3391
+ return !castOp || !castOps.contains (castOp);
3392
+ }))
3393
+ markOpLive (op);
3394
+ }
3395
+
3396
+ // Erase all dead cast ops.
3397
+ for (UnrealizedConversionCastOp op : castOps) {
3398
+ if (liveOps.contains (op)) {
3399
+ // Op is alive and was not erased. Add it to the remaining cast ops.
3400
+ if (remainingCastOps)
3372
3401
remainingCastOps->push_back (op);
3402
+ continue ;
3403
+ }
3404
+
3405
+ // Op is dead. Erase it.
3406
+ op->dropAllUses ();
3407
+ op->erase ();
3408
+ }
3373
3409
}
3374
3410
3375
3411
// ===----------------------------------------------------------------------===//
0 commit comments