diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 239cd16c5db69..fcb4af8f0fbcd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2053,31 +2053,32 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, const APInt &C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Type *MulTy = Mul->getType(); + Value *X = Mul->getOperand(0); + // If there's no overflow: // X * X == 0 --> X == 0 // X * X != 0 --> X != 0 - Type *MulTy = Mul->getType(); - if (Cmp.isEquality() && C.isZero() && - Mul->getOperand(0) == Mul->getOperand(1) && + if (Cmp.isEquality() && C.isZero() && X == Mul->getOperand(1) && (Mul->hasNoUnsignedWrap() || Mul->hasNoSignedWrap())) - return new ICmpInst(Cmp.getPredicate(), Mul->getOperand(0), - ConstantInt::getNullValue(MulTy)); + return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy)); const APInt *MulC; if (!match(Mul->getOperand(1), m_APInt(MulC))) return nullptr; // If this is a test of the sign bit and the multiply is sign-preserving with - // a constant operand, use the multiply LHS operand instead. - ICmpInst::Predicate Pred = Cmp.getPredicate(); + // a constant operand, use the multiply LHS operand instead: + // (X * +MulC) < 0 --> X < 0 + // (X * -MulC) < 0 --> X > 0 if (isSignTest(Pred, C) && Mul->hasNoSignedWrap()) { if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); - return new ICmpInst(Pred, Mul->getOperand(0), - Constant::getNullValue(Mul->getType())); + return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy)); } - if (MulC->isZero() || !(Mul->hasNoSignedWrap() || Mul->hasNoUnsignedWrap())) + if (MulC->isZero() || (!Mul->hasNoSignedWrap() && !Mul->hasNoUnsignedWrap())) return nullptr; // If the multiply does not wrap, try to divide the compare constant by the @@ -2085,48 +2086,45 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, if (Cmp.isEquality()) { // (mul nsw X, MulC) == C --> X == C /s MulC if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) { - Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC)); - return new ICmpInst(Pred, Mul->getOperand(0), NewC); + Constant *NewC = ConstantInt::get(MulTy, C.sdiv(*MulC)); + return new ICmpInst(Pred, X, NewC); } // (mul nuw X, MulC) == C --> X == C /u MulC if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) { - Constant *NewC = ConstantInt::get(Mul->getType(), C.udiv(*MulC)); - return new ICmpInst(Pred, Mul->getOperand(0), NewC); + Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC)); + return new ICmpInst(Pred, X, NewC); } } + // With a matching no-overflow guarantee, fold the constants: + // (X * MulC) < C --> X < (C / MulC) + // (X * MulC) > C --> X > (C / MulC) + // TODO: Assert that Pred is not equal to SGE, SLE, UGE, ULE? Constant *NewC = nullptr; - - // FIXME: Add assert that Pred is not equal to ICMP_SGE, ICMP_SLE, - // ICMP_UGE, ICMP_ULE. - if (Mul->hasNoSignedWrap()) { - if (MulC->isNegative()) { - // MININT / -1 --> overflow. - if (C.isMinSignedValue() && MulC->isAllOnes()) - return nullptr; + // MININT / -1 --> overflow. + if (C.isMinSignedValue() && MulC->isAllOnes()) + return nullptr; + if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); - } + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP)); + MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP)); if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN)); - } else if (Mul->hasNoUnsignedWrap()) { + MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN)); + } else { + assert(Mul->hasNoUnsignedWrap() && "Expected mul nuw"); if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP)); + MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP)); if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN)); + MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN)); } - return NewC ? new ICmpInst(Pred, Mul->getOperand(0), NewC) : nullptr; + return NewC ? new ICmpInst(Pred, X, NewC) : nullptr; } /// Fold icmp (shl 1, Y), C.