Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1002,10 +1002,14 @@ class ScalarEvolution {
/// (at every loop iteration). It is, at the same time, the minimum number
/// of times S is divisible by 2. For example, given {4,+,8} it returns 2.
/// If S is guaranteed to be 0, it returns the bitwidth of S.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S);
/// If \p CtxI is not nullptr, return the trailing-zero count valid at \p CtxI.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S,
const Instruction *CtxI = nullptr);

/// Returns the max constant multiple of S.
LLVM_ABI APInt getConstantMultiple(const SCEV *S);
/// Returns the max constant multiple of S. If \p CtxI is not nullptr, return
/// a constant multiple valid at \p CtxI.
LLVM_ABI APInt getConstantMultiple(const SCEV *S,
const Instruction *CtxI = nullptr);

// Returns the max constant multiple of S. If S is exactly 0, return 1.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S);
Expand Down Expand Up @@ -1525,8 +1529,10 @@ class ScalarEvolution {
/// Return the Value set from which the SCEV expr is generated.
ArrayRef<Value *> getSCEVValues(const SCEV *S);

/// Private helper method for the getConstantMultiple method.
APInt getConstantMultipleImpl(const SCEV *S);
/// Private helper method for the getConstantMultiple method. If \p Ctx is
/// not nullptr, return a constant multiple valid at \p Ctx.
APInt getConstantMultipleImpl(const SCEV *S,
const Instruction *Ctx = nullptr);

/// Information about the number of times a particular loop exit may be
/// reached before exiting the loop.
Expand Down
56 changes: 34 additions & 22 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6351,61 +6351,62 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
return getGEPExpr(GEP, IndexExprs);
}

APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
const Instruction *CtxI) {
uint64_t BitWidth = getTypeSizeInBits(S->getType());
auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
return TrailingZeros >= BitWidth
? APInt::getZero(BitWidth)
: APInt::getOneBitSet(BitWidth, TrailingZeros);
};
auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
// The result is GCD of all operands results.
APInt Res = getConstantMultiple(N->getOperand(0));
APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
Res = APIntOps::GreatestCommonDivisor(
Res, getConstantMultiple(N->getOperand(I)));
Res, getConstantMultiple(N->getOperand(I), CtxI));
return Res;
};

switch (S->getSCEVType()) {
case scConstant:
return cast<SCEVConstant>(S)->getAPInt();
case scPtrToInt:
return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
case scUDivExpr:
case scVScale:
return APInt(BitWidth, 1);
case scTruncate: {
// Only multiples that are a power of 2 will hold after truncation.
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
uint32_t TZ = getMinTrailingZeros(T->getOperand());
uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
return GetShiftedByZeros(TZ);
}
case scZeroExtend: {
const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
return getConstantMultiple(Z->getOperand()).zext(BitWidth);
return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
}
case scSignExtend: {
// Only multiples that are a power of 2 will hold after sext.
const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
uint32_t TZ = getMinTrailingZeros(E->getOperand());
uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
return GetShiftedByZeros(TZ);
}
case scMulExpr: {
const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
if (M->hasNoUnsignedWrap()) {
// The result is the product of all operand results.
APInt Res = getConstantMultiple(M->getOperand(0));
APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
for (const SCEV *Operand : M->operands().drop_front())
Res = Res * getConstantMultiple(Operand);
Res = Res * getConstantMultiple(Operand, CtxI);
return Res;
}

// If there are no wrap guarantees, find the trailing zeros, which is the
// sum of trailing zeros for all its operands.
uint32_t TZ = 0;
for (const SCEV *Operand : M->operands())
TZ += getMinTrailingZeros(Operand);
TZ += getMinTrailingZeros(Operand, CtxI);
return GetShiftedByZeros(TZ);
}
case scAddExpr:
Expand All @@ -6414,9 +6415,9 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
if (N->hasNoUnsignedWrap())
return GetGCDMultiple(N);
// Find the trailing bits, which is the minimum of its operands.
uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
for (const SCEV *Operand : N->operands().drop_front())
TZ = std::min(TZ, getMinTrailingZeros(Operand));
TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
return GetShiftedByZeros(TZ);
}
case scUMaxExpr:
Expand All @@ -6429,7 +6430,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
// ask ValueTracking for known bits
const SCEVUnknown *U = cast<SCEVUnknown>(S);
unsigned Known =
computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT)
computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
.countMinTrailingZeros();
return GetShiftedByZeros(Known);
}
Expand All @@ -6439,12 +6440,18 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
llvm_unreachable("Unknown SCEV kind!");
}

APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
const Instruction *CtxI) {
// Skip looking up and updating the cache if there is a context instruction,
// as the result will only be valid in the specified context.
if (CtxI)
return getConstantMultipleImpl(S, CtxI);

auto I = ConstantMultipleCache.find(S);
if (I != ConstantMultipleCache.end())
return I->second;

APInt Result = getConstantMultipleImpl(S);
APInt Result = getConstantMultipleImpl(S, CtxI);
auto InsertPair = ConstantMultipleCache.insert({S, Result});
assert(InsertPair.second && "Should insert a new key");
return InsertPair.first->second;
Expand All @@ -6455,8 +6462,9 @@ APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
}

uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
return std::min(getConstantMultiple(S).countTrailingZeros(),
uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
const Instruction *CtxI) {
return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
(unsigned)getTypeSizeInBits(S->getType()));
}

Expand Down Expand Up @@ -10243,8 +10251,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
static const SCEV *
SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
SmallVectorImpl<const SCEVPredicate *> *Predicates,

ScalarEvolution &SE) {
ScalarEvolution &SE, const Loop *L) {
uint32_t BW = A.getBitWidth();
assert(BW == SE.getTypeSizeInBits(B->getType()));
assert(A != 0 && "A must be non-zero.");
Expand All @@ -10260,7 +10267,12 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
//
// B is divisible by D if and only if the multiplicity of prime factor 2 for B
// is not less than multiplicity of this prime factor for D.
if (SE.getMinTrailingZeros(B) < Mult2) {
unsigned MinTZ = SE.getMinTrailingZeros(B);
// Try again with the terminator of the loop predecessor for context-specific
// result, if MinTZ is too small.
if (MinTZ < Mult2 && L->getLoopPredecessor())
MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
if (MinTZ < Mult2) {
// Check if we can prove there's no remainder using URem.
const SCEV *URem =
SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
Expand Down Expand Up @@ -10708,7 +10720,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
return getCouldNotCompute();
const SCEV *E = SolveLinEquationWithOverflow(
StepC->getAPInt(), getNegativeSCEV(Start),
AllowPredicates ? &Predicates : nullptr, *this);
AllowPredicates ? &Predicates : nullptr, *this, L);

const SCEV *M = E;
if (E != getCouldNotCompute()) {
Expand Down
40 changes: 12 additions & 28 deletions llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
Original file line number Diff line number Diff line change
Expand Up @@ -615,22 +615,14 @@ define void @test_ptrs_aligned_by_4_via_assumption(ptr %start, ptr %end) {
; CHECK-LABEL: 'test_ptrs_aligned_by_4_via_assumption'
; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_4_via_assumption
; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ]
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: ((4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: (4 + (4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_4_via_assumption
; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count.
; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count.
; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count.
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
; CHECK-NEXT: Predicates:
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903
; CHECK-NEXT: Predicates:
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
; CHECK-NEXT: Predicates:
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 4611686018427387903
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
; CHECK-NEXT: Loop %loop: Trip multiple is 1
;
entry:
call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 4) ]
Expand All @@ -652,22 +644,14 @@ define void @test_ptrs_aligned_by_8_via_assumption(ptr %start, ptr %end) {
; CHECK-LABEL: 'test_ptrs_aligned_by_8_via_assumption'
; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_8_via_assumption
; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ]
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: ((4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: (4 + (4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_8_via_assumption
; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count.
; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count.
; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count.
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
; CHECK-NEXT: Predicates:
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903
; CHECK-NEXT: Predicates:
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
; CHECK-NEXT: Predicates:
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 4611686018427387903
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
; CHECK-NEXT: Loop %loop: Trip multiple is 1
;
entry:
call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 8) ]
Expand Down