diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index 7e2424d8eb044..8ac6485362b29 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -660,7 +660,7 @@ class TargetTransformInfo { /// \Returns true if the target supports broadcasting a load to a vector of /// type . - bool isLegalBroadcastLoad(Type *ElementTy, unsigned NumElements) const; + bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const; /// Return true if the target supports masked scatter. bool isLegalMaskedScatter(Type *DataType, Align Alignment) const; @@ -1560,7 +1560,7 @@ class TargetTransformInfo::Concept { virtual bool isLegalNTStore(Type *DataType, Align Alignment) = 0; virtual bool isLegalNTLoad(Type *DataType, Align Alignment) = 0; virtual bool isLegalBroadcastLoad(Type *ElementTy, - unsigned NumElements) const = 0; + ElementCount NumElements) const = 0; virtual bool isLegalMaskedScatter(Type *DataType, Align Alignment) = 0; virtual bool isLegalMaskedGather(Type *DataType, Align Alignment) = 0; virtual bool forceScalarizeMaskedGather(VectorType *DataType, @@ -1968,7 +1968,7 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { return Impl.isLegalNTLoad(DataType, Alignment); } bool isLegalBroadcastLoad(Type *ElementTy, - unsigned NumElements) const override { + ElementCount NumElements) const override { return Impl.isLegalBroadcastLoad(ElementTy, NumElements); } bool isLegalMaskedScatter(Type *DataType, Align Alignment) override { diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 538e099cf76ce..ff73e62d1f332 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -256,7 +256,7 @@ class TargetTransformInfoImplBase { return Alignment >= DataSize && isPowerOf2_32(DataSize); } - bool isLegalBroadcastLoad(Type *ElementTy, unsigned NumElements) const { + bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const { return false; } diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 6186f0061eb8d..8a46569e1730c 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -397,7 +397,7 @@ bool TargetTransformInfo::isLegalNTLoad(Type *DataType, Align Alignment) const { } bool TargetTransformInfo::isLegalBroadcastLoad(Type *ElementTy, - unsigned NumElements) const { + ElementCount NumElements) const { return TTIImpl->isLegalBroadcastLoad(ElementTy, NumElements); } diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index abefa62c9d52c..1079bd68468b0 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -1558,7 +1558,7 @@ InstructionCost X86TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, if (const auto *Entry = CostTableLookup(SSE3BroadcastLoadTbl, Kind, LT.second)) { assert(isLegalBroadcastLoad(BaseTp->getElementType(), - LT.second.getVectorNumElements()) && + LT.second.getVectorElementCount()) && "Table entry missing from isLegalBroadcastLoad()"); return LT.first * Entry->Cost; } @@ -5137,9 +5137,10 @@ bool X86TTIImpl::isLegalNTStore(Type *DataType, Align Alignment) { } bool X86TTIImpl::isLegalBroadcastLoad(Type *ElementTy, - unsigned NumElements) const { + ElementCount NumElements) const { // movddup - return ST->hasSSE3() && NumElements == 2 && + return ST->hasSSE3() && !NumElements.isScalable() && + NumElements.getFixedValue() == 2 && ElementTy == Type::getDoubleTy(ElementTy->getContext()); } diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h index 4f874783f6989..3dbc2d0b6c2e3 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.h +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h @@ -232,7 +232,7 @@ class X86TTIImpl : public BasicTTIImplBase { bool isLegalMaskedStore(Type *DataType, Align Alignment); bool isLegalNTLoad(Type *DataType, Align Alignment); bool isLegalNTStore(Type *DataType, Align Alignment); - bool isLegalBroadcastLoad(Type *ElementTy, unsigned NumElements) const; + bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const; bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment); bool forceScalarizeMaskedScatter(VectorType *VTy, Align Alignment) { return forceScalarizeMaskedGather(VTy, Alignment); diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp index bc976c95dcc9b..2150ed99e79de 100644 --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -1188,7 +1188,8 @@ class BoUpSLP { return AllUsersVectorized(V1) && AllUsersVectorized(V2); }; // A broadcast of a load can be cheaper on some targets. - if (R.TTI->isLegalBroadcastLoad(V1->getType(), NumLanes) && + if (R.TTI->isLegalBroadcastLoad(V1->getType(), + ElementCount::getFixed(NumLanes)) && ((int)V1->getNumUses() == NumLanes || AllUsersAreInternal(V1, V2))) return VLOperands::ScoreSplatLoads;