// Review summary of the patch below (a git diff against three LLVM files):
//  * llvm/include/llvm/IR/PatternMatch.h: adds USubWithOverflow_match and its
//    m_USubWithOverflow() factory, which match icmp forms of an unsigned
//    subtract overflow check (the sub form, the add-of-constant form, and the
//    `0 - a != 0` form) and hand the pieces to the L / R / S sub-matchers.
//  * llvm/lib/CodeGen/CodeGenPrepare.cpp: turns the previous matching logic of
//    combineToUSubWithOverflow into the static helper
//    matchUSubWithOverflowConstantEdgeCases (returning the found binary
//    operator through its Sub out-parameter), and makes
//    combineToUSubWithOverflow try m_USubWithOverflow first and fall back to
//    that helper.
//  * llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp: adds a fold in
//    InstCombinerImpl::visitICmpInst that calls OptimizeOverflowCheck on a
//    matched sub and, on success, replaces the sub with Result and the compare
//    with Overflow.
// NOTE(review): this copy looks damaged by extraction -- line breaks are
// collapsed and angle-bracketed text seems to be missing (template parameter
// lists, isa/cast type arguments, and most of the branch that follows the
// `// a` comment in USubWithOverflow_match::match). Restore it from the
// original patch before applying; TODO confirm.
// NOTE(review): in the add form the matcher gives R a freshly created negated
// constant while S is bound to the add, and combineToUSubWithOverflow now
// passes the matched A / B to replaceMathCmpWithIntrinsic instead of the
// instruction's own operands -- verify that callee (not shown here) expects
// an already negated constant.
// NOTE(review): the use-count argument given to shouldFormOverflowOp changes
// from hasNUsesOrMore(1) to hasNUsesOrMore(EdgeCase ? 1 : 2); confirm this is
// intended for the non-edge-case path.
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index 6168e24569f99..6da6eca8677f8 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -2685,6 +2685,81 @@ m_UAddWithOverflow(const LHS_t &L, const RHS_t &R, const Sum_t &S) { return UAddWithOverflow_match(L, R, S); } +template +struct USubWithOverflow_match { + LHS_t L; + RHS_t R; + Diff_t S; + + USubWithOverflow_match(const LHS_t &L, const RHS_t &R, const Diff_t &S) + : L(L), R(R), S(S) {} + + template bool match(OpTy *V) const { + Value *ICmpLHS = nullptr, *ICmpRHS = nullptr; + CmpPredicate Pred; + if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) + return false; + + Value *SubLHS = nullptr, *SubRHS = nullptr; + auto SubExpr = m_Sub(m_Value(SubLHS), m_Value(SubRHS)); + + Value *AddLHS = nullptr, *AddRHS = nullptr; + auto AddExpr = m_Add(m_Value(AddLHS), m_Value(AddRHS)); + + // (a - b) >u a OR (a + (-c)) >u a (allow add-canonicalized forms + // where the add RHS is a constant APInt; its sign is not checked here) + if (Pred == ICmpInst::ICMP_UGT) { + if (SubExpr.match(ICmpLHS) && ICmpRHS == SubLHS) + return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS); + + if (AddExpr.match(ICmpLHS)) { + const APInt *AddC = nullptr; + if (m_APInt(AddC).match(AddRHS) && ICmpRHS == AddLHS) { + APInt NegC = -(*AddC); + Constant *NegConst = ConstantInt::get(AddRHS->getType(), NegC); + return L.match(AddLHS) && R.match(NegConst) && S.match(ICmpLHS); + } + } + } + + // a getType(), NegC); + return L.match(AddLHS) && R.match(NegConst) && S.match(ICmpRHS); + } + } + } + + // Special-case for 0 - a != 0 (common canonicalization) + if (Pred == ICmpInst::ICMP_NE) { + // (0 - a) != 0 + if (SubExpr.match(ICmpLHS) && m_Zero().match(ICmpRHS) && + m_Zero().match(SubLHS)) + return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS); + + // 0 != (0 - a) + if (m_Zero().match(ICmpLHS) && SubExpr.match(ICmpRHS) && + m_Zero().match(SubLHS)) + return 
L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpRHS); + } + + return false; + } +}; + +template +USubWithOverflow_match +m_USubWithOverflow(const LHS_t &L, const RHS_t &R, const Diff_t &S) { + return USubWithOverflow_match(L, R, S); +} + template struct Argument_match { unsigned OpI; Opnd_t Val; diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp index d290f202f3cca..cc596aed4cc85 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -1695,19 +1695,23 @@ bool CodeGenPrepare::combineToUAddWithOverflow(CmpInst *Cmp, return true; } -bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp, - ModifyDT &ModifiedDT) { - // We are not expecting non-canonical/degenerate code. Just bail out. +static bool matchUSubWithOverflowConstantEdgeCases(CmpInst *Cmp, + BinaryOperator *&Sub) { + // A - B, A u< B --> usubo(A, B) Value *A = Cmp->getOperand(0), *B = Cmp->getOperand(1); + + // We are not expecting non-canonical/degenerate code. Just bail out. if (isa(A) && isa(B)) return false; - // Convert (A u> B) to (A u< B) to simplify pattern matching. ICmpInst::Predicate Pred = Cmp->getPredicate(); + + // Normalize: convert (A u> B) -> (B u< A) if (Pred == ICmpInst::ICMP_UGT) { std::swap(A, B); Pred = ICmpInst::ICMP_ULT; } + // Convert special-case: (A == 0) is the same as (A u< 1). if (Pred == ICmpInst::ICMP_EQ && match(B, m_ZeroInt())) { B = ConstantInt::get(B->getType(), 1); @@ -1718,19 +1722,22 @@ bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp, std::swap(A, B); Pred = ICmpInst::ICMP_ULT; } + if (Pred != ICmpInst::ICMP_ULT) return false; - // Walk the users of a variable operand of a compare looking for a subtract or - // add with that same operand. Also match the 2nd operand of the compare to - // the add/sub, but that may be a negated constant operand of an add. + // Walk the users of the variable operand of the compare looking for a + // subtract or add with that same operand. 
Also match the 2nd operand of the + // compare to the add/sub, but that may be a negated constant operand of an + // add. Value *CmpVariableOperand = isa(A) ? B : A; - BinaryOperator *Sub = nullptr; + Sub = nullptr; + for (User *U : CmpVariableOperand->users()) { // A - B, A u< B --> usubo(A, B) if (match(U, m_Sub(m_Specific(A), m_Specific(B)))) { Sub = cast(U); - break; + return true; } // A + (-C), A u< C (canonicalized form of (sub A, C)) @@ -1738,19 +1745,42 @@ bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp, if (match(U, m_Add(m_Specific(A), m_APInt(AddC))) && match(B, m_APInt(CmpC)) && *AddC == -(*CmpC)) { Sub = cast(U); - break; + return true; } } - if (!Sub) - return false; + return false; +} + +bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp, + ModifyDT &ModifiedDT) { + bool EdgeCase = false; + Value *A = nullptr, *B = nullptr; + BinaryOperator *Sub = nullptr; + + // If the compare already matches the (sub, icmp) pattern use it directly. + if (!match(Cmp, m_USubWithOverflow(m_Value(A), m_Value(B), m_BinOp(Sub)))) { + // Otherwise try to recognize constant-edge-case forms like + // icmp ne (sub 0, B), 0 or + // icmp eq (sub A, 1), 0 + if (!matchUSubWithOverflowConstantEdgeCases(Cmp, Sub)) + return false; + // Set A/B from the discovered Sub and record that this was an edge-case + // match. + A = Sub->getOperand(0); + B = Sub->getOperand(1); + EdgeCase = true; + } + + // Check target wants the overflow intrinsic formed. When matching an + // edge-case we allow forming the intrinsic with fewer uses. if (!TLI->shouldFormOverflowOp(ISD::USUBO, TLI->getValueType(*DL, Sub->getType()), - Sub->hasNUsesOrMore(1))) + Sub->hasNUsesOrMore(EdgeCase ? 
1 : 2))) return false; - if (!replaceMathCmpWithIntrinsic(Sub, Sub->getOperand(0), Sub->getOperand(1), - Cmp, Intrinsic::usub_with_overflow)) + if (!replaceMathCmpWithIntrinsic(Sub, A, B, Cmp, + Intrinsic::usub_with_overflow)) return false; // Reset callers - do not crash by iterating over a dead instruction. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index e4cb457499ef5..5c7aae5f91fab 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -7829,6 +7829,23 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { } } + Instruction *SubI = nullptr; + if (match(&I, m_USubWithOverflow(m_Value(X), m_Value(Y), + m_Instruction(SubI))) && + isa(X->getType())) { + Value *Result; + Constant *Overflow; + // m_USubWithOverflow can also match the add-of-constant form, where the + // matched op is an "add", so check the opcode of the matched op. + if (SubI->getOpcode() == Instruction::Sub && + OptimizeOverflowCheck(Instruction::Sub, /*Signed*/ false, X, Y, *SubI, + Result, Overflow)) { + replaceInstUsesWith(*SubI, Result); + eraseInstFromFunction(*SubI); + return replaceInstUsesWith(I, Overflow); + } + } + // (zext X) * (zext Y) --> llvm.umul.with.overflow. if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) && match(Op1, m_APInt(C))) {