diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 9e7a97e9667d2..1f0da263c5980 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -2234,11 +2234,25 @@ class BoUpSLP {
   /// TODO: If load combining is allowed in the IR optimizer, this analysis
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
-  bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
-                     ArrayRef<unsigned> Order, const TargetTransformInfo &TTI,
-                     const DataLayout &DL, ScalarEvolution &SE,
-                     const bool IsAnyPointerUsedOutGraph, const int64_t Diff,
-                     StridedPtrInfo &SPtrInfo) const;
+
+  /// Suppose we are given pointers of the form: %b + x * %c.
+  /// where %c is constant. Check if the pointers can be rearranged as follows:
+  /// %b + 0 * %c
+  /// %b + 1 * %c
+  /// %b + 2 * %c
+  /// ...
+  /// %b + n * %c
+  bool analyzeConstantStrideCandidate(ArrayRef<Value *> PointerOps,
+                                      Type *ElemTy, Align CommonAlignment,
+                                      SmallVectorImpl<unsigned> &SortedIndices,
+                                      StridedPtrInfo &SPtrInfo, int64_t Diff,
+                                      Value *Ptr0, Value *PtrN) const;
+
+  /// Same as analyzeConstantStrideCandidate but for run-time stride.
+  bool analyzeRtStrideCandidate(ArrayRef<Value *> PointerOps, Type *ElemTy,
+                                Align CommonAlignment,
+                                SmallVectorImpl<unsigned> &SortedIndices,
+                                StridedPtrInfo &SPtrInfo) const;
 
   /// Checks if the given array of loads can be represented as a vectorized,
   /// scatter or just simple gather.
@@ -6805,53 +6819,50 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
       CompressMask, LoadVecTy);
 }
 
-/// Checks if strided loads can be generated out of \p VL loads with pointers \p
-/// PointerOps:
-/// 1. Target with strided load support is detected.
-/// 2. The number of loads is greater than MinProfitableStridedLoads, or the
-/// potential stride <= MaxProfitableLoadStride and the potential stride is
-/// power-of-2 (to avoid perf regressions for the very small number of loads)
-/// and max distance > number of loads, or potential stride is -1.
-/// 3. The loads are ordered, or number of unordered loads <=
-/// MaxProfitableUnorderedLoads, or loads are in reversed order. (this check is
-/// to avoid extra costs for very expensive shuffles).
-/// 4. Any pointer operand is an instruction with the users outside of the
-/// current graph (for masked gathers extra extractelement instructions
-/// might be required).
-bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
-                            ArrayRef<unsigned> Order,
-                            const TargetTransformInfo &TTI,
-                            const DataLayout &DL, ScalarEvolution &SE,
-                            const bool IsAnyPointerUsedOutGraph,
-                            const int64_t Diff,
-                            StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = VL.size();
+bool BoUpSLP::analyzeConstantStrideCandidate(
+    ArrayRef<Value *> PointerOps, Type *ElemTy, Align CommonAlignment,
+    SmallVectorImpl<unsigned> &SortedIndices, StridedPtrInfo &SPtrInfo,
+    int64_t Diff, Value *Ptr0, Value *PtrN) const {
+  const size_t Sz = PointerOps.size();
+  auto *StridedLoadTy = getWidenedType(ElemTy, Sz);
+
+  // Try to generate strided load node if:
+  // 1. Target with strided load support is detected.
+  // 2. The number of loads is greater than MinProfitableStridedLoads,
+  // or the potential stride <= MaxProfitableLoadStride and the
+  // potential stride is power-of-2 (to avoid perf regressions for the very
+  // small number of loads) and max distance > number of loads, or potential
+  // stride is -1.
+  // 3. The loads are ordered, or number of unordered loads <=
+  // MaxProfitableUnorderedLoads, or loads are in reversed order.
+  // (this check is to avoid extra costs for very expensive shuffles).
+  // 4. Any pointer operand is an instruction with the users outside of the
+  // current graph (for masked gathers extra extractelement instructions
+  // might be required).
+
+  if (!TTI->isTypeLegal(StridedLoadTy) ||
+      !TTI->isLegalStridedLoadStore(StridedLoadTy, CommonAlignment))
+    return false;
+
+  // Simple check if not a strided access - clear order.
+  bool IsPossibleStrided = Diff % (Sz - 1) == 0;
+  auto IsAnyPointerUsedOutGraph =
+      IsPossibleStrided && any_of(PointerOps, [&](Value *V) {
+        return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
+                 return !isVectorized(U) && !MustGather.contains(U);
+               });
+      });
   const uint64_t AbsoluteDiff = std::abs(Diff);
-  Type *ScalarTy = VL.front()->getType();
-  auto *VecTy = getWidenedType(ScalarTy, Sz);
   if (IsAnyPointerUsedOutGraph ||
-      (AbsoluteDiff > Sz &&
-       (Sz > MinProfitableStridedLoads ||
+      ((Sz > MinProfitableStridedLoads ||
        (AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
-         AbsoluteDiff % Sz == 0 && has_single_bit(AbsoluteDiff / Sz)))) ||
+         has_single_bit(AbsoluteDiff))) &&
+       AbsoluteDiff > Sz) ||
      Diff == -(static_cast<int64_t>(Sz) - 1)) {
    int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
    if (Diff != Stride * static_cast<int64_t>(Sz - 1))
      return false;
-    Align Alignment =
-        cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
-            ->getAlign();
-    if (!TTI.isLegalStridedLoadStore(VecTy, Alignment))
-      return false;
-    Value *Ptr0;
-    Value *PtrN;
-    if (Order.empty()) {
-      Ptr0 = PointerOps.front();
-      PtrN = PointerOps.back();
-    } else {
-      Ptr0 = PointerOps[Order.front()];
-      PtrN = PointerOps[Order.back()];
-    }
+
     // Iterate through all pointers and check if all distances are
     // unique multiple of Dist.
     SmallSet<int64_t, 4> Dists;
@@ -6860,22 +6871,40 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
       if (Ptr == PtrN)
        Dist = Diff;
      else if (Ptr != Ptr0)
-        Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, DL, SE);
+        Dist = *getPointersDiff(ElemTy, Ptr0, ElemTy, Ptr, *DL, *SE);
      // If the strides are not the same or repeated, we can't
      // vectorize.
      if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
        break;
    }
    if (Dists.size() == Sz) {
-      Type *StrideTy = DL.getIndexType(Ptr0->getType());
+      Type *StrideTy = DL->getIndexType(Ptr0->getType());
      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
+      SPtrInfo.Ty = StridedLoadTy;
      return true;
    }
  }
  return false;
}
 
+bool BoUpSLP::analyzeRtStrideCandidate(ArrayRef<Value *> PointerOps,
+                                       Type *ElemTy, Align CommonAlignment,
+                                       SmallVectorImpl<unsigned> &SortedIndices,
+                                       StridedPtrInfo &SPtrInfo) const {
+  const size_t Sz = PointerOps.size();
+  auto *VecTy = getWidenedType(ElemTy, Sz);
+  if (!TTI->isLegalStridedLoadStore(VecTy, CommonAlignment))
+    return false;
+  const SCEV *Stride =
+      calculateRtStride(PointerOps, ElemTy, *DL, *SE, SortedIndices);
+  if (!Stride)
+    return false;
+
+  SPtrInfo.Ty = VecTy;
+  SPtrInfo.StrideSCEV = Stride;
+  return true;
+}
+
 BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
     ArrayRef<Value *> VL, const Value *VL0, SmallVectorImpl<unsigned> &Order,
     SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo,
@@ -6916,15 +6945,10 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
   auto *VecTy = getWidenedType(ScalarTy, Sz);
   Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
   if (!IsSorted) {
-    if (Sz > MinProfitableStridedLoads && TTI->isTypeLegal(VecTy)) {
-      if (const SCEV *Stride =
-              calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order);
-          Stride && TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) {
-        SPtrInfo.Ty = getWidenedType(ScalarTy, PointerOps.size());
-        SPtrInfo.StrideSCEV = Stride;
-        return LoadsState::StridedVectorize;
-      }
-    }
+    if (Sz > MinProfitableStridedLoads &&
+        analyzeRtStrideCandidate(PointerOps, ScalarTy, CommonAlignment, Order,
+                                 SPtrInfo))
+      return LoadsState::StridedVectorize;
 
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
         TTI->forceScalarizeMaskedGather(VecTy, CommonAlignment))
@@ -6956,18 +6980,9 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
                                    cast<Instruction>(V), UserIgnoreList);
                              }))
       return LoadsState::CompressVectorize;
-    // Simple check if not a strided access - clear order.
-    bool IsPossibleStrided = *Diff % (Sz - 1) == 0;
-    // Try to generate strided load node.
-    auto IsAnyPointerUsedOutGraph =
-        IsPossibleStrided && any_of(PointerOps, [&](Value *V) {
-          return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
-                   return !isVectorized(U) && !MustGather.contains(U);
-                 });
-        });
-    if (IsPossibleStrided &&
-        isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE,
-                      IsAnyPointerUsedOutGraph, *Diff, SPtrInfo))
+
+    if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, CommonAlignment,
+                                       Order, SPtrInfo, *Diff, Ptr0, PtrN))
      return LoadsState::StridedVectorize;
  }
  if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||