diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 60de70dcb16a0..bd8d29cb22a12 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -826,18 +826,18 @@ class TargetTransformInfoImplCRTPBase : public TargetTransformInfoImplBase { return TTI::TCC_Expensive; case Instruction::IntToPtr: case Instruction::PtrToInt: + case Instruction::SIToFP: + case Instruction::UIToFP: + case Instruction::FPToUI: + case Instruction::FPToSI: case Instruction::Trunc: + case Instruction::FPTrunc: case Instruction::BitCast: - if (TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I) == - TTI::TCC_Free) - return TTI::TCC_Free; - break; case Instruction::FPExt: case Instruction::SExt: case Instruction::ZExt: - if (TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I) == TTI::TCC_Free) - return TTI::TCC_Free; - break; + case Instruction::AddrSpaceCast: + return TargetTTI->getCastInstrCost(Opcode, Ty, OpTy, CostKind, I); } // By default, just classify everything as 'basic'. return TTI::TCC_Basic; diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 86952a5ad6592..a14199515faf5 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -1325,10 +1325,8 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { case Instruction::Trunc: case Instruction::FPTrunc: case Instruction::BitCast: - case Instruction::AddrSpaceCast: { - Type *SrcTy = I->getOperand(0)->getType(); - return getCastInstrCost(I->getOpcode(), I->getType(), SrcTy, CostKind, I); - } + case Instruction::AddrSpaceCast: + return getUserCost(I, CostKind); case Instruction::ExtractElement: { const ExtractElementInst *EEI = cast(I); ConstantInt *CI = dyn_cast(I->getOperand(1)); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp index 1324945c4d4ea..f0961646c31ff 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -295,11 +295,18 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, } } + // TODO: Allow non-throughput costs that aren't binary. + auto AdjustCost = [&CostKind](int Cost) { + if (CostKind != TTI::TCK_RecipThroughput) + return Cost == 0 ? 0 : 1; + return Cost; + }; + EVT SrcTy = TLI->getValueType(DL, Src); EVT DstTy = TLI->getValueType(DL, Dst); if (!SrcTy.isSimple() || !DstTy.isSimple()) - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I)); static const TypeConversionCostTblEntry ConversionTbl[] = { @@ -401,9 +408,9 @@ int AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) - return Entry->Cost; + return AdjustCost(Entry->Cost); - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I)); } int AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst, diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp index 1ca74bfc3df08..c1af19727ba2b 100644 --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -173,6 +173,13 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); + // TODO: Allow non-throughput costs that aren't binary. + auto AdjustCost = [&CostKind](int Cost) { + if (CostKind != TTI::TCK_RecipThroughput) + return Cost == 0 ? 0 : 1; + return Cost; + }; + // Single to/from double precision conversions. static const CostTblEntry NEONFltDblTbl[] = { // Vector fptrunc/fpext conversions. @@ -185,14 +192,14 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, ISD == ISD::FP_EXTEND)) { std::pair LT = TLI->getTypeLegalizationCost(DL, Src); if (const auto *Entry = CostTableLookup(NEONFltDblTbl, ISD, LT.second)) - return LT.first * Entry->Cost; + return AdjustCost(LT.first * Entry->Cost); } EVT SrcTy = TLI->getValueType(DL, Src); EVT DstTy = TLI->getValueType(DL, Dst); if (!SrcTy.isSimple() || !DstTy.isSimple()) - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I)); // The extend of a load is free if (I && isa(I->getOperand(0))) { @@ -212,7 +219,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, }; if (const auto *Entry = ConvertCostTableLookup( LoadConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) - return Entry->Cost; + return AdjustCost(Entry->Cost); static const TypeConversionCostTblEntry MVELoadConversionTbl[] = { {ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 0}, @@ -226,7 +233,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, if (const auto *Entry = ConvertCostTableLookup(MVELoadConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) - return Entry->Cost; + return AdjustCost(Entry->Cost); } } @@ -253,7 +260,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, if (auto *Entry = ConvertCostTableLookup(NEONDoubleWidthTbl, UserISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) { - return Entry->Cost; + return AdjustCost(Entry->Cost); } } @@ -347,7 +354,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, if (const auto *Entry = ConvertCostTableLookup(NEONVectorConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) - return Entry->Cost; + return AdjustCost(Entry->Cost); } // Scalar float to integer conversions. @@ -377,7 +384,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, if (const auto *Entry = ConvertCostTableLookup(NEONFloatConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) - return Entry->Cost; + return AdjustCost(Entry->Cost); } // Scalar integer to float conversions. @@ -408,7 +415,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, if (const auto *Entry = ConvertCostTableLookup(NEONIntegerConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) - return Entry->Cost; + return AdjustCost(Entry->Cost); } // MVE extend costs, taken from codegen tests. i8->i16 or i16->i32 is one @@ -433,7 +440,7 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, if (const auto *Entry = ConvertCostTableLookup(MVEVectorConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) - return Entry->Cost * ST->getMVEVectorCostFactor(); + return AdjustCost(Entry->Cost * ST->getMVEVectorCostFactor()); } // Scalar integer conversion costs. @@ -452,13 +459,14 @@ int ARMTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, if (const auto *Entry = ConvertCostTableLookup(ARMIntegerConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) - return Entry->Cost; + return AdjustCost(Entry->Cost); } int BaseCost = ST->hasMVEIntegerOps() && Src->isVectorTy() ? ST->getMVEVectorCostFactor() : 1; - return BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return AdjustCost( + BaseCost * BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I)); } int ARMTTIImpl::getVectorInstrCost(unsigned Opcode, Type *ValTy, diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp index 92e32ca99090e..381941df2fb46 100644 --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -263,7 +263,11 @@ unsigned HexagonTTIImpl::getCastInstrCost(unsigned Opcode, Type *DstTy, std::pair SrcLT = TLI.getTypeLegalizationCost(DL, SrcTy); std::pair DstLT = TLI.getTypeLegalizationCost(DL, DstTy); - return std::max(SrcLT.first, DstLT.first) + FloatFactor * (SrcN + DstN); + unsigned Cost = std::max(SrcLT.first, DstLT.first) + FloatFactor * (SrcN + DstN); + // TODO: Allow non-throughput costs that aren't binary. + if (CostKind != TTI::TCK_RecipThroughput) + return Cost == 0 ? 0 : 1; + return Cost; } return 1; } diff --git a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp index 002905febbc8b..a41c6b41a991b 100644 --- a/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp +++ b/llvm/lib/Target/PowerPC/PPCTargetTransformInfo.cpp @@ -212,7 +212,8 @@ int PPCTTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx, unsigned PPCTTIImpl::getUserCost(const User *U, ArrayRef Operands, TTI::TargetCostKind CostKind) { - if (U->getType()->isVectorTy()) { + // We already implement getCastInstrCost and perform the vector adjustment there. + if (!isa(U) && U->getType()->isVectorTy()) { // Instructions that need to be split should cost more. std::pair LT = TLI->getTypeLegalizationCost(DL, U->getType()); return LT.first * BaseT::getUserCost(U, Operands, CostKind); @@ -760,7 +761,11 @@ int PPCTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, assert(TLI->InstructionOpcodeToISD(Opcode) && "Invalid opcode"); int Cost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); - return vectorCostAdjustment(Cost, Opcode, Dst, Src); + Cost = vectorCostAdjustment(Cost, Opcode, Dst, Src); + // TODO: Allow non-throughput costs that aren't binary. + if (CostKind != TTI::TCK_RecipThroughput) + return Cost == 0 ? 0 : 1; + return Cost; } int PPCTTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, diff --git a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp index d9efb40f0ab65..bce02cc793bf6 100644 --- a/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp +++ b/llvm/lib/Target/SystemZ/SystemZTargetTransformInfo.cpp @@ -691,6 +691,12 @@ getBoolVecToIntConversionCost(unsigned Opcode, Type *Dst, int SystemZTTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, TTI::TargetCostKind CostKind, const Instruction *I) { + // FIXME: Can the logic below also be used for these cost kinds? + if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency) { + int BaseCost = BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return BaseCost == 0 ? BaseCost : 1; + } + unsigned DstScalarBits = Dst->getScalarSizeInBits(); unsigned SrcScalarBits = Src->getScalarSizeInBits(); diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index 4170b102f2b31..6bfcadeaf8b67 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -1368,6 +1368,13 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); + // TODO: Allow non-throughput costs that aren't binary. + auto AdjustCost = [&CostKind](int Cost) { + if (CostKind != TTI::TCK_RecipThroughput) + return Cost == 0 ? 0 : 1; + return Cost; + }; + // FIXME: Need a better design of the cost table to handle non-simple types of // potential massive combinations (elem_num x src_type x dst_type). @@ -1969,7 +1976,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, if (ST->hasSSE2() && !ST->hasAVX()) { if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD, LTDest.second, LTSrc.second)) - return LTSrc.first * Entry->Cost; + return AdjustCost(LTSrc.first * Entry->Cost); } EVT SrcTy = TLI->getValueType(DL, Src); @@ -1977,7 +1984,7 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, // The function getSimpleVT only handles simple value types. if (!SrcTy.isSimple() || !DstTy.isSimple()) - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind); + return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind)); MVT SimpleSrcTy = SrcTy.getSimpleVT(); MVT SimpleDstTy = DstTy.getSimpleVT(); @@ -1986,59 +1993,59 @@ int X86TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, if (ST->hasBWI()) if (const auto *Entry = ConvertCostTableLookup(AVX512BWConversionTbl, ISD, SimpleDstTy, SimpleSrcTy)) - return Entry->Cost; + return AdjustCost(Entry->Cost); if (ST->hasDQI()) if (const auto *Entry = ConvertCostTableLookup(AVX512DQConversionTbl, ISD, SimpleDstTy, SimpleSrcTy)) - return Entry->Cost; + return AdjustCost(Entry->Cost); if (ST->hasAVX512()) if (const auto *Entry = ConvertCostTableLookup(AVX512FConversionTbl, ISD, SimpleDstTy, SimpleSrcTy)) - return Entry->Cost; + return AdjustCost(Entry->Cost); } if (ST->hasBWI()) if (const auto *Entry = ConvertCostTableLookup(AVX512BWVLConversionTbl, ISD, SimpleDstTy, SimpleSrcTy)) - return Entry->Cost; + return AdjustCost(Entry->Cost); if (ST->hasDQI()) if (const auto *Entry = ConvertCostTableLookup(AVX512DQVLConversionTbl, ISD, SimpleDstTy, SimpleSrcTy)) - return Entry->Cost; + return AdjustCost(Entry->Cost); if (ST->hasAVX512()) if (const auto *Entry = ConvertCostTableLookup(AVX512VLConversionTbl, ISD, SimpleDstTy, SimpleSrcTy)) - return Entry->Cost; + return AdjustCost(Entry->Cost); if (ST->hasAVX2()) { if (const auto *Entry = ConvertCostTableLookup(AVX2ConversionTbl, ISD, SimpleDstTy, SimpleSrcTy)) - return Entry->Cost; + return AdjustCost(Entry->Cost); } if (ST->hasAVX()) { if (const auto *Entry = ConvertCostTableLookup(AVXConversionTbl, ISD, SimpleDstTy, SimpleSrcTy)) - return Entry->Cost; + return AdjustCost(Entry->Cost); } if (ST->hasSSE41()) { if (const auto *Entry = ConvertCostTableLookup(SSE41ConversionTbl, ISD, SimpleDstTy, SimpleSrcTy)) - return Entry->Cost; + return AdjustCost(Entry->Cost); } if (ST->hasSSE2()) { if (const auto *Entry = ConvertCostTableLookup(SSE2ConversionTbl, ISD, SimpleDstTy, SimpleSrcTy)) - return Entry->Cost; + return AdjustCost(Entry->Cost); } - return BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I); + return AdjustCost(BaseT::getCastInstrCost(Opcode, Dst, Src, CostKind, I)); } int X86TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,