@@ -3100,6 +3100,7 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
3100
3100
// ===----------------------------------------------------------------------===//
3101
3101
// OperationConverter
3102
3102
// ===----------------------------------------------------------------------===//
3103
+
3103
3104
namespace {
3104
3105
enum OpConversionMode {
3105
3106
// / In this mode, the conversion will ignore failed conversions to allow
@@ -3117,6 +3118,13 @@ enum OpConversionMode {
3117
3118
} // namespace
3118
3119
3119
3120
namespace mlir {
3121
+
3122
+ // Predeclaration only.
3123
+ static void reconcileUnrealizedCasts (
3124
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3125
+ &castOps,
3126
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);
3127
+
3120
3128
// This class converts operations to a given conversion target via a set of
3121
3129
// rewrite patterns. The conversion behaves differently depending on the
3122
3130
// conversion mode.
@@ -3264,18 +3272,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
3264
3272
// After a successful conversion, apply rewrites.
3265
3273
rewriterImpl.applyRewrites ();
3266
3274
3267
- // Gather all unresolved materializations.
3268
- SmallVector<UnrealizedConversionCastOp> allCastOps;
3269
- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3270
- &materializations = rewriterImpl.unresolvedMaterializations ;
3271
- for (auto it : materializations)
3272
- allCastOps.push_back (it.first );
3273
-
3274
3275
// Reconcile all UnrealizedConversionCastOps that were inserted by the
3275
- // dialect conversion frameworks. (Not the one that were inserted by
3276
+ // dialect conversion frameworks. (Not the ones that were inserted by
3276
3277
// patterns.)
3278
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3279
+ &materializations = rewriterImpl.unresolvedMaterializations ;
3277
3280
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
3278
- reconcileUnrealizedCasts (allCastOps , &remainingCastOps);
3281
+ reconcileUnrealizedCasts (materializations , &remainingCastOps);
3279
3282
3280
3283
// Drop markers.
3281
3284
for (UnrealizedConversionCastOp castOp : remainingCastOps)
@@ -3303,20 +3306,15 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
3303
3306
// Reconcile Unrealized Casts
3304
3307
// ===----------------------------------------------------------------------===//
3305
3308
3306
- void mlir::reconcileUnrealizedCasts (
3307
- ArrayRef<UnrealizedConversionCastOp> castOps,
3309
+ // / Try to reconcile all given UnrealizedConversionCastOps and store the
3310
+ // / left-over ops in `remainingCastOps` (if provided). See documentation in
3311
+ // / DialectConversion.h for more details.
3312
+ template <typename RangeT>
3313
+ static void reconcileUnrealizedCastsImpl (
3314
+ RangeT castOps, function_ref<bool (UnrealizedConversionCastOp)> isCastOpFn,
3308
3315
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3316
+ // A worklist of cast ops to process.
3309
3317
SetVector<UnrealizedConversionCastOp> worklist (llvm::from_range, castOps);
3310
- // This set is maintained only if `remainingCastOps` is provided.
3311
- DenseSet<Operation *> erasedOps;
3312
-
3313
- // Helper function that adds all operands to the worklist that are an
3314
- // unrealized_conversion_cast op result.
3315
- auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
3316
- for (Value v : castOp.getInputs ())
3317
- if (auto inputCastOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3318
- worklist.insert (inputCastOp);
3319
- };
3320
3318
3321
3319
// Helper function that return the unrealized_conversion_cast op that
3322
3320
// defines all inputs of the given op (in the same order). Return "nullptr"
@@ -3337,39 +3335,106 @@ void mlir::reconcileUnrealizedCasts(
3337
3335
// Process ops in the worklist bottom-to-top.
3338
3336
while (!worklist.empty ()) {
3339
3337
UnrealizedConversionCastOp castOp = worklist.pop_back_val ();
3340
- if (castOp->use_empty ()) {
3341
- // DCE: If the op has no users, erase it. Add the operands to the
3342
- // worklist to find additional DCE opportunities.
3343
- enqueueOperands (castOp);
3344
- if (remainingCastOps)
3345
- erasedOps.insert (castOp.getOperation ());
3346
- castOp->erase ();
3347
- continue ;
3348
- }
3349
3338
3350
3339
// Traverse the chain of input cast ops to see if an op with the same
3351
3340
// input types can be found.
3352
3341
UnrealizedConversionCastOp nextCast = castOp;
3353
3342
while (nextCast) {
3354
3343
if (nextCast.getInputs ().getTypes () == castOp.getResultTypes ()) {
3344
+ if (llvm::any_of (nextCast.getInputs (), [&](Value v) {
3345
+ return v.getDefiningOp () == castOp;
3346
+ })) {
3347
+ // Ran into a cycle.
3348
+ break ;
3349
+ }
3350
+
3355
3351
// Found a cast where the input types match the output types of the
3356
- // matched op. We can directly use those inputs and the matched op can
3357
- // be removed.
3358
- enqueueOperands (castOp);
3352
+ // matched op. We can directly use those inputs.
3359
3353
castOp.replaceAllUsesWith (nextCast.getInputs ());
3360
- if (remainingCastOps)
3361
- erasedOps.insert (castOp.getOperation ());
3362
- castOp->erase ();
3363
3354
break ;
3364
3355
}
3365
3356
nextCast = getInputCast (nextCast);
3366
3357
}
3367
3358
}
3368
3359
3369
- if (remainingCastOps)
3370
- for (UnrealizedConversionCastOp op : castOps)
3371
- if (!erasedOps.contains (op.getOperation ()))
3360
+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
3361
+ // used by an op that is not a cast op.
3362
+ DenseSet<Operation *> liveOps;
3363
+
3364
+ // Helper function that marks the given op and transitively reachable input
3365
+ // cast ops as alive.
3366
+ auto markOpLive = [&](Operation *op) {
3367
+ SmallVector<Operation *> worklist;
3368
+ worklist.push_back (op);
3369
+ while (!worklist.empty ()) {
3370
+ Operation *op = worklist.pop_back_val ();
3371
+ if (liveOps.insert (op).second ) {
3372
+ // Successfully inserted: process reachable input cast ops.
3373
+ for (Value v : op->getOperands ())
3374
+ if (auto castOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3375
+ if (isCastOpFn (castOp))
3376
+ worklist.push_back (castOp);
3377
+ }
3378
+ }
3379
+ };
3380
+
3381
+ // Find all alive cast ops.
3382
+ for (UnrealizedConversionCastOp op : castOps) {
3383
+ // If any of the users is not a cast op, mark the current op (and its
3384
+ // input ops) as live.
3385
+ if (llvm::any_of (op->getUsers (), [&](Operation *user) {
3386
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3387
+ return !castOp || !isCastOpFn (castOp);
3388
+ }))
3389
+ markOpLive (op);
3390
+ }
3391
+
3392
+ // Erase all dead cast ops.
3393
+ for (UnrealizedConversionCastOp op : castOps) {
3394
+ if (liveOps.contains (op)) {
3395
+ // Op is alive and was not erased. Add it to the remaining cast ops.
3396
+ if (remainingCastOps)
3372
3397
remainingCastOps->push_back (op);
3398
+ continue ;
3399
+ }
3400
+
3401
+ // Op is dead. Erase it.
3402
+ op->dropAllUses ();
3403
+ op->erase ();
3404
+ }
3405
+ }
3406
+
3407
+ void mlir::reconcileUnrealizedCasts (
3408
+ ArrayRef<UnrealizedConversionCastOp> castOps,
3409
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3410
+ // Set of all cast ops for faster lookups.
3411
+ DenseSet<UnrealizedConversionCastOp> castOpSet;
3412
+ for (UnrealizedConversionCastOp op : castOps)
3413
+ castOpSet.insert (op);
3414
+ reconcileUnrealizedCasts (castOpSet, remainingCastOps);
3415
+ }
3416
+
3417
+ void mlir::reconcileUnrealizedCasts (
3418
+ const DenseSet<UnrealizedConversionCastOp> &castOps,
3419
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3420
+ reconcileUnrealizedCastsImpl (
3421
+ llvm::make_range (castOps.begin (), castOps.end ()),
3422
+ [&](UnrealizedConversionCastOp castOp) {
3423
+ return castOps.contains (castOp);
3424
+ },
3425
+ remainingCastOps);
3426
+ }
3427
+
3428
+ static void mlir::reconcileUnrealizedCasts (
3429
+ const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3430
+ &castOps,
3431
+ SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3432
+ reconcileUnrealizedCastsImpl (
3433
+ castOps.keys (),
3434
+ [&](UnrealizedConversionCastOp castOp) {
3435
+ return castOps.contains (castOp);
3436
+ },
3437
+ remainingCastOps);
3373
3438
}
3374
3439
3375
3440
// ===----------------------------------------------------------------------===//
0 commit comments