Skip to content

Commit

Permalink
[InstCombine] Use APInt for all the math in foldICmpDivConstant
Browse files Browse the repository at this point in the history
Summary: This currently uses ConstantExpr to do its math, but as noted in a TODO it can all be done directly on APInt.

Reviewers: spatel, majnemer

Reviewed By: majnemer

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D38440

llvm-svn: 314640
  • Loading branch information
topperc committed Oct 1, 2017
1 parent c20b46d commit 6e025a3
Showing 1 changed file with 46 additions and 95 deletions.
141 changes: 46 additions & 95 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Expand Up @@ -37,77 +37,30 @@ using namespace PatternMatch;
STATISTIC(NumSel, "Number of select opts");


static ConstantInt *extractElement(Constant *V, Constant *Idx) {
return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx));
}

static bool hasAddOverflow(ConstantInt *Result,
ConstantInt *In1, ConstantInt *In2,
bool IsSigned) {
if (!IsSigned)
return Result->getValue().ult(In1->getValue());

if (In2->isNegative())
return Result->getValue().sgt(In1->getValue());
return Result->getValue().slt(In1->getValue());
}

/// Compute Result = In1+In2, returning true if the result overflowed for this
/// type.
static bool addWithOverflow(Constant *&Result, Constant *In1,
Constant *In2, bool IsSigned = false) {
Result = ConstantExpr::getAdd(In1, In2);

if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) {
for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i);
if (hasAddOverflow(extractElement(Result, Idx),
extractElement(In1, Idx),
extractElement(In2, Idx),
IsSigned))
return true;
}
return false;
}

return hasAddOverflow(cast<ConstantInt>(Result),
cast<ConstantInt>(In1), cast<ConstantInt>(In2),
IsSigned);
}

static bool hasSubOverflow(ConstantInt *Result,
ConstantInt *In1, ConstantInt *In2,
bool IsSigned) {
if (!IsSigned)
return Result->getValue().ugt(In1->getValue());

if (In2->isNegative())
return Result->getValue().slt(In1->getValue());
static bool addWithOverflow(APInt &Result, const APInt &In1,
const APInt &In2, bool IsSigned = false) {
bool Overflow;
if (IsSigned)
Result = In1.sadd_ov(In2, Overflow);
else
Result = In1.uadd_ov(In2, Overflow);

return Result->getValue().sgt(In1->getValue());
return Overflow;
}

/// Compute Result = In1-In2, returning true if the result overflowed for this
/// type.
static bool subWithOverflow(Constant *&Result, Constant *In1,
Constant *In2, bool IsSigned = false) {
Result = ConstantExpr::getSub(In1, In2);

if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) {
for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i);
if (hasSubOverflow(extractElement(Result, Idx),
extractElement(In1, Idx),
extractElement(In2, Idx),
IsSigned))
return true;
}
return false;
}
static bool subWithOverflow(APInt &Result, const APInt &In1,
const APInt &In2, bool IsSigned = false) {
bool Overflow;
if (IsSigned)
Result = In1.ssub_ov(In2, Overflow);
else
Result = In1.usub_ov(In2, Overflow);

return hasSubOverflow(cast<ConstantInt>(Result),
cast<ConstantInt>(In1), cast<ConstantInt>(In2),
IsSigned);
return Overflow;
}

/// Given an icmp instruction, return true if any use of this comparison is a
Expand Down Expand Up @@ -2186,28 +2139,22 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
(DivIsSigned && C2->isAllOnesValue()))
return nullptr;

// TODO: We could do all of the computations below using APInt.
Constant *CmpRHS = cast<Constant>(Cmp.getOperand(1));
Constant *DivRHS = cast<Constant>(Div->getOperand(1));

// Compute Prod = CmpRHS * DivRHS. We are essentially solving an equation of
// form X / C2 = C. We solve for X by multiplying C2 (DivRHS) and C (CmpRHS).
// Compute Prod = C * C2. We are essentially solving an equation of
// form X / C2 = C. We solve for X by multiplying C2 and C.
// By solving for X, we can turn this into a range check instead of computing
// a divide.
Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS);
APInt Prod = *C * *C2;

// Determine if the product overflows by seeing if the product is not equal to
// the divide. Make sure we do the same kind of divide as in the LHS
// instruction that we're folding.
bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS)
: ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS;
bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != *C;

ICmpInst::Predicate Pred = Cmp.getPredicate();

// If the division is known to be exact, then there is no remainder from the
// divide, so the covered range size is unit, otherwise it is the divisor.
Constant *RangeSize =
Div->isExact() ? ConstantInt::get(Div->getType(), 1) : DivRHS;
APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2;

// Figure out the interval that is being checked. For example, a comparison
// like "X /u 5 == 0" is really checking that X is in the interval [0, 5).
Expand All @@ -2217,7 +2164,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
// overflow variable is set to 0 if it's corresponding bound variable is valid
// -1 if overflowed off the bottom end, or +1 if overflowed off the top end.
int LoOverflow = 0, HiOverflow = 0;
Constant *LoBound = nullptr, *HiBound = nullptr;
APInt LoBound, HiBound;

if (!DivIsSigned) { // udiv
// e.g. X/5 op 3 --> [15, 20)
Expand All @@ -2231,7 +2178,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
} else if (C2->isStrictlyPositive()) { // Divisor is > 0.
if (C->isNullValue()) { // (X / pos) op 0
// Can't overflow. e.g. X/2 op 0 --> [-1, 2)
LoBound = ConstantExpr::getNeg(SubOne(RangeSize));
LoBound = -(RangeSize - 1);
HiBound = RangeSize;
} else if (C->isStrictlyPositive()) { // (X / pos) op pos
LoBound = Prod; // e.g. X/5 op 3 --> [15, 20)
Expand All @@ -2240,27 +2187,27 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true);
} else { // (X / pos) op neg
// e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14)
HiBound = AddOne(Prod);
HiBound = Prod + 1;
LoOverflow = HiOverflow = ProdOV ? -1 : 0;
if (!LoOverflow) {
Constant *DivNeg = ConstantExpr::getNeg(RangeSize);
APInt DivNeg = -RangeSize;
LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0;
}
}
} else if (C2->isNegative()) { // Divisor is < 0.
if (Div->isExact())
RangeSize = ConstantExpr::getNeg(RangeSize);
RangeSize.negate();
if (C->isNullValue()) { // (X / neg) op 0
// e.g. X/-5 op 0 --> [-4, 5)
LoBound = AddOne(RangeSize);
HiBound = ConstantExpr::getNeg(RangeSize);
if (HiBound == DivRHS) { // -INTMIN = INTMIN
LoBound = RangeSize + 1;
HiBound = -RangeSize;
if (HiBound == *C2) { // -INTMIN = INTMIN
HiOverflow = 1; // [INTMIN+1, overflow)
HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN
HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN
}
} else if (C->isStrictlyPositive()) { // (X / neg) op pos
// e.g. X/-5 op 3 --> [-19, -14)
HiBound = AddOne(Prod);
HiBound = Prod + 1;
HiOverflow = LoOverflow = ProdOV ? -1 : 0;
if (!LoOverflow)
LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0;
Expand All @@ -2283,42 +2230,46 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
return replaceInstUsesWith(Cmp, Builder.getFalse());
if (HiOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
ICmpInst::ICMP_UGE, X, LoBound);
ICmpInst::ICMP_UGE, X,
ConstantInt::get(Div->getType(), LoBound));
if (LoOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
ICmpInst::ICMP_ULT, X, HiBound);
ICmpInst::ICMP_ULT, X,
ConstantInt::get(Div->getType(), HiBound));
return replaceInstUsesWith(
Cmp, insertRangeTest(X, LoBound->getUniqueInteger(),
HiBound->getUniqueInteger(), DivIsSigned, true));
Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true));
case ICmpInst::ICMP_NE:
if (LoOverflow && HiOverflow)
return replaceInstUsesWith(Cmp, Builder.getTrue());
if (HiOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
ICmpInst::ICMP_ULT, X, LoBound);
ICmpInst::ICMP_ULT, X,
ConstantInt::get(Div->getType(), LoBound));
if (LoOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
ICmpInst::ICMP_UGE, X, HiBound);
ICmpInst::ICMP_UGE, X,
ConstantInt::get(Div->getType(), HiBound));
return replaceInstUsesWith(Cmp,
insertRangeTest(X, LoBound->getUniqueInteger(),
HiBound->getUniqueInteger(),
insertRangeTest(X, LoBound, HiBound,
DivIsSigned, false));
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_SLT:
if (LoOverflow == +1) // Low bound is greater than input range.
return replaceInstUsesWith(Cmp, Builder.getTrue());
if (LoOverflow == -1) // Low bound is less than input range.
return replaceInstUsesWith(Cmp, Builder.getFalse());
return new ICmpInst(Pred, X, LoBound);
return new ICmpInst(Pred, X, ConstantInt::get(Div->getType(), LoBound));
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_SGT:
if (HiOverflow == +1) // High bound greater than input range.
return replaceInstUsesWith(Cmp, Builder.getFalse());
if (HiOverflow == -1) // High bound less than input range.
return replaceInstUsesWith(Cmp, Builder.getTrue());
if (Pred == ICmpInst::ICMP_UGT)
return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound);
return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
return new ICmpInst(ICmpInst::ICMP_UGE, X,
ConstantInt::get(Div->getType(), HiBound));
return new ICmpInst(ICmpInst::ICMP_SGE, X,
ConstantInt::get(Div->getType(), HiBound));
}

return nullptr;
Expand Down

0 comments on commit 6e025a3

Please sign in to comment.