[SLP]Do not reorder reduction nodes.
The final reduction nodes should not be reordered, since the order does
not matter for reductions. Also, it might be profitable to vectorize
smaller reduction trees, since the reduction cost may compensate for the
small tree cost.

Part of D111574

Differential Revision: https://reviews.llvm.org/D112467
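
[Editor's illustration, not part of the commit: the pattern in question is a horizontal reduction over a small vectorizable tree, e.g. a four-element dot product, assuming fast-math so the final adds form a reduction.]

```cpp
// Hypothetical source pattern: the fmul tree is tiny, but once the final
// horizontal add is costed as a single vector reduction, vectorizing the
// small tree can still be profitable.
float dot4(const float *a, const float *b) {
  return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + a[3] * b[3];
}
```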
alexey-bataev committed Oct 26, 2021
1 parent 158083f commit ce14d1b
Showing 3 changed files with 76 additions and 59 deletions.
76 changes: 47 additions & 29 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -781,7 +781,7 @@ class BoUpSLP {
/// operands. Plus, even if the leaf nodes have different orders, it allows
/// sinking reordering in the graph closer to the root node and merging it
/// later during analysis.
-void reorderBottomToTop();
+void reorderBottomToTop(bool IgnoreReorder = false);

/// \return The vector element size in bits to use when vectorizing the
/// expression tree ending at \p V. If V is a store, the size is the width of
@@ -824,7 +824,7 @@ class BoUpSLP {

/// \returns True if the VectorizableTree is both tiny and not fully
/// vectorizable. We do not vectorize such trees.
-bool isTreeTinyAndNotFullyVectorizable() const;
+bool isTreeTinyAndNotFullyVectorizable(bool ForReduction = false) const;

/// Assume that a legal-sized 'or'-reduction of shifted/zexted loaded values
/// can be load combined in the backend. Load combining may not be allowed in
@@ -1620,7 +1620,7 @@ class BoUpSLP {

/// \returns whether the VectorizableTree is fully vectorizable and will
/// be beneficial even when the tree height is tiny.
-bool isFullyVectorizableTinyTree() const;
+bool isFullyVectorizableTinyTree(bool ForReduction) const;

/// Reorder commutative or alt operands to get better probability of
/// generating vectorized code.
@@ -2820,7 +2820,7 @@ void BoUpSLP::reorderTopToBottom() {
}
}

-void BoUpSLP::reorderBottomToTop() {
+void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
SetVector<TreeEntry *> OrderedEntries;
DenseMap<const TreeEntry *, OrdersType> GathersToOrders;
// Find all reorderable leaf nodes with the given VF.
@@ -2950,7 +2950,8 @@ void BoUpSLP::reorderBottomToTop() {
SmallPtrSet<const TreeEntry *, 4> VisitedOps;
for (const auto &Op : Data.second) {
TreeEntry *OpTE = Op.second;
-if (!OpTE->ReuseShuffleIndices.empty())
+if (!OpTE->ReuseShuffleIndices.empty() ||
+    (IgnoreReorder && OpTE == VectorizableTree.front().get()))
continue;
const auto &Order = [OpTE, &GathersToOrders]() -> const OrdersType & {
if (OpTE->State == TreeEntry::NeedToGather)
@@ -3061,6 +3062,10 @@ void BoUpSLP::reorderBottomToTop() {
}
}
}
+// If the reordering is unnecessary, just remove the reorder.
+if (IgnoreReorder && !VectorizableTree.front()->ReorderIndices.empty() &&
+    VectorizableTree.front()->ReuseShuffleIndices.empty())
+  VectorizableTree.front()->ReorderIndices.clear();
}

void BoUpSLP::buildExternalUses(
@@ -4894,13 +4899,29 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
}
}

-bool BoUpSLP::isFullyVectorizableTinyTree() const {
+bool BoUpSLP::isFullyVectorizableTinyTree(bool ForReduction) const {
LLVM_DEBUG(dbgs() << "SLP: Check whether the tree with height "
<< VectorizableTree.size() << " is fully vectorizable .\n");

+auto &&AreVectorizableGathers = [this](const TreeEntry *TE, unsigned Limit) {
+  SmallVector<int> Mask;
+  return TE->State == TreeEntry::NeedToGather &&
+         !any_of(TE->Scalars,
+                 [this](Value *V) { return EphValues.contains(V); }) &&
+         (allConstant(TE->Scalars) || isSplat(TE->Scalars) ||
+          TE->Scalars.size() < Limit ||
+          (TE->getOpcode() == Instruction::ExtractElement &&
+           isFixedVectorShuffle(TE->Scalars, Mask)));
+};

// We only handle trees of heights 1 and 2.
if (VectorizableTree.size() == 1 &&
-    VectorizableTree[0]->State == TreeEntry::Vectorize)
+    (VectorizableTree[0]->State == TreeEntry::Vectorize ||
+     (ForReduction &&
+      AreVectorizableGathers(VectorizableTree[0].get(),
+                             VectorizableTree[0]->Scalars.size()) &&
+      (VectorizableTree[0]->Scalars.size() > 2 ||
+       VectorizableTree[0]->ReuseShuffleIndices.size() > 2))))
return true;

if (VectorizableTree.size() != 2)
@@ -4912,19 +4933,14 @@ bool BoUpSLP::isFullyVectorizableTinyTree() const {
// or they are extractelements, which form shuffle.
SmallVector<int> Mask;
if (VectorizableTree[0]->State == TreeEntry::Vectorize &&
-    (allConstant(VectorizableTree[1]->Scalars) ||
-     isSplat(VectorizableTree[1]->Scalars) ||
-     (VectorizableTree[1]->State == TreeEntry::NeedToGather &&
-      VectorizableTree[1]->Scalars.size() <
-          VectorizableTree[0]->Scalars.size()) ||
-     (VectorizableTree[1]->State == TreeEntry::NeedToGather &&
-      VectorizableTree[1]->getOpcode() == Instruction::ExtractElement &&
-      isFixedVectorShuffle(VectorizableTree[1]->Scalars, Mask))))
+    AreVectorizableGathers(VectorizableTree[1].get(),
+                           VectorizableTree[0]->Scalars.size()))
return true;

// Gathering cost would be too much for tiny trees.
if (VectorizableTree[0]->State == TreeEntry::NeedToGather ||
-    VectorizableTree[1]->State == TreeEntry::NeedToGather)
+    (VectorizableTree[1]->State == TreeEntry::NeedToGather &&
+     VectorizableTree[0]->State != TreeEntry::ScatterVectorize))
return false;

return true;
@@ -4993,7 +5009,7 @@ bool BoUpSLP::isLoadCombineCandidate() const {
return true;
}

-bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const {
+bool BoUpSLP::isTreeTinyAndNotFullyVectorizable(bool ForReduction) const {
// No need to vectorize inserts of gathered values.
if (VectorizableTree.size() == 2 &&
isa<InsertElementInst>(VectorizableTree[0]->Scalars[0]) &&
@@ -5007,7 +5023,7 @@ bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() const {

// If we have a tiny tree (a tree whose size is less than MinTreeSize), we
// can vectorize it if we can prove it fully vectorizable.
-if (isFullyVectorizableTinyTree())
+if (isFullyVectorizableTinyTree(ForReduction))
return false;

assert(VectorizableTree.empty()
@@ -5769,7 +5785,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
VF = E->ReuseShuffleIndices.size();
ShuffleInstructionBuilder ShuffleBuilder(Builder, VF);
if (E->State == TreeEntry::NeedToGather) {
-setInsertPointAfterBundle(E);
+if (E->getMainOp())
+  setInsertPointAfterBundle(E);
Value *Vec;
SmallVector<int> Mask;
SmallVector<const TreeEntry *> Entries;
@@ -8447,12 +8464,12 @@ class HorizontalReduction {
while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) {
ArrayRef<Value *> VL(&ReducedVals[i], ReduxWidth);
V.buildTree(VL, IgnoreList);
-if (V.isTreeTinyAndNotFullyVectorizable())
+if (V.isTreeTinyAndNotFullyVectorizable(/*ForReduction=*/true))
break;
if (V.isLoadCombineReductionCandidate(RdxKind))
break;
V.reorderTopToBottom();
-V.reorderBottomToTop();
+V.reorderBottomToTop(/*IgnoreReorder=*/true);
V.buildExternalUses(ExternallyUsedValues);

// For a poison-safe boolean logic reduction, do not replace select
@@ -8630,6 +8647,7 @@ class HorizontalReduction {
assert(isPowerOf2_32(ReduxWidth) &&
"We only handle power-of-two reductions for now");

+  ++NumVectorInstructions;
return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind,
ReductionOps.back());
}
@@ -8889,15 +8907,15 @@ static bool tryToVectorizeHorReductionOrInstOperands(
continue;
}
}
-// Set P to nullptr to avoid re-analysis of phi node in
-// matchAssociativeReduction function unless this is the root node.
-P = nullptr;
-// Do not try to vectorize CmpInst operands, this is done separately.
-// Final attempt for binop args vectorization should happen after the loop
-// to try to find reductions.
-if (!isa<CmpInst>(Inst))
-  PostponedInsts.push_back(Inst);
}
+// Set P to nullptr to avoid re-analysis of phi node in
+// matchAssociativeReduction function unless this is the root node.
+P = nullptr;
+// Do not try to vectorize CmpInst operands, this is done separately.
+// Final attempt for binop args vectorization should happen after the loop
+// to try to find reductions.
+if (!isa<CmpInst>(Inst))
+  PostponedInsts.push_back(Inst);

// Try to vectorize operands.
// Continue analysis for the instruction from the same basic block only to
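[Editor's note: why is it safe for the reduction path to pass IgnoreReorder=true and drop the root node's ReorderIndices? Because the root feeds only the horizontal reduction, which consumes all lanes symmetrically. A minimal standalone sketch of that algebraic fact, in plain C++ rather than LLVM API:]

```cpp
#include <array>
#include <numeric>

// For an associative, commutative reduction (integer add here), reducing a
// permuted vector yields the same value as reducing the original, so a
// lane reorder whose only user is the reduction can simply be dropped.
int reduce_ignores_lane_order(const std::array<int, 4> &v) {
  std::array<int, 4> permuted{v[2], v[0], v[3], v[1]}; // a "reorder"
  int direct = std::accumulate(v.begin(), v.end(), 0);
  int reordered = std::accumulate(permuted.begin(), permuted.end(), 0);
  return direct == reordered ? direct : -1; // always returns direct
}
```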
11 changes: 3 additions & 8 deletions llvm/test/Transforms/PhaseOrdering/X86/vector-reductions.ll
@@ -41,14 +41,9 @@ define i32 @ext_ext_partial_add_reduction_v4i32(<4 x i32> %x) {

define i32 @ext_ext_partial_add_reduction_and_extra_add_v4i32(<4 x i32> %x, <4 x i32> %y) {
; CHECK-LABEL: @ext_ext_partial_add_reduction_and_extra_add_v4i32(
-; CHECK-NEXT: [[SHIFT:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> poison, <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
-; CHECK-NEXT: [[TMP1:%.*]] = add <4 x i32> [[SHIFT]], [[Y:%.*]]
-; CHECK-NEXT: [[SHIFT1:%.*]] = shufflevector <4 x i32> [[Y]], <4 x i32> poison, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
-; CHECK-NEXT: [[TMP2:%.*]] = add <4 x i32> [[TMP1]], [[SHIFT1]]
-; CHECK-NEXT: [[SHIFT2:%.*]] = shufflevector <4 x i32> [[Y]], <4 x i32> poison, <4 x i32> <i32 2, i32 undef, i32 undef, i32 undef>
-; CHECK-NEXT: [[TMP3:%.*]] = add <4 x i32> [[TMP2]], [[SHIFT2]]
-; CHECK-NEXT: [[X2Y210:%.*]] = extractelement <4 x i32> [[TMP3]], i32 0
-; CHECK-NEXT: ret i32 [[X2Y210]]
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> [[Y:%.*]], <4 x i32> <i32 4, i32 2, i32 5, i32 6>
+; CHECK-NEXT: [[TMP2:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP1]])
+; CHECK-NEXT: ret i32 [[TMP2]]
;
%y0 = extractelement <4 x i32> %y, i32 0
%y1 = extractelement <4 x i32> %y, i32 1
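[Editor's note: a hypothetical C++ rendering of the test above. The new output gathers {y[0], x[2], y[1], y[2]} with a single shuffle (mask <4, 2, 5, 6> selects from the concatenation of x and y) and reduces it in one call, instead of the old chain of shuffle+add steps:]

```cpp
// Hypothetical scalar equivalent of
// @ext_ext_partial_add_reduction_and_extra_add_v4i32: both the old and the
// new IR compute this same sum; the new IR does it as one 4-wide add
// reduction over {y[0], x[2], y[1], y[2]}.
int partial_add_reduction(const int x[4], const int y[4]) {
  return y[0] + x[2] + y[1] + y[2];
}
```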
48 changes: 26 additions & 22 deletions llvm/test/Transforms/SLPVectorizer/X86/horizontal-list.ll
@@ -17,23 +17,25 @@ define float @baz() {
; CHECK-NEXT: [[TMP2:%.*]] = load <2 x float>, <2 x float>* bitcast ([20 x float]* @arr1 to <2 x float>*), align 16
; CHECK-NEXT: [[TMP3:%.*]] = fmul fast <2 x float> [[TMP2]], [[TMP1]]
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x float> [[TMP3]], i32 0
-; CHECK-NEXT: [[ADD:%.*]] = fadd fast float [[TMP4]], [[CONV]]
; CHECK-NEXT: [[TMP5:%.*]] = extractelement <2 x float> [[TMP3]], i32 1
-; CHECK-NEXT: [[ADD_1:%.*]] = fadd fast float [[TMP5]], [[ADD]]
; CHECK-NEXT: [[TMP6:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr, i64 0, i64 2) to <2 x float>*), align 8
; CHECK-NEXT: [[TMP7:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr1, i64 0, i64 2) to <2 x float>*), align 8
; CHECK-NEXT: [[TMP8:%.*]] = fmul fast <2 x float> [[TMP7]], [[TMP6]]
; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x float> [[TMP8]], i32 0
-; CHECK-NEXT: [[ADD_2:%.*]] = fadd fast float [[TMP9]], [[ADD_1]]
; CHECK-NEXT: [[TMP10:%.*]] = extractelement <2 x float> [[TMP8]], i32 1
-; CHECK-NEXT: [[ADD_3:%.*]] = fadd fast float [[TMP10]], [[ADD_2]]
-; CHECK-NEXT: [[ADD7:%.*]] = fadd fast float [[ADD_3]], [[CONV]]
-; CHECK-NEXT: [[ADD19:%.*]] = fadd fast float [[TMP4]], [[ADD7]]
-; CHECK-NEXT: [[ADD19_1:%.*]] = fadd fast float [[TMP5]], [[ADD19]]
-; CHECK-NEXT: [[ADD19_2:%.*]] = fadd fast float [[TMP9]], [[ADD19_1]]
-; CHECK-NEXT: [[ADD19_3:%.*]] = fadd fast float [[TMP10]], [[ADD19_2]]
-; CHECK-NEXT: store float [[ADD19_3]], float* @res, align 4
-; CHECK-NEXT: ret float [[ADD19_3]]
+; CHECK-NEXT: [[TMP11:%.*]] = insertelement <8 x float> poison, float [[TMP10]], i32 0
+; CHECK-NEXT: [[TMP12:%.*]] = insertelement <8 x float> [[TMP11]], float [[TMP9]], i32 1
+; CHECK-NEXT: [[TMP13:%.*]] = insertelement <8 x float> [[TMP12]], float [[TMP5]], i32 2
+; CHECK-NEXT: [[TMP14:%.*]] = insertelement <8 x float> [[TMP13]], float [[TMP4]], i32 3
+; CHECK-NEXT: [[TMP15:%.*]] = insertelement <8 x float> [[TMP14]], float [[TMP10]], i32 4
+; CHECK-NEXT: [[TMP16:%.*]] = insertelement <8 x float> [[TMP15]], float [[TMP9]], i32 5
+; CHECK-NEXT: [[TMP17:%.*]] = insertelement <8 x float> [[TMP16]], float [[TMP5]], i32 6
+; CHECK-NEXT: [[TMP18:%.*]] = insertelement <8 x float> [[TMP17]], float [[TMP4]], i32 7
+; CHECK-NEXT: [[TMP19:%.*]] = call fast float @llvm.vector.reduce.fadd.v8f32(float -0.000000e+00, <8 x float> [[TMP18]])
+; CHECK-NEXT: [[OP_EXTRA:%.*]] = fadd fast float [[TMP19]], [[CONV]]
+; CHECK-NEXT: [[OP_EXTRA1:%.*]] = fadd fast float [[OP_EXTRA]], [[CONV]]
+; CHECK-NEXT: store float [[OP_EXTRA1]], float* @res, align 4
+; CHECK-NEXT: ret float [[OP_EXTRA1]]
;
; THRESHOLD-LABEL: @baz(
; THRESHOLD-NEXT: entry:
@@ -44,23 +46,25 @@ define float @baz() {
; THRESHOLD-NEXT: [[TMP2:%.*]] = load <2 x float>, <2 x float>* bitcast ([20 x float]* @arr1 to <2 x float>*), align 16
; THRESHOLD-NEXT: [[TMP3:%.*]] = fmul fast <2 x float> [[TMP2]], [[TMP1]]
; THRESHOLD-NEXT: [[TMP4:%.*]] = extractelement <2 x float> [[TMP3]], i32 0
-; THRESHOLD-NEXT: [[ADD:%.*]] = fadd fast float [[TMP4]], [[CONV]]
; THRESHOLD-NEXT: [[TMP5:%.*]] = extractelement <2 x float> [[TMP3]], i32 1
-; THRESHOLD-NEXT: [[ADD_1:%.*]] = fadd fast float [[TMP5]], [[ADD]]
; THRESHOLD-NEXT: [[TMP6:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr, i64 0, i64 2) to <2 x float>*), align 8
; THRESHOLD-NEXT: [[TMP7:%.*]] = load <2 x float>, <2 x float>* bitcast (float* getelementptr inbounds ([20 x float], [20 x float]* @arr1, i64 0, i64 2) to <2 x float>*), align 8
; THRESHOLD-NEXT: [[TMP8:%.*]] = fmul fast <2 x float> [[TMP7]], [[TMP6]]
; THRESHOLD-NEXT: [[TMP9:%.*]] = extractelement <2 x float> [[TMP8]], i32 0
-; THRESHOLD-NEXT: [[ADD_2:%.*]] = fadd fast float [[TMP9]], [[ADD_1]]
; THRESHOLD-NEXT: [[TMP10:%.*]] = extractelement <2 x float> [[TMP8]], i32 1
-; THRESHOLD-NEXT: [[ADD_3:%.*]] = fadd fast float [[TMP10]], [[ADD_2]]
-; THRESHOLD-NEXT: [[ADD7:%.*]] = fadd fast float [[ADD_3]], [[CONV]]
-; THRESHOLD-NEXT: [[ADD19:%.*]] = fadd fast float [[TMP4]], [[ADD7]]
-; THRESHOLD-NEXT: [[ADD19_1:%.*]] = fadd fast float [[TMP5]], [[ADD19]]
-; THRESHOLD-NEXT: [[ADD19_2:%.*]] = fadd fast float [[TMP9]], [[ADD19_1]]
-; THRESHOLD-NEXT: [[ADD19_3:%.*]] = fadd fast float [[TMP10]], [[ADD19_2]]
-; THRESHOLD-NEXT: store float [[ADD19_3]], float* @res, align 4
-; THRESHOLD-NEXT: ret float [[ADD19_3]]
+; THRESHOLD-NEXT: [[TMP11:%.*]] = insertelement <8 x float> poison, float [[TMP10]], i32 0
+; THRESHOLD-NEXT: [[TMP12:%.*]] = insertelement <8 x float> [[TMP11]], float [[TMP9]], i32 1
+; THRESHOLD-NEXT: [[TMP13:%.*]] = insertelement <8 x float> [[TMP12]], float [[TMP5]], i32 2
+; THRESHOLD-NEXT: [[TMP14:%.*]] = insertelement <8 x float> [[TMP13]], float [[TMP4]], i32 3
+; THRESHOLD-NEXT: [[TMP15:%.*]] = insertelement <8 x float> [[TMP14]], float [[TMP10]], i32 4
+; THRESHOLD-NEXT: [[TMP16:%.*]] = insertelement <8 x float> [[TMP15]], float [[TMP9]], i32 5
+; THRESHOLD-NEXT: [[TMP17:%.*]] = insertelement <8 x float> [[TMP16]], float [[TMP5]], i32 6
+; THRESHOLD-NEXT: [[TMP18:%.*]] = insertelement <8 x float> [[TMP17]], float [[TMP4]], i32 7
+; THRESHOLD-NEXT: [[TMP19:%.*]] = call fast float @llvm.vector.reduce.fadd.v8f32(float -0.000000e+00, <8 x float> [[TMP18]])
+; THRESHOLD-NEXT: [[OP_EXTRA:%.*]] = fadd fast float [[TMP19]], [[CONV]]
+; THRESHOLD-NEXT: [[OP_EXTRA1:%.*]] = fadd fast float [[OP_EXTRA]], [[CONV]]
+; THRESHOLD-NEXT: store float [[OP_EXTRA1]], float* @res, align 4
+; THRESHOLD-NEXT: ret float [[OP_EXTRA1]]
;
entry:
%0 = load i32, i32* @n, align 4
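[Editor's note: to see why an 8-wide reduction is emitted for @baz, consider what the scalar chain computes: each of the four products is summed twice and CONV is added twice. A hypothetical C++ sketch of that arithmetic; the names and signature are mine:]

```cpp
// Hypothetical sketch of @baz's arithmetic: two passes over the same four
// products, plus the converted scalar added twice. The new CHECK lines
// express this as one 8-wide fadd reduction over the duplicated products,
// followed by the two leftover adds of conv (OP_EXTRA, OP_EXTRA1).
float baz_like(const float *arr, const float *arr1, float conv) {
  float p0 = arr1[0] * arr[0], p1 = arr1[1] * arr[1];
  float p2 = arr1[2] * arr[2], p3 = arr1[3] * arr[3];
  float first = p0 + conv + p1 + p2 + p3 + conv; // ADD .. ADD7 chain
  return first + p0 + p1 + p2 + p3;              // ADD19 .. ADD19_3 chain
}
```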
