LAA: thoroughly clarify stride-versioning code (NFC) #97075
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-ir
Author: Ramkumar Ramachandra (artagnon)
Changes
The current stride versioning code in collectStridedAccess is quite fragile, and has implicit effects. Make it more robust by making it clear that Stride - 1 == BTC is a special case, and operate on ConstantRanges directly. Query the exact backedge-taken count instead of the symbolic maximum of it.
This patch has the side effect of making it possible to directly return the SCEVUnknown under a cast in getStrideFromPointer, eliminating a second cast-stripping in collectStridedAccess. It also has the side effect of a positive test update in symbolic-stride.
Full diff: https://github.com/llvm/llvm-project/pull/97075.diff
4 Files Affected:
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index 7b94b9c6c6d11..86d0a6b35d748 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -277,6 +277,9 @@ class [[nodiscard]] ConstantRange {
/// Return true if all values in this range are non-negative.
bool isAllNonNegative() const;
+ /// Return true if all values in this range are positive.
+ bool isAllPositive() const;
+
/// Return the largest unsigned value contained in the ConstantRange.
APInt getUnsignedMax() const;
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 38bf6d8160aa9..4932ed61ec1ba 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -35,6 +35,7 @@
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
@@ -61,6 +62,8 @@
#include <cassert>
#include <cstdint>
#include <iterator>
+#include <optional>
+#include <sys/types.h>
#include <utility>
#include <variant>
#include <vector>
@@ -2914,7 +2917,7 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(V))
if (isa<SCEVUnknown>(C->getOperand()))
- return V;
+ return C->getOperand();
return nullptr;
}
@@ -2930,7 +2933,8 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
// computation of an interesting IV - but we chose not to as we
// don't have a cost model here, and broadening the scope exposes
// far too many unprofitable cases.
- const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
+ ScalarEvolution *SE = PSE->getSE();
+ const SCEV *StrideExpr = getStrideFromPointer(Ptr, SE, TheLoop);
if (!StrideExpr)
return;
@@ -2943,10 +2947,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
return;
}
- // Avoid adding the "Stride == 1" predicate when we know that
- // Stride >= Trip-Count. Such a predicate will effectively optimize a single
- // or zero iteration loop, as Trip-Count <= Stride == 1.
- //
// TODO: We are currently not making a very informed decision on when it is
// beneficial to apply stride versioning. It might make more sense that the
// users of this analysis (such as the vectorizer) will trigger it, based on
@@ -2956,40 +2956,30 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
// of various possible stride specializations, considering the alternatives
// of using gather/scatters (if available).
- const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount();
-
- // Match the types so we can compare the stride and the MaxBTC.
- // The Stride can be positive/negative, so we sign extend Stride;
- // The backedgeTakenCount is non-negative, so we zero extend MaxBTC.
- const DataLayout &DL = TheLoop->getHeader()->getDataLayout();
- uint64_t StrideTypeSizeBits = DL.getTypeSizeInBits(StrideExpr->getType());
- uint64_t BETypeSizeBits = DL.getTypeSizeInBits(MaxBTC->getType());
- const SCEV *CastedStride = StrideExpr;
- const SCEV *CastedBECount = MaxBTC;
- ScalarEvolution *SE = PSE->getSE();
- if (BETypeSizeBits >= StrideTypeSizeBits)
- CastedStride = SE->getNoopOrSignExtend(StrideExpr, MaxBTC->getType());
- else
- CastedBECount = SE->getZeroExtendExpr(MaxBTC, StrideExpr->getType());
- const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
- // Since TripCount == BackEdgeTakenCount + 1, checking:
- // "Stride >= TripCount" is equivalent to checking:
- // Stride - MaxBTC> 0
- if (SE->isKnownPositive(StrideMinusBETaken)) {
- LLVM_DEBUG(
- dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
- "Stride==1 predicate will imply that the loop executes "
- "at most once.\n");
+ // Get two signed ranges and compare them, after adjusting for bitwidth. BTC
+ // range could extend into -1.
+ const SCEV *BTC = PSE->getBackedgeTakenCount();
+ ConstantRange BTCRange = SE->getSignedRange(BTC);
+ ConstantRange StrideRange =
+ SE->getSignedRange(StrideExpr).sextOrTrunc(BTCRange.getBitWidth());
+
+ // Stride is zero-extended to compare with BTC.
+ const SCEV *CastedStride =
+ SE->getTruncateOrZeroExtend(StrideExpr, BTC->getType());
+ const SCEV *StrideMinusOne =
+ SE->getMinusSCEV(CastedStride, SE->getOne(CastedStride->getType()));
+
+ // Stride - 1 exactly equal to BTC is a special case for which the loop should
+ // not be versioned. Otherwise, the loop should not be versioned if the range
+ // difference is all positive.
+ if (StrideMinusOne == BTC ||
+ StrideRange.difference(BTCRange).isAllPositive()) {
+ LLVM_DEBUG(dbgs() << "LAA: Not versioning with Stride==1 predicate.\n");
return;
}
LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n");
- // Strip back off the integer cast, and check that our result is a
- // SCEVUnknown as we expect.
- const SCEV *StrideBase = StrideExpr;
- if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
- StrideBase = C->getOperand();
- SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
+ SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideExpr);
}
LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 19041704a40be..b942894d34467 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -440,6 +440,11 @@ bool ConstantRange::isAllNonNegative() const {
return !isSignWrappedSet() && Lower.isNonNegative();
}
+bool ConstantRange::isAllPositive() const {
+ // Empty and full set are automatically treated correctly.
+ return !isSignWrappedSet() && Lower.isStrictlyPositive();
+}
+
APInt ConstantRange::getUnsignedMax() const {
if (isFullSet() || isUpperWrapped())
return APInt::getMaxValue(getBitWidth());
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll b/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
index 7c1b11e22aef2..e9aeac7ac2bc5 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
@@ -170,23 +170,16 @@ define void @single_stride_castexpr_multiuse(i32 %offset, ptr %src, ptr %dst, i1
; CHECK-NEXT: %gep.src = getelementptr inbounds i32, ptr %src, i64 %iv.3
; CHECK-NEXT: Grouped accesses:
; CHECK-NEXT: Group [[GRP3]]:
-; CHECK-NEXT: (Low: ((4 * %iv.1) + %dst) High: (804 + (4 * %iv.1) + (-4 * (zext i32 %offset to i64))<nsw> + %dst))
-; CHECK-NEXT: Member: {((4 * %iv.1) + %dst),+,4}<%inner.loop>
+; CHECK-NEXT: (Low: (((4 * %iv.1) + %dst) umin ((4 * %iv.1) + (4 * (sext i32 %offset to i64) * (200 + (-1 * (zext i32 %offset to i64))<nsw>)<nsw>) + %dst)) High: (4 + (((4 * %iv.1) + %dst) umax ((4 * %iv.1) + (4 * (sext i32 %offset to i64) * (200 + (-1 * (zext i32 %offset to i64))<nsw>)<nsw>) + %dst))))
+; CHECK-NEXT: Member: {((4 * %iv.1) + %dst),+,(4 * (sext i32 %offset to i64))<nsw>}<%inner.loop>
; CHECK-NEXT: Group [[GRP4]]:
-; CHECK-NEXT: (Low: (4 + %src) High: (808 + (-4 * (zext i32 %offset to i64))<nsw> + %src))
-; CHECK-NEXT: Member: {(4 + %src),+,4}<%inner.loop>
+; CHECK-NEXT: (Low: ((4 * (zext i32 %offset to i64))<nuw><nsw> + %src) High: (804 + %src))
+; CHECK-NEXT: Member: {((4 * (zext i32 %offset to i64))<nuw><nsw> + %src),+,4}<%inner.loop>
; CHECK-EMPTY:
; CHECK-NEXT: Non vectorizable stores to invariant address were not found in loop.
; CHECK-NEXT: SCEV assumptions:
-; CHECK-NEXT: Equal predicate: %offset == 1
; CHECK-EMPTY:
; CHECK-NEXT: Expressions re-written:
-; CHECK-NEXT: [PSE] %gep.src = getelementptr inbounds i32, ptr %src, i64 %iv.3:
-; CHECK-NEXT: {((4 * (zext i32 %offset to i64))<nuw><nsw> + %src),+,4}<%inner.loop>
-; CHECK-NEXT: --> {(4 + %src),+,4}<%inner.loop>
-; CHECK-NEXT: [PSE] %gep.dst = getelementptr i32, ptr %dst, i64 %iv.2:
-; CHECK-NEXT: {((4 * %iv.1) + %dst),+,(4 * (sext i32 %offset to i64))<nsw>}<%inner.loop>
-; CHECK-NEXT: --> {((4 * %iv.1) + %dst),+,4}<%inner.loop>
; CHECK-NEXT: outer.header:
; CHECK-NEXT: Report: loop is not the innermost loop
; CHECK-NEXT: Dependences:
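The isAllPositive predicate added to ConstantRange in the diff above can be illustrated with a simplified 8-bit model. This is only a sketch: Range8 and matchesBruteForce are hypothetical names, not LLVM API, and the model assumes a wrapping half-open range [Lower, Upper) in which Lower == Upper == 0xFF encodes the full set and Lower == Upper == 0 encodes the empty set.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 8-bit stand-in for a wrapping half-open range [Lower, Upper).
// Assumption: Lower == Upper == 0xFF is the full set, Lower == Upper == 0 is
// the empty set, and any other Lower == Upper pair is invalid.
struct Range8 {
  uint8_t Lower;
  uint8_t Upper;

  bool isFullSet() const { return Lower == Upper && Lower == 0xFF; }

  // Membership test, treating the bounds as unsigned and allowing wrap-around.
  bool contains(uint8_t V) const {
    if (Lower == Upper)
      return isFullSet();
    if (Lower < Upper)
      return Lower <= V && V < Upper;
    return Lower <= V || V < Upper;
  }

  // True if the range wraps when its bounds are read as signed values.
  bool isSignWrappedSet() const {
    int8_t SL = static_cast<int8_t>(Lower);
    int8_t SU = static_cast<int8_t>(Upper);
    return SL > SU && SU != INT8_MIN;
  }

  // Mirrors the shape of the predicate in the diff: no signed wrap, and a
  // strictly positive signed lower bound.
  bool isAllPositive() const {
    return !isSignWrappedSet() && static_cast<int8_t>(Lower) > 0;
  }
};

// Compares isAllPositive against explicit enumeration for every valid
// non-empty range. The empty set is skipped: in this model the predicate
// returns false for it, although "all values are positive" holds vacuously.
bool matchesBruteForce() {
  for (int L = 0; L < 256; ++L) {
    for (int U = 0; U < 256; ++U) {
      if (L == U && L != 0xFF)
        continue; // empty or invalid encoding
      Range8 R{static_cast<uint8_t>(L), static_cast<uint8_t>(U)};
      bool AllPositive = true;
      for (int V = 0; V < 256; ++V) {
        uint8_t Bits = static_cast<uint8_t>(V);
        if (R.contains(Bits) && static_cast<int8_t>(Bits) <= 0)
          AllPositive = false;
      }
      if (R.isAllPositive() != AllPositive)
        return false;
    }
  }
  return true;
}
```

In this model, the range [1, 0x80) covers the signed values 1 through 127 and is reported as all positive, while [0, 5) and the full set are not.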
e81a7eb to 7b1876a
Just to double check, the latest version doesn't seem to have any test changes (modulo adding new ones), is that intentional/expected?
Yes.
Make the stride-versioning code in LoopAccessAnalysis more robust, and make it possible to directly return the SCEVUnknown under a cast in getStrideFromPointer, eliminating a second cast-stripping in collectStridedAccess.
7b1876a to 52935f6
Thanks to the comprehensive testing that recently landed, we now know that the existing code is correct, and cannot be improved. The patch has now been converted into an NFC.
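The reasoning behind the existing check, as stated in the comments of the original code in the diff above (TripCount == BackEdgeTakenCount + 1, so Stride >= TripCount is equivalent to Stride - MaxBTC > 0), can be sketched on plain integers. This is a minimal illustration only: strideCoversTripCount and claimsHoldForSmallValues are hypothetical helpers that work on concrete values rather than SCEV expressions.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not LLVM API: models the "Stride >= TripCount" test
// on concrete integers. Because TripCount == BTC + 1, the test can be written
// as Stride - BTC > 0.
bool strideCoversTripCount(int64_t Stride, int64_t BTC) {
  return Stride - BTC > 0;
}

// Exhaustively checks two claims over a small range of values:
//  1. Stride - BTC > 0 is equivalent to Stride >= TripCount.
//  2. If that holds and Stride == 1, the loop runs at most once, so a
//     "Stride == 1" predicate would only specialize a trivial loop.
bool claimsHoldForSmallValues() {
  for (int64_t Stride = -8; Stride <= 8; ++Stride) {
    for (int64_t BTC = 0; BTC <= 8; ++BTC) {
      int64_t TripCount = BTC + 1;
      if (strideCoversTripCount(Stride, BTC) != (Stride >= TripCount))
        return false;
      if (Stride == 1 && strideCoversTripCount(Stride, BTC) && TripCount > 1)
        return false;
    }
  }
  return true;
}
```

For example, a stride of 4 with a backedge-taken count of 3 (trip count 4) passes the test, so versioning would be skipped, whereas a stride of 4 with a backedge-taken count of 4 does not.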
@@ -2955,40 +2956,35 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
// of various possible stride specializations, considering the alternatives
// of using gather/scatters (if available).
- const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount();
+ const SCEV *BTC = PSE->getSymbolicMaxBackedgeTakenCount();
Small caveat: I think this should be getBackedgeTakenCount, the exact version, and not the symbolic-max version. However, it's not going to be easy to cook up a test to support this change, but I can think about it in a follow-up.
Not pursuing this.
Thoroughly rewrite the comments and debug messages in the stride-versioning code in LoopAccessAnalysis, learning from the recent a80dd4 (LAA: pre-commit tests for stride-versioning) that added tests to cover it.