diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index 040e2dafb56a6..c85d994cf09d6 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -6323,17 +6323,11 @@ static bool isReverseOrder(ArrayRef Order) { } /// Checks if the provided list of pointers \p Pointers represents the strided -/// pointers for type ElemTy. If they are not, std::nullopt is returned. -/// Otherwise, if \p Inst is not specified, just initialized optional value is -/// returned to show that the pointers represent strided pointers. If \p Inst -/// specified, the runtime stride is materialized before the given \p Inst. -/// \returns std::nullopt if the pointers are not pointers with the runtime -/// stride, nullptr or actual stride value, otherwise. -static std::optional -calculateRtStride(ArrayRef PointerOps, Type *ElemTy, - const DataLayout &DL, ScalarEvolution &SE, - SmallVectorImpl &SortedIndices, - Instruction *Inst = nullptr) { +/// pointers for type ElemTy. If they are not, nullptr is returned. +/// Otherwise, SCEV* of the stride value is returned. +static const SCEV *calculateRtStride(ArrayRef PointerOps, Type *ElemTy, + const DataLayout &DL, ScalarEvolution &SE, + SmallVectorImpl &SortedIndices) { SmallVector SCEVs; const SCEV *PtrSCEVLowest = nullptr; const SCEV *PtrSCEVHighest = nullptr; @@ -6342,7 +6336,7 @@ calculateRtStride(ArrayRef PointerOps, Type *ElemTy, for (Value *Ptr : PointerOps) { const SCEV *PtrSCEV = SE.getSCEV(Ptr); if (!PtrSCEV) - return std::nullopt; + return nullptr; SCEVs.push_back(PtrSCEV); if (!PtrSCEVLowest && !PtrSCEVHighest) { PtrSCEVLowest = PtrSCEVHighest = PtrSCEV; @@ -6350,14 +6344,14 @@ calculateRtStride(ArrayRef PointerOps, Type *ElemTy, } const SCEV *Diff = SE.getMinusSCEV(PtrSCEV, PtrSCEVLowest); if (isa(Diff)) - return std::nullopt; + return nullptr; if (Diff->isNonConstantNegative()) { PtrSCEVLowest = PtrSCEV; continue; } const SCEV *Diff1 = SE.getMinusSCEV(PtrSCEVHighest, PtrSCEV); if (isa(Diff1)) - return std::nullopt; + return nullptr; if (Diff1->isNonConstantNegative()) { PtrSCEVHighest = PtrSCEV; continue; @@ -6366,7 +6360,7 @@ calculateRtStride(ArrayRef PointerOps, Type *ElemTy, // Dist = PtrSCEVHighest - PtrSCEVLowest; const SCEV *Dist = SE.getMinusSCEV(PtrSCEVHighest, PtrSCEVLowest); if (isa(Dist)) - return std::nullopt; + return nullptr; int Size = DL.getTypeStoreSize(ElemTy); auto TryGetStride = [&](const SCEV *Dist, const SCEV *Multiplier) -> const SCEV * { @@ -6387,10 +6381,10 @@ calculateRtStride(ArrayRef PointerOps, Type *ElemTy, const SCEV *Sz = SE.getConstant(Dist->getType(), Size * (SCEVs.size() - 1)); Stride = TryGetStride(Dist, Sz); if (!Stride) - return std::nullopt; + return nullptr; } if (!Stride || isa(Stride)) - return std::nullopt; + return nullptr; // Iterate through all pointers and check if all distances are // unique multiple of Stride. using DistOrdPair = std::pair; @@ -6404,28 +6398,28 @@ calculateRtStride(ArrayRef PointerOps, Type *ElemTy, const SCEV *Diff = SE.getMinusSCEV(PtrSCEV, PtrSCEVLowest); const SCEV *Coeff = TryGetStride(Diff, Stride); if (!Coeff) - return std::nullopt; + return nullptr; const auto *SC = dyn_cast(Coeff); if (!SC || isa(SC)) - return std::nullopt; + return nullptr; if (!SE.getMinusSCEV(PtrSCEV, SE.getAddExpr(PtrSCEVLowest, SE.getMulExpr(Stride, SC))) ->isZero()) - return std::nullopt; + return nullptr; Dist = SC->getAPInt().getZExtValue(); } // If the strides are not the same or repeated, we can't vectorize. if ((Dist / Size) * Size != Dist || (Dist / Size) >= SCEVs.size()) - return std::nullopt; + return nullptr; auto Res = Offsets.emplace(Dist, Cnt); if (!Res.second) - return std::nullopt; + return nullptr; // Consecutive order if the inserted element is the last one. IsConsecutive = IsConsecutive && std::next(Res.first) == Offsets.end(); ++Cnt; } if (Offsets.size() != SCEVs.size()) - return std::nullopt; + return nullptr; SortedIndices.clear(); if (!IsConsecutive) { // Fill SortedIndices array only if it is non-consecutive. @@ -6436,10 +6430,7 @@ calculateRtStride(ArrayRef PointerOps, Type *ElemTy, ++Cnt; } } - if (!Inst) - return nullptr; - SCEVExpander Expander(SE, DL, "strided-load-vec"); - return Expander.expandCodeFor(Stride, Stride->getType(), Inst); + return Stride; } static std::pair @@ -19520,11 +19511,14 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { return cast(V)->getPointerOperand(); }); OrdersType Order; - std::optional Stride = - calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order, - &*Builder.GetInsertPoint()); + const SCEV *StrideSCEV = + calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order); + assert(StrideSCEV && "At this point stride should be known"); + SCEVExpander Expander(*SE, *DL, "strided-load-vec"); + Value *Stride = Expander.expandCodeFor( + StrideSCEV, StrideSCEV->getType(), &*Builder.GetInsertPoint()); Value *NewStride = - Builder.CreateIntCast(*Stride, StrideTy, /*isSigned=*/true); + Builder.CreateIntCast(Stride, StrideTy, /*isSigned=*/true); StrideVal = Builder.CreateMul( NewStride, ConstantInt::get(