diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 858c1d5392071..8876e4ed6ae4f 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1002,10 +1002,14 @@ class ScalarEvolution { /// (at every loop iteration). It is, at the same time, the minimum number /// of times S is divisible by 2. For example, given {4,+,8} it returns 2. /// If S is guaranteed to be 0, it returns the bitwidth of S. - LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S); + /// If \p CtxI is not nullptr, return a trailing-zero count valid at \p CtxI. + LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, + const Instruction *CtxI = nullptr); - /// Returns the max constant multiple of S. - LLVM_ABI APInt getConstantMultiple(const SCEV *S); + /// Returns the max constant multiple of S. If \p CtxI is not nullptr, return + /// a constant multiple valid at \p CtxI. + LLVM_ABI APInt getConstantMultiple(const SCEV *S, + const Instruction *CtxI = nullptr); // Returns the max constant multiple of S. If S is exactly 0, return 1. LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S); @@ -1525,8 +1529,10 @@ class ScalarEvolution { /// Return the Value set from which the SCEV expr is generated. ArrayRef getSCEVValues(const SCEV *S); - /// Private helper method for the getConstantMultiple method. - APInt getConstantMultipleImpl(const SCEV *S); + /// Private helper method for the getConstantMultiple method. If \p Ctx is + /// not nullptr, return a constant multiple valid at \p Ctx. + APInt getConstantMultipleImpl(const SCEV *S, + const Instruction *Ctx = nullptr); /// Information about the number of times a particular loop exit may be /// reached before exiting the loop. 
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 63e1b1462d007..6f6776c827729 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -6351,19 +6351,20 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { return getGEPExpr(GEP, IndexExprs); } -APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { +APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S, + const Instruction *CtxI) { uint64_t BitWidth = getTypeSizeInBits(S->getType()); auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) { return TrailingZeros >= BitWidth ? APInt::getZero(BitWidth) : APInt::getOneBitSet(BitWidth, TrailingZeros); }; - auto GetGCDMultiple = [this](const SCEVNAryExpr *N) { + auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) { // The result is GCD of all operands results. - APInt Res = getConstantMultiple(N->getOperand(0)); + APInt Res = getConstantMultiple(N->getOperand(0), CtxI); for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I) Res = APIntOps::GreatestCommonDivisor( - Res, getConstantMultiple(N->getOperand(I))); + Res, getConstantMultiple(N->getOperand(I), CtxI)); return Res; }; @@ -6371,33 +6372,33 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { case scConstant: return cast(S)->getAPInt(); case scPtrToInt: - return getConstantMultiple(cast(S)->getOperand()); + return getConstantMultiple(cast(S)->getOperand(), CtxI); case scUDivExpr: case scVScale: return APInt(BitWidth, 1); case scTruncate: { // Only multiples that are a power of 2 will hold after truncation. 
const SCEVTruncateExpr *T = cast(S); - uint32_t TZ = getMinTrailingZeros(T->getOperand()); + uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI); return GetShiftedByZeros(TZ); } case scZeroExtend: { const SCEVZeroExtendExpr *Z = cast(S); - return getConstantMultiple(Z->getOperand()).zext(BitWidth); + return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth); } case scSignExtend: { // Only multiples that are a power of 2 will hold after sext. const SCEVSignExtendExpr *E = cast(S); - uint32_t TZ = getMinTrailingZeros(E->getOperand()); + uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI); return GetShiftedByZeros(TZ); } case scMulExpr: { const SCEVMulExpr *M = cast(S); if (M->hasNoUnsignedWrap()) { // The result is the product of all operand results. - APInt Res = getConstantMultiple(M->getOperand(0)); + APInt Res = getConstantMultiple(M->getOperand(0), CtxI); for (const SCEV *Operand : M->operands().drop_front()) - Res = Res * getConstantMultiple(Operand); + Res = Res * getConstantMultiple(Operand, CtxI); return Res; } @@ -6405,7 +6406,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { // sum of trailing zeros for all its operands. uint32_t TZ = 0; for (const SCEV *Operand : M->operands()) - TZ += getMinTrailingZeros(Operand); + TZ += getMinTrailingZeros(Operand, CtxI); return GetShiftedByZeros(TZ); } case scAddExpr: @@ -6414,9 +6415,9 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { if (N->hasNoUnsignedWrap()) return GetGCDMultiple(N); // Find the trailing bits, which is the minimum of its operands. 
- uint32_t TZ = getMinTrailingZeros(N->getOperand(0)); + uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI); for (const SCEV *Operand : N->operands().drop_front()) - TZ = std::min(TZ, getMinTrailingZeros(Operand)); + TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI)); return GetShiftedByZeros(TZ); } case scUMaxExpr: @@ -6429,7 +6430,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { // ask ValueTracking for known bits const SCEVUnknown *U = cast(S); unsigned Known = - computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT) + computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT) .countMinTrailingZeros(); return GetShiftedByZeros(Known); } @@ -6439,12 +6440,18 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) { llvm_unreachable("Unknown SCEV kind!"); } -APInt ScalarEvolution::getConstantMultiple(const SCEV *S) { +APInt ScalarEvolution::getConstantMultiple(const SCEV *S, + const Instruction *CtxI) { + // Skip looking up and updating the cache if there is a context instruction, + // as the result will only be valid in the specified context. + if (CtxI) + return getConstantMultipleImpl(S, CtxI); + auto I = ConstantMultipleCache.find(S); if (I != ConstantMultipleCache.end()) return I->second; - APInt Result = getConstantMultipleImpl(S); + APInt Result = getConstantMultipleImpl(S, CtxI); auto InsertPair = ConstantMultipleCache.insert({S, Result}); assert(InsertPair.second && "Should insert a new key"); return InsertPair.first->second; @@ -6455,8 +6462,9 @@ APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) { return Multiple == 0 ? 
APInt(Multiple.getBitWidth(), 1) : Multiple; } -uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) { - return std::min(getConstantMultiple(S).countTrailingZeros(), +uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S, + const Instruction *CtxI) { + return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(), (unsigned)getTypeSizeInBits(S->getType())); } @@ -10243,8 +10251,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const { static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl *Predicates, - - ScalarEvolution &SE) { + ScalarEvolution &SE, const Loop *L) { uint32_t BW = A.getBitWidth(); assert(BW == SE.getTypeSizeInBits(B->getType())); assert(A != 0 && "A must be non-zero."); @@ -10260,7 +10267,12 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, // // B is divisible by D if and only if the multiplicity of prime factor 2 for B // is not less than multiplicity of this prime factor for D. - if (SE.getMinTrailingZeros(B) < Mult2) { + unsigned MinTZ = SE.getMinTrailingZeros(B); + // Try again with the terminator of the loop predecessor for context-specific + // result, if MinTZ is too small. + if (MinTZ < Mult2 && L->getLoopPredecessor()) + MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator()); + if (MinTZ < Mult2) { // Check if we can prove there's no remainder using URem. const SCEV *URem = SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2))); @@ -10708,7 +10720,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, return getCouldNotCompute(); const SCEV *E = SolveLinEquationWithOverflow( StepC->getAPInt(), getNegativeSCEV(Start), - AllowPredicates ? 
&Predicates : nullptr, *this, L); const SCEV *M = E; if (E != getCouldNotCompute()) { diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll index b1fe7b1b2b7ee..7ba422da79ad8 100644 --- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll +++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll @@ -615,22 +615,14 @@ define void @test_ptrs_aligned_by_4_via_assumption(ptr %start, ptr %end) { ; CHECK-LABEL: 'test_ptrs_aligned_by_4_via_assumption' ; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_4_via_assumption ; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ] -; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: ((4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)) + %start) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4 -; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: (4 + (4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)) + %start) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_4_via_assumption -; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. -; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count. -; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count. 
-; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) -; CHECK-NEXT: Predicates: -; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 -; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903 -; CHECK-NEXT: Predicates: -; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 -; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) -; CHECK-NEXT: Predicates: -; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 4611686018427387903 +; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Loop %loop: Trip multiple is 1 ; entry: call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 4) ] @@ -652,22 +644,14 @@ define void @test_ptrs_aligned_by_8_via_assumption(ptr %start, ptr %end) { ; CHECK-LABEL: 'test_ptrs_aligned_by_8_via_assumption' ; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_8_via_assumption ; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ] -; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: ((4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)) + 
%start) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4 -; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: (4 + (4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)) + %start) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_8_via_assumption -; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. -; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count. -; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count. -; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) -; CHECK-NEXT: Predicates: -; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 -; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903 -; CHECK-NEXT: Predicates: -; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 -; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) -; CHECK-NEXT: Predicates: -; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0 +; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 4611686018427387903 +; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint 
ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4) +; CHECK-NEXT: Loop %loop: Trip multiple is 1 ; entry: call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 8) ]