diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 3b615bc700bbb..58577a6b6eb5c 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -781,6 +781,9 @@ class TargetTransformInfo { /// Return true if the target supports masked expand load. bool isLegalMaskedExpandLoad(Type *DataType) const; + /// Return true if the target supports strided load. + bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const; + /// Return true if this is an alternating opcode pattern that can be lowered /// to a single instruction on the target. In X86 this is for the addsub /// instruction which corrsponds to a Shuffle + Fadd + FSub pattern in IR. @@ -1412,6 +1415,20 @@ class TargetTransformInfo { Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput, const Instruction *I = nullptr) const; + /// \return The cost of strided memory operations. + /// \p Opcode - is a type of memory access Load or Store + /// \p DataTy - a vector type of the data to be loaded or stored + /// \p Ptr - pointer [or vector of pointers] - address[es] in memory + /// \p VariableMask - true when the memory access is predicated with a mask + /// that is not a compile-time constant + /// \p Alignment - alignment of single element + /// \p I - the optional original context instruction, if one exists, e.g. the + /// load/store to transform or the call to the gather/scatter intrinsic + InstructionCost getStridedMemoryOpCost( + unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, + Align Alignment, TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput, + const Instruction *I = nullptr) const; + /// \return The cost of the interleaved memory operation. /// \p Opcode is the memory operation code /// \p VecTy is the vector type of the interleaved access. @@ -1848,6 +1865,7 @@ class TargetTransformInfo::Concept { Align Alignment) = 0; virtual bool isLegalMaskedCompressStore(Type *DataType) = 0; virtual bool isLegalMaskedExpandLoad(Type *DataType) = 0; + virtual bool isLegalStridedLoadStore(Type *DataType, Align Alignment) = 0; virtual bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1, const SmallBitVector &OpcodeMask) const = 0; @@ -2023,6 +2041,11 @@ class TargetTransformInfo::Concept { bool VariableMask, Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I = nullptr) = 0; + virtual InstructionCost + getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr, + bool VariableMask, Align Alignment, + TTI::TargetCostKind CostKind, + const Instruction *I = nullptr) = 0; virtual InstructionCost getInterleavedMemoryOpCost( unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef Indices, @@ -2341,6 +2364,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { bool isLegalMaskedExpandLoad(Type *DataType) override { return Impl.isLegalMaskedExpandLoad(DataType); } + bool isLegalStridedLoadStore(Type *DataType, Align Alignment) override { + return Impl.isLegalStridedLoadStore(DataType, Alignment); + } bool isLegalAltInstr(VectorType *VecTy, unsigned Opcode0, unsigned Opcode1, const SmallBitVector &OpcodeMask) const override { return Impl.isLegalAltInstr(VecTy, Opcode0, Opcode1, OpcodeMask); @@ -2671,6 +2697,14 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { return Impl.getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I); } + InstructionCost + getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, const Value *Ptr, + bool VariableMask, Align Alignment, + TTI::TargetCostKind CostKind, + const Instruction *I = nullptr) override { + return Impl.getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask, + Alignment, CostKind, I); + } InstructionCost getInterleavedMemoryOpCost( unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef Indices, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 9958b4daa6ed8..3d5db96e86b80 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -304,6 +304,10 @@ class TargetTransformInfoImplBase { bool isLegalMaskedExpandLoad(Type *DataType) const { return false; } + bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const { + return false; + } + bool enableOrderedReductions() const { return false; } bool hasDivRemOp(Type *DataType, bool IsSigned) const { return false; } @@ -687,6 +691,14 @@ class TargetTransformInfoImplBase { return 1; } + InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, + const Value *Ptr, bool VariableMask, + Align Alignment, + TTI::TargetCostKind CostKind, + const Instruction *I = nullptr) const { + return InstructionCost::getInvalid(); + } + unsigned getInterleavedMemoryOpCost( unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef Indices, Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 8902dde37cbca..1f11f0d7dd620 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -500,6 +500,11 @@ bool TargetTransformInfo::isLegalMaskedExpandLoad(Type *DataType) const { return TTIImpl->isLegalMaskedExpandLoad(DataType); } +bool TargetTransformInfo::isLegalStridedLoadStore(Type *DataType, + Align Alignment) const { + return TTIImpl->isLegalStridedLoadStore(DataType, Alignment); +} + bool TargetTransformInfo::enableOrderedReductions() const { return TTIImpl->enableOrderedReductions(); } @@ -1037,6 +1042,16 @@ InstructionCost TargetTransformInfo::getGatherScatterOpCost( Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const { InstructionCost Cost = TTIImpl->getGatherScatterOpCost( Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I); + assert((!Cost.isValid() || Cost >= 0) && + "TTI should not produce negative costs!"); + return Cost; +} + +InstructionCost TargetTransformInfo::getStridedMemoryOpCost( + unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, + Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const { + InstructionCost Cost = TTIImpl->getStridedMemoryOpCost( + Opcode, DataTy, Ptr, VariableMask, Alignment, CostKind, I); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp index fe1cdb2dfa423..cb48720cc1902 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp @@ -658,6 +658,29 @@ InstructionCost RISCVTTIImpl::getGatherScatterOpCost( return NumLoads * MemOpCost; } +InstructionCost RISCVTTIImpl::getStridedMemoryOpCost( + unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, + Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) { + if (((Opcode == Instruction::Load || Opcode == Instruction::Store) && + !isLegalStridedLoadStore(DataTy, Alignment)) || + (Opcode != Instruction::Load && Opcode != Instruction::Store)) + return BaseT::getStridedMemoryOpCost(Opcode, DataTy, Ptr, VariableMask, + Alignment, CostKind, I); + + if (CostKind == TTI::TCK_CodeSize) + return TTI::TCC_Basic; + + // Cost is proportional to the number of memory operations implied. For + // scalable vectors, we use an estimate on that number since we don't + // know exactly what VL will be. + auto &VTy = *cast(DataTy); + InstructionCost MemOpCost = + getMemoryOpCost(Opcode, VTy.getElementType(), Alignment, 0, CostKind, + {TTI::OK_AnyValue, TTI::OP_None}, I); + unsigned NumLoads = getEstimatedVLFor(&VTy); + return NumLoads * MemOpCost; +} + // Currently, these represent both throughput and codesize costs // for the respective intrinsics. The costs in this table are simply // instruction counts with the following adjustments made: diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index 0747a778fe9a2..af36e9d5d5e88 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -143,6 +143,12 @@ class RISCVTTIImpl : public BasicTTIImplBase { TTI::TargetCostKind CostKind, const Instruction *I); + InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy, + const Value *Ptr, bool VariableMask, + Align Alignment, + TTI::TargetCostKind CostKind, + const Instruction *I); + InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, TTI::CastContextHint CCH, TTI::TargetCostKind CostKind, @@ -250,6 +256,11 @@ class RISCVTTIImpl : public BasicTTIImplBase { return ST->is64Bit() && !ST->hasVInstructionsI64(); } + bool isLegalStridedLoadStore(Type *DataType, Align Alignment) { + EVT DataTypeVT = TLI->getValueType(DL, DataType); + return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment); + } + bool isVScaleKnownToBeAPowerOfTwo() const { return TLI->isVScaleKnownToBeAPowerOfTwo(); }