LAA: thoroughly clarify stride-versioning code (NFC) #97075

Closed
artagnon wants to merge 1 commit from the laa-stridever-robust branch

Conversation

artagnon
Contributor

@artagnon artagnon commented Jun 28, 2024

Thoroughly rewrite the comments and debug messages in the stride-versioning code in LoopAccessAnalysis, learning from the recent a80dd4 (LAA: pre-commit tests for stride-versioning), which added tests to cover it.

@artagnon artagnon requested review from nikic, fhahn and preames June 28, 2024 15:56
@llvmbot llvmbot added the llvm:ir and llvm:analysis (Includes value tracking, cost tables and constant folding) labels Jun 28, 2024
@llvmbot
Member

llvmbot commented Jun 28, 2024

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-ir

Author: Ramkumar Ramachandra (artagnon)

Changes

The current stride-versioning code in collectStridedAccess is quite fragile and has implicit effects. Make it more robust by making it clear that Stride - 1 == BTC is a special case, and by operating on ConstantRanges directly. Query the exact backedge-taken count instead of its symbolic maximum. This patch has the side effect of making it possible to directly return the SCEVUnknown under a cast in getStrideFromPointer, eliminating a second cast-stripping in collectStridedAccess. It also has the side effect of a positive test update in symbolic-stride.
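
The Stride >= TripCount reasoning that this check relies on can be sketched outside LLVM. The following is a minimal standalone model, not the LAA implementation: `Interval`, `strideMinusBTCIsPositive` and `maxTripCountIfStrideIsOne` are hypothetical names, and closed 64-bit intervals stand in for SCEV ranges. It assumes TripCount == BTC + 1, as the code comments in the diff state, and it does not model any `ConstantRange` API.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a signed value range: closed interval [lo, hi].
struct Interval {
  int64_t lo;
  int64_t hi;
};

// Models "Stride - BTC is known positive": true only when the smallest
// possible stride exceeds the largest possible backedge-taken count.
// Assumes the bounds are small enough that the subtraction cannot overflow.
constexpr bool strideMinusBTCIsPositive(Interval stride, Interval btc) {
  return stride.lo - btc.hi > 0;
}

// Upper bound on the trip count once Stride == 1 is assumed on top of
// Stride - BTC > 0: BTC < 1, so BTC is at most 0 and TripCount is at most 1.
constexpr int64_t maxTripCountIfStrideIsOne(Interval btc) {
  return (btc.hi < 0 ? btc.hi : 0) + 1;
}
```

In this model, a stride in [200, 300] against a BTC in [0, 99] is rejected for versioning, because assuming Stride == 1 would leave a loop that runs at most once. A stride in [1, 300] against the same BTC is not rejected.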


Full diff: https://github.com/llvm/llvm-project/pull/97075.diff

4 Files Affected:

  • (modified) llvm/include/llvm/IR/ConstantRange.h (+3)
  • (modified) llvm/lib/Analysis/LoopAccessAnalysis.cpp (+26-36)
  • (modified) llvm/lib/IR/ConstantRange.cpp (+5)
  • (modified) llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll (+4-11)
diff --git a/llvm/include/llvm/IR/ConstantRange.h b/llvm/include/llvm/IR/ConstantRange.h
index 7b94b9c6c6d11..86d0a6b35d748 100644
--- a/llvm/include/llvm/IR/ConstantRange.h
+++ b/llvm/include/llvm/IR/ConstantRange.h
@@ -277,6 +277,9 @@ class [[nodiscard]] ConstantRange {
   /// Return true if all values in this range are non-negative.
   bool isAllNonNegative() const;
 
+  /// Return true if all values in this range are positive.
+  bool isAllPositive() const;
+
   /// Return the largest unsigned value contained in the ConstantRange.
   APInt getUnsignedMax() const;
 
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 38bf6d8160aa9..4932ed61ec1ba 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -35,6 +35,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugLoc.h"
@@ -61,6 +62,8 @@
 #include <cassert>
 #include <cstdint>
 #include <iterator>
+#include <optional>
+#include <sys/types.h>
 #include <utility>
 #include <variant>
 #include <vector>
@@ -2914,7 +2917,7 @@ static const SCEV *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *L
 
   if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(V))
     if (isa<SCEVUnknown>(C->getOperand()))
-      return V;
+      return C->getOperand();
 
   return nullptr;
 }
@@ -2930,7 +2933,8 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // computation of an interesting IV - but we chose not to as we
   // don't have a cost model here, and broadening the scope exposes
   // far too many unprofitable cases.
-  const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
+  ScalarEvolution *SE = PSE->getSE();
+  const SCEV *StrideExpr = getStrideFromPointer(Ptr, SE, TheLoop);
   if (!StrideExpr)
     return;
 
@@ -2943,10 +2947,6 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
     return;
   }
 
-  // Avoid adding the "Stride == 1" predicate when we know that
-  // Stride >= Trip-Count. Such a predicate will effectively optimize a single
-  // or zero iteration loop, as Trip-Count <= Stride == 1.
-  //
   // TODO: We are currently not making a very informed decision on when it is
   // beneficial to apply stride versioning. It might make more sense that the
   // users of this analysis (such as the vectorizer) will trigger it, based on
@@ -2956,40 +2956,30 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // of various possible stride specializations, considering the alternatives
   // of using gather/scatters (if available).
 
-  const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount();
-
-  // Match the types so we can compare the stride and the MaxBTC.
-  // The Stride can be positive/negative, so we sign extend Stride;
-  // The backedgeTakenCount is non-negative, so we zero extend MaxBTC.
-  const DataLayout &DL = TheLoop->getHeader()->getDataLayout();
-  uint64_t StrideTypeSizeBits = DL.getTypeSizeInBits(StrideExpr->getType());
-  uint64_t BETypeSizeBits = DL.getTypeSizeInBits(MaxBTC->getType());
-  const SCEV *CastedStride = StrideExpr;
-  const SCEV *CastedBECount = MaxBTC;
-  ScalarEvolution *SE = PSE->getSE();
-  if (BETypeSizeBits >= StrideTypeSizeBits)
-    CastedStride = SE->getNoopOrSignExtend(StrideExpr, MaxBTC->getType());
-  else
-    CastedBECount = SE->getZeroExtendExpr(MaxBTC, StrideExpr->getType());
-  const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
-  // Since TripCount == BackEdgeTakenCount + 1, checking:
-  // "Stride >= TripCount" is equivalent to checking:
-  // Stride - MaxBTC> 0
-  if (SE->isKnownPositive(StrideMinusBETaken)) {
-    LLVM_DEBUG(
-        dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
-                  "Stride==1 predicate will imply that the loop executes "
-                  "at most once.\n");
+  // Get two signed ranges and compare them, after adjusting for bitwidth. BTC
+  // range could extend into -1.
+  const SCEV *BTC = PSE->getBackedgeTakenCount();
+  ConstantRange BTCRange = SE->getSignedRange(BTC);
+  ConstantRange StrideRange =
+      SE->getSignedRange(StrideExpr).sextOrTrunc(BTCRange.getBitWidth());
+
+  // Stride is zero-extended to compare with BTC.
+  const SCEV *CastedStride =
+      SE->getTruncateOrZeroExtend(StrideExpr, BTC->getType());
+  const SCEV *StrideMinusOne =
+      SE->getMinusSCEV(CastedStride, SE->getOne(CastedStride->getType()));
+
+  // Stride - 1 exactly equal to BTC is a special case for which the loop should
+  // not be versioned. Otherwise, the loop should not be versioned if the range
+  // difference is all positive.
+  if (StrideMinusOne == BTC ||
+      StrideRange.difference(BTCRange).isAllPositive()) {
+    LLVM_DEBUG(dbgs() << "LAA: Not versioning with Stride==1 predicate.\n");
     return;
   }
   LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n");
 
-  // Strip back off the integer cast, and check that our result is a
-  // SCEVUnknown as we expect.
-  const SCEV *StrideBase = StrideExpr;
-  if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
-    StrideBase = C->getOperand();
-  SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
+  SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideExpr);
 }
 
 LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
diff --git a/llvm/lib/IR/ConstantRange.cpp b/llvm/lib/IR/ConstantRange.cpp
index 19041704a40be..b942894d34467 100644
--- a/llvm/lib/IR/ConstantRange.cpp
+++ b/llvm/lib/IR/ConstantRange.cpp
@@ -440,6 +440,11 @@ bool ConstantRange::isAllNonNegative() const {
   return !isSignWrappedSet() && Lower.isNonNegative();
 }
 
+bool ConstantRange::isAllPositive() const {
+  // Empty and full set are automatically treated correctly.
+  return !isSignWrappedSet() && Lower.isStrictlyPositive();
+}
+
 APInt ConstantRange::getUnsignedMax() const {
   if (isFullSet() || isUpperWrapped())
     return APInt::getMaxValue(getBitWidth());
diff --git a/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll b/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
index 7c1b11e22aef2..e9aeac7ac2bc5 100644
--- a/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
+++ b/llvm/test/Analysis/LoopAccessAnalysis/symbolic-stride.ll
@@ -170,23 +170,16 @@ define void @single_stride_castexpr_multiuse(i32 %offset, ptr %src, ptr %dst, i1
 ; CHECK-NEXT:          %gep.src = getelementptr inbounds i32, ptr %src, i64 %iv.3
 ; CHECK-NEXT:      Grouped accesses:
 ; CHECK-NEXT:        Group [[GRP3]]:
-; CHECK-NEXT:          (Low: ((4 * %iv.1) + %dst) High: (804 + (4 * %iv.1) + (-4 * (zext i32 %offset to i64))<nsw> + %dst))
-; CHECK-NEXT:            Member: {((4 * %iv.1) + %dst),+,4}<%inner.loop>
+; CHECK-NEXT:          (Low: (((4 * %iv.1) + %dst) umin ((4 * %iv.1) + (4 * (sext i32 %offset to i64) * (200 + (-1 * (zext i32 %offset to i64))<nsw>)<nsw>) + %dst)) High: (4 + (((4 * %iv.1) + %dst) umax ((4 * %iv.1) + (4 * (sext i32 %offset to i64) * (200 + (-1 * (zext i32 %offset to i64))<nsw>)<nsw>) + %dst))))
+; CHECK-NEXT:            Member: {((4 * %iv.1) + %dst),+,(4 * (sext i32 %offset to i64))<nsw>}<%inner.loop>
 ; CHECK-NEXT:        Group [[GRP4]]:
-; CHECK-NEXT:          (Low: (4 + %src) High: (808 + (-4 * (zext i32 %offset to i64))<nsw> + %src))
-; CHECK-NEXT:            Member: {(4 + %src),+,4}<%inner.loop>
+; CHECK-NEXT:          (Low: ((4 * (zext i32 %offset to i64))<nuw><nsw> + %src) High: (804 + %src))
+; CHECK-NEXT:            Member: {((4 * (zext i32 %offset to i64))<nuw><nsw> + %src),+,4}<%inner.loop>
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Non vectorizable stores to invariant address were not found in loop.
 ; CHECK-NEXT:      SCEV assumptions:
-; CHECK-NEXT:      Equal predicate: %offset == 1
 ; CHECK-EMPTY:
 ; CHECK-NEXT:      Expressions re-written:
-; CHECK-NEXT:      [PSE] %gep.src = getelementptr inbounds i32, ptr %src, i64 %iv.3:
-; CHECK-NEXT:        {((4 * (zext i32 %offset to i64))<nuw><nsw> + %src),+,4}<%inner.loop>
-; CHECK-NEXT:        --> {(4 + %src),+,4}<%inner.loop>
-; CHECK-NEXT:      [PSE] %gep.dst = getelementptr i32, ptr %dst, i64 %iv.2:
-; CHECK-NEXT:        {((4 * %iv.1) + %dst),+,(4 * (sext i32 %offset to i64))<nsw>}<%inner.loop>
-; CHECK-NEXT:        --> {((4 * %iv.1) + %dst),+,4}<%inner.loop>
 ; CHECK-NEXT:    outer.header:
 ; CHECK-NEXT:      Report: loop is not the innermost loop
 ; CHECK-NEXT:      Dependences:
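
The `isAllPositive` helper added in the diff above combines a sign-wrap check with a test on the lower bound. A rough way to see what that expression accepts is an 8-bit model, sketched below. `Range8` and both functions are hypothetical and only imitate the shape of the expressions in the diff. The model assumes a half-open, possibly wrapping range [lower, upper), with the full set represented as lower == upper == -1 (all ones).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 8-bit model of a half-open, possibly wrapping signed range
// [lower, upper). This is an illustration, not llvm::ConstantRange.
struct Range8 {
  int8_t lower;
  int8_t upper;
};

// The range crosses the signed max/min boundary, unless it merely ends
// exactly at the signed minimum (for example [100, -128) is {100..127}).
constexpr bool isSignWrapped(Range8 r) {
  return r.lower > r.upper && r.upper != INT8_MIN;
}

// Same shape as the expression in the diff: not sign-wrapped, and the
// lower bound is strictly positive.
constexpr bool isAllPositive(Range8 r) {
  return !isSignWrapped(r) && r.lower > 0;
}
```

Under these assumptions, [1, 10) and [100, -128) count as all positive, while [0, 10), [-5, 3), the sign-wrapped [100, -100) and the full set do not.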

@artagnon artagnon force-pushed the laa-stridever-robust branch from e81a7eb to 7b1876a Compare July 3, 2024 13:04
@fhahn
Contributor

fhahn commented Jul 7, 2024

Just to double check, the latest version doesn't seem to have any test changes (modulo adding new ones), is that intentional/expected?

@artagnon
Contributor Author

artagnon commented Jul 8, 2024

> Just to double check, the latest version doesn't seem to have any test changes (modulo adding new ones), is that intentional/expected?

Yes.

Make the stride-versioning code in LoopAccessAnalysis more robust,
and make it possible to directly return the SCEVUnknown under a cast in
getStrideFromPointer, eliminating a second cast-stripping in
collectStridedAccess.
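
The cast-stripping change described in this commit message can be illustrated with a toy expression tree. This is only a sketch: `Kind`, `Expr` and `strideBase` are hypothetical and are not the SCEV class hierarchy. The branch that returns a bare unknown unchanged is an assumption, since the diff shows only the cast branch.

```cpp
#include <cassert>

// Hypothetical miniature expression tree, standing in for SCEV nodes.
enum class Kind { Unknown, IntegralCast, Other };

struct Expr {
  Kind kind;
  const Expr *operand; // only meaningful for IntegralCast
};

// Returns the underlying Unknown node when the stride is either an Unknown
// (assumed behaviour) or an integral cast of one; otherwise nullptr.
inline const Expr *strideBase(const Expr *v) {
  if (v->kind == Kind::Unknown)
    return v;
  if (v->kind == Kind::IntegralCast && v->operand &&
      v->operand->kind == Kind::Unknown)
    return v->operand; // hand back the operand, so callers need no second strip
  return nullptr;
}
```

Because the helper already returns the node under the cast, a caller in this model can store the result directly without stripping the cast again.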
@artagnon artagnon force-pushed the laa-stridever-robust branch from 7b1876a to 52935f6 Compare August 21, 2024 18:20
@artagnon artagnon changed the title LAA: make stride versioning code more robust LAA: thoroughly clarify stride-versioning code (NFC) Aug 21, 2024
@artagnon
Contributor Author

Thanks to the comprehensive testing that recently landed, we now know that the existing code is correct and cannot be improved. This patch has been converted into an NFC change.

@@ -2955,40 +2956,35 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   // of various possible stride specializations, considering the alternatives
   // of using gather/scatters (if available).
 
-  const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount();
+  const SCEV *BTC = PSE->getSymbolicMaxBackedgeTakenCount();
Contributor Author

Small caveat: I think this should be getBackedgeTakenCount, the exact version, rather than the symbolic-max version. It won't be easy to cook up a test to support this change, though, so I can think about it in a follow-up.

@artagnon
Copy link
Contributor Author

artagnon commented Oct 7, 2024

Not pursuing this.

@artagnon artagnon closed this Oct 7, 2024
@artagnon artagnon deleted the laa-stridever-robust branch October 7, 2024 09:06