[SCEV] Pass loop pred branch as context instruction to getMinTrailingZ. #160941
@llvm/pr-subscribers-llvm-analysis

Author: Florian Hahn (fhahn)

Changes

When computing the backedge-taken count, we know that the expression must be valid just before we enter the loop. Using the terminator of the loop predecessor as the context instruction for getConstantMultiple and getMinTrailingZeros allows using information from things like alignment assumptions. When a context instruction is used, the result is not cached, as it is only valid at the specific context instruction.

Compile-time looks neutral: http://llvm-compile-time-tracker.com/compare.php?from=9be276ec75c087595ebb62fe11b35c1a90371a49&to=745980f5e1c8094ea1293cd145d0ef1390f03029&stat=instructions:u

No impact on llvm-opt-benchmark (dtcxzyw/llvm-opt-benchmark#2867), but it leads to additional unrolling in ~90 files across a C/C++-based corpus including LLVM on AArch64 using libc++ (which emits alignment assumptions for things like std::vector::begin).

Full diff: https://github.com/llvm/llvm-project/pull/160941.diff

3 Files Affected:
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 858c1d5392071..8876e4ed6ae4f 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1002,10 +1002,14 @@ class ScalarEvolution {
/// (at every loop iteration). It is, at the same time, the minimum number
/// of times S is divisible by 2. For example, given {4,+,8} it returns 2.
/// If S is guaranteed to be 0, it returns the bitwidth of S.
- LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S);
+ /// If \p CtxI is not nullptr, return the minimum trailing-zero count valid at \p CtxI.
+ LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S,
+ const Instruction *CtxI = nullptr);
- /// Returns the max constant multiple of S.
- LLVM_ABI APInt getConstantMultiple(const SCEV *S);
+ /// Returns the max constant multiple of S. If \p CtxI is not nullptr, return
+ /// a constant multiple valid at \p CtxI.
+ LLVM_ABI APInt getConstantMultiple(const SCEV *S,
+ const Instruction *CtxI = nullptr);
// Returns the max constant multiple of S. If S is exactly 0, return 1.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S);
@@ -1525,8 +1529,10 @@ class ScalarEvolution {
/// Return the Value set from which the SCEV expr is generated.
ArrayRef<Value *> getSCEVValues(const SCEV *S);
- /// Private helper method for the getConstantMultiple method.
- APInt getConstantMultipleImpl(const SCEV *S);
+ /// Private helper method for the getConstantMultiple method. If \p CtxI is
+ /// not nullptr, return a constant multiple valid at \p CtxI.
+ APInt getConstantMultipleImpl(const SCEV *S,
+ const Instruction *CtxI = nullptr);
/// Information about the number of times a particular loop exit may be
/// reached before exiting the loop.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index b08399b381f34..b201ea47b111f 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -6344,19 +6344,20 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
return getGEPExpr(GEP, IndexExprs);
}
-APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
+APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
+ const Instruction *CtxI) {
uint64_t BitWidth = getTypeSizeInBits(S->getType());
auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
return TrailingZeros >= BitWidth
? APInt::getZero(BitWidth)
: APInt::getOneBitSet(BitWidth, TrailingZeros);
};
- auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
+ auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
// The result is GCD of all operands results.
- APInt Res = getConstantMultiple(N->getOperand(0));
+ APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
Res = APIntOps::GreatestCommonDivisor(
- Res, getConstantMultiple(N->getOperand(I)));
+ Res, getConstantMultiple(N->getOperand(I), CtxI));
return Res;
};
@@ -6364,33 +6365,33 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
case scConstant:
return cast<SCEVConstant>(S)->getAPInt();
case scPtrToInt:
- return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
+ return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
case scUDivExpr:
case scVScale:
return APInt(BitWidth, 1);
case scTruncate: {
// Only multiples that are a power of 2 will hold after truncation.
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
- uint32_t TZ = getMinTrailingZeros(T->getOperand());
+ uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
return GetShiftedByZeros(TZ);
}
case scZeroExtend: {
const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
- return getConstantMultiple(Z->getOperand()).zext(BitWidth);
+ return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
}
case scSignExtend: {
// Only multiples that are a power of 2 will hold after sext.
const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
- uint32_t TZ = getMinTrailingZeros(E->getOperand());
+ uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
return GetShiftedByZeros(TZ);
}
case scMulExpr: {
const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
if (M->hasNoUnsignedWrap()) {
// The result is the product of all operand results.
- APInt Res = getConstantMultiple(M->getOperand(0));
+ APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
for (const SCEV *Operand : M->operands().drop_front())
- Res = Res * getConstantMultiple(Operand);
+ Res = Res * getConstantMultiple(Operand, CtxI);
return Res;
}
@@ -6398,7 +6399,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
// sum of trailing zeros for all its operands.
uint32_t TZ = 0;
for (const SCEV *Operand : M->operands())
- TZ += getMinTrailingZeros(Operand);
+ TZ += getMinTrailingZeros(Operand, CtxI);
return GetShiftedByZeros(TZ);
}
case scAddExpr:
@@ -6407,9 +6408,9 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
if (N->hasNoUnsignedWrap())
return GetGCDMultiple(N);
// Find the trailing bits, which is the minimum of its operands.
- uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
+ uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
for (const SCEV *Operand : N->operands().drop_front())
- TZ = std::min(TZ, getMinTrailingZeros(Operand));
+ TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
return GetShiftedByZeros(TZ);
}
case scUMaxExpr:
@@ -6422,7 +6423,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
// ask ValueTracking for known bits
const SCEVUnknown *U = cast<SCEVUnknown>(S);
unsigned Known =
- computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT)
+ computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
.countMinTrailingZeros();
return GetShiftedByZeros(Known);
}
@@ -6432,12 +6433,18 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
llvm_unreachable("Unknown SCEV kind!");
}
-APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
+APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
+ const Instruction *CtxI) {
+ // Skip looking up and updating the cache if there is a context instruction,
+ // as the result will only be valid in the specified context.
+ if (CtxI)
+ return getConstantMultipleImpl(S, CtxI);
+
auto I = ConstantMultipleCache.find(S);
if (I != ConstantMultipleCache.end())
return I->second;
- APInt Result = getConstantMultipleImpl(S);
+ APInt Result = getConstantMultipleImpl(S, CtxI);
auto InsertPair = ConstantMultipleCache.insert({S, Result});
assert(InsertPair.second && "Should insert a new key");
return InsertPair.first->second;
@@ -6448,8 +6455,9 @@ APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
}
-uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
- return std::min(getConstantMultiple(S).countTrailingZeros(),
+uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
+ const Instruction *CtxI) {
+ return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
(unsigned)getTypeSizeInBits(S->getType()));
}
@@ -10236,8 +10244,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
static const SCEV *
SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
SmallVectorImpl<const SCEVPredicate *> *Predicates,
-
- ScalarEvolution &SE) {
+ ScalarEvolution &SE, const Loop *L) {
uint32_t BW = A.getBitWidth();
assert(BW == SE.getTypeSizeInBits(B->getType()));
assert(A != 0 && "A must be non-zero.");
@@ -10253,7 +10260,12 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
//
// B is divisible by D if and only if the multiplicity of prime factor 2 for B
// is not less than multiplicity of this prime factor for D.
- if (SE.getMinTrailingZeros(B) < Mult2) {
+ unsigned MinTZ = SE.getMinTrailingZeros(B);
+ // Try again with the terminator of the loop predecessor for a
+ // context-specific result, if MinTZ is too small.
+ if (MinTZ < Mult2 && L->getLoopPredecessor())
+ MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
+ if (MinTZ < Mult2) {
// Check if we can prove there's no remainder using URem.
const SCEV *URem =
SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
@@ -10701,7 +10713,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
return getCouldNotCompute();
const SCEV *E = SolveLinEquationWithOverflow(
StepC->getAPInt(), getNegativeSCEV(Start),
- AllowPredicates ? &Predicates : nullptr, *this);
+ AllowPredicates ? &Predicates : nullptr, *this, L);
const SCEV *M = E;
if (E != getCouldNotCompute()) {
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
index b1fe7b1b2b7ee..7ba422da79ad8 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
@@ -615,22 +615,14 @@ define void @test_ptrs_aligned_by_4_via_assumption(ptr %start, ptr %end) {
; CHECK-LABEL: 'test_ptrs_aligned_by_4_via_assumption'
; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_4_via_assumption
; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ]
-; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: ((4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4
-; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: (4 + (4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_4_via_assumption
-; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count.
-; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count.
-; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count.
-; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
-; CHECK-NEXT: Predicates:
-; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
-; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903
-; CHECK-NEXT: Predicates:
-; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
-; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
-; CHECK-NEXT: Predicates:
-; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
+; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
+; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 4611686018427387903
+; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
+; CHECK-NEXT: Loop %loop: Trip multiple is 1
;
entry:
call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 4) ]
@@ -652,22 +644,14 @@ define void @test_ptrs_aligned_by_8_via_assumption(ptr %start, ptr %end) {
; CHECK-LABEL: 'test_ptrs_aligned_by_8_via_assumption'
; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_8_via_assumption
; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ]
-; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: ((4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4
-; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: (4 + (4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_8_via_assumption
-; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count.
-; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count.
-; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count.
-; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
-; CHECK-NEXT: Predicates:
-; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
-; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903
-; CHECK-NEXT: Predicates:
-; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
-; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
-; CHECK-NEXT: Predicates:
-; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
+; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
+; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 4611686018427387903
+; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
+; CHECK-NEXT: Loop %loop: Trip multiple is 1
;
entry:
call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 8) ]
Why are we handling this case via context instruction rather than loop guards?
Ah yes, good question, I forgot to mention that. I think to use loop guards, we would need to rewrite. The only way I could think of would be something like the equivalent to.
ping :)
LGTM. Agree that we shouldn't rewrite pointers with casts to add the divisibility information, so this seems like the best we can do...