[SCEV][reland] More precise trip multiples
We currently have getMinTrailingZeros(), from which we can get a SCEV's
multiple by computing 1 << MinTrailingZeros. However, this only gets us
multiples that are a power of 2. This patch introduces a way to get max
constant multiples that are not just a power of 2. The logic is similar
to that of getMinTrailingZeros. getMinTrailingZerosImpl is replaced by
getConstantMultipleImpl, and getMinTrailingZeros now computes the max
constant multiple and counts its trailing zero bits.
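
To make the distinction concrete, here is a minimal standalone sketch, not LLVM code. The helper names `powerOfTwoMultiple` and `gcdMultiple` are hypothetical and use plain integers instead of APInt. It assumes an expression of the form A*x + B*y that does not wrap, for which the largest constant known to divide every value is gcd(A, B).

```cpp
#include <cstdint>
#include <numeric>

// Hypothetical helper: the largest power of two dividing V, i.e. the
// "1 << MinTrailingZeros" view of a multiple. Returns 0 for V == 0.
uint32_t powerOfTwoMultiple(uint32_t V) {
  return V & (~V + 1u); // isolates the lowest set bit
}

// Hypothetical helper: the largest constant dividing A*x + B*y for all
// x and y, assuming the arithmetic does not wrap.
uint32_t gcdMultiple(uint32_t A, uint32_t B) {
  return std::gcd(A, B);
}
```

For 12*x + 18*y, `gcdMultiple(12, 18)` gives 6, while the power-of-two view of that same multiple, `powerOfTwoMultiple(6)`, only gives 2.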

I have so far found this useful in two places:

1) Computing unsigned constant ranges. For example, if we have i8
   {10,+,10}<nuw>, we know its maximum possible value is 250.

2) My original intent was to use this in getSmallConstantTripMultiple,
   but it has no effect right now due to the change from D110587. For
   example, if we have backedge count `(6 * %N) - 1`, the trip count
   becomes `1 + zext((6 * %N) - 1)`, and we cannot say that 6 is a
   multiple of the SCEV. I plan to look further into this separately.
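
The arithmetic behind both items can be sketched with plain integers. This is an illustration only: the helper names are hypothetical, and the second function assumes the backedge count is computed in i8 and zero-extended to i32, which the message does not state.

```cpp
#include <cstdint>

// Hypothetical helper: the largest i8 value that is a multiple of
// Multiple. Multiple must be non-zero.
uint8_t maxUnsignedMultiple(uint8_t Multiple) {
  const uint8_t Max = UINT8_MAX;
  return static_cast<uint8_t>(Max - Max % Multiple);
}

// Hypothetical helper: trip count computed as 1 + zext(6 * N - 1),
// where the inner expression is evaluated in wrapping i8 arithmetic
// (assumed width) and then zero-extended to 32 bits.
uint32_t tripCountFromI8(uint8_t N) {
  uint8_t BackedgeCount = static_cast<uint8_t>(6u * N - 1u);
  return 1u + static_cast<uint32_t>(BackedgeCount);
}
```

`maxUnsignedMultiple(10)` is 250, matching the first item. `tripCountFromI8(1)` is 6, but `tripCountFromI8(0)` is 256 because the inner expression wraps to 255 before the extension. That is one way to see why 6 cannot be assumed to divide the extended form.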

The implementation assumes the value is unsigned. It can probably be
extended to handle signed values as well.

If the code sees that a SCEV does not have <nuw>, it will fall back to
finding the max multiple that is a power of 2. A multiple that is a
power of 2 remains a multiple even after the SCEV overflows, because the
modulus 2^BitWidth is itself divisible by it. This does not hold for
other values.
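
A small wrapping-arithmetic sketch illustrates the point. It is not LLVM code: `wrappingAddRec` is a hypothetical helper that evaluates an i8 add recurrence {Start,+,Step} without any no-wrap guarantee.

```cpp
#include <cstdint>

// Hypothetical helper: value of the i8 recurrence {Start,+,Step} after
// Count iterations, with wrapping (no <nuw>) arithmetic.
uint8_t wrappingAddRec(uint8_t Start, uint8_t Step, unsigned Count) {
  uint8_t V = Start;
  for (unsigned I = 0; I < Count; ++I)
    V = static_cast<uint8_t>(V + Step); // wraps modulo 256
  return V;
}
```

With a step of 4, `wrappingAddRec(0, 4, 65)` wraps to 4, which is still a multiple of 4. With a step of 6, `wrappingAddRec(0, 6, 43)` wraps to 2, which is no longer a multiple of 6 but is still a multiple of 2, the largest power of 2 dividing 6.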

---

This relands https://reviews.llvm.org/D141823. The original patch failed
SCEV verification when expensive checks were turned on. This can occur when:

1. SCEV S's multiple is cached
2. SCEV S's no wrap flags are strengthened, and the multiple changes
3. SCEV verifier finds that S's cached and recomputed multiple are
   different

We eliminate most cases by forgetting a SCEVAddRecExpr's cached values
when its flags are modified, but there are still cases for other SCEV
types. We relax the check to require only that the cached multiple
divides the recomputed multiple, which ensures the cached multiple is a
correct, conservative multiple.
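
The intent of the relaxed check can be sketched as a standalone predicate. This follows the description above and the zero special case noted in the verifier comment. It is not a transcription of the verifier code, and the function name is hypothetical.

```cpp
#include <cstdint>

// Hypothetical helper: returns true if Cached is an acceptable, possibly
// conservative, value for a multiple that now recomputes to Recomputed.
bool cachedMultipleIsConservative(uint64_t Cached, uint64_t Recomputed) {
  // Special case: if either multiple is zero, both must be zero.
  if (Cached == 0 || Recomputed == 0)
    return Cached == Recomputed;
  // Otherwise the cached multiple must divide the recomputed one.
  return Recomputed % Cached == 0;
}
```

For example, a cached multiple of 2 is accepted when the recomputed multiple is 6, because strengthened flags only made the result more precise. A cached 6 against a recomputed 2 is rejected.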

Reviewed By: mkazantsev

Differential Revision: https://reviews.llvm.org/D149529
caojoshua committed May 8, 2023
1 parent 52ae3b5 commit 9c1d5e4
Showing 6 changed files with 139 additions and 72 deletions.
14 changes: 10 additions & 4 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -967,6 +967,12 @@ class ScalarEvolution {
/// If S is guaranteed to be 0, it returns the bitwidth of S.
uint32_t getMinTrailingZeros(const SCEV *S);

/// Returns the max constant multiple of S.
APInt getConstantMultiple(const SCEV *S);

// Returns the max constant multiple of S. If S is exactly 0, return 1.
APInt getNonZeroConstantMultiple(const SCEV *S);

/// Determine the unsigned range for a particular SCEV.
/// NOTE: This returns a copy of the reference returned by getRangeRef.
ConstantRange getUnsignedRange(const SCEV *S) {
@@ -1434,14 +1440,14 @@ class ScalarEvolution {
/// predicate by splitting it into a set of independent predicates.
bool ProvingSplitPredicate = false;

/// Memoized values for the GetMinTrailingZeros
DenseMap<const SCEV *, uint32_t> MinTrailingZerosCache;
/// Memoized values for the getConstantMultiple
DenseMap<const SCEV *, APInt> ConstantMultipleCache;

/// Return the Value set from which the SCEV expr is generated.
ArrayRef<Value *> getSCEVValues(const SCEV *S);

/// Private helper method for the GetMinTrailingZeros method
uint32_t getMinTrailingZerosImpl(const SCEV *S);
/// Private helper method for the getConstantMultiple method.
APInt getConstantMultipleImpl(const SCEV *S);

/// Information about the number of times a particular loop exit may be
/// reached before exiting the loop.
167 changes: 114 additions & 53 deletions llvm/lib/Analysis/ScalarEvolution.cpp
@@ -6224,77 +6224,113 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
return getGEPExpr(GEP, IndexExprs);
}

uint32_t ScalarEvolution::getMinTrailingZerosImpl(const SCEV *S) {
APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
uint64_t BitWidth = getTypeSizeInBits(S->getType());
auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
return TrailingZeros >= BitWidth
? APInt::getZero(BitWidth)
: APInt::getOneBitSet(BitWidth, TrailingZeros);
};
auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
// The result is GCD of all operands results.
APInt Res = getConstantMultiple(N->getOperand(0));
for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
Res = APIntOps::GreatestCommonDivisor(
Res, getConstantMultiple(N->getOperand(I)));
return Res;
};

switch (S->getSCEVType()) {
case scConstant:
return cast<SCEVConstant>(S)->getAPInt().countr_zero();
return cast<SCEVConstant>(S)->getAPInt();
case scPtrToInt:
return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
case scUDivExpr:
case scVScale:
return APInt(BitWidth, 1);
case scTruncate: {
// Only multiples that are a power of 2 will hold after truncation.
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
return std::min(getMinTrailingZeros(T->getOperand()),
(uint32_t)getTypeSizeInBits(T->getType()));
uint32_t TZ = getMinTrailingZeros(T->getOperand());
return GetShiftedByZeros(TZ);
}
case scZeroExtend: {
const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
return getConstantMultiple(Z->getOperand()).zext(BitWidth);
}
case scZeroExtend:
case scSignExtend: {
const SCEVIntegralCastExpr *E = cast<SCEVIntegralCastExpr>(S);
uint32_t OpRes = getMinTrailingZeros(E->getOperand());
return OpRes == getTypeSizeInBits(E->getOperand()->getType())
? getTypeSizeInBits(E->getType())
: OpRes;
const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
return getConstantMultiple(E->getOperand()).sext(BitWidth);
}
case scMulExpr: {
const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
// The result is the sum of all operands results.
uint32_t SumOpRes = getMinTrailingZeros(M->getOperand(0));
uint32_t BitWidth = getTypeSizeInBits(M->getType());
for (unsigned I = 1, E = M->getNumOperands();
SumOpRes != BitWidth && I != E; ++I)
SumOpRes =
std::min(SumOpRes + getMinTrailingZeros(M->getOperand(I)), BitWidth);
return SumOpRes;
if (M->hasNoUnsignedWrap()) {
// The result is the product of all operand results.
APInt Res = getConstantMultiple(M->getOperand(0));
for (const SCEV *Operand : M->operands().drop_front())
Res = Res * getConstantMultiple(Operand);
return Res;
}

// If there are no wrap guarentees, find the trailing zeros, which is the
// sum of trailing zeros for all its operands.
uint32_t TZ = 0;
for (const SCEV *Operand : M->operands())
TZ += getMinTrailingZeros(Operand);
return GetShiftedByZeros(TZ);
}
case scVScale:
return 0;
case scUDivExpr:
return 0;
case scPtrToInt:
case scAddExpr:
case scAddRecExpr:
case scAddRecExpr: {
const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
if (N->hasNoUnsignedWrap())
return GetGCDMultiple(N);
// Find the trailing bits, which is the minimum of its operands.
uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
for (const SCEV *Operand : N->operands().drop_front())
TZ = std::min(TZ, getMinTrailingZeros(Operand));
return GetShiftedByZeros(TZ);
}
case scUMaxExpr:
case scSMaxExpr:
case scUMinExpr:
case scSMinExpr:
case scSequentialUMinExpr: {
// The result is the min of all operands results.
ArrayRef<const SCEV *> Ops = S->operands();
uint32_t MinOpRes = getMinTrailingZeros(Ops[0]);
for (unsigned I = 1, E = Ops.size(); MinOpRes && I != E; ++I)
MinOpRes = std::min(MinOpRes, getMinTrailingZeros(Ops[I]));
return MinOpRes;
}
case scSequentialUMinExpr:
return GetGCDMultiple(cast<SCEVNAryExpr>(S));
case scUnknown: {
// ask ValueTracking for known bits
const SCEVUnknown *U = cast<SCEVUnknown>(S);
// For a SCEVUnknown, ask ValueTracking.
KnownBits Known =
computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT);
return Known.countMinTrailingZeros();
unsigned Known =
computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
.countMinTrailingZeros();
return GetShiftedByZeros(Known);
}
case scCouldNotCompute:
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
}
llvm_unreachable("Unknown SCEV kind!");
}

uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
auto I = MinTrailingZerosCache.find(S);
if (I != MinTrailingZerosCache.end())
APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
auto I = ConstantMultipleCache.find(S);
if (I != ConstantMultipleCache.end())
return I->second;

uint32_t Result = getMinTrailingZerosImpl(S);
auto InsertPair = MinTrailingZerosCache.insert({S, Result});
APInt Result = getConstantMultipleImpl(S);
auto InsertPair = ConstantMultipleCache.insert({S, Result});
assert(InsertPair.second && "Should insert a new key");
return InsertPair.first->second;
}

APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
APInt Multiple = getConstantMultiple(S);
return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
}

uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
return std::min(getConstantMultiple(S).countTrailingZeros(),
(unsigned)getTypeSizeInBits(S->getType()));
}

/// Helper method to assign a range to V from metadata present in the IR.
static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
if (Instruction *I = dyn_cast<Instruction>(V))
@@ -6310,6 +6346,7 @@ void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
AddRec->setNoWrapFlags(Flags);
UnsignedRanges.erase(AddRec);
SignedRanges.erase(AddRec);
ConstantMultipleCache.erase(AddRec);
}
}

@@ -6543,16 +6580,21 @@ const ConstantRange &ScalarEvolution::getRangeRef(

// If the value has known zeros, the maximum value will have those known zeros
// as well.
uint32_t TZ = getMinTrailingZeros(S);
if (TZ != 0) {
if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
APInt Multiple = getNonZeroConstantMultiple(S);
APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
if (!Remainder.isZero())
ConservativeResult =
ConstantRange(APInt::getMinValue(BitWidth),
APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
else
APInt::getMaxValue(BitWidth) - Remainder + 1);
}
else {
uint32_t TZ = getMinTrailingZeros(S);
if (TZ != 0) {
ConservativeResult = ConstantRange(
APInt::getSignedMinValue(BitWidth),
APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
}
}

switch (S->getSCEVType()) {
@@ -8214,10 +8256,12 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
};

const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
if (!TC)
// Attempt to factor more general cases. Returns the greatest power of
// two divisor.
return GetSmallMultiple(getMinTrailingZeros(TCExpr));
if (!TC) {
APInt Multiple = getNonZeroConstantMultiple(TCExpr);
return Multiple.getActiveBits() > 32
? 1
: Multiple.zextOrTrunc(32).getZExtValue();
}

ConstantInt *Result = TC->getValue();
assert(Result && "SCEVConstant expected to have non-null ConstantInt");
@@ -8398,7 +8442,7 @@ void ScalarEvolution::forgetAllLoops() {
SignedRanges.clear();
ExprValueMap.clear();
HasRecMap.clear();
MinTrailingZerosCache.clear();
ConstantMultipleCache.clear();
PredicatedSCEVRewrites.clear();
FoldCache.clear();
FoldCacheUser.clear();
@@ -13414,7 +13458,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
PendingMerges(std::move(Arg.PendingMerges)),
MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
PredicatedBackedgeTakenCounts(
std::move(Arg.PredicatedBackedgeTakenCounts)),
@@ -13889,7 +13933,7 @@ void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
UnsignedRanges.erase(S);
SignedRanges.erase(S);
HasRecMap.erase(S);
MinTrailingZerosCache.erase(S);
ConstantMultipleCache.erase(S);

if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
UnsignedWrapViaInductionTried.erase(AR);
@@ -14269,6 +14313,23 @@ void ScalarEvolution::verify() const {
}
}
}

// Verify that ConstantMultipleCache computations are correct. It is possible
// that a recomputed multiple has a higher multiple than the cached multiple
// due to strengthened wrap flags. In this case, the cached multiple is a
// conservative, but still correct if it divides the recomputed multiple. As
// a special case, if if one multiple is zero, the other must also be zero.
for (auto [S, Multiple] : ConstantMultipleCache) {
APInt RecomputedMultiple = SE2.getConstantMultipleImpl(S);
if ((Multiple != RecomputedMultiple &&
(Multiple == 0 || RecomputedMultiple == 0)) &&
RecomputedMultiple.urem(Multiple) != 0) {
dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
<< *S << " : Computed " << RecomputedMultiple
<< " but cache contains " << Multiple << "!\n";
std::abort();
}
}
}

bool ScalarEvolution::invalidate(
2 changes: 1 addition & 1 deletion llvm/test/Analysis/ScalarEvolution/nsw.ll
@@ -322,7 +322,7 @@ define void @bad_postinc_nsw_a(i32 %n) {
; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
; CHECK-NEXT: --> {0,+,7}<nuw><nsw><%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (7 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 7) + (1 umin %n))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 7
; CHECK-NEXT: --> {7,+,7}<nuw><%loop> U: [7,0) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {7,+,7}<nuw><%loop> U: [7,-3) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @bad_postinc_nsw_a
; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 7) + (1 umin %n))
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 613566756
12 changes: 6 additions & 6 deletions llvm/test/Analysis/ScalarEvolution/ranges.ll
@@ -1,6 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" -scev-range-iter-threshold=1 2>&1 | FileCheck %s
; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>,verify<scalar-evolution>" 2>&1 | FileCheck %s
; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>,verify<scalar-evolution>" -scev-range-iter-threshold=1 2>&1 | FileCheck %s

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64"

@@ -133,7 +133,7 @@ define void @add_6(i32 %n) {
; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
; CHECK-NEXT: --> {0,+,6}<nuw><nsw><%loop> U: [0,-2147483648) S: [0,2147483647) Exits: (6 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 6) + (1 umin %n))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 6
; CHECK-NEXT: --> {6,+,6}<nuw><%loop> U: [6,-1) S: [-2147483648,2147483647) Exits: (6 + (6 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 6) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {6,+,6}<nuw><%loop> U: [6,-3) S: [-2147483648,2147483647) Exits: (6 + (6 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 6) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @add_6
; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 6) + (1 umin %n))
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 715827882
@@ -160,7 +160,7 @@ define void @add_7(i32 %n) {
; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
; CHECK-NEXT: --> {0,+,7}<nuw><nsw><%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (7 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 7) + (1 umin %n))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 7
; CHECK-NEXT: --> {7,+,7}<nuw><%loop> U: [7,0) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {7,+,7}<nuw><%loop> U: [7,-3) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @add_7
; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 7) + (1 umin %n))
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 613566756
@@ -215,7 +215,7 @@ define void @add_9(i32 %n) {
; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
; CHECK-NEXT: --> {0,+,9}<nuw><nsw><%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (9 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 9) + (1 umin %n))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 9
; CHECK-NEXT: --> {9,+,9}<nuw><%loop> U: [9,0) S: [9,0) Exits: (9 + (9 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 9) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {9,+,9}<nuw><%loop> U: [9,-3) S: [9,0) Exits: (9 + (9 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 9) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @add_9
; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 9) + (1 umin %n))
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 477218588
@@ -243,7 +243,7 @@ define void @add_10(i32 %n) {
; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
; CHECK-NEXT: --> {0,+,10}<nuw><nsw><%loop> U: [0,-2147483648) S: [0,2147483647) Exits: (10 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 10) + (1 umin %n))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 10
; CHECK-NEXT: --> {10,+,10}<nuw><%loop> U: [10,-1) S: [-2147483648,2147483647) Exits: (10 + (10 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 10) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: --> {10,+,10}<nuw><%loop> U: [10,-5) S: [-2147483648,2147483647) Exits: (10 + (10 * ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 10) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @add_10
; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n))<nuw><nsw> + %n) /u 10) + (1 umin %n))
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 429496729
@@ -520,7 +520,7 @@ define void @test_trip_multiple_5(i32 %num) {
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
; CHECK-NEXT: Predicates:
; CHECK: Loop %for.body: Trip multiple is 1
; CHECK: Loop %for.body: Trip multiple is 5
;
entry:
%u = urem i32 %num, 5