[SLP]Check if masked gather can be emitted as a series of loads/insertsubvector.

A masked gather is a very expensive operation, and it is sometimes better to
represent it as a series of consecutive/strided loads plus insertsubvector
sequences. The patch adds some basic cost estimation and, if the
loads + insertsubvector form is cheaper, represents the node that way rather
than as a masked gather.
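As a rough illustration of the rewrite the cost model now prefers (a condensed sketch based on the updated test below, where the four elements live at byte offsets 32, 24, 8 and 0 off %arg; value names are illustrative and %ptrs stands for a splat of %arg):

  ; Before: one 4-element masked gather over non-contiguous offsets.
  %gep = getelementptr i8, <4 x ptr> %ptrs, <4 x i64> <i64 32, i64 24, i64 8, i64 0>
  %gather = call <4 x i64> @llvm.masked.gather.v4i64.v4p0(<4 x ptr> %gep, i32 8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i64> poison)

  ; After: two consecutive <2 x i64> loads, combined into the <4 x i64> vector with shuffles.
  %lo = load <2 x i64>, ptr %arg, align 8
  %hi.addr = getelementptr inbounds i8, ptr %arg, i64 24
  %hi = load <2 x i64>, ptr %hi.addr, align 8
  %hi.rev = shufflevector <2 x i64> %hi, <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
  %lo.rev = shufflevector <2 x i64> %lo, <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
  %vec = shufflevector <4 x i64> %hi.rev, <4 x i64> %lo.rev, <4 x i32> <i32 0, i32 1, i32 4, i32 5>

On most targets the two consecutive loads plus shuffles are considerably cheaper than the @llvm.masked.gather call.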

Reviewers: RKSimon

Reviewed By: RKSimon

Pull Request: #83481
alexey-bataev committed Mar 1, 2024
1 parent dd426fa commit 3a30d8e
Showing 2 changed files with 102 additions and 16 deletions.
96 changes: 89 additions & 7 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -4000,12 +4000,14 @@ static bool isReverseOrder(ArrayRef<unsigned> Order) {

/// Checks if the given array of loads can be represented as a vectorized,
/// scatter or just simple gather.
static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
static LoadsState canVectorizeLoads(const BoUpSLP &R, ArrayRef<Value *> VL,
const Value *VL0,
const TargetTransformInfo &TTI,
const DataLayout &DL, ScalarEvolution &SE,
LoopInfo &LI, const TargetLibraryInfo &TLI,
SmallVectorImpl<unsigned> &Order,
SmallVectorImpl<Value *> &PointerOps) {
SmallVectorImpl<Value *> &PointerOps,
bool TryRecursiveCheck = true) {
// Check that a vectorized load would load the same memory as a scalar
// load. For example, we don't want to vectorize loads that are smaller
// than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM
@@ -4098,6 +4100,78 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
}
}
}
auto CheckForShuffledLoads = [&](Align CommonAlignment) {
unsigned Sz = DL.getTypeSizeInBits(ScalarTy);
unsigned MinVF = R.getMinVF(Sz);
unsigned MaxVF = std::max<unsigned>(bit_floor(VL.size() / 2), MinVF);
MaxVF = std::min(R.getMaximumVF(Sz, Instruction::Load), MaxVF);
for (unsigned VF = MaxVF; VF >= MinVF; VF /= 2) {
unsigned VectorizedCnt = 0;
SmallVector<LoadsState> States;
for (unsigned Cnt = 0, End = VL.size(); Cnt + VF <= End;
Cnt += VF, ++VectorizedCnt) {
ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
SmallVector<unsigned> Order;
SmallVector<Value *> PointerOps;
LoadsState LS =
canVectorizeLoads(R, Slice, Slice.front(), TTI, DL, SE, LI, TLI,
Order, PointerOps, /*TryRecursiveCheck=*/false);
// Check that the sorted loads are consecutive.
if (LS == LoadsState::Gather)
break;
// If a reorder is required, treat it as a high-cost masked gather for now.
if ((LS == LoadsState::Vectorize ||
LS == LoadsState::StridedVectorize) &&
!Order.empty() && !isReverseOrder(Order))
LS = LoadsState::ScatterVectorize;
States.push_back(LS);
}
// Can be vectorized later as a series of loads/insertelements.
if (VectorizedCnt == VL.size() / VF) {
// Compare the masked gather cost and the loads + insertsubvector cost.
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
InstructionCost MaskedGatherCost = TTI.getGatherScatterOpCost(
Instruction::Load, VecTy,
cast<LoadInst>(VL0)->getPointerOperand(),
/*VariableMask=*/false, CommonAlignment, CostKind);
InstructionCost VecLdCost = 0;
auto *SubVecTy = FixedVectorType::get(ScalarTy, VF);
for (auto [I, LS] : enumerate(States)) {
auto *LI0 = cast<LoadInst>(VL[I * VF]);
switch (LS) {
case LoadsState::Vectorize:
VecLdCost += TTI.getMemoryOpCost(
Instruction::Load, SubVecTy, LI0->getAlign(),
LI0->getPointerAddressSpace(), CostKind,
TTI::OperandValueInfo());
break;
case LoadsState::StridedVectorize:
VecLdCost += TTI.getStridedMemoryOpCost(
Instruction::Load, SubVecTy, LI0->getPointerOperand(),
/*VariableMask=*/false, CommonAlignment, CostKind);
break;
case LoadsState::ScatterVectorize:
VecLdCost += TTI.getGatherScatterOpCost(
Instruction::Load, SubVecTy, LI0->getPointerOperand(),
/*VariableMask=*/false, CommonAlignment, CostKind);
break;
case LoadsState::Gather:
llvm_unreachable(
"Expected only consecutive, strided or masked gather loads.");
}
VecLdCost +=
TTI.getShuffleCost(TTI::SK_InsertSubvector, VecTy,
std::nullopt, CostKind, I * VF, SubVecTy);
}
// If the masked gather cost is higher, the loads + insertsubvector
// form is better, so report this as a gather node for now. It will be
// estimated more precisely later.
if (MaskedGatherCost > VecLdCost)
return true;
}
}
return false;
};
// TODO: need to improve analysis of the pointers, if not all of them are
// GEPs or have > 2 operands, we end up with a gather node, which just
// increases the cost.
@@ -4114,8 +4188,17 @@ static LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
})) {
Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
if (TTI.isLegalMaskedGather(VecTy, CommonAlignment) &&
!TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment))
!TTI.forceScalarizeMaskedGather(VecTy, CommonAlignment)) {
// Check if the potential masked gather can be represented as a
// series of loads + insertsubvectors.
if (TryRecursiveCheck && CheckForShuffledLoads(CommonAlignment)) {
// If the masked gather cost is higher, the loads + insertsubvector
// form is better, so treat this as a gather node. It will be
// estimated more precisely later.
return LoadsState::Gather;
}
return LoadsState::ScatterVectorize;
}
}
}

@@ -5554,8 +5637,8 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
// treats loading/storing it as an i8 struct. If we vectorize loads/stores
// from such a struct, we read/write packed bits disagreeing with the
// unvectorized version.
switch (canVectorizeLoads(VL, VL0, *TTI, *DL, *SE, *LI, *TLI, CurrentOrder,
PointerOps)) {
switch (canVectorizeLoads(*this, VL, VL0, *TTI, *DL, *SE, *LI, *TLI,
CurrentOrder, PointerOps)) {
case LoadsState::Vectorize:
return TreeEntry::Vectorize;
case LoadsState::ScatterVectorize:
@@ -7336,7 +7419,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
SmallVector<Value *> PointerOps;
OrdersType CurrentOrder;
LoadsState LS =
canVectorizeLoads(Slice, Slice.front(), TTI, *R.DL, *R.SE,
canVectorizeLoads(R, Slice, Slice.front(), TTI, *R.DL, *R.SE,
*R.LI, *R.TLI, CurrentOrder, PointerOps);
switch (LS) {
case LoadsState::Vectorize:
@@ -7599,7 +7682,6 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
transformMaskAfterShuffle(CommonMask, CommonMask);
}
SameNodesEstimated = false;
Cost += createShuffle(&E1, E2, Mask);
if (!E2 && InVectors.size() == 1) {
unsigned VF = E1.getVectorFactor();
if (Value *V1 = InVectors.front().dyn_cast<Value *>()) {
@@ -5,19 +5,23 @@ define void @test(i1 %c, ptr %arg) {
; CHECK-LABEL: @test(
; CHECK-NEXT: br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
; CHECK: if:
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <4 x ptr> poison, ptr [[ARG:%.*]], i32 0
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x ptr> [[TMP1]], <4 x ptr> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, <4 x ptr> [[TMP2]], <4 x i64> <i64 32, i64 24, i64 8, i64 0>
; CHECK-NEXT: [[TMP4:%.*]] = call <4 x i64> @llvm.masked.gather.v4i64.v4p0(<4 x ptr> [[TMP3]], i32 8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i64> poison)
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x i64>, ptr [[ARG:%.*]], align 8
; CHECK-NEXT: [[ARG2_2:%.*]] = getelementptr inbounds i8, ptr [[ARG]], i64 24
; CHECK-NEXT: [[TMP2:%.*]] = load <2 x i64>, ptr [[ARG2_2]], align 8
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <2 x i64> [[TMP2]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i64> [[TMP3]], <4 x i64> [[TMP4]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
; CHECK-NEXT: br label [[JOIN:%.*]]
; CHECK: else:
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x ptr> poison, ptr [[ARG]], i32 0
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x ptr> [[TMP5]], <4 x ptr> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP7:%.*]] = getelementptr i8, <4 x ptr> [[TMP6]], <4 x i64> <i64 32, i64 24, i64 8, i64 0>
; CHECK-NEXT: [[TMP8:%.*]] = call <4 x i64> @llvm.masked.gather.v4i64.v4p0(<4 x ptr> [[TMP7]], i32 8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i64> poison)
; CHECK-NEXT: [[TMP6:%.*]] = load <2 x i64>, ptr [[ARG]], align 8
; CHECK-NEXT: [[ARG_2:%.*]] = getelementptr inbounds i8, ptr [[ARG]], i64 24
; CHECK-NEXT: [[TMP7:%.*]] = load <2 x i64>, ptr [[ARG_2]], align 8
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <2 x i64> [[TMP7]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <2 x i64> [[TMP6]], <2 x i64> poison, <4 x i32> <i32 1, i32 0, i32 poison, i32 poison>
; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <4 x i64> [[TMP8]], <4 x i64> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
; CHECK-NEXT: br label [[JOIN]]
; CHECK: join:
; CHECK-NEXT: [[TMP9:%.*]] = phi <4 x i64> [ [[TMP4]], [[IF]] ], [ [[TMP8]], [[ELSE]] ]
; CHECK-NEXT: [[TMP11:%.*]] = phi <4 x i64> [ [[TMP5]], [[IF]] ], [ [[TMP10]], [[ELSE]] ]
; CHECK-NEXT: ret void
;
br i1 %c, label %if, label %else