@@ -3100,7 +3100,6 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
3100
3100
// ===----------------------------------------------------------------------===//
3101
3101
// OperationConverter
3102
3102
// ===----------------------------------------------------------------------===//
3103
-
3104
3103
namespace {
3105
3104
enum OpConversionMode {
3106
3105
// / In this mode, the conversion will ignore failed conversions to allow
@@ -3118,13 +3117,6 @@ enum OpConversionMode {
3118
3117
} // namespace
3119
3118
3120
3119
namespace mlir {
3121
-
3122
- // Predeclaration only.
3123
- static void reconcileUnrealizedCasts (
3124
- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3125
- &castOps,
3126
- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);
3127
-
3128
3120
// This class converts operations to a given conversion target via a set of
3129
3121
// rewrite patterns. The conversion behaves differently depending on the
3130
3122
// conversion mode.
@@ -3272,13 +3264,18 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
3272
3264
// After a successful conversion, apply rewrites.
3273
3265
rewriterImpl.applyRewrites ();
3274
3266
3275
- // Reconcile all UnrealizedConversionCastOps that were inserted by the
3276
- // dialect conversion frameworks. (Not the ones that were inserted by
3277
- // patterns.)
3267
+ // Gather all unresolved materializations.
3268
+ SmallVector<UnrealizedConversionCastOp> allCastOps;
3278
3269
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3279
3270
&materializations = rewriterImpl.unresolvedMaterializations ;
3271
+ for (auto it : materializations)
3272
+ allCastOps.push_back (it.first );
3273
+
3274
+ // Reconcile all UnrealizedConversionCastOps that were inserted by the
3275
+ // dialect conversion frameworks. (Not the one that were inserted by
3276
+ // patterns.)
3280
3277
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
3281
- reconcileUnrealizedCasts (materializations , &remainingCastOps);
3278
+ reconcileUnrealizedCasts (allCastOps , &remainingCastOps);
3282
3279
3283
3280
// Drop markers.
3284
3281
for (UnrealizedConversionCastOp castOp : remainingCastOps)
@@ -3306,19 +3303,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
3306
3303
// Reconcile Unrealized Casts
3307
3304
// ===----------------------------------------------------------------------===//
3308
3305
3309
- // / Try to reconcile all given UnrealizedConversionCastOps and store the
3310
- // / left-over ops in `remainingCastOps` (if provided). See documentation in
3311
- // / DialectConversion.h for more details.
3312
- // / The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
3313
- // / algorithm may visit an operand (or user) which is a cast op, but will not
3314
- // / try to reconcile it if not in the filtered set.
3315
- template <typename RangeT>
3316
- static void reconcileUnrealizedCastsImpl (
3317
- RangeT castOps,
3318
- function_ref<bool (UnrealizedConversionCastOp)> isCastOpOfInterestFn,
3306
+ void mlir::reconcileUnrealizedCasts (
3307
+ ArrayRef<UnrealizedConversionCastOp> castOps,
3319
3308
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3320
- // A worklist of cast ops to process.
3321
3309
SetVector<UnrealizedConversionCastOp> worklist (llvm::from_range, castOps);
3310
+ // This set is maintained only if `remainingCastOps` is provided.
3311
+ DenseSet<Operation *> erasedOps;
3312
+
3313
+ // Helper function that adds all operands to the worklist that are an
3314
+ // unrealized_conversion_cast op result.
3315
+ auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
3316
+ for (Value v : castOp.getInputs ())
3317
+ if (auto inputCastOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3318
+ worklist.insert (inputCastOp);
3319
+ };
3322
3320
3323
3321
// Helper function that return the unrealized_conversion_cast op that
3324
3322
// defines all inputs of the given op (in the same order). Return "nullptr"
@@ -3339,110 +3337,39 @@ static void reconcileUnrealizedCastsImpl(
3339
3337
// Process ops in the worklist bottom-to-top.
3340
3338
while (!worklist.empty ()) {
3341
3339
UnrealizedConversionCastOp castOp = worklist.pop_back_val ();
3340
+ if (castOp->use_empty ()) {
3341
+ // DCE: If the op has no users, erase it. Add the operands to the
3342
+ // worklist to find additional DCE opportunities.
3343
+ enqueueOperands (castOp);
3344
+ if (remainingCastOps)
3345
+ erasedOps.insert (castOp.getOperation ());
3346
+ castOp->erase ();
3347
+ continue ;
3348
+ }
3342
3349
3343
3350
// Traverse the chain of input cast ops to see if an op with the same
3344
3351
// input types can be found.
3345
3352
UnrealizedConversionCastOp nextCast = castOp;
3346
3353
while (nextCast) {
3347
3354
if (nextCast.getInputs ().getTypes () == castOp.getResultTypes ()) {
3348
- if (llvm::any_of (nextCast.getInputs (), [&](Value v) {
3349
- return v.getDefiningOp () == castOp;
3350
- })) {
3351
- // Ran into a cycle.
3352
- break ;
3353
- }
3354
-
3355
3355
// Found a cast where the input types match the output types of the
3356
- // matched op. We can directly use those inputs.
3356
+ // matched op. We can directly use those inputs and the matched op can
3357
+ // be removed.
3358
+ enqueueOperands (castOp);
3357
3359
castOp.replaceAllUsesWith (nextCast.getInputs ());
3360
+ if (remainingCastOps)
3361
+ erasedOps.insert (castOp.getOperation ());
3362
+ castOp->erase ();
3358
3363
break ;
3359
3364
}
3360
3365
nextCast = getInputCast (nextCast);
3361
3366
}
3362
3367
}
3363
3368
3364
- // A set of all alive cast ops. I.e., ops whose results are (transitively)
3365
- // used by an op that is not a cast op.
3366
- DenseSet<Operation *> liveOps;
3367
-
3368
- // Helper function that marks the given op and transitively reachable input
3369
- // cast ops as alive.
3370
- auto markOpLive = [&](Operation *rootOp) {
3371
- SmallVector<Operation *> worklist;
3372
- worklist.push_back (rootOp);
3373
- while (!worklist.empty ()) {
3374
- Operation *op = worklist.pop_back_val ();
3375
- if (liveOps.insert (op).second ) {
3376
- // Successfully inserted: process reachable input cast ops.
3377
- for (Value v : op->getOperands ())
3378
- if (auto castOp = v.getDefiningOp <UnrealizedConversionCastOp>())
3379
- if (isCastOpOfInterestFn (castOp))
3380
- worklist.push_back (castOp);
3381
- }
3382
- }
3383
- };
3384
-
3385
- // Find all alive cast ops.
3386
- for (UnrealizedConversionCastOp op : castOps) {
3387
- // The op may have been marked live already as being an operand of another
3388
- // live cast op.
3389
- if (liveOps.contains (op.getOperation ()))
3390
- continue ;
3391
- // If any of the users is not a cast op, mark the current op (and its
3392
- // input ops) as live.
3393
- if (llvm::any_of (op->getUsers (), [&](Operation *user) {
3394
- auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
3395
- return !castOp || !isCastOpOfInterestFn (castOp);
3396
- }))
3397
- markOpLive (op);
3398
- }
3399
-
3400
- // Erase all dead cast ops.
3401
- for (UnrealizedConversionCastOp op : castOps) {
3402
- if (liveOps.contains (op)) {
3403
- // Op is alive and was not erased. Add it to the remaining cast ops.
3404
- if (remainingCastOps)
3369
+ if (remainingCastOps)
3370
+ for (UnrealizedConversionCastOp op : castOps)
3371
+ if (!erasedOps.contains (op.getOperation ()))
3405
3372
remainingCastOps->push_back (op);
3406
- continue ;
3407
- }
3408
-
3409
- // Op is dead. Erase it.
3410
- op->dropAllUses ();
3411
- op->erase ();
3412
- }
3413
- }
3414
-
3415
- void mlir::reconcileUnrealizedCasts (
3416
- ArrayRef<UnrealizedConversionCastOp> castOps,
3417
- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3418
- // Set of all cast ops for faster lookups.
3419
- DenseSet<UnrealizedConversionCastOp> castOpSet;
3420
- for (UnrealizedConversionCastOp op : castOps)
3421
- castOpSet.insert (op);
3422
- reconcileUnrealizedCasts (castOpSet, remainingCastOps);
3423
- }
3424
-
3425
- void mlir::reconcileUnrealizedCasts (
3426
- const DenseSet<UnrealizedConversionCastOp> &castOps,
3427
- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3428
- reconcileUnrealizedCastsImpl (
3429
- llvm::make_range (castOps.begin (), castOps.end ()),
3430
- [&](UnrealizedConversionCastOp castOp) {
3431
- return castOps.contains (castOp);
3432
- },
3433
- remainingCastOps);
3434
- }
3435
-
3436
- static void mlir::reconcileUnrealizedCasts (
3437
- const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
3438
- &castOps,
3439
- SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
3440
- reconcileUnrealizedCastsImpl (
3441
- castOps.keys (),
3442
- [&](UnrealizedConversionCastOp castOp) {
3443
- return castOps.contains (castOp);
3444
- },
3445
- remainingCastOps);
3446
3373
}
3447
3374
3448
3375
// ===----------------------------------------------------------------------===//
0 commit comments