Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 26 additions & 30 deletions llvm/lib/Analysis/LoopAccessAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2929,13 +2929,14 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
// computation of an interesting IV - but we chose not to as we
// don't have a cost model here, and broadening the scope exposes
// far too many unprofitable cases.
const SCEV *StrideExpr = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop);
if (!StrideExpr)
ScalarEvolution *SE = PSE->getSE();
const SCEV *Stride = getStrideFromPointer(Ptr, SE, TheLoop);
if (!Stride)
return;

LLVM_DEBUG(dbgs() << "LAA: Found a strided access that is a candidate for "
"versioning:");
LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *StrideExpr << "\n");
LLVM_DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *Stride << "\n");

if (!SpeculateUnitStride) {
LLVM_DEBUG(dbgs() << " Chose not to due to -laa-speculate-unit-stride\n");
Expand All @@ -2955,40 +2956,35 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
// of various possible stride specializations, considering the alternatives
// of using gather/scatters (if available).

const SCEV *MaxBTC = PSE->getSymbolicMaxBackedgeTakenCount();
const SCEV *BTC = PSE->getSymbolicMaxBackedgeTakenCount();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small caveat: I think this should be getBackedgeTakenCount, the exact version, and not the symbolic-max version. However, it's not going to be easy to cook up a test to support this change, but I can think about it in a follow-up.


// Match the types so we can compare the stride and the MaxBTC.
// The Stride can be positive/negative, so we sign extend Stride;
// The backedgeTakenCount is non-negative, so we zero extend MaxBTC.
// Sign-extend the stride or zero-extend the BTC, as appropriate, before
// performing subtraction. We take care to do this because an unknown stride
// might equal an unknown TC, and we don't want to version the loop in that
// case.
const SCEV *CastedStride = Stride;
const SCEV *CastedBTC = BTC;
const DataLayout &DL = TheLoop->getHeader()->getDataLayout();
uint64_t StrideTypeSizeBits = DL.getTypeSizeInBits(StrideExpr->getType());
uint64_t BETypeSizeBits = DL.getTypeSizeInBits(MaxBTC->getType());
const SCEV *CastedStride = StrideExpr;
const SCEV *CastedBECount = MaxBTC;
ScalarEvolution *SE = PSE->getSE();
if (BETypeSizeBits >= StrideTypeSizeBits)
CastedStride = SE->getNoopOrSignExtend(StrideExpr, MaxBTC->getType());
if (DL.getTypeSizeInBits(BTC->getType()) >=
DL.getTypeSizeInBits(Stride->getType()))
CastedStride = SE->getNoopOrSignExtend(Stride, BTC->getType());
else
CastedBECount = SE->getZeroExtendExpr(MaxBTC, StrideExpr->getType());
const SCEV *StrideMinusBETaken = SE->getMinusSCEV(CastedStride, CastedBECount);
// Since TripCount == BackEdgeTakenCount + 1, checking:
// "Stride >= TripCount" is equivalent to checking:
// Stride - MaxBTC> 0
if (SE->isKnownPositive(StrideMinusBETaken)) {
LLVM_DEBUG(
dbgs() << "LAA: Stride>=TripCount; No point in versioning as the "
"Stride==1 predicate will imply that the loop executes "
"at most once.\n");
CastedBTC = SE->getZeroExtendExpr(BTC, Stride->getType());
const SCEV *StrideMinusBTC = SE->getMinusSCEV(CastedStride, CastedBTC);

// Stride - BTC > 0 is equivalent to Stride >= TripCount, but computing
// TripCount from BTC would introduce more casts, and Stride - TC might fail
// the known-non-negative test.
if (SE->isKnownPositive(StrideMinusBTC)) {
LLVM_DEBUG(dbgs() << "LAA: Not versioning with Stride==1 predicate.\n");
return;
}
LLVM_DEBUG(dbgs() << "LAA: Found a strided access that we can version.\n");

// Strip back off the integer cast, and check that our result is a
// SCEVUnknown as we expect.
const SCEV *StrideBase = StrideExpr;
if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(StrideBase))
StrideBase = C->getOperand();
SymbolicStrides[Ptr] = cast<SCEVUnknown>(StrideBase);
// Strip back off the integer cast, to get the resulting SCEVUnknown.
if (const auto *C = dyn_cast<SCEVIntegralCastExpr>(Stride))
Stride = C->getOperand();
SymbolicStrides[Ptr] = cast<SCEVUnknown>(Stride);
}

LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
Expand Down
Loading