diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index e74d01be3de8e..9408485e95cbb 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -865,6 +865,17 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase { LI->getPointerAddressSpace(), CostKind, I); } + case Instruction::Select: { + Type *CondTy = U->getOperand(0)->getType(); + return TargetTTI->getCmpSelInstrCost(Opcode, U->getType(), CondTy, + CostKind, I); + } + case Instruction::ICmp: + case Instruction::FCmp: { + Type *ValTy = U->getOperand(0)->getType(); + return TargetTTI->getCmpSelInstrCost(Opcode, ValTy, U->getType(), + CostKind, I); + } } // By default, just classify everything as 'basic'. return TTI::TCC_Basic; diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index 8d83306ba3be0..361d9813a60ec 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -838,6 +838,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + // Selects on vectors are actually vector selects. if (ISD == ISD::SELECT) { assert(CondTy && "CondTy must exist"); diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 4bcadfe09a8c8..ed5db33731271 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1296,18 +1296,10 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { Op1VK, Op2VK, Op1VP, Op2VP, Operands, I); } - case Instruction::Select: { - const SelectInst *SI = cast(I); - Type *CondTy = SI->getCondition()->getType(); - return getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy, - CostKind, I); - } + case Instruction::Select: case Instruction::ICmp: - case Instruction::FCmp: { - Type *ValTy = I->getOperand(0)->getType(); - return getCmpSelInstrCost(I->getOpcode(), ValTy, I->getType(), - CostKind, I); - } + case Instruction::FCmp: + return getUserCost(I, CostKind); case Instruction::Store: case Instruction::Load: return getUserCost(I, CostKind); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 3030c949db43e..cc5157fbf9c8f 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -620,6 +620,9 @@ int AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); int ISD = TLI->InstructionOpcodeToISD(Opcode); // We don't lower some vector selects well that are wider than the register diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp index 0709f7c3028e5..6b4899f624d8c 100644 --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -508,6 +508,10 @@ int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy, int ARMTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + int ISD = TLI->InstructionOpcodeToISD(Opcode); // On NEON a vector select gets lowered to vbsl. if (ST->hasNEON() && ValTy->isVectorTy() && ISD == ISD::SELECT) { diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp index e0eaf55aba6e9..b70386f1565cd 100644 --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -236,7 +236,7 @@ unsigned HexagonTTIImpl::getInterleavedMemoryOpCost(unsigned Opcode, unsigned HexagonTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { - if (ValTy->isVectorTy()) { + if (ValTy->isVectorTy() && CostKind == TTI::TCK_RecipThroughput) { std::pair LT = TLI.getTypeLegalizationCost(DL, ValTy); if (Opcode == Instruction::FCmp) return LT.first + FloatFactor * getTypeNumElements(ValTy); diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp index 8c77d743562f8..bc8897664b141 100644 --- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp +++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp @@ -779,6 +779,9 @@ int PPCTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { int Cost = BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return Cost; return vectorCostAdjustment(Cost, Opcode, ValTy, nullptr); } diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp index ef26a91d02a95..f7a4c76fbbd5e 100644 --- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp +++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp @@ -837,6 +837,9 @@ int SystemZTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind); + if (!ValTy->isVectorTy()) { switch (Opcode) { case Instruction::ICmp: { diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index 3980909ca5fd4..964ac06c863a2 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -2051,6 +2051,10 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, TTI::TargetCostKind CostKind, const Instruction *I) { + // TODO: Handle other cost kinds. + if (CostKind != TTI::TCK_RecipThroughput) + return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, CostKind, I); + // Legalize the type. std::pair LT = TLI->getTypeLegalizationCost(DL, ValTy);