diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 4061dae83c10f..502c91d33df2c 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -1028,33 +1028,43 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q, // Make sure that a constant is not the minimum signed value because taking // the abs() of that is undefined. Type *Ty = X->getType(); - const APInt *C; - if (match(X, m_APInt(C)) && !C->isMinSignedValue()) { - // Is the variable divisor magnitude always greater than the constant - // dividend magnitude? - // |Y| > |C| --> Y < -abs(C) or Y > abs(C) - Constant *PosDividendC = ConstantInt::get(Ty, C->abs()); - Constant *NegDividendC = ConstantInt::get(Ty, -C->abs()); + + // Is the variable divisor magnitude always greater than the constant + // dividend magnitude? + // |Y| > |C| --> Y < -abs(C) or Y > abs(C) + auto CheckSignCmp = [Ty, Y, Q, MaxRecurse](const APInt &C) { + if (C.isMinSignedValue()) + return false; + Constant *PosDividendC = ConstantInt::get(Ty, C.abs()); + Constant *NegDividendC = ConstantInt::get(Ty, -C.abs()); if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) || isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse)) return true; - } - if (match(Y, m_APInt(C))) { + return false; + }; + + auto CheckSignCmpY = [Ty, X, Y, Q, MaxRecurse](const APInt &C) { // Special-case: we can't take the abs() of a minimum signed value. If // that's the divisor, then all we have to do is prove that the dividend // is also not the minimum signed value. - if (C->isMinSignedValue()) + if (C.isMinSignedValue()) return isICmpTrue(CmpInst::ICMP_NE, X, Y, Q, MaxRecurse); // Is the variable dividend magnitude always less than the constant // divisor magnitude? // |X| < |C| --> X > -abs(C) and X < abs(C) - Constant *PosDivisorC = ConstantInt::get(Ty, C->abs()); - Constant *NegDivisorC = ConstantInt::get(Ty, -C->abs()); - if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) && - isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse)) + Constant *PosDividendC = ConstantInt::get(Ty, C.abs()); + Constant *NegDividendC = ConstantInt::get(Ty, -C.abs()); + if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) || + isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse)) return true; - } + return false; + }; + + if (match(X, m_CheckedInt(CheckSignCmp))) + return true; + if (match(Y, m_CheckedInt(CheckSignCmpY))) + return true; return false; } @@ -1063,9 +1073,11 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q, // Is the unsigned dividend known to be less than a constant divisor? // TODO: Convert this (and above) to range analysis // ("computeConstantRangeIncludingKnownBits")? - const APInt *C; - if (match(Y, m_APInt(C)) && - computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C)) + + auto CheckULT1 = [X, Q](const APInt &C) { + return computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(C); + }; + if (match(Y, m_CheckedInt(CheckULT1))) return true; // Try again for any divisor: @@ -2362,15 +2374,16 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // (-1 << X) | (-1 >> (C - X)) --> -1 // (-1 >> X) | (-1 << (C - X)) --> -1 // ...with C <= bitwidth (and commuted variants). - Value *X, *Y; + Value *X = nullptr, *Y = nullptr; + auto CheckULE = [X](const APInt &C) { + return C.ule(X->getType()->getScalarSizeInBits()); + }; if ((match(Op0, m_Shl(m_AllOnes(), m_Value(X))) && match(Op1, m_LShr(m_AllOnes(), m_Value(Y)))) || (match(Op1, m_Shl(m_AllOnes(), m_Value(X))) && match(Op0, m_LShr(m_AllOnes(), m_Value(Y))))) { - const APInt *C; - if ((match(X, m_Sub(m_APInt(C), m_Specific(Y))) || - match(Y, m_Sub(m_APInt(C), m_Specific(X)))) && - C->ule(X->getType()->getScalarSizeInBits())) { + if (match(X, m_Sub(m_CheckedInt(CheckULE), m_Specific(Y))) || + match(Y, m_Sub(m_CheckedInt(CheckULE), m_Specific(X)))) { return ConstantInt::getAllOnesValue(X->getType()); } } @@ -3158,9 +3171,10 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, // x udiv C >=u x --> false for C != 1. // x udiv C == x --> false for C != 1. // TODO: allow non-constant shift amount/divisor - const APInt *C; - if ((match(LBO, m_LShr(m_Specific(RHS), m_APInt(C))) && *C != 0) || - (match(LBO, m_UDiv(m_Specific(RHS), m_APInt(C))) && *C != 1)) { + auto IsNotZero = [](const APInt &C) { return C != 0; }; + auto IsNotOne = [](const APInt &C) { return C != 1; }; + if (match(LBO, m_LShr(m_Specific(RHS), m_CheckedInt(IsNotZero))) || + match(LBO, m_UDiv(m_Specific(RHS), m_CheckedInt(IsNotOne)))) { if (isKnownNonZero(RHS, Q)) { switch (Pred) { default: @@ -3203,6 +3217,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, // (sub C, X) == X, C is odd --> false // (sub C, X) != X, C is odd --> true + const APInt *C; if (match(LBO, m_Sub(m_APIntAllowPoison(C), m_Specific(RHS))) && (*C & 1) == 1 && ICmpInst::isEquality(Pred)) return (Pred == ICmpInst::ICMP_EQ) ? getFalse(ITy) : getTrue(ITy); diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 0dbb39d7c8ec4..9a4ae6cdcf825 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -3275,11 +3275,11 @@ static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2, /// the multiplication is nuw or nsw. static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth, const SimplifyQuery &Q) { + auto NotZeroOrOne = [](const APInt &C) { return !C.isZero() && !C.isOne(); }; if (auto *OBO = dyn_cast(V2)) { - const APInt *C; - return match(OBO, m_Mul(m_Specific(V1), m_APInt(C))) && + return match(OBO, m_Mul(m_Specific(V1), m_CheckedInt(NotZeroOrOne))) && (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) && - !C->isZero() && !C->isOne() && isKnownNonZero(V1, Q, Depth + 1); + isKnownNonZero(V1, Q, Depth + 1); } return false; } @@ -3288,11 +3288,11 @@ static bool isNonEqualMul(const Value *V1, const Value *V2, unsigned Depth, /// the shift is nuw or nsw. static bool isNonEqualShl(const Value *V1, const Value *V2, unsigned Depth, const SimplifyQuery &Q) { + auto NotZeroOrOne = [](const APInt &C) { return !C.isZero() && !C.isOne(); }; if (auto *OBO = dyn_cast(V2)) { - const APInt *C; - return match(OBO, m_Shl(m_Specific(V1), m_APInt(C))) && + return match(OBO, m_Shl(m_Specific(V1), m_CheckedInt(NotZeroOrOne))) && (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) && - !C->isZero() && isKnownNonZero(V1, Q, Depth + 1); + isKnownNonZero(V1, Q, Depth + 1); } return false; } diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index cf4a64ffded2e..b5728b0ca3a04 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -30447,11 +30447,12 @@ static std::pair FindSingleBitChange(Value *V) { Value *BitV = I->getOperand(1); Value *AndOp; - const APInt *AndC; - if (match(BitV, m_c_And(m_Value(AndOp), m_APInt(AndC)))) { - // Read past a shiftmask instruction to find count - if (*AndC == (I->getType()->getPrimitiveSizeInBits() - 1)) - BitV = AndOp; + // Read past a shiftmask instruction to find count + auto IsMask = [&I](const APInt &AndC) { + return AndC == I->getType()->getPrimitiveSizeInBits() - 1; + }; + if (match(BitV, m_c_And(m_Value(AndOp), m_CheckedInt(IsMask)))) { + BitV = AndOp; } return {BitV, BTK}; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 51ac77348ed9e..10964eeb8ba81 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1761,7 +1761,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { // zext(ctpop(A) >u/!= 1) + (ctlz(A, true) ^ (BW - 1)) // --> // BW - ctlz(A - 1, false) - const APInt *XorC; + auto CheckBW = [A](const APInt &XorC) { + return XorC == A->getType()->getScalarSizeInBits() - 1; + }; if (match(&I, m_c_Add( m_ZExt(m_ICmp(Pred, m_Intrinsic(m_Value(A)), @@ -1769,9 +1771,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { m_OneUse(m_ZExtOrSelf(m_OneUse(m_Xor( m_OneUse(m_TruncOrSelf(m_OneUse( m_Intrinsic(m_Deferred(A), m_One())))), - m_APInt(XorC))))))) && - (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_NE) && - *XorC == A->getType()->getScalarSizeInBits() - 1) { + m_CheckedInt(CheckBW))))))) && + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_NE)) { Value *Sub = Builder.CreateAdd(A, Constant::getAllOnesValue(A->getType())); Value *Ctlz = Builder.CreateIntrinsic(Intrinsic::ctlz, {A->getType()}, {Sub, Builder.getFalse()}); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 11e31877de38c..bc6c9fd7deeaf 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -764,8 +764,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { } { - const APInt *C; - if (match(Src, m_Shl(m_APInt(C), m_Value(X))) && (*C)[0] == 1) { + auto CheckOdd = [](const APInt &C) { return (C)[0] == 1; }; + if (match(Src, m_Shl(m_CheckedInt(CheckOdd), m_Value(X)))) { // trunc (C << X) to i1 --> X == 0, where C is odd return new ICmpInst(ICmpInst::Predicate::ICMP_EQ, X, Zero); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 6739b8745d74e..2fbcf29c20d53 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -336,9 +336,12 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If the RHS is a constant, see if we can change it. Don't alter a -1 // constant because that's a canonical 'not' op, and that is better for // combining, SCEV, and codegen. - const APInt *C; - if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) { - if ((*C | ~DemandedMask).isAllOnes()) { + auto IsNotAllOnes = [](const APInt &C) { return C.isAllOnes(); }; + auto IsNotAllOnesAndDemandedMask = [&DemandedMask](const APInt &C) { + return (C | ~DemandedMask).isAllOnes(); + }; + if (match(I->getOperand(1), m_CheckedInt(IsNotAllOnes))) { + if (match(I->getOperand(1), m_CheckedInt(IsNotAllOnesAndDemandedMask))) { // Force bits to 1 to create a 'not' op. I->setOperand(1, ConstantInt::getAllOnesValue(VTy)); return I; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 99f1f8eb34bb5..7c23a22d717b9 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -2071,8 +2071,10 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) { } case Instruction::Or: { // or X, C --> add X, C (when X and C have no common bits set) - const APInt *C; - if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL)) + auto CheckMaskedValIsZero = [BO0, DL](const APInt &C) { + return MaskedValueIsZero(BO0, C, DL); + }; + if (match(BO1, m_CheckedInt(CheckMaskedValIsZero))) return {Instruction::Add, BO0, BO1}; break; }