diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 9f5f65cf3bb2d..c64463da9a16a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2427,6 +2427,11 @@ Instruction *InstCombinerImpl::foldICmpUDivConstant(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, const APInt &C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Div->getOperand(0); + Value *Y = Div->getOperand(1); + Type *Ty = Div->getType(); + // Fold: icmp pred ([us]div X, C2), C -> range test // Fold this div into the comparison, producing a range check. // Determine, based on the divide type, what the range is being @@ -2434,7 +2439,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, // it, otherwise compute the range [low, hi) bounding the new value. // See: InsertRangeTest above for the kinds of replacements possible. const APInt *C2; - if (!match(Div->getOperand(1), m_APInt(C2))) + if (!match(Y, m_APInt(C2))) return nullptr; // FIXME: If the operand types don't match the type of the divide @@ -2467,8 +2472,6 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, // instruction that we're folding. bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != C; - ICmpInst::Predicate Pred = Cmp.getPredicate(); - // If the division is known to be exact, then there is no remainder from the // divide, so the covered range size is unit, otherwise it is the divisor. APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2; @@ -2483,7 +2486,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, int LoOverflow = 0, HiOverflow = 0; APInt LoBound, HiBound; - if (!DivIsSigned) { // udiv + if (!DivIsSigned) { // udiv // e.g. X/5 op 3 --> [15, 20) LoBound = Prod; HiOverflow = LoOverflow = ProdOV; @@ -2498,7 +2501,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, LoBound = -(RangeSize - 1); HiBound = RangeSize; } else if (C.isStrictlyPositive()) { // (X / pos) op pos - LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) + LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); @@ -2518,18 +2521,19 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, // e.g. X/-5 op 0 --> [-4, 5) LoBound = RangeSize + 1; HiBound = -RangeSize; - if (HiBound == *C2) { // -INTMIN = INTMIN - HiOverflow = 1; // [INTMIN+1, overflow) - HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN + if (HiBound == *C2) { // -INTMIN = INTMIN + HiOverflow = 1; // [INTMIN+1, overflow) + HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN } } else if (C.isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) HiBound = Prod + 1; HiOverflow = LoOverflow = ProdOV ? -1 : 0; if (!LoOverflow) - LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; - } else { // (X / neg) op neg - LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) + LoOverflow = + addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1 : 0; + } else { // (X / neg) op neg + LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20) LoOverflow = HiOverflow = ProdOV; if (!HiOverflow) HiOverflow = subWithOverflow(HiBound, Prod, RangeSize, true); @@ -2539,54 +2543,47 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, Pred = ICmpInst::getSwappedPredicate(Pred); } - Value *X = Div->getOperand(0); switch (Pred) { - default: llvm_unreachable("Unhandled icmp opcode!"); - case ICmpInst::ICMP_EQ: - if (LoOverflow && HiOverflow) - return replaceInstUsesWith(Cmp, Builder.getFalse()); - if (HiOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, - ConstantInt::get(Div->getType(), LoBound)); - if (LoOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, - ConstantInt::get(Div->getType(), HiBound)); - return replaceInstUsesWith( - Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); - case ICmpInst::ICMP_NE: - if (LoOverflow && HiOverflow) - return replaceInstUsesWith(Cmp, Builder.getTrue()); - if (HiOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, - ConstantInt::get(Div->getType(), LoBound)); - if (LoOverflow) - return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, - ConstantInt::get(Div->getType(), HiBound)); - return replaceInstUsesWith(Cmp, - insertRangeTest(X, LoBound, HiBound, - DivIsSigned, false)); - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLT: - if (LoOverflow == +1) // Low bound is greater than input range. - return replaceInstUsesWith(Cmp, Builder.getTrue()); - if (LoOverflow == -1) // Low bound is less than input range. - return replaceInstUsesWith(Cmp, Builder.getFalse()); - return new ICmpInst(Pred, X, ConstantInt::get(Div->getType(), LoBound)); - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_SGT: - if (HiOverflow == +1) // High bound greater than input range. - return replaceInstUsesWith(Cmp, Builder.getFalse()); - if (HiOverflow == -1) // High bound less than input range. - return replaceInstUsesWith(Cmp, Builder.getTrue()); - if (Pred == ICmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_UGE, X, - ConstantInt::get(Div->getType(), HiBound)); - return new ICmpInst(ICmpInst::ICMP_SGE, X, - ConstantInt::get(Div->getType(), HiBound)); + default: + llvm_unreachable("Unhandled icmp predicate!"); + case ICmpInst::ICMP_EQ: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder.getFalse()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, + X, ConstantInt::get(Ty, LoBound)); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + X, ConstantInt::get(Ty, HiBound)); + return replaceInstUsesWith( + Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); + case ICmpInst::ICMP_NE: + if (LoOverflow && HiOverflow) + return replaceInstUsesWith(Cmp, Builder.getTrue()); + if (HiOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + X, ConstantInt::get(Ty, LoBound)); + if (LoOverflow) + return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, + X, ConstantInt::get(Ty, HiBound)); + return replaceInstUsesWith( + Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, false)); + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLT: + if (LoOverflow == +1) // Low bound is greater than input range. + return replaceInstUsesWith(Cmp, Builder.getTrue()); + if (LoOverflow == -1) // Low bound is less than input range. + return replaceInstUsesWith(Cmp, Builder.getFalse()); + return new ICmpInst(Pred, X, ConstantInt::get(Ty, LoBound)); + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGT: + if (HiOverflow == +1) // High bound greater than input range. + return replaceInstUsesWith(Cmp, Builder.getFalse()); + if (HiOverflow == -1) // High bound less than input range. + return replaceInstUsesWith(Cmp, Builder.getTrue()); + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, HiBound)); + return new ICmpInst(ICmpInst::ICMP_SGE, X, ConstantInt::get(Ty, HiBound)); } return nullptr;