diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 5882972a358cf..5a70120ea544d 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -906,6 +906,16 @@ class TargetTransformInfo {
   struct OperandValueInfo {
     OperandValueKind Kind = OK_AnyValue;
     OperandValueProperties Properties = OP_None;
+
+    bool isConstant() const {
+      return Kind == OK_UniformConstantValue || Kind == OK_NonUniformConstantValue;
+    }
+    bool isUniform() const {
+      return Kind == OK_UniformConstantValue || Kind == OK_UniformValue;
+    }
+    bool isPowerOf2() const {
+      return Properties == OP_PowerOf2;
+    }
   };
 
   /// \return the number of registers in the target-provided register class.
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index f9a977c73e735..66c011104dffc 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1982,6 +1982,9 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     TTI::OperandValueProperties Opd1PropInfo,
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
+
+  const TTI::OperandValueInfo Op2Info = {Opd2Info, Opd2PropInfo};
+
   // TODO: Handle more cost kinds.
   if (CostKind != TTI::TCK_RecipThroughput)
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
@@ -1997,8 +2000,7 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Opd1Info,
                                          Opd2Info, Opd1PropInfo, Opd2PropInfo);
   case ISD::SDIV:
-    if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue &&
-        Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
+    if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
       // On AArch64, scalar signed division by constants power-of-two are
       // normally expanded to the sequence ADD + CMP + SELECT + SRA.
       // The OperandValue properties many not be same as that of previous
@@ -2019,7 +2021,7 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
     }
     [[fallthrough]];
   case ISD::UDIV: {
-    if (Opd2Info == TargetTransformInfo::OK_UniformConstantValue) {
+    if (Op2Info.isConstant() && Op2Info.isUniform()) {
       auto VT = TLI->getValueType(DL, Ty);
       if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
         // Vector signed division by constant are expanded to the
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index a3d2dfe24c3ba..703ad9f5dbe6e 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -514,11 +514,9 @@ InstructionCost RISCVTTIImpl::getExtendedReductionCost(
 }
 
 InstructionCost RISCVTTIImpl::getVectorImmCost(VectorType *VecTy,
-                                               TTI::OperandValueKind OpInfo,
-                                               TTI::OperandValueProperties PropInfo,
+                                               TTI::OperandValueInfo OpInfo,
                                                TTI::TargetCostKind CostKind) {
-  assert((OpInfo == TTI::OK_UniformConstantValue ||
-          OpInfo == TTI::OK_NonUniformConstantValue) && "non constant operand?");
+  assert(OpInfo.isConstant() && "non constant operand?");
   APInt PseudoAddr = APInt::getAllOnes(DL.getPointerSizeInBits());
   // Add a cost of address load + the cost of the vector load.
   return RISCVMatInt::getIntMatCost(PseudoAddr, DL.getPointerSizeInBits(),
@@ -532,16 +530,14 @@ InstructionCost RISCVTTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
                                               MaybeAlign Alignment,
                                               unsigned AddressSpace,
                                               TTI::TargetCostKind CostKind,
-                                              TTI::OperandValueKind OpdInfo,
+                                              TTI::OperandValueKind OpdKind,
                                               const Instruction *I) {
+  const TTI::OperandValueInfo OpInfo = {OpdKind, TTI::OP_None};
   InstructionCost Cost = 0;
-  if (Opcode == Instruction::Store && isa<VectorType>(Src) &&
-      (OpdInfo == TTI::OK_UniformConstantValue ||
-       OpdInfo == TTI::OK_NonUniformConstantValue)) {
-    Cost += getVectorImmCost(cast<VectorType>(Src), OpdInfo, TTI::OP_None, CostKind);
-  }
+  if (Opcode == Instruction::Store && isa<VectorType>(Src) && OpInfo.isConstant())
+    Cost += getVectorImmCost(cast<VectorType>(Src), OpInfo, CostKind);
   return Cost + BaseT::getMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
-                                       CostKind, OpdInfo, I);
+                                       CostKind, OpInfo.Kind, I);
 }
 
 void RISCVTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 0d424b2b0a0c7..43be1cf599ac2 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -54,8 +54,7 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
   /// Return the cost of materializing a vector immediate, assuming it does
   /// not get folded into the using instruction(s).
   InstructionCost getVectorImmCost(VectorType *VecTy,
-                                   TTI::OperandValueKind OpInfo,
-                                   TTI::OperandValueProperties PropInfo,
+                                   TTI::OperandValueInfo OpInfo,
                                    TTI::TargetCostKind CostKind);
 
   InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index c1cd58d140b78..8df77a8a6fef7 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -57,6 +57,8 @@ InstructionCost WebAssemblyTTIImpl::getArithmeticInstrCost(
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
 
+  const TTI::OperandValueInfo Op2Info = {Opd2Info, Opd2PropInfo};
+
   InstructionCost Cost =
       BasicTTIImplBase<WebAssemblyTTIImpl>::getArithmeticInstrCost(
           Opcode, Ty, CostKind, Opd1Info, Opd2Info, Opd1PropInfo, Opd2PropInfo);
@@ -69,8 +71,7 @@ InstructionCost WebAssemblyTTIImpl::getArithmeticInstrCost(
       // SIMD128's shifts currently only accept a scalar shift count. For each
       // element, we'll need to extract, op, insert. The following is a rough
       // approximation.
-      if (Opd2Info != TTI::OK_UniformValue &&
-          Opd2Info != TTI::OK_UniformConstantValue)
+      if (!Op2Info.isUniform())
         Cost =
             cast<FixedVectorType>(VTy)->getNumElements() *
             (TargetTransformInfo::TCC_Basic +
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 74885eba5d034..b3d88b84ad68f 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -180,6 +180,9 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     TTI::OperandValueProperties Opd1PropInfo,
     TTI::OperandValueProperties Opd2PropInfo, ArrayRef<const Value *> Args,
     const Instruction *CxtI) {
+
+  const TTI::OperandValueInfo Op2Info = {Op2Kind, Opd2PropInfo};
+
   // vXi8 multiplications are always promoted to vXi16.
   if (Opcode == Instruction::Mul && Ty->isVectorTy() &&
       Ty->getScalarSizeInBits() == 8) {
@@ -232,10 +235,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   }
 
   // Vector multiply by pow2 will be simplified to shifts.
-  if (ISD == ISD::MUL &&
-      (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      Opd2PropInfo == TargetTransformInfo::OP_PowerOf2)
+  if (ISD == ISD::MUL && Op2Info.isConstant() && Op2Info.isPowerOf2())
     return getArithmeticInstrCost(Instruction::Shl, Ty, CostKind, Op1Kind,
                                   Op2Kind, TargetTransformInfo::OP_None,
                                   TargetTransformInfo::OP_None);
@@ -245,9 +245,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   // The OperandValue properties may not be the same as that of the previous
   // operation; conservatively assume OP_None.
   if ((ISD == ISD::SDIV || ISD == ISD::SREM) &&
-      (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
+      Op2Info.isConstant() && Op2Info.isPowerOf2()) {
     InstructionCost Cost =
         2 * getArithmeticInstrCost(Instruction::AShr, Ty, CostKind, Op1Kind,
                                    Op2Kind, TargetTransformInfo::OP_None,
@@ -272,9 +270,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
 
   // Vector unsigned division/remainder will be simplified to shifts/masks.
   if ((ISD == ISD::UDIV || ISD == ISD::UREM) &&
-      (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      Opd2PropInfo == TargetTransformInfo::OP_PowerOf2) {
+      Op2Info.isConstant() && Op2Info.isPowerOf2()) {
     if (ISD == ISD::UDIV)
       return getArithmeticInstrCost(Instruction::LShr, Ty, CostKind, Op1Kind,
                                     Op2Kind, TargetTransformInfo::OP_None,
@@ -372,8 +368,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::SRA, MVT::v64i8, 4 }, // psrlw, pand, pxor, psubb.
   };
 
-  if (Op2Kind == TargetTransformInfo::OK_UniformConstantValue &&
-      ST->hasBWI()) {
+  if (Op2Info.isUniform() && Op2Info.isConstant() && ST->hasBWI()) {
     if (const auto *Entry = CostTableLookup(AVX512BWUniformConstCostTable, ISD,
                                             LT.second))
       return LT.first * Entry->Cost;
@@ -394,8 +389,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v16i32, 7 }, // pmuludq+mul+sub sequence
   };
 
-  if (Op2Kind == TargetTransformInfo::OK_UniformConstantValue &&
-      ST->hasAVX512()) {
+  if (Op2Info.isUniform() && Op2Info.isConstant() && ST->hasAVX512()) {
     if (const auto *Entry = CostTableLookup(AVX512UniformConstCostTable, ISD,
                                             LT.second))
       return LT.first * Entry->Cost;
@@ -414,8 +408,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v8i32, 7 }, // pmuludq+mul+sub sequence
   };
 
-  if (Op2Kind == TargetTransformInfo::OK_UniformConstantValue &&
-      ST->hasAVX2()) {
+  if (Op2Info.isUniform() && Op2Info.isConstant() && ST->hasAVX2()) {
     if (const auto *Entry = CostTableLookup(AVX2UniformConstCostTable, ISD,
                                             LT.second))
       return LT.first * Entry->Cost;
@@ -441,7 +434,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   };
 
   // XOP has faster vXi8 shifts.
-  if (Op2Kind == TargetTransformInfo::OK_UniformConstantValue &&
+  if (Op2Info.isUniform() && Op2Info.isConstant() &&
       ST->hasSSE2() && !ST->hasXOP()) {
     if (const auto *Entry =
             CostTableLookup(SSE2UniformConstCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
@@ -459,9 +452,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v32i16, 8 }, // vpmulhuw+mul+sub sequence
   };
 
-  if ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      ST->hasBWI()) {
+  if (Op2Info.isConstant() && ST->hasBWI()) {
     if (const auto *Entry =
             CostTableLookup(AVX512BWConstCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
@@ -482,9 +473,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v32i16, 16 }, // 2*vpmulhuw+mul+sub sequence
   };
 
-  if ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      ST->hasAVX512()) {
+  if (Op2Info.isConstant() && ST->hasAVX512()) {
     if (const auto *Entry =
             CostTableLookup(AVX512ConstCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
@@ -505,9 +494,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v8i32, 19 }, // vpmuludq+mul+sub sequence
   };
 
-  if ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      ST->hasAVX2()) {
+  if (Op2Info.isConstant() && ST->hasAVX2()) {
     if (const auto *Entry = CostTableLookup(AVX2ConstCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
   }
@@ -539,9 +526,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::UREM, MVT::v4i32, 20 }, // pmuludq+mul+sub sequence
   };
 
-  if ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-       Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue) &&
-      ST->hasSSE2()) {
+  if (Op2Info.isConstant() && ST->hasSSE2()) {
     // pmuldq sequence.
     if (ISD == ISD::SDIV && LT.second == MVT::v8i32 && ST->hasAVX())
       return LT.first * 32;
@@ -598,9 +583,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::SRL, MVT::v4i64, 1 }, // psrlq
   };
 
-  if (ST->hasAVX2() &&
-      ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue) ||
-       (Op2Kind == TargetTransformInfo::OK_UniformValue))) {
+  if (ST->hasAVX2() && Op2Info.isUniform()) {
     if (const auto *Entry =
             CostTableLookup(AVX2UniformCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
@@ -620,9 +603,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::SRA, MVT::v4i32, 1 }, // psrad.
   };
 
-  if (ST->hasSSE2() &&
-      ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue) ||
-       (Op2Kind == TargetTransformInfo::OK_UniformValue))) {
+  if (ST->hasSSE2() && Op2Info.isUniform()) {
     if (const auto *Entry =
             CostTableLookup(SSE2UniformCostTable, ISD, LT.second))
       return LT.first * Entry->Cost;
@@ -717,9 +698,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   };
 
   if (ST->hasAVX512()) {
-    if (ISD == ISD::SHL && LT.second == MVT::v32i16 &&
-        (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-         Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue))
+    if (ISD == ISD::SHL && LT.second == MVT::v32i16 && Op2Info.isConstant())
       // On AVX512, a packed v32i16 shift left by a constant build_vector
       // is lowered into a vector multiply (vpmullw).
      return getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
@@ -731,8 +710,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
   // Look for AVX2 lowering tricks (XOP is always better at v4i32 shifts).
   if (ST->hasAVX2() && !(ST->hasXOP() && LT.second == MVT::v4i32)) {
     if (ISD == ISD::SHL && LT.second == MVT::v16i16 &&
-        (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-         Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue)
+        Op2Info.isConstant())
       // On AVX2, a packed v16i16 shift left by a constant build_vector
       // is lowered into a vector multiply (vpmullw).
       return getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
@@ -778,9 +756,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     // If the right shift is constant then we'll fold the negation so
     // it's as cheap as a left shift.
     int ShiftISD = ISD;
-    if ((ShiftISD == ISD::SRL || ShiftISD == ISD::SRA) &&
-        (Op2Kind == TargetTransformInfo::OK_UniformConstantValue ||
-         Op2Kind == TargetTransformInfo::OK_NonUniformConstantValue))
+    if ((ShiftISD == ISD::SRL || ShiftISD == ISD::SRA) && Op2Info.isConstant())
       ShiftISD = ISD::SHL;
     if (const auto *Entry =
             CostTableLookup(XOPShiftCostTable, ShiftISD, LT.second))
@@ -803,9 +779,7 @@ InstructionCost X86TTIImpl::getArithmeticInstrCost(
     { ISD::SRA, MVT::v4i64, 8+2 }, // 2*(2*psrad + shuffle) + split.
   };
 
-  if (ST->hasSSE2() &&
-      ((Op2Kind == TargetTransformInfo::OK_UniformConstantValue) ||
-       (Op2Kind == TargetTransformInfo::OK_UniformValue))) {
+  if (ST->hasSSE2() && Op2Info.isUniform()) {
 
     // Handle AVX2 uniform v4i64 ISD::SRA, it's not worth a table.
     if (ISD == ISD::SRA && LT.second == MVT::v4i64 && ST->hasAVX2())
@@ -4069,8 +4043,10 @@ InstructionCost X86TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
                                             MaybeAlign Alignment,
                                             unsigned AddressSpace,
                                             TTI::TargetCostKind CostKind,
-                                            TTI::OperandValueKind OpdInfo,
+                                            TTI::OperandValueKind OpdKind,
                                             const Instruction *I) {
+  const TTI::OperandValueInfo OpInfo = {OpdKind, TTI::OP_None};
+
   // TODO: Handle other cost kinds.
   if (CostKind != TTI::TCK_RecipThroughput) {
     if (auto *SI = dyn_cast_or_null<StoreInst>(I)) {
@@ -4099,9 +4075,7 @@ InstructionCost X86TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src,
   InstructionCost Cost = 0;
 
   // Add a cost for constant load to vector.
-  if (Opcode == Instruction::Store &&
-      (OpdInfo == TTI::OK_UniformConstantValue ||
-       OpdInfo == TTI::OK_NonUniformConstantValue))
+  if (Opcode == Instruction::Store && OpInfo.isConstant())
    Cost += getMemoryOpCost(Instruction::Load, Src, DL.getABITypeAlign(Src),
                            /*AddressSpace=*/0, CostKind);
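
Not part of the patch: a small standalone sketch of the new predicates. It mirrors the OperandValueKind/OperandValueProperties names and the helper bodies added to TTI::OperandValueInfo above in a self-contained translation unit, so the collapse of checks like "Op2Kind == OK_UniformConstantValue && Opd2PropInfo == OP_PowerOf2" into isConstant()/isUniform()/isPowerOf2() can be exercised outside an LLVM build; the main() driver and its sample values are illustrative only, and in-tree callers simply brace-initialize the struct ({Opd2Info, Opd2PropInfo}) as the target changes do. Compiles with -std=c++14 or later.

// Standalone illustration only; not LLVM code. Mirrors the helpers added to
// TTI::OperandValueInfo in TargetTransformInfo.h.
#include <cassert>

enum OperandValueKind {
  OK_AnyValue,
  OK_UniformValue,
  OK_UniformConstantValue,
  OK_NonUniformConstantValue
};
enum OperandValueProperties { OP_None, OP_PowerOf2 };

struct OperandValueInfo {
  OperandValueKind Kind = OK_AnyValue;
  OperandValueProperties Properties = OP_None;

  bool isConstant() const {
    return Kind == OK_UniformConstantValue || Kind == OK_NonUniformConstantValue;
  }
  bool isUniform() const {
    return Kind == OK_UniformConstantValue || Kind == OK_UniformValue;
  }
  bool isPowerOf2() const { return Properties == OP_PowerOf2; }
};

int main() {
  // Divisor operand that is a splat of a power-of-two constant: all three
  // predicates hold, as in the AArch64/X86 "division by pow2" paths above.
  OperandValueInfo Op2Info = {OK_UniformConstantValue, OP_PowerOf2};
  assert(Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2());

  // Uniform but non-constant shift amount: only isUniform() holds, which is
  // the case the WebAssembly scalar-shift-count check lets through.
  OperandValueInfo ShiftAmt = {OK_UniformValue, OP_None};
  assert(ShiftAmt.isUniform() && !ShiftAmt.isConstant() && !ShiftAmt.isPowerOf2());
  return 0;
}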