-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[SLPVectorizer][NFC] Refactor canVectorizeLoads
.
#157911
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2234,11 +2234,25 @@ class BoUpSLP { | |
/// TODO: If load combining is allowed in the IR optimizer, this analysis | ||
/// may not be necessary. | ||
bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const; | ||
bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps, | ||
ArrayRef<unsigned> Order, const TargetTransformInfo &TTI, | ||
const DataLayout &DL, ScalarEvolution &SE, | ||
const bool IsAnyPointerUsedOutGraph, const int64_t Diff, | ||
StridedPtrInfo &SPtrInfo) const; | ||
|
||
/// Suppose we are given pointers of the form: %b + x * %c, | ||
/// where %c is a constant. Check if the pointers can be rearranged as follows: | ||
/// %b + 0 * %c | ||
/// %b + 1 * %c | ||
/// %b + 2 * %c | ||
/// ... | ||
/// %b + n * %c | ||
bool analyzeConstantStrideCandidate(ArrayRef<Value *> PointerOps, | ||
Type *ElemTy, Align CommonAlignment, | ||
SmallVectorImpl<unsigned> &SortedIndices, | ||
StridedPtrInfo &SPtrInfo, int64_t Diff, | ||
Value *Ptr0, Value *PtrN) const; | ||
|
||
/// Same as analyzeConstantStrideCandidate but for run-time stride. | ||
bool analyzeRtStrideCandidate(ArrayRef<Value *> PointerOps, Type *ElemTy, | ||
Align CommonAlignment, | ||
SmallVectorImpl<unsigned> &SortedIndices, | ||
StridedPtrInfo &SPtrInfo) const; | ||
|
||
/// Checks if the given array of loads can be represented as a vectorized, | ||
/// scatter or just simple gather. | ||
|
@@ -6805,53 +6819,50 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps, | |
CompressMask, LoadVecTy); | ||
} | ||
|
||
/// Checks if strided loads can be generated out of \p VL loads with pointers \p | ||
/// PointerOps: | ||
/// 1. Target with strided load support is detected. | ||
/// 2. The number of loads is greater than MinProfitableStridedLoads, or the | ||
/// potential stride <= MaxProfitableLoadStride and the potential stride is | ||
/// power-of-2 (to avoid perf regressions for the very small number of loads) | ||
/// and max distance > number of loads, or potential stride is -1. | ||
/// 3. The loads are ordered, or number of unordered loads <= | ||
/// MaxProfitableUnorderedLoads, or loads are in reversed order. (this check is | ||
/// to avoid extra costs for very expensive shuffles). | ||
/// 4. Any pointer operand is an instruction with the users outside of the | ||
/// current graph (for masked gathers extra extractelement instructions | ||
/// might be required). | ||
bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps, | ||
ArrayRef<unsigned> Order, | ||
const TargetTransformInfo &TTI, | ||
const DataLayout &DL, ScalarEvolution &SE, | ||
const bool IsAnyPointerUsedOutGraph, | ||
const int64_t Diff, | ||
StridedPtrInfo &SPtrInfo) const { | ||
const size_t Sz = VL.size(); | ||
bool BoUpSLP::analyzeConstantStrideCandidate( | ||
ArrayRef<Value *> PointerOps, Type *ElemTy, Align CommonAlignment, | ||
SmallVectorImpl<unsigned> &SortedIndices, StridedPtrInfo &SPtrInfo, | ||
int64_t Diff, Value *Ptr0, Value *PtrN) const { | ||
const size_t Sz = PointerOps.size(); | ||
auto *StridedLoadTy = getWidenedType(ElemTy, Sz); | ||
|
||
// Try to generate strided load node if: | ||
// 1. Target with strided load support is detected. | ||
// 2. The number of loads is greater than MinProfitableStridedLoads, | ||
// or the potential stride <= MaxProfitableLoadStride and the | ||
// potential stride is power-of-2 (to avoid perf regressions for the very | ||
// small number of loads) and max distance > number of loads, or potential | ||
// stride is -1. | ||
// 3. The loads are ordered, or number of unordered loads <= | ||
// MaxProfitableUnorderedLoads, or loads are in reversed order. | ||
// (this check is to avoid extra costs for very expensive shuffles). | ||
// 4. Any pointer operand is an instruction with the users outside of the | ||
// current graph (for masked gathers extra extractelement instructions | ||
// might be required). | ||
|
||
if (!TTI->isTypeLegal(StridedLoadTy) || | ||
!TTI->isLegalStridedLoadStore(StridedLoadTy, CommonAlignment)) | ||
return false; | ||
|
||
// Simple check if not a strided access - clear order. | ||
bool IsPossibleStrided = Diff % (Sz - 1) == 0; | ||
auto IsAnyPointerUsedOutGraph = | ||
|
||
IsPossibleStrided && any_of(PointerOps, [&](Value *V) { | ||
return isa<Instruction>(V) && any_of(V->users(), [&](User *U) { | ||
mgudim marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
return !isVectorized(U) && !MustGather.contains(U); | ||
}); | ||
}); | ||
const uint64_t AbsoluteDiff = std::abs(Diff); | ||
Type *ScalarTy = VL.front()->getType(); | ||
auto *VecTy = getWidenedType(ScalarTy, Sz); | ||
if (IsAnyPointerUsedOutGraph || | ||
(AbsoluteDiff > Sz && | ||
(Sz > MinProfitableStridedLoads || | ||
((Sz > MinProfitableStridedLoads || | ||
(AbsoluteDiff <= MaxProfitableLoadStride * Sz && | ||
AbsoluteDiff % Sz == 0 && has_single_bit(AbsoluteDiff / Sz)))) || | ||
has_single_bit(AbsoluteDiff))) && | ||
AbsoluteDiff > Sz) || | ||
Comment on lines
+6859
to
+6860
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean, this one. Why did you change it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, not sure what this is. Let me split it up in even smaller patches. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
Diff == -(static_cast<int64_t>(Sz) - 1)) { | ||
int64_t Stride = Diff / static_cast<int64_t>(Sz - 1); | ||
if (Diff != Stride * static_cast<int64_t>(Sz - 1)) | ||
return false; | ||
Align Alignment = | ||
cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()]) | ||
->getAlign(); | ||
if (!TTI.isLegalStridedLoadStore(VecTy, Alignment)) | ||
return false; | ||
Value *Ptr0; | ||
Value *PtrN; | ||
if (Order.empty()) { | ||
Ptr0 = PointerOps.front(); | ||
PtrN = PointerOps.back(); | ||
} else { | ||
Ptr0 = PointerOps[Order.front()]; | ||
PtrN = PointerOps[Order.back()]; | ||
} | ||
|
||
// Iterate through all pointers and check if all distances are | ||
// unique multiple of Dist. | ||
SmallSet<int64_t, 4> Dists; | ||
|
@@ -6860,22 +6871,40 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps, | |
if (Ptr == PtrN) | ||
Dist = Diff; | ||
else if (Ptr != Ptr0) | ||
Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, DL, SE); | ||
Dist = *getPointersDiff(ElemTy, Ptr0, ElemTy, Ptr, *DL, *SE); | ||
// If the distance is not an exact multiple of Stride, or the same | ||
// distance occurs more than once, we can't vectorize. | ||
if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second) | ||
break; | ||
} | ||
if (Dists.size() == Sz) { | ||
Type *StrideTy = DL.getIndexType(Ptr0->getType()); | ||
Type *StrideTy = DL->getIndexType(Ptr0->getType()); | ||
SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride); | ||
SPtrInfo.Ty = getWidenedType(ScalarTy, Sz); | ||
SPtrInfo.Ty = StridedLoadTy; | ||
return true; | ||
} | ||
mgudim marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
} | ||
return false; | ||
} | ||
|
||
bool BoUpSLP::analyzeRtStrideCandidate(ArrayRef<Value *> PointerOps, | ||
Type *ElemTy, Align CommonAlignment, | ||
SmallVectorImpl<unsigned> &SortedIndices, | ||
StridedPtrInfo &SPtrInfo) const { | ||
mgudim marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
const size_t Sz = PointerOps.size(); | ||
auto *VecTy = getWidenedType(ElemTy, Sz); | ||
if (!TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) | ||
return false; | ||
const SCEV *Stride = | ||
calculateRtStride(PointerOps, ElemTy, *DL, *SE, SortedIndices); | ||
if (!Stride) | ||
return false; | ||
|
||
SPtrInfo.Ty = VecTy; | ||
SPtrInfo.StrideSCEV = Stride; | ||
return true; | ||
} | ||
|
||
BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads( | ||
ArrayRef<Value *> VL, const Value *VL0, SmallVectorImpl<unsigned> &Order, | ||
SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo, | ||
|
@@ -6916,15 +6945,10 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads( | |
auto *VecTy = getWidenedType(ScalarTy, Sz); | ||
Align CommonAlignment = computeCommonAlignment<LoadInst>(VL); | ||
if (!IsSorted) { | ||
if (Sz > MinProfitableStridedLoads && TTI->isTypeLegal(VecTy)) { | ||
if (const SCEV *Stride = | ||
calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order); | ||
Stride && TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) { | ||
SPtrInfo.Ty = getWidenedType(ScalarTy, PointerOps.size()); | ||
SPtrInfo.StrideSCEV = Stride; | ||
return LoadsState::StridedVectorize; | ||
} | ||
} | ||
if (Sz > MinProfitableStridedLoads && | ||
analyzeRtStrideCandidate(PointerOps, ScalarTy, CommonAlignment, Order, | ||
SPtrInfo)) | ||
return LoadsState::StridedVectorize; | ||
|
||
if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) || | ||
TTI->forceScalarizeMaskedGather(VecTy, CommonAlignment)) | ||
|
@@ -6956,18 +6980,9 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads( | |
cast<Instruction>(V), UserIgnoreList); | ||
})) | ||
return LoadsState::CompressVectorize; | ||
// Simple check if not a strided access - clear order. | ||
bool IsPossibleStrided = *Diff % (Sz - 1) == 0; | ||
// Try to generate strided load node. | ||
auto IsAnyPointerUsedOutGraph = | ||
IsPossibleStrided && any_of(PointerOps, [&](Value *V) { | ||
return isa<Instruction>(V) && any_of(V->users(), [&](User *U) { | ||
return !isVectorized(U) && !MustGather.contains(U); | ||
}); | ||
}); | ||
if (IsPossibleStrided && | ||
isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE, | ||
IsAnyPointerUsedOutGraph, *Diff, SPtrInfo)) | ||
Comment on lines
-6959
to
-6970
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you move this code (and other code movements) in separate NFC patch(es)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, CommonAlignment, | ||
Order, SPtrInfo, *Diff, Ptr0, PtrN)) | ||
return LoadsState::StridedVectorize; | ||
} | ||
if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) || | ||
|
Uh oh!
There was an error while loading. Please reload this page.