diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp index 5cecc69285bea..f84990d0a8c47 100644 --- a/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/UnsignedWhenEquivalent.cpp @@ -12,40 +12,55 @@ #include "mlir/Analysis/IntRangeAnalysis.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Arithmetic/Transforms/Passes.h" +#include "mlir/Transforms/DialectConversion.h" using namespace mlir; using namespace mlir::arith; -using OpList = llvm::SmallVector; - -/// Returns true when a value is statically non-negative in that it has a lower +/// Succeeds when a value is statically non-negative in that it has a lower /// bound on its value (if it is treated as signed) and that bound is /// non-negative. -static bool staticallyNonNegative(IntRangeAnalysis &analysis, Value v) { +static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis, + Value v) { Optional result = analysis.getResult(v); if (!result.hasValue()) - return false; + return failure(); const ConstantIntRanges &range = result.getValue(); - return (range.smin().isNonNegative()); + return success(range.smin().isNonNegative()); } -/// Identify all operations in a block that have signed equivalents and have -/// operands and results that are statically non-negative. -template -static void getConvertableOps(Operation *root, OpList &toRewrite, - IntRangeAnalysis &analysis) { +/// Succeeds if an op can be converted to its unsigned equivalent without +/// changing its semantics. This is the case when none of its openands or +/// results can be below 0 when analyzed from a signed perspective. +static LogicalResult staticallyNonNegative(IntRangeAnalysis &analysis, + Operation *op) { auto nonNegativePred = [&analysis](Value v) -> bool { - return staticallyNonNegative(analysis, v); + return succeeded(staticallyNonNegative(analysis, v)); }; - root->walk([&nonNegativePred, &toRewrite](Operation *orig) { - if (isa(orig) && - llvm::all_of(orig->getOperands(), nonNegativePred) && - llvm::all_of(orig->getResults(), nonNegativePred)) { - toRewrite.push_back(orig); - } - }); + return success(llvm::all_of(op->getOperands(), nonNegativePred) && + llvm::all_of(op->getResults(), nonNegativePred)); } +/// Succeeds when the comparison predicate is a signed operation and all the +/// operands are non-negative, indicating that the cmpi operation `op` can have +/// its predicate changed to an unsigned equivalent. +static LogicalResult isCmpIConvertable(IntRangeAnalysis &analysis, CmpIOp op) { + CmpIPredicate pred = op.getPredicate(); + switch (pred) { + case CmpIPredicate::sle: + case CmpIPredicate::slt: + case CmpIPredicate::sge: + case CmpIPredicate::sgt: + return success(llvm::all_of(op.getOperands(), [&analysis](Value v) -> bool { + return succeeded(staticallyNonNegative(analysis, v)); + })); + default: + return failure(); + } +} + +/// Return the unsigned equivalent of a signed comparison predicate, +/// or the predicate itself if there is none. static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { switch (pred) { case CmpIPredicate::sle: @@ -61,72 +76,30 @@ static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { } } -/// Find all cmpi ops that can be replaced by their unsigned equivalents. -static void getConvertableCmpi(Operation *root, OpList &toRewrite, - IntRangeAnalysis &analysis) { - auto nonNegativePred = [&analysis](Value v) -> bool { - return staticallyNonNegative(analysis, v); - }; - root->walk([&nonNegativePred, &toRewrite](arith::CmpIOp orig) { - CmpIPredicate pred = orig.getPredicate(); - if (toUnsignedPred(pred) != pred && - // i1 will spuriously and trivially show up as pontentially negative, - // so don't check the results - llvm::all_of(orig->getOperands(), nonNegativePred)) { - toRewrite.push_back(orig.getOperation()); - } - }); -} - -/// Return ops to be replaced in the order they should be rewritten. -static OpList getMatching(Operation *root, IntRangeAnalysis &analysis) { - OpList ret; - getConvertableOps(root, ret, analysis); - // Since these are in-place changes, they don't need to be topological order - // like the others. - getConvertableCmpi(root, ret, analysis); - return ret; -} +namespace { +template +struct ConvertOpToUnsigned : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; -template -static bool rewriteOp(Operation *op, OpBuilder &b) { - if (isa(op)) { - OpBuilder::InsertionGuard guard(b); - b.setInsertionPoint(op); - Operation *newOp = b.create(op->getLoc(), op->getResultTypes(), - op->getOperands(), op->getAttrs()); - op->replaceAllUsesWith(newOp->getResults()); - op->erase(); - return true; + LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor, + ConversionPatternRewriter &rw) const override { + rw.replaceOpWithNewOp(op, op->getResultTypes(), + adaptor.getOperands(), op->getAttrs()); + return success(); } - return false; -} +}; -static bool rewriteCmpI(Operation *op, OpBuilder &b) { - if (auto cmpOp = dyn_cast(op)) { - cmpOp.setPredicateAttr(CmpIPredicateAttr::get( - b.getContext(), toUnsignedPred(cmpOp.getPredicate()))); - return true; - } - return false; -} +struct ConvertCmpIToUnsigned : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; -static void rewrite(Operation *root, const OpList &toReplace) { - OpBuilder b(root->getContext()); - b.setInsertionPoint(root); - for (Operation *op : toReplace) { - rewriteOp(op, b) || - rewriteOp(op, b) || - rewriteOp(op, b) || - rewriteOp(op, b) || - rewriteOp(op, b) || - rewriteOp(op, b) || - rewriteOp(op, b) || rewriteCmpI(op, b); + LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor, + ConversionPatternRewriter &rw) const override { + rw.replaceOpWithNewOp(op, toUnsignedPred(op.getPredicate()), + op.getLhs(), op.getRhs()); + return success(); } -} +}; -namespace { struct ArithmeticUnsignedWhenEquivalentPass : public ArithmeticUnsignedWhenEquivalentBase< ArithmeticUnsignedWhenEquivalentPass> { @@ -135,8 +108,35 @@ struct ArithmeticUnsignedWhenEquivalentPass /// ensures that analysis results are not invalidated during rewriting. void runOnOperation() override { Operation *op = getOperation(); + MLIRContext *ctx = op->getContext(); IntRangeAnalysis analysis(op); - rewrite(op, getMatching(op, analysis)); + + ConversionTarget target(*ctx); + target.addLegalDialect(); + target + .addDynamicallyLegalOp( + [&analysis](Operation *op) -> Optional { + return failed(staticallyNonNegative(analysis, op)); + }); + target.addDynamicallyLegalOp( + [&analysis](CmpIOp op) -> Optional { + return failed(isCmpIConvertable(analysis, op)); + }); + + RewritePatternSet patterns(ctx); + patterns.add, + ConvertOpToUnsigned, + ConvertOpToUnsigned, + ConvertOpToUnsigned, + ConvertOpToUnsigned, + ConvertOpToUnsigned, + ConvertOpToUnsigned, ConvertCmpIToUnsigned>( + ctx); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + signalPassFailure(); + } } }; } // end anonymous namespace