Skip to content

Commit

Permalink
[TTI][X86] getMinMaxCost - use existing integer min/max intrinsic cost values instead of maintaining a duplicate cost table
Browse files Browse the repository at this point in the history

getMinMaxCost has an alternative set of min/max costs to getIntrinsicInstrCost that are only used by getMinMaxReductionCost, but are a lot less thorough and fall back to an expansion in most cases, resulting in cost overestimations - we're better off just using getIntrinsicInstrCost.

getIntrinsicInstrCost is still missing complete FMINNUM/FMAXNUM costs, so getMinMaxCost will still be used for these until that is addressed; after that we can remove getMinMaxCost and have getMinMaxReductionCost call getIntrinsicInstrCost directly.

Fixes regression noticed in D148036
  • Loading branch information
RKSimon committed Apr 12, 2023
1 parent d514382 commit 63c3895
Show file tree
Hide file tree
Showing 8 changed files with 354 additions and 685 deletions.
107 changes: 20 additions & 87 deletions llvm/lib/Target/X86/X86TargetTransformInfo.cpp
Expand Up @@ -5191,106 +5191,48 @@ X86TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
}

InstructionCost X86TTIImpl::getMinMaxCost(Type *Ty, Type *CondTy,
TTI::TargetCostKind CostKind,
bool IsUnsigned) {
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);

MVT MTy = LT.second;

int ISD;
if (Ty->isIntOrIntVectorTy()) {
ISD = IsUnsigned ? ISD::UMIN : ISD::SMIN;
} else {
assert(Ty->isFPOrFPVectorTy() &&
"Expected float point or integer vector type.");
ISD = ISD::FMINNUM;
Intrinsic::ID Id = IsUnsigned ? Intrinsic::umin : Intrinsic::smin;
IntrinsicCostAttributes ICA(Id, Ty, {Ty, Ty});
return getIntrinsicInstrCost(ICA, CostKind);
}

// TODO: Use getIntrinsicInstrCost once ISD::FMINNUM costs are improved.
assert(Ty->isFPOrFPVectorTy() &&
"Expected float point or integer vector type.");
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
MVT MTy = LT.second;
int ISD = ISD::FMINNUM;

static const CostTblEntry SSE1CostTbl[] = {
{ISD::FMINNUM, MVT::v4f32, 1},
{ISD::FMINNUM, MVT::v4f32, 1},
};

static const CostTblEntry SSE2CostTbl[] = {
{ISD::FMINNUM, MVT::v2f64, 1},
{ISD::SMIN, MVT::v8i16, 1},
{ISD::UMIN, MVT::v16i8, 1},
};

static const CostTblEntry SSE41CostTbl[] = {
{ISD::SMIN, MVT::v4i32, 1},
{ISD::UMIN, MVT::v4i32, 1},
{ISD::UMIN, MVT::v8i16, 1},
{ISD::SMIN, MVT::v16i8, 1},
};

static const CostTblEntry SSE42CostTbl[] = {
{ISD::UMIN, MVT::v2i64, 3}, // xor+pcmpgtq+blendvpd
{ISD::FMINNUM, MVT::v2f64, 1},
};

static const CostTblEntry AVX1CostTbl[] = {
{ISD::FMINNUM, MVT::v8f32, 1},
{ISD::FMINNUM, MVT::v4f64, 1},
{ISD::SMIN, MVT::v8i32, 3},
{ISD::UMIN, MVT::v8i32, 3},
{ISD::SMIN, MVT::v16i16, 3},
{ISD::UMIN, MVT::v16i16, 3},
{ISD::SMIN, MVT::v32i8, 3},
{ISD::UMIN, MVT::v32i8, 3},
};

static const CostTblEntry AVX2CostTbl[] = {
{ISD::SMIN, MVT::v8i32, 1},
{ISD::UMIN, MVT::v8i32, 1},
{ISD::SMIN, MVT::v16i16, 1},
{ISD::UMIN, MVT::v16i16, 1},
{ISD::SMIN, MVT::v32i8, 1},
{ISD::UMIN, MVT::v32i8, 1},
{ISD::FMINNUM, MVT::v8f32, 1},
{ISD::FMINNUM, MVT::v4f64, 1},
};

static const CostTblEntry AVX512CostTbl[] = {
{ISD::FMINNUM, MVT::v16f32, 1},
{ISD::FMINNUM, MVT::v8f64, 1},
{ISD::SMIN, MVT::v2i64, 1},
{ISD::UMIN, MVT::v2i64, 1},
{ISD::SMIN, MVT::v4i64, 1},
{ISD::UMIN, MVT::v4i64, 1},
{ISD::SMIN, MVT::v8i64, 1},
{ISD::UMIN, MVT::v8i64, 1},
{ISD::SMIN, MVT::v16i32, 1},
{ISD::UMIN, MVT::v16i32, 1},
};

static const CostTblEntry AVX512BWCostTbl[] = {
{ISD::SMIN, MVT::v32i16, 1},
{ISD::UMIN, MVT::v32i16, 1},
{ISD::SMIN, MVT::v64i8, 1},
{ISD::UMIN, MVT::v64i8, 1},
{ISD::FMINNUM, MVT::v16f32, 1},
{ISD::FMINNUM, MVT::v8f64, 1},
};

// If we have a native MIN/MAX instruction for this type, use it.
if (ST->hasBWI())
if (const auto *Entry = CostTableLookup(AVX512BWCostTbl, ISD, MTy))
return LT.first * Entry->Cost;

if (ST->hasAVX512())
if (const auto *Entry = CostTableLookup(AVX512CostTbl, ISD, MTy))
return LT.first * Entry->Cost;

if (ST->hasAVX2())
if (const auto *Entry = CostTableLookup(AVX2CostTbl, ISD, MTy))
return LT.first * Entry->Cost;

if (ST->hasAVX())
if (const auto *Entry = CostTableLookup(AVX1CostTbl, ISD, MTy))
return LT.first * Entry->Cost;

if (ST->hasSSE42())
if (const auto *Entry = CostTableLookup(SSE42CostTbl, ISD, MTy))
return LT.first * Entry->Cost;

if (ST->hasSSE41())
if (const auto *Entry = CostTableLookup(SSE41CostTbl, ISD, MTy))
return LT.first * Entry->Cost;

if (ST->hasSSE2())
if (const auto *Entry = CostTableLookup(SSE2CostTbl, ISD, MTy))
return LT.first * Entry->Cost;
Expand All @@ -5299,17 +5241,8 @@ InstructionCost X86TTIImpl::getMinMaxCost(Type *Ty, Type *CondTy,
if (const auto *Entry = CostTableLookup(SSE1CostTbl, ISD, MTy))
return LT.first * Entry->Cost;

unsigned CmpOpcode;
if (Ty->isFPOrFPVectorTy()) {
CmpOpcode = Instruction::FCmp;
} else {
assert(Ty->isIntOrIntVectorTy() &&
"expecting floating point or integer type for min/max reduction");
CmpOpcode = Instruction::ICmp;
}

TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
// Otherwise fall back to cmp+select.
unsigned CmpOpcode = Instruction::FCmp;
InstructionCost Result =
getCmpSelInstrCost(CmpOpcode, Ty, CondTy, CmpInst::BAD_ICMP_PREDICATE,
CostKind) +
Expand Down Expand Up @@ -5410,7 +5343,7 @@ X86TTIImpl::getMinMaxReductionCost(VectorType *ValTy, VectorType *CondTy,
MTy.getVectorNumElements());
auto *SubCondTy = FixedVectorType::get(CondTy->getElementType(),
MTy.getVectorNumElements());
MinMaxCost = getMinMaxCost(Ty, SubCondTy, IsUnsigned);
MinMaxCost = getMinMaxCost(Ty, SubCondTy, CostKind, IsUnsigned);
MinMaxCost *= LT.first - 1;
NumVecElts = MTy.getVectorNumElements();
}
Expand Down Expand Up @@ -5483,7 +5416,7 @@ X86TTIImpl::getMinMaxReductionCost(VectorType *ValTy, VectorType *CondTy,
// Add the arithmetic op for this level.
auto *SubCondTy =
FixedVectorType::get(CondTy->getElementType(), Ty->getNumElements());
MinMaxCost += getMinMaxCost(Ty, SubCondTy, IsUnsigned);
MinMaxCost += getMinMaxCost(Ty, SubCondTy, CostKind, IsUnsigned);
}

// Add the final extract element to the cost.
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/X86/X86TargetTransformInfo.h
Expand Up @@ -205,7 +205,8 @@ class X86TTIImpl : public BasicTTIImplBase<X86TTIImpl> {
std::optional<FastMathFlags> FMF,
TTI::TargetCostKind CostKind);

InstructionCost getMinMaxCost(Type *Ty, Type *CondTy, bool IsUnsigned);
InstructionCost getMinMaxCost(Type *Ty, Type *CondTy, TTI::TargetCostKind CostKind,
bool IsUnsigned);

InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
bool IsUnsigned,
Expand Down

0 comments on commit 63c3895

Please sign in to comment.