From 52935f610360d345415dbeab7ec240bb95c81064 Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Thu, 27 Jun 2024 17:12:33 +0100 Subject: [PATCH] LAA: make stride versioning code more robust Rewrite the stride-versioning code in LoopAccessAnalysis more robust, and make it possible to directly return the SCEVUnknown under a cast in getStrideFromPointer, eliminating a second cast-stripping in collectStridedAccess. --- llvm/lib/Analysis/LoopAccessAnalysis.cpp | 56 +++++++++++------------- 1 file changed, 26 insertions(+), 30 deletions(-) diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp index 980f142f11326..109c256121b77 100644 --- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp +++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp @@ -2929,13 +2929,14 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { // computation of an interesting IV - but we chose not to as we // don't have a cost model here, and broadening the scope exposes // far too many unprofitable cases. - const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop); - if (!StrideExpr) + ScalarEvolution *SE = PSE->getSE(); + const SCEV *Stride = getStrideFromPointer(Ptr, SE, TheLoop); + if (!Stride) return; LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for " "versioning:"); - LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *StrideExpr << "\n"); + LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *Stride << "\n"); if (!SpeculateUnitStride) { LLVM_DEBUG(dbgs() << " Chose not to due to -laa-speculate-unit-stride\n"); @@ -2955,40 +2956,35 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { // of various possible stride specializations, considering the alternatives // of using gather/scatters (if available). - const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount(); + const SCEV *BTC = PSE->getSymbolicMaxBackedgeTakenCount(); - // Match the types so we can compare the stride and the MaxBTC. - // The Stride can be positive/negative, so we sign extend Stride; - // The backedgeTakenCount is non-negative, so we zero extend MaxBTC. + // Sign-extend the stride or zero-extend the BTC, as appropriate, before + // performing subtraction. We take care to do this because an unknown stride + // might equal an unknown TC, and we don't want to version the loop in that + // case. + const SCEV *CastedStride = Stride; + const SCEV *CastedBTC = BTC; const DataLayout &DL = TheLoop->getHeader()->getDataLayout(); - uint64_t StrideTypeSizeBits = DL.getTypeSizeInBits(StrideExpr->getType()); - uint64_t BETypeSizeBits = DL.getTypeSizeInBits(MaxBTC->getType()); - const SCEV *CastedStride = StrideExpr; - const SCEV *CastedBECount = MaxBTC; - ScalarEvolution *SE = PSE->getSE(); - if (BETypeSizeBits >= StrideTypeSizeBits) - CastedStride = SE->getNoopOrSignExtend(StrideExpr, MaxBTC->getType()); + if (DL.getTypeSizeInBits(BTC->getType()) >= + DL.getTypeSizeInBits(Stride->getType())) + CastedStride = SE->getNoopOrSignExtend(Stride, BTC->getType()); else - CastedBECount = SE->getZeroExtendExpr(MaxBTC, StrideExpr->getType()); - const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount); - // Since TripCount == BackEdgeTakenCount + 1, checking: - // "Stride >= TripCount" is equivalent to checking: - // Stride - MaxBTC> 0 - if (SE->isKnownPositive(StrideMinusBETaken)) { - LLVM_DEBUG( - dbgs() << "LAA: Stride>=TripCount; No point in versioning as the " - "Stride==1 predicate will imply that the loop executes " - "at most once.\n"); + CastedBTC = SE->getZeroExtendExpr(BTC, Stride->getType()); + const SCEV *StrideMinusBTC = SE->getMinusSCEV(CastedStride, CastedBTC); + + // Stride - BTC > 0 is equivalent to Stride >= TripCount, but computing + // TripCount from BTC would introduce more casts, and Stride - TC might fail + // the known-non-negative test. + if (SE->isKnownPositive(StrideMinusBTC)) { + LLVM_DEBUG(dbgs() << "LAA: Not versioning with Stride==1 predicate.\n"); return; } LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n"); - // Strip back off the integer cast, and check that our result is a - // SCEVUnknown as we expect. - const SCEV *StrideBase = StrideExpr; - if (const auto *C = dyn_cast(StrideBase)) - StrideBase = C->getOperand(); - SymbolicStrides[Ptr] = cast(StrideBase); + // Strip back off the integer cast, to get the resulting SCEVUnknown. + if (const auto *C = dyn_cast(Stride)) + Stride = C->getOperand(); + SymbolicStrides[Ptr] = cast(Stride); } LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,