132 changes: 53 additions & 79 deletions llvm/lib/Analysis/ValueTracking.cpp
@@ -6245,37 +6245,30 @@ static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
 }
 
 /// Combine constant ranges from computeConstantRange() and computeKnownBits().
-static ConstantRange computeConstantRangeIncludingKnownBits(
-    const Value *V, bool ForSigned, const DataLayout &DL, AssumptionCache *AC,
-    const Instruction *CxtI, const DominatorTree *DT,
-    bool UseInstrInfo = true) {
-  KnownBits Known =
-      computeKnownBits(V, DL, /*Depth=*/0, AC, CxtI, DT, UseInstrInfo);
+static ConstantRange
+computeConstantRangeIncludingKnownBits(const Value *V, bool ForSigned,
+                                       const SimplifyQuery &SQ) {
+  KnownBits Known = ::computeKnownBits(V, /*Depth=*/0, SQ);
   ConstantRange CR1 = ConstantRange::fromKnownBits(Known, ForSigned);
-  ConstantRange CR2 = computeConstantRange(V, ForSigned, UseInstrInfo);
+  ConstantRange CR2 = computeConstantRange(V, ForSigned, SQ.IIQ.UseInstrInfo);
   ConstantRange::PreferredRangeType RangeType =
       ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
   return CR1.intersectWith(CR2, RangeType);
 }
 
-OverflowResult llvm::computeOverflowForUnsignedMul(
-    const Value *LHS, const Value *RHS, const DataLayout &DL,
-    AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
-    bool UseInstrInfo) {
-  KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                        UseInstrInfo);
-  KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                        UseInstrInfo);
+OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
+                                                   const Value *RHS,
+                                                   const SimplifyQuery &SQ) {
+  KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
+  KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
   ConstantRange LHSRange = ConstantRange::fromKnownBits(LHSKnown, false);
   ConstantRange RHSRange = ConstantRange::fromKnownBits(RHSKnown, false);
   return mapOverflowResult(LHSRange.unsignedMulMayOverflow(RHSRange));
 }
 
-OverflowResult
-llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
-                                  const DataLayout &DL, AssumptionCache *AC,
-                                  const Instruction *CxtI,
-                                  const DominatorTree *DT, bool UseInstrInfo) {
+OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
+                                                 const Value *RHS,
+                                                 const SimplifyQuery &SQ) {
   // Multiplying n * m significant bits yields a result of n + m significant
   // bits. If the total number of significant bits does not exceed the
   // result bit width (minus 1), there is no overflow.
@@ -6286,8 +6279,8 @@ llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
 
   // Note that underestimating the number of sign bits gives a more
   // conservative answer.
-  unsigned SignBits = ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) +
-                      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT);
+  unsigned SignBits =
+      ::ComputeNumSignBits(LHS, 0, SQ) + ::ComputeNumSignBits(RHS, 0, SQ);
 
   // First handle the easy case: if we have enough sign bits there's
   // definitely no overflow.
@@ -6304,34 +6297,28 @@ llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS,
     // product is exactly the minimum negative number.
     // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
     // For simplicity we just check if at least one side is not negative.
-    KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                          UseInstrInfo);
-    KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT,
-                                          UseInstrInfo);
+    KnownBits LHSKnown = ::computeKnownBits(LHS, /*Depth=*/0, SQ);
+    KnownBits RHSKnown = ::computeKnownBits(RHS, /*Depth=*/0, SQ);
     if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
       return OverflowResult::NeverOverflows;
   }
   return OverflowResult::MayOverflow;
 }
 
-OverflowResult llvm::computeOverflowForUnsignedAdd(
-    const Value *LHS, const Value *RHS, const DataLayout &DL,
-    AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT,
-    bool UseInstrInfo) {
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/false, DL, AC, CxtI, DT, UseInstrInfo);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/false, DL, AC, CxtI, DT, UseInstrInfo);
+OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS,
+                                                   const Value *RHS,
+                                                   const SimplifyQuery &SQ) {
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
   return mapOverflowResult(LHSRange.unsignedAddMayOverflow(RHSRange));
 }
 
 static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
                                                   const Value *RHS,
                                                   const AddOperator *Add,
-                                                  const DataLayout &DL,
-                                                  AssumptionCache *AC,
-                                                  const Instruction *CxtI,
-                                                  const DominatorTree *DT) {
+                                                  const SimplifyQuery &SQ) {
   if (Add && Add->hasNoSignedWrap()) {
     return OverflowResult::NeverOverflows;
   }
@@ -6350,14 +6337,14 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
   //
   // Since the carry into the most significant position is always equal to
   // the carry out of the addition, there is no signed overflow.
-  if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 &&
-      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1)
+  if (::ComputeNumSignBits(LHS, 0, SQ) > 1 &&
+      ::ComputeNumSignBits(RHS, 0, SQ) > 1)
     return OverflowResult::NeverOverflows;
 
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/true, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/true, SQ);
   OverflowResult OR =
       mapOverflowResult(LHSRange.signedAddMayOverflow(RHSRange));
   if (OR != OverflowResult::MayOverflow)
@@ -6378,8 +6365,7 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
       (LHSRange.isAllNegative() || RHSRange.isAllNegative());
   if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) {
     KnownBits AddKnown(LHSRange.getBitWidth());
-    computeKnownBitsFromAssume(Add, AddKnown, /*Depth=*/0,
-                               SimplifyQuery(DL, DT, AC, CxtI, DT));
+    computeKnownBitsFromAssume(Add, AddKnown, /*Depth=*/0, SQ);
     if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) ||
         (AddKnown.isNegative() && LHSOrRHSKnownNegative))
       return OverflowResult::NeverOverflows;
@@ -6390,10 +6376,7 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS,
 
 OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
                                                    const Value *RHS,
-                                                   const DataLayout &DL,
-                                                   AssumptionCache *AC,
-                                                   const Instruction *CxtI,
-                                                   const DominatorTree *DT) {
+                                                   const SimplifyQuery &SQ) {
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -6407,32 +6390,29 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
   // See simplifyICmpWithBinOpOnLHS() for candidates.
   if (match(RHS, m_URem(m_Specific(LHS), m_Value())) ||
       match(RHS, m_NUWSub(m_Specific(LHS), m_Value())))
-    if (isGuaranteedNotToBeUndefOrPoison(LHS, AC, CxtI, DT))
+    if (isGuaranteedNotToBeUndefOrPoison(LHS, SQ.AC, SQ.CxtI, SQ.DT))
       return OverflowResult::NeverOverflows;
 
   // Checking for conditions implied by dominating conditions may be expensive.
   // Limit it to usub_with_overflow calls for now.
-  if (match(CxtI,
+  if (match(SQ.CxtI,
             m_Intrinsic<Intrinsic::usub_with_overflow>(m_Value(), m_Value())))
-    if (auto C =
-            isImpliedByDomCondition(CmpInst::ICMP_UGE, LHS, RHS, CxtI, DL)) {
+    if (auto C = isImpliedByDomCondition(CmpInst::ICMP_UGE, LHS, RHS, SQ.CxtI,
+                                         SQ.DL)) {
       if (*C)
         return OverflowResult::NeverOverflows;
       return OverflowResult::AlwaysOverflowsLow;
     }
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/false, DL, AC, CxtI, DT);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/false, DL, AC, CxtI, DT);
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/false, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/false, SQ);
   return mapOverflowResult(LHSRange.unsignedSubMayOverflow(RHSRange));
 }
 
 OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
                                                  const Value *RHS,
-                                                 const DataLayout &DL,
-                                                 AssumptionCache *AC,
-                                                 const Instruction *CxtI,
-                                                 const DominatorTree *DT) {
+                                                 const SimplifyQuery &SQ) {
   // X - (X % ?)
   // The remainder of a value can't have greater magnitude than itself,
   // so the subtraction can't overflow.
@@ -6443,19 +6423,19 @@ OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
   // then determining no-overflow may allow other transforms.
   if (match(RHS, m_SRem(m_Specific(LHS), m_Value())) ||
       match(RHS, m_NSWSub(m_Specific(LHS), m_Value())))
-    if (isGuaranteedNotToBeUndefOrPoison(LHS, AC, CxtI, DT))
+    if (isGuaranteedNotToBeUndefOrPoison(LHS, SQ.AC, SQ.CxtI, SQ.DT))
       return OverflowResult::NeverOverflows;
 
   // If LHS and RHS each have at least two sign bits, the subtraction
   // cannot overflow.
-  if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 &&
-      ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1)
+  if (::ComputeNumSignBits(LHS, 0, SQ) > 1 &&
+      ::ComputeNumSignBits(RHS, 0, SQ) > 1)
     return OverflowResult::NeverOverflows;
 
-  ConstantRange LHSRange = computeConstantRangeIncludingKnownBits(
-      LHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
-  ConstantRange RHSRange = computeConstantRangeIncludingKnownBits(
-      RHS, /*ForSigned=*/true, DL, AC, CxtI, DT);
+  ConstantRange LHSRange =
+      computeConstantRangeIncludingKnownBits(LHS, /*ForSigned=*/true, SQ);
+  ConstantRange RHSRange =
+      computeConstantRangeIncludingKnownBits(RHS, /*ForSigned=*/true, SQ);
   return mapOverflowResult(LHSRange.signedSubMayOverflow(RHSRange));
 }
 
@@ -6949,21 +6929,15 @@ bool llvm::mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
 }
 
 OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
-                                                 const DataLayout &DL,
-                                                 AssumptionCache *AC,
-                                                 const Instruction *CxtI,
-                                                 const DominatorTree *DT) {
+                                                 const SimplifyQuery &SQ) {
   return ::computeOverflowForSignedAdd(Add->getOperand(0), Add->getOperand(1),
-                                       Add, DL, AC, CxtI, DT);
+                                       Add, SQ);
 }
 
 OverflowResult llvm::computeOverflowForSignedAdd(const Value *LHS,
                                                  const Value *RHS,
-                                                 const DataLayout &DL,
-                                                 AssumptionCache *AC,
-                                                 const Instruction *CxtI,
-                                                 const DominatorTree *DT) {
-  return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, DL, AC, CxtI, DT);
+                                                 const SimplifyQuery &SQ) {
+  return ::computeOverflowForSignedAdd(LHS, RHS, nullptr, SQ);
}
 
 bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
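
Note: every caller updated below follows the same pattern, so here is a minimal sketch of the new calling convention. The analyses that were previously threaded through as separate parameters (DataLayout, DominatorTree, AssumptionCache, context instruction) now travel in one SimplifyQuery, constructed in the same argument order the call sites below use. The helper name and locals are hypothetical; on trees of this vintage SimplifyQuery is declared in InstructionSimplify.h.

#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/InstructionSimplify.h" // declares SimplifyQuery here
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"

using namespace llvm;

// Hypothetical caller: prove an unsigned add cannot wrap.
// Before this patch the call was
//   computeOverflowForUnsignedAdd(LHS, RHS, DL, AC, &CxtI, DT);
static bool addNeverWrapsUnsigned(const Value *LHS, const Value *RHS,
                                  const DataLayout &DL,
                                  const DominatorTree *DT, AssumptionCache *AC,
                                  const Instruction &CxtI) {
  SimplifyQuery SQ(DL, DT, AC, &CxtI); // bundle the analyses once
  return computeOverflowForUnsignedAdd(LHS, RHS, SQ) ==
         OverflowResult::NeverOverflows;
}
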
14 changes: 8 additions & 6 deletions llvm/lib/Transforms/Scalar/LICM.cpp
@@ -2540,8 +2540,9 @@ static bool hoistAdd(ICmpInst::Predicate Pred, Value *VariantLHS,
   // we want to avoid this.
   auto &DL = L.getHeader()->getModule()->getDataLayout();
   bool ProvedNoOverflowAfterReassociate =
-      computeOverflowForSignedSub(InvariantRHS, InvariantOp, DL, AC, &ICmp,
-                                  DT) == llvm::OverflowResult::NeverOverflows;
+      computeOverflowForSignedSub(InvariantRHS, InvariantOp,
+                                  SimplifyQuery(DL, DT, AC, &ICmp)) ==
+      llvm::OverflowResult::NeverOverflows;
   if (!ProvedNoOverflowAfterReassociate)
     return false;
   auto *Preheader = L.getLoopPreheader();
@@ -2591,15 +2592,16 @@ static bool hoistSub(ICmpInst::Predicate Pred, Value *VariantLHS,
   // we want to avoid this. Likewise, for "C1 - LV < C2" we need to prove that
   // "C1 - C2" does not overflow.
   auto &DL = L.getHeader()->getModule()->getDataLayout();
+  SimplifyQuery SQ(DL, DT, AC, &ICmp);
   if (VariantSubtracted) {
     // C1 - LV < C2 --> LV > C1 - C2
-    if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, DL, AC, &ICmp,
-                                    DT) != llvm::OverflowResult::NeverOverflows)
+    if (computeOverflowForSignedSub(InvariantOp, InvariantRHS, SQ) !=
+        llvm::OverflowResult::NeverOverflows)
       return false;
   } else {
     // LV - C1 < C2 --> LV < C1 + C2
-    if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, DL, AC, &ICmp,
-                                    DT) != llvm::OverflowResult::NeverOverflows)
+    if (computeOverflowForSignedAdd(InvariantOp, InvariantRHS, SQ) !=
+        llvm::OverflowResult::NeverOverflows)
       return false;
   }
   auto *Preheader = L.getLoopPreheader();
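
A worked example, not from the patch, of what the NeverOverflows requirement in hoistSub guards against: if the folded constant C1 - C2 wraps, the reassociated compare disagrees with the original one. In 8-bit two's complement:

#include <cassert>
#include <cstdint>

int main() {
  // hoistSub rewrites "C1 - LV < C2" as "LV > C1 - C2". With C1 = -100 and
  // C2 = 50, the folded constant C1 - C2 = -150 wraps to 106 in i8.
  const int8_t C1 = -100, C2 = 50;
  const int8_t LV = 0;                   // a sample loop-variant value
  bool Original = int8_t(C1 - LV) < C2;  // -100 < 50: true
  bool Rewritten = LV > int8_t(C1 - C2); // 0 > 106:  false
  assert(Original != Rewritten);         // the rewrite changed the result
  return 0;
}
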
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -641,8 +641,9 @@ static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
   // Check if the multiply could not overflow due to known ranges of the
   // input values.
   OverflowResult OR = computeOverflowForUnsignedMul(
-      FI.InnerTripCount, FI.OuterTripCount, DL, AC,
-      FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
+      FI.InnerTripCount, FI.OuterTripCount,
+      SimplifyQuery(DL, DT, AC,
+                    FI.OuterLoop->getLoopPreheader()->getTerminator()));
   if (OR != OverflowResult::MayOverflow)
     return OR;
 
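
The trip-count check above ultimately reduces to ConstantRange::unsignedMulMayOverflow over the operands' computed ranges. A standalone sketch of that range reasoning, assuming only the public ConstantRange API (the i8 width and bounds are illustrative):

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include <cassert>

using namespace llvm;

int main() {
  // Trip counts known to lie in [0, 16) cannot overflow an i8 multiply,
  // since the largest possible product is 15 * 15 = 225 <= 255.
  ConstantRange Inner(APInt(8, 0), APInt(8, 16));
  ConstantRange Outer(APInt(8, 0), APInt(8, 16));
  assert(Inner.unsignedMulMayOverflow(Outer) ==
         ConstantRange::OverflowResult::NeverOverflows);

  // Widen one range to [0, 32) and the product can reach 31 * 15 = 465 > 255,
  // so the analysis must answer MayOverflow.
  ConstantRange Wide(APInt(8, 0), APInt(8, 32));
  assert(Wide.unsignedMulMayOverflow(Outer) ==
         ConstantRange::OverflowResult::MayOverflow);
  return 0;
}
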
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/Scalar/NaryReassociate.cpp
@@ -372,9 +372,9 @@ NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
   // If the I-th index needs sext and the underlying add is not equipped with
   // nsw, we cannot split the add because
   // sext(LHS + RHS) != sext(LHS) + sext(RHS).
+  SimplifyQuery SQ(*DL, DT, AC, GEP);
   if (requiresSignExtension(IndexToSplit, GEP) &&
-      computeOverflowForSignedAdd(AO, *DL, AC, GEP, DT) !=
-          OverflowResult::NeverOverflows)
+      computeOverflowForSignedAdd(AO, SQ) != OverflowResult::NeverOverflows)
     return nullptr;
 
   Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
Expand Down