Skip to content

Commit

Permalink
[X86] More accurately model the cost of horizontal reductions.
Browse files Browse the repository at this point in the history
This patch attempts to more accurately model the reduction of
power of 2 vectors of types we natively support. This takes into
account the narrowing of vectors that occurs as we go from 512
bits to 256 bits, to 128 bits. It also takes into account the use
of wider elements in the shuffles for the first 2 steps of a
reduction from 128 bits. And uses a v8i16 shift for the final step
of vXi8 reduction.

The default implementation uses the legalized type for the arithmetic
for all levels. And uses the single source permute cost of the
legalized type for all levels. This penalizes things like
lack of v16i8 pshufb on pre-ssse3 targets and the splitting and
joining that needs to be done for integer types on AVX1. We never
need v16i8 shuffle for a reduction and we only need split AVX1 ops
when the type is wide and needs to be split. I think we're still
over costing splits and joins for AVX1, but we're closer now.

I've also removed all pairwise special casing because I don't
think we ever want to generate that on X86. I've also adjusted
the add handling to more accurately account for any type splitting
that occurs before we reach a legal type.

Differential Revision: https://reviews.llvm.org/D76478
  • Loading branch information
topperc committed Mar 22, 2020
1 parent 7cfd5de commit f4c67df
Show file tree
Hide file tree
Showing 8 changed files with 551 additions and 736 deletions.
180 changes: 102 additions & 78 deletions llvm/lib/Target/X86/X86TargetTransformInfo.cpp
Expand Up @@ -2648,39 +2648,13 @@ int X86TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,

int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
bool IsPairwise) {
// Just use the default implementation for pair reductions.
if (IsPairwise)
return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise);

// We use the Intel Architecture Code Analyzer(IACA) to measure the throughput
// and make it as the cost.

static const CostTblEntry SLMCostTblPairWise[] = {
{ ISD::FADD, MVT::v2f64, 3 },
{ ISD::ADD, MVT::v2i64, 5 },
};

static const CostTblEntry SSE2CostTblPairWise[] = {
{ ISD::FADD, MVT::v2f64, 2 },
{ ISD::FADD, MVT::v4f32, 4 },
{ ISD::ADD, MVT::v2i64, 2 }, // The data reported by the IACA tool is "1.6".
{ ISD::ADD, MVT::v2i32, 2 }, // FIXME: chosen to be less than v4i32.
{ ISD::ADD, MVT::v4i32, 3 }, // The data reported by the IACA tool is "3.5".
{ ISD::ADD, MVT::v2i16, 3 }, // FIXME: chosen to be less than v4i16
{ ISD::ADD, MVT::v4i16, 4 }, // FIXME: chosen to be less than v8i16
{ ISD::ADD, MVT::v8i16, 5 },
{ ISD::ADD, MVT::v2i8, 2 },
{ ISD::ADD, MVT::v4i8, 2 },
{ ISD::ADD, MVT::v8i8, 2 },
{ ISD::ADD, MVT::v16i8, 3 },
};

static const CostTblEntry AVX1CostTblPairWise[] = {
{ ISD::FADD, MVT::v4f64, 5 },
{ ISD::FADD, MVT::v8f32, 7 },
{ ISD::ADD, MVT::v2i64, 1 }, // The data reported by the IACA tool is "1.5".
{ ISD::ADD, MVT::v4i64, 5 }, // The data reported by the IACA tool is "4.8".
{ ISD::ADD, MVT::v8i32, 5 },
{ ISD::ADD, MVT::v16i16, 6 },
{ ISD::ADD, MVT::v32i8, 4 },
};

static const CostTblEntry SLMCostTblNoPairWise[] = {
{ ISD::FADD, MVT::v2f64, 3 },
{ ISD::ADD, MVT::v2i64, 5 },
Expand Down Expand Up @@ -2721,62 +2695,44 @@ int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
EVT VT = TLI->getValueType(DL, ValTy);
if (VT.isSimple()) {
MVT MTy = VT.getSimpleVT();
if (IsPairwise) {
if (ST->isSLM())
if (const auto *Entry = CostTableLookup(SLMCostTblPairWise, ISD, MTy))
return Entry->Cost;

if (ST->hasAVX())
if (const auto *Entry = CostTableLookup(AVX1CostTblPairWise, ISD, MTy))
return Entry->Cost;

if (ST->hasSSE2())
if (const auto *Entry = CostTableLookup(SSE2CostTblPairWise, ISD, MTy))
return Entry->Cost;
} else {
if (ST->isSLM())
if (const auto *Entry = CostTableLookup(SLMCostTblNoPairWise, ISD, MTy))
return Entry->Cost;
if (ST->isSLM())
if (const auto *Entry = CostTableLookup(SLMCostTblNoPairWise, ISD, MTy))
return Entry->Cost;

if (ST->hasAVX())
if (const auto *Entry = CostTableLookup(AVX1CostTblNoPairWise, ISD, MTy))
return Entry->Cost;
if (ST->hasAVX())
if (const auto *Entry = CostTableLookup(AVX1CostTblNoPairWise, ISD, MTy))
return Entry->Cost;

if (ST->hasSSE2())
if (const auto *Entry = CostTableLookup(SSE2CostTblNoPairWise, ISD, MTy))
return Entry->Cost;
}
if (ST->hasSSE2())
if (const auto *Entry = CostTableLookup(SSE2CostTblNoPairWise, ISD, MTy))
return Entry->Cost;
}

std::pair<int, MVT> LT = TLI->getTypeLegalizationCost(DL, ValTy);

MVT MTy = LT.second;

if (IsPairwise) {
if (ST->isSLM())
if (const auto *Entry = CostTableLookup(SLMCostTblPairWise, ISD, MTy))
return LT.first * Entry->Cost;

if (ST->hasAVX())
if (const auto *Entry = CostTableLookup(AVX1CostTblPairWise, ISD, MTy))
return LT.first * Entry->Cost;
unsigned ArithmeticCost = 0;
if (LT.first != 1 && MTy.isVector() &&
MTy.getVectorNumElements() < ValTy->getVectorNumElements()) {
// Type needs to be split. We need LT.first - 1 arithmetic ops.
Type *SingleOpTy = VectorType::get(ValTy->getVectorElementType(),
MTy.getVectorNumElements());
ArithmeticCost = getArithmeticInstrCost(Opcode, SingleOpTy);
ArithmeticCost *= LT.first - 1;
}

if (ST->hasSSE2())
if (const auto *Entry = CostTableLookup(SSE2CostTblPairWise, ISD, MTy))
return LT.first * Entry->Cost;
} else {
if (ST->isSLM())
if (const auto *Entry = CostTableLookup(SLMCostTblNoPairWise, ISD, MTy))
return LT.first * Entry->Cost;
if (ST->isSLM())
if (const auto *Entry = CostTableLookup(SLMCostTblNoPairWise, ISD, MTy))
return ArithmeticCost + Entry->Cost;

if (ST->hasAVX())
if (const auto *Entry = CostTableLookup(AVX1CostTblNoPairWise, ISD, MTy))
return LT.first * Entry->Cost;
if (ST->hasAVX())
if (const auto *Entry = CostTableLookup(AVX1CostTblNoPairWise, ISD, MTy))
return ArithmeticCost + Entry->Cost;

if (ST->hasSSE2())
if (const auto *Entry = CostTableLookup(SSE2CostTblNoPairWise, ISD, MTy))
return LT.first * Entry->Cost;
}
if (ST->hasSSE2())
if (const auto *Entry = CostTableLookup(SSE2CostTblNoPairWise, ISD, MTy))
return ArithmeticCost + Entry->Cost;

// FIXME: These assume a naive kshift+binop lowering, which is probably
// conservative in most cases.
Expand Down Expand Up @@ -2825,9 +2781,9 @@ int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
};

// Handle bool allof/anyof patterns.
if (!IsPairwise && ValTy->getVectorElementType()->isIntegerTy(1)) {
if (ValTy->getVectorElementType()->isIntegerTy(1)) {
unsigned ArithmeticCost = 0;
if (MTy.isVector() &&
if (LT.first != 1 && MTy.isVector() &&
MTy.getVectorNumElements() < ValTy->getVectorNumElements()) {
// Type needs to be split. We need LT.first - 1 arithmetic ops.
Type *SingleOpTy = VectorType::get(ValTy->getVectorElementType(),
Expand All @@ -2848,9 +2804,77 @@ int X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, Type *ValTy,
if (ST->hasSSE2())
if (const auto *Entry = CostTableLookup(SSE2BoolReduction, ISD, MTy))
return ArithmeticCost + Entry->Cost;

return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise);
}

unsigned NumVecElts = ValTy->getVectorNumElements();
unsigned ScalarSize = ValTy->getScalarSizeInBits();

// Special case power of 2 reductions where the scalar type isn't changed
// by type legalization.
if (!isPowerOf2_32(NumVecElts) || ScalarSize != MTy.getScalarSizeInBits())
return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise);

unsigned ReductionCost = 0;

Type *Ty = ValTy;
if (LT.first != 1 && MTy.isVector() &&
MTy.getVectorNumElements() < ValTy->getVectorNumElements()) {
// Type needs to be split. We need LT.first - 1 arithmetic ops.
Ty = VectorType::get(ValTy->getVectorElementType(),
MTy.getVectorNumElements());
ReductionCost = getArithmeticInstrCost(Opcode, Ty);
ReductionCost *= LT.first - 1;
NumVecElts = MTy.getVectorNumElements();
}

// Now handle reduction with the legal type, taking into account size changes
// at each level.
while (NumVecElts > 1) {
// Determine the size of the remaining vector we need to reduce.
unsigned Size = NumVecElts * ScalarSize;
NumVecElts /= 2;
// If we're reducing from 256/512 bits, use an extract_subvector.
if (Size > 128) {
Type *SubTy = VectorType::get(ValTy->getVectorElementType(), NumVecElts);
ReductionCost +=
getShuffleCost(TTI::SK_ExtractSubvector, Ty, NumVecElts, SubTy);
Ty = SubTy;
} else if (Size == 128) {
// Reducing from 128 bits is a permute of v2f64/v2i64.
Type *ShufTy;
if (ValTy->isFloatingPointTy())
ShufTy = VectorType::get(Type::getDoubleTy(ValTy->getContext()), 2);
else
ShufTy = VectorType::get(Type::getInt64Ty(ValTy->getContext()), 2);
ReductionCost +=
getShuffleCost(TTI::SK_PermuteSingleSrc, ShufTy, 0, nullptr);
} else if (Size == 64) {
// Reducing from 64 bits is a shuffle of v4f32/v4i32.
Type *ShufTy;
if (ValTy->isFloatingPointTy())
ShufTy = VectorType::get(Type::getFloatTy(ValTy->getContext()), 4);
else
ShufTy = VectorType::get(Type::getInt32Ty(ValTy->getContext()), 4);
ReductionCost +=
getShuffleCost(TTI::SK_PermuteSingleSrc, ShufTy, 0, nullptr);
} else {
// Reducing from smaller size is a shift by immediate.
Type *ShiftTy = VectorType::get(
Type::getIntNTy(ValTy->getContext(), Size), 128 / Size);
ReductionCost += getArithmeticInstrCost(
Instruction::LShr, ShiftTy, TargetTransformInfo::OK_AnyValue,
TargetTransformInfo::OK_UniformConstantValue,
TargetTransformInfo::OP_None, TargetTransformInfo::OP_None);
}

// Add the arithmetic op for this level.
ReductionCost += getArithmeticInstrCost(Opcode, Ty);
}

return BaseT::getArithmeticReductionCost(Opcode, ValTy, IsPairwise);
// Add the final extract element to the cost.
return ReductionCost + getVectorInstrCost(Instruction::ExtractElement, Ty, 0);
}

int X86TTIImpl::getMinMaxReductionCost(Type *ValTy, Type *CondTy,
Expand Down

0 comments on commit f4c67df

Please sign in to comment.