diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 7ca43efb47c6e..4d67fb7892c8f 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1926,6 +1926,19 @@ class BoUpSLP {
   class ShuffleCostEstimator;
   class ShuffleInstructionBuilder;
 
+  /// If we decide to generate strided load / store, this struct contains all
+  /// the necessary info. Its fields are calculated by analyzeRtStrideCandidate
+  /// and analyzeConstantStrideCandidate. Note that the stride can be given
+  /// either as a SCEV or as a Value, if it already exists. To get the stride
+  /// in bytes, StrideVal (or the value obtained from StrideSCEV) has to be
+  /// multiplied by the size of an element of FixedVectorType.
+  struct StridedPtrInfo {
+    Value *StrideVal = nullptr;
+    const SCEV *StrideSCEV = nullptr;
+    FixedVectorType *Ty = nullptr;
+  };
+  SmallDenseMap<const TreeEntry *, StridedPtrInfo> TreeEntryToStridedPtrInfoMap;
+
 public:
   /// Tracks the state we can represent the loads in the given sequence.
   enum class LoadsState {
@@ -2221,6 +2234,11 @@ class BoUpSLP {
   /// TODO: If load combining is allowed in the IR optimizer, this analysis
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
+  bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
+                     ArrayRef<unsigned> Order, const TargetTransformInfo &TTI,
+                     const DataLayout &DL, ScalarEvolution &SE,
+                     const bool IsAnyPointerUsedOutGraph, const int64_t Diff,
+                     StridedPtrInfo &SPtrInfo) const;
 
   /// Checks if the given array of loads can be represented as a vectorized,
   /// scatter or just simple gather.
@@ -2235,6 +2253,7 @@ class BoUpSLP {
   LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
                                SmallVectorImpl<unsigned> &Order,
                                SmallVectorImpl<Value *> &PointerOps,
+                               StridedPtrInfo &SPtrInfo,
                                unsigned *BestVF = nullptr,
                                bool TryRecursiveCheck = true) const;
 
@@ -4479,11 +4498,10 @@ class BoUpSLP {
   /// Checks if the specified list of the instructions/values can be vectorized
   /// and fills required data before actual scheduling of the instructions.
-  TreeEntry::EntryState
-  getScalarsVectorizationState(const InstructionsState &S, ArrayRef<Value *> VL,
-                               bool IsScatterVectorizeUserTE,
-                               OrdersType &CurrentOrder,
-                               SmallVectorImpl<Value *> &PointerOps);
+  TreeEntry::EntryState getScalarsVectorizationState(
+      const InstructionsState &S, ArrayRef<Value *> VL,
+      bool IsScatterVectorizeUserTE, OrdersType &CurrentOrder,
+      SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo);
 
   /// Maps a specific scalar to its tree entry(ies).
   SmallDenseMap<Value *, SmallVector<TreeEntry *>> ScalarToTreeEntries;
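A note on the contract encoded by StridedPtrInfo: the stride is kept in element units (either as a constant Value or as a SCEV to be expanded later), and codegen converts it to the byte stride the strided-load intrinsic takes. A minimal sketch of that conversion in plain C++, using hypothetical stand-ins (StridedPtrInfoModel, byteStride) rather than the LLVM types:

    #include <cassert>
    #include <cstdint>
    #include <optional>

    // Simplified stand-in for StridedPtrInfo: the stride is kept in element
    // units; a run-time stride (StrideSCEV) is expanded to a Value first.
    struct StridedPtrInfoModel {
      std::optional<int64_t> StrideVal; // constant stride, in elements
      unsigned NumElements = 0;         // element count of the vector type (Ty)
      unsigned ElementSize = 0;         // size of one element, in bytes
    };

    // Byte stride passed to the strided load: element stride times element
    // size, negated when the lanes are consumed in reverse order.
    int64_t byteStride(const StridedPtrInfoModel &Info, bool IsReverseOrder) {
      assert(Info.StrideVal && "run-time strides need expansion first");
      return (IsReverseOrder ? -1 : 1) * *Info.StrideVal *
             static_cast<int64_t>(Info.ElementSize);
    }

    int main() {
      // <4 x i32> loading every other element: stride 2 -> 8 bytes.
      StridedPtrInfoModel Info{/*StrideVal=*/2, /*NumElements=*/4,
                               /*ElementSize=*/4};
      return byteStride(Info, /*IsReverseOrder=*/false) == 8 ? 0 : 1;
    }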
@@ -6799,12 +6817,13 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// 4. Any pointer operand is an instruction with the users outside of the
 /// current graph (for masked gathers extra extractelement instructions
 /// might be required).
-static bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
-                          ArrayRef<unsigned> Order,
-                          const TargetTransformInfo &TTI, const DataLayout &DL,
-                          ScalarEvolution &SE,
-                          const bool IsAnyPointerUsedOutGraph,
-                          const int64_t Diff) {
+bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
+                            ArrayRef<unsigned> Order,
+                            const TargetTransformInfo &TTI,
+                            const DataLayout &DL, ScalarEvolution &SE,
+                            const bool IsAnyPointerUsedOutGraph,
+                            const int64_t Diff,
+                            StridedPtrInfo &SPtrInfo) const {
   const size_t Sz = VL.size();
   const uint64_t AbsoluteDiff = std::abs(Diff);
   Type *ScalarTy = VL.front()->getType();
@@ -6846,17 +6865,20 @@ static bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
       if (((Dist / Stride) * Stride) != Dist ||
           !Dists.insert(Dist).second)
         break;
     }
-    if (Dists.size() == Sz)
+    if (Dists.size() == Sz) {
+      Type *StrideTy = DL.getIndexType(Ptr0->getType());
+      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
+      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
       return true;
+    }
   }
   return false;
 }
 
-BoUpSLP::LoadsState
-BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
-                           SmallVectorImpl<unsigned> &Order,
-                           SmallVectorImpl<Value *> &PointerOps,
-                           unsigned *BestVF, bool TryRecursiveCheck) const {
+BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
+    ArrayRef<Value *> VL, const Value *VL0, SmallVectorImpl<unsigned> &Order,
+    SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo,
+    unsigned *BestVF, bool TryRecursiveCheck) const {
   // Check that a vectorized load would load the same memory as a scalar
   // load. For example, we don't want to vectorize loads that are smaller
   // than 8-bit. Even though we have a packed struct {<i2, i2>} LLVM
@@ -6894,9 +6916,13 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
   Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
   if (!IsSorted) {
     if (Sz > MinProfitableStridedLoads && TTI->isTypeLegal(VecTy)) {
-      if (TTI->isLegalStridedLoadStore(VecTy, CommonAlignment) &&
-          calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order))
+      if (const SCEV *Stride =
+              calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order);
+          Stride && TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) {
+        SPtrInfo.Ty = getWidenedType(ScalarTy, PointerOps.size());
+        SPtrInfo.StrideSCEV = Stride;
         return LoadsState::StridedVectorize;
+      }
     }
 
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -6940,7 +6966,7 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
             });
       if (IsPossibleStrided &&
           isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE,
-                        IsAnyPointerUsedOutGraph, *Diff))
+                        IsAnyPointerUsedOutGraph, *Diff, SPtrInfo))
         return LoadsState::StridedVectorize;
     }
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -7024,9 +7050,9 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
         ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
         SmallVector<unsigned> Order;
         SmallVector<Value *> PointerOps;
-        LoadsState LS =
-            canVectorizeLoads(Slice, Slice.front(), Order, PointerOps, BestVF,
-                              /*TryRecursiveCheck=*/false);
+        LoadsState LS = canVectorizeLoads(Slice, Slice.front(), Order,
+                                          PointerOps, SPtrInfo, BestVF,
+                                          /*TryRecursiveCheck=*/false);
         // Check that the sorted loads are consecutive.
         if (LS == LoadsState::Gather) {
           if (BestVF) {
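For the constant-stride path, isStridedLoad derives the candidate stride from the overall pointer span and, as in the hunk above, accepts it only if every lane's distance from the base pointer is an exact multiple of the stride and no two lanes share a distance. A self-contained sketch of that check in plain C++ (the isConstantStrideCandidate helper is hypothetical, and the real code also caps the span and checks target legality):

    #include <cstdint>
    #include <unordered_set>
    #include <vector>

    // Given each lane's distance (in elements) from the base pointer, accept
    // a constant stride only if it evenly divides every distance and no two
    // lanes alias the same address. Mirrors the Dists set logic above.
    bool isConstantStrideCandidate(const std::vector<int64_t> &Dists,
                                   int64_t Stride) {
      if (Stride == 0)
        return false;
      std::unordered_set<int64_t> Seen;
      for (int64_t Dist : Dists)
        if (Dist % Stride != 0 || !Seen.insert(Dist).second)
          return false;
      return true;
    }

    int main() {
      // Loads at element offsets {0, 3, 6, 9}: a stride of 3 is acceptable.
      return isConstantStrideCandidate({0, 3, 6, 9}, 3) ? 0 : 1;
    }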
@@ -7698,9 +7724,10 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom,
   // extra analysis later, so include such nodes into a special list.
   if (TE.hasState() && TE.getOpcode() == Instruction::Load) {
     SmallVector<Value *> PointerOps;
+    StridedPtrInfo SPtrInfo;
     OrdersType CurrentOrder;
     LoadsState Res = canVectorizeLoads(TE.Scalars, TE.Scalars.front(),
-                                       CurrentOrder, PointerOps);
+                                       CurrentOrder, PointerOps, SPtrInfo);
     if (Res == LoadsState::Vectorize || Res == LoadsState::StridedVectorize ||
         Res == LoadsState::CompressVectorize)
       return std::move(CurrentOrder);
@@ -9206,8 +9233,9 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
           // Try to build vector load.
           ArrayRef<Value *> Values(
               reinterpret_cast<Value *const *>(Slice.begin()), Slice.size());
+          StridedPtrInfo SPtrInfo;
           LoadsState LS = canVectorizeLoads(Values, Slice.front(), CurrentOrder,
-                                            PointerOps, &BestVF);
+                                            PointerOps, SPtrInfo, &BestVF);
           if (LS != LoadsState::Gather ||
               (BestVF > 1 && static_cast<unsigned>(NumElts) == 2 * BestVF)) {
             if (LS == LoadsState::ScatterVectorize) {
@@ -9401,6 +9429,7 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
                 unsigned VF = *CommonVF;
                 OrdersType Order;
                 SmallVector<Value *> PointerOps;
+                StridedPtrInfo SPtrInfo;
                 // Segmented load detected - vectorize at maximum vector factor.
                 if (InterleaveFactor <= Slice.size() &&
                     TTI.isLegalInterleavedAccessType(
@@ -9409,8 +9438,8 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
                         cast<LoadInst>(Slice.front())->getAlign(),
                         cast<LoadInst>(Slice.front())
                             ->getPointerAddressSpace()) &&
-                    canVectorizeLoads(Slice, Slice.front(), Order,
-                                      PointerOps) == LoadsState::Vectorize) {
+                    canVectorizeLoads(Slice, Slice.front(), Order, PointerOps,
+                                      SPtrInfo) == LoadsState::Vectorize) {
                   UserMaxVF = InterleaveFactor * VF;
                 } else {
                   InterleaveFactor = 0;
@@ -9432,8 +9461,9 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
                       ArrayRef<Value *> VL = TE.Scalars;
                       OrdersType Order;
                       SmallVector<Value *> PointerOps;
+                      StridedPtrInfo SPtrInfo;
                       LoadsState State = canVectorizeLoads(
-                          VL, VL.front(), Order, PointerOps);
+                          VL, VL.front(), Order, PointerOps, SPtrInfo);
                       if (State == LoadsState::ScatterVectorize ||
                           State == LoadsState::CompressVectorize)
                         return false;
@@ -9451,11 +9481,11 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
                     [&, Slice = Slice](unsigned Idx) {
                       OrdersType Order;
                       SmallVector<Value *> PointerOps;
+                      StridedPtrInfo SPtrInfo;
                       return canVectorizeLoads(
                                  Slice.slice(Idx * UserMaxVF, UserMaxVF),
-                                 Slice[Idx * UserMaxVF], Order,
-                                 PointerOps) ==
-                             LoadsState::ScatterVectorize;
+                                 Slice[Idx * UserMaxVF], Order, PointerOps,
+                                 SPtrInfo) == LoadsState::ScatterVectorize;
                     }))
                 UserMaxVF = MaxVF;
               if (Slice.size() != ConsecutiveNodesSize)
@@ -9812,7 +9842,7 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
 BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
     const InstructionsState &S, ArrayRef<Value *> VL,
     bool IsScatterVectorizeUserTE, OrdersType &CurrentOrder,
-    SmallVectorImpl<Value *> &PointerOps) {
+    SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo) {
   assert(S.getMainOp() &&
          "Expected instructions with same/alternate opcodes only.");
@@ -9914,7 +9944,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
       });
     });
   };
-  switch (canVectorizeLoads(VL, VL0, CurrentOrder, PointerOps)) {
+  switch (canVectorizeLoads(VL, VL0, CurrentOrder, PointerOps, SPtrInfo)) {
   case LoadsState::Vectorize:
     return TreeEntry::Vectorize;
   case LoadsState::CompressVectorize:
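SPtrInfo follows a fill / record / consume pattern across the hunks above and below: it is populated as an out-parameter while canVectorizeLoads analyzes a candidate, stored in TreeEntryToStridedPtrInfoMap once buildTreeRec actually creates a strided tree entry, and read back in vectorizeTree. A toy model of that flow in plain C++ (all names are hypothetical stand-ins):

    #include <map>

    // Toy stand-ins; the real types carry much more state.
    struct TreeEntryStub {};
    struct StridedPtrInfoStub {
      long StrideVal = 0; // stride in elements, 0 = not set
    };

    std::map<const TreeEntryStub *, StridedPtrInfoStub> StridedInfoMap;

    // Analysis (canVectorizeLoads stand-in): fills the out-parameter.
    bool analyzeLoads(StridedPtrInfoStub &SPtrInfo) {
      SPtrInfo.StrideVal = 2; // decided on a strided load with stride 2
      return true;
    }

    // Tree construction (buildTreeRec stand-in): records the info per entry.
    void buildStridedEntry(const TreeEntryStub *TE,
                           const StridedPtrInfoStub &SPtrInfo) {
      StridedInfoMap[TE] = SPtrInfo;
    }

    // Codegen (vectorizeTree stand-in): consumes the recorded info.
    long codegenStride(const TreeEntryStub *TE) {
      return StridedInfoMap.at(TE).StrideVal;
    }

    int main() {
      TreeEntryStub TE;
      StridedPtrInfoStub SPtrInfo;
      if (analyzeLoads(SPtrInfo))
        buildStridedEntry(&TE, SPtrInfo);
      return codegenStride(&TE) == 2 ? 0 : 1;
    }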
@@ -11384,8 +11414,9 @@ void BoUpSLP::buildTreeRec(ArrayRef<Value *> VLRef, unsigned Depth,
         UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize;
     OrdersType CurrentOrder;
     SmallVector<Value *> PointerOps;
+    StridedPtrInfo SPtrInfo;
     TreeEntry::EntryState State = getScalarsVectorizationState(
-        S, VL, IsScatterVectorizeUserTE, CurrentOrder, PointerOps);
+        S, VL, IsScatterVectorizeUserTE, CurrentOrder, PointerOps, SPtrInfo);
     if (State == TreeEntry::NeedToGather) {
       newGatherTreeEntry(VL, S, UserTreeIdx, ReuseShuffleIndices);
       return;
@@ -11545,6 +11576,7 @@ void BoUpSLP::buildTreeRec(ArrayRef<Value *> VLRef, unsigned Depth,
         // Vectorizing non-consecutive loads with `llvm.masked.gather`.
         TE = newTreeEntry(VL, TreeEntry::StridedVectorize, Bundle, S,
                           UserTreeIdx, ReuseShuffleIndices, CurrentOrder);
+        TreeEntryToStridedPtrInfoMap[TE] = SPtrInfo;
         LLVM_DEBUG(dbgs() << "SLP: added a new TreeEntry (strided LoadInst).\n";
                    TE->dump());
         break;
@@ -12933,8 +12965,9 @@ void BoUpSLP::transformNodes() {
           if (S.getOpcode() == Instruction::Load) {
             OrdersType Order;
             SmallVector<Value *> PointerOps;
-            LoadsState Res =
-                canVectorizeLoads(Slice, Slice.front(), Order, PointerOps);
+            StridedPtrInfo SPtrInfo;
+            LoadsState Res = canVectorizeLoads(Slice, Slice.front(), Order,
+                                               PointerOps, SPtrInfo);
             AllStrided &= Res == LoadsState::StridedVectorize ||
                           Res == LoadsState::ScatterVectorize ||
                           Res == LoadsState::Gather;
@@ -13040,10 +13073,18 @@
         InstructionCost StridedCost = TTI->getStridedMemoryOpCost(
             Instruction::Load, VecTy, BaseLI->getPointerOperand(),
             /*VariableMask=*/false, CommonAlignment, CostKind, BaseLI);
-        if (StridedCost < OriginalVecCost || ForceStridedLoads)
+        if (StridedCost < OriginalVecCost || ForceStridedLoads) {
           // Strided load is more profitable than consecutive load + reverse -
           // transform the node to strided load.
+          Type *StrideTy = DL->getIndexType(cast<LoadInst>(E.Scalars.front())
+                                                ->getPointerOperand()
+                                                ->getType());
+          StridedPtrInfo SPtrInfo;
+          SPtrInfo.StrideVal = ConstantInt::get(StrideTy, 1);
+          SPtrInfo.Ty = VecTy;
+          TreeEntryToStridedPtrInfoMap[&E] = SPtrInfo;
           E.State = TreeEntry::StridedVectorize;
+        }
       }
       break;
     }
@@ -19484,6 +19525,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
       LoadInst *LI = cast<LoadInst>(VL0);
       Instruction *NewLI;
+      FixedVectorType *StridedLoadTy = nullptr;
       Value *PO = LI->getPointerOperand();
       if (E->State == TreeEntry::Vectorize) {
         NewLI = Builder.CreateAlignedLoad(VecTy, PO, LI->getAlign());
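The transformNodes hunk above records a stride of 1 for a reversed consecutive load; the negation by IsReverseOrder happens later, in the codegen below. The equivalence this relies on, sketched in plain C++ with a hypothetical stridedLoad helper: reversing N consecutive elements is the same as a strided load from the last address with a negative element stride.

    #include <cstddef>
    #include <vector>

    // Model of a strided load: lane I reads Base[I * StrideElems].
    template <typename T>
    std::vector<T> stridedLoad(const T *Base, std::ptrdiff_t StrideElems,
                               std::size_t N) {
      std::vector<T> R(N);
      for (std::size_t I = 0; I < N; ++I)
        R[I] = Base[static_cast<std::ptrdiff_t>(I) * StrideElems];
      return R;
    }

    int main() {
      int A[4] = {10, 20, 30, 40};
      // reverse(load A[0..3]) == strided load from &A[3] with stride -1.
      std::vector<int> Rev = stridedLoad(&A[3], -1, 4);
      return (Rev[0] == 40 && Rev[3] == 10) ? 0 : 1;
    }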
@@ -19521,43 +19563,36 @@
       Value *Ptr0 = cast<LoadInst>(E->Scalars.front())->getPointerOperand();
       Value *PtrN = cast<LoadInst>(E->Scalars.back())->getPointerOperand();
       PO = IsReverseOrder ? PtrN : Ptr0;
-      std::optional<int64_t> Diff = getPointersDiff(
-          VL0->getType(), Ptr0, VL0->getType(), PtrN, *DL, *SE);
       Type *StrideTy = DL->getIndexType(PO->getType());
       Value *StrideVal;
-      if (Diff) {
-        int64_t Stride =
-            *Diff / (static_cast<int64_t>(E->Scalars.size()) - 1);
-        StrideVal =
-            ConstantInt::get(StrideTy, (IsReverseOrder ? -1 : 1) * Stride *
-                                           DL->getTypeAllocSize(ScalarTy));
-      } else {
-        SmallVector<Value *> PointerOps(E->Scalars.size(), nullptr);
-        transform(E->Scalars, PointerOps.begin(), [](Value *V) {
-          return cast<LoadInst>(V)->getPointerOperand();
-        });
-        OrdersType Order;
-        const SCEV *StrideSCEV =
-            calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order);
-        assert(StrideSCEV && "At this point stride should be known");
+      const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+      StridedLoadTy = SPtrInfo.Ty;
+      assert(StridedLoadTy && "Missing StridedPtrInfo for tree entry.");
+      unsigned StridedLoadEC =
+          StridedLoadTy->getElementCount().getKnownMinValue();
+
+      Value *Stride = SPtrInfo.StrideVal;
+      if (!Stride) {
+        const SCEV *StrideSCEV = SPtrInfo.StrideSCEV;
+        assert(StrideSCEV && "Neither StrideVal nor StrideSCEV were set.");
         SCEVExpander Expander(*SE, *DL, "strided-load-vec");
-        Value *Stride = Expander.expandCodeFor(
-            StrideSCEV, StrideSCEV->getType(), &*Builder.GetInsertPoint());
-        Value *NewStride =
-            Builder.CreateIntCast(Stride, StrideTy, /*isSigned=*/true);
-        StrideVal = Builder.CreateMul(
-            NewStride,
-            ConstantInt::get(
-                StrideTy,
-                (IsReverseOrder ? -1 : 1) *
-                    static_cast<int>(DL->getTypeAllocSize(ScalarTy))));
-      }
+        Stride = Expander.expandCodeFor(StrideSCEV, StrideSCEV->getType(),
+                                        &*Builder.GetInsertPoint());
+      }
+      Value *NewStride =
+          Builder.CreateIntCast(Stride, StrideTy, /*isSigned=*/true);
+      StrideVal = Builder.CreateMul(
+          NewStride, ConstantInt::get(
+                         StrideTy, (IsReverseOrder ? -1 : 1) *
                                        static_cast<int>(
+                                           DL->getTypeAllocSize(ScalarTy))));
       Align CommonAlignment = computeCommonAlignment<LoadInst>(E->Scalars);
       auto *Inst = Builder.CreateIntrinsic(
           Intrinsic::experimental_vp_strided_load,
-          {VecTy, PO->getType(), StrideTy},
-          {PO, StrideVal, Builder.getAllOnesMask(VecTy->getElementCount()),
-           Builder.getInt32(E->Scalars.size())});
+          {StridedLoadTy, PO->getType(), StrideTy},
+          {PO, StrideVal,
+           Builder.getAllOnesMask(ElementCount::getFixed(StridedLoadEC)),
+           Builder.getInt32(StridedLoadEC)});
       Inst->addParamAttr(
           /*ArgNo=*/0,
           Attribute::getWithAlignment(Inst->getContext(), CommonAlignment));
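For reference, a plain C++ model of what the emitted llvm.experimental.vp.strided.load computes under an all-ones mask, per my reading of the LangRef: the stride operand is in bytes, and lane I is loaded from Ptr + I * Stride (the vpStridedLoadI32 helper below is illustrative only).

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Scalar semantics of a strided load of i32 lanes: the stride is a byte
    // offset (possibly negative), EVL lanes are produced, mask is all-ones.
    std::vector<int32_t> vpStridedLoadI32(const void *Ptr, int64_t StrideBytes,
                                          uint32_t EVL) {
      std::vector<int32_t> Lanes(EVL);
      const char *P = static_cast<const char *>(Ptr);
      for (uint32_t I = 0; I < EVL; ++I)
        std::memcpy(&Lanes[I], P + static_cast<int64_t>(I) * StrideBytes,
                    sizeof(int32_t));
      return Lanes;
    }

    int main() {
      int32_t Mem[8] = {0, 1, 2, 3, 4, 5, 6, 7};
      // A stride of 8 bytes over i32 memory -> every other element.
      std::vector<int32_t> V = vpStridedLoadI32(Mem, 8, 4);
      return (V[0] == 0 && V[1] == 2 && V[2] == 4 && V[3] == 6) ? 0 : 1;
    }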