diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp index 30fb51725dcb0..5cecc69285bea 100644 --- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp @@ -90,7 +90,7 @@ static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) { } template -static void rewriteOp(Operation *op, OpBuilder &b) { +static bool rewriteOp(Operation *op, OpBuilder &b) { if (isa(op)) { OpBuilder::InsertionGuard guard(b); b.setInsertionPoint(op); @@ -98,28 +98,31 @@ static void rewriteOp(Operation *op, OpBuilder &b) { op->getOperands(), op->getAttrs()); op->replaceAllUsesWith(newOp->getResults()); op->erase(); + return true; } + return false; } -static void rewriteCmpI(Operation *op, OpBuilder &b) { +static bool rewriteCmpI(Operation *op, OpBuilder &b) { if (auto cmpOp = dyn_cast(op)) { cmpOp.setPredicateAttr(CmpIPredicateAttr::get( b.getContext(), toUnsignedPred(cmpOp.getPredicate()))); + return true; } + return false; } static void rewrite(Operation *root, const OpList &toReplace) { OpBuilder b(root->getContext()); b.setInsertionPoint(root); for (Operation *op : toReplace) { - rewriteOp(op, b); - rewriteOp(op, b); - rewriteOp(op, b); - rewriteOp(op, b); - rewriteOp(op, b); - rewriteOp(op, b); - rewriteOp(op, b); - rewriteCmpI(op, b); + rewriteOp(op, b) || + rewriteOp(op, b) || + rewriteOp(op, b) || + rewriteOp(op, b) || + rewriteOp(op, b) || + rewriteOp(op, b) || + rewriteOp(op, b) || rewriteCmpI(op, b); } }