diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 6ac9018df641e..1814d9a6811c0 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -2237,8 +2237,7 @@ class BoUpSLP { bool isStridedLoad(ArrayRef VL, ArrayRef PointerOps, ArrayRef Order, const TargetTransformInfo &TTI, const DataLayout &DL, ScalarEvolution &SE, - const bool IsAnyPointerUsedOutGraph, const int64_t Diff, - StridedPtrInfo &SPtrInfo) const; + const int64_t Diff, StridedPtrInfo &SPtrInfo) const; /// Checks if the given array of loads can be represented as a vectorized, /// scatter or just simple gather. @@ -6822,10 +6821,19 @@ bool BoUpSLP::isStridedLoad(ArrayRef VL, ArrayRef PointerOps, ArrayRef Order, const TargetTransformInfo &TTI, const DataLayout &DL, ScalarEvolution &SE, - const bool IsAnyPointerUsedOutGraph, const int64_t Diff, StridedPtrInfo &SPtrInfo) const { const size_t Sz = VL.size(); + if (Diff % (Sz - 1) != 0) + return false; + + // Try to generate strided load node. + auto IsAnyPointerUsedOutGraph = any_of(PointerOps, [&](Value *V) { + return isa(V) && any_of(V->users(), [&](User *U) { + return !isVectorized(U) && !MustGather.contains(U); + }); + }); + const uint64_t AbsoluteDiff = std::abs(Diff); Type *ScalarTy = VL.front()->getType(); auto *VecTy = getWidenedType(ScalarTy, Sz); @@ -6956,18 +6964,7 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads( cast(V), UserIgnoreList); })) return LoadsState::CompressVectorize; - // Simple check if not a strided access - clear order. - bool IsPossibleStrided = *Diff % (Sz - 1) == 0; - // Try to generate strided load node. - auto IsAnyPointerUsedOutGraph = - IsPossibleStrided && any_of(PointerOps, [&](Value *V) { - return isa(V) && any_of(V->users(), [&](User *U) { - return !isVectorized(U) && !MustGather.contains(U); - }); - }); - if (IsPossibleStrided && - isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE, - IsAnyPointerUsedOutGraph, *Diff, SPtrInfo)) + if (isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE, *Diff, SPtrInfo)) return LoadsState::StridedVectorize; } if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||