diff --git a/llvm/include/llvm/IR/Attributes.h b/llvm/include/llvm/IR/Attributes.h index 906aba02644910..002277eb4d7201 100644 --- a/llvm/include/llvm/IR/Attributes.h +++ b/llvm/include/llvm/IR/Attributes.h @@ -216,9 +216,12 @@ class Attribute { /// if not known). std::pair> getAllocSizeArgs() const; - /// Returns the argument numbers for the vscale_range attribute (or pair(1, 0) - /// if not known). - std::pair getVScaleRangeArgs() const; + /// Returns the minimum value for the vscale_range attribute. + unsigned getVScaleRangeMin() const; + + /// Returns the maximum value for the vscale_range attribute or None when + /// unknown. + Optional getVScaleRangeMax() const; /// The Attribute is converted to a string of equivalent mnemonic. This /// is, presumably, for writing out the mnemonics for the assembly writer. @@ -348,7 +351,8 @@ class AttributeSet { Type *getInAllocaType() const; Type *getElementType() const; std::pair> getAllocSizeArgs() const; - std::pair getVScaleRangeArgs() const; + unsigned getVScaleRangeMin() const; + Optional getVScaleRangeMax() const; std::string getAsString(bool InAttrGrp = false) const; /// Return true if this attribute set belongs to the LLVMContext. @@ -1053,9 +1057,11 @@ class AttrBuilder { /// doesn't exist, pair(0, 0) is returned. std::pair> getAllocSizeArgs() const; - /// Retrieve the vscale_range args, if the vscale_range attribute exists. If - /// it doesn't exist, pair(1, 0) is returned. - std::pair getVScaleRangeArgs() const; + /// Retrieve the minimum value of 'vscale_range'. + unsigned getVScaleRangeMin() const; + + /// Retrieve the maximum value of 'vscale_range' or None when unknown. + Optional getVScaleRangeMax() const; /// Add integer attribute with raw value (packed/encoded if necessary). AttrBuilder &addRawIntAttr(Attribute::AttrKind Kind, uint64_t Value); @@ -1097,7 +1103,8 @@ class AttrBuilder { const Optional &NumElemsArg); /// This turns two ints into the form used internally in Attribute. - AttrBuilder &addVScaleRangeAttr(unsigned MinValue, unsigned MaxValue); + AttrBuilder &addVScaleRangeAttr(unsigned MinValue, + Optional MaxValue); /// Add a type attribute with the given type. AttrBuilder &addTypeAttr(Attribute::AttrKind Kind, Type *Ty); diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 8894a2c1b052e6..2ee1e1e98f2eae 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -5886,9 +5886,9 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { auto Attr = Call->getFunction()->getFnAttribute(Attribute::VScaleRange); if (!Attr.isValid()) return nullptr; - unsigned VScaleMin, VScaleMax; - std::tie(VScaleMin, VScaleMax) = Attr.getVScaleRangeArgs(); - if (VScaleMin == VScaleMax && VScaleMax != 0) + unsigned VScaleMin = Attr.getVScaleRangeMin(); + Optional VScaleMax = Attr.getVScaleRangeMax(); + if (VScaleMax && VScaleMin == VScaleMax) return ConstantInt::get(F->getReturnType(), VScaleMin); return nullptr; } diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 1c41c77a8cfb18..4bec851870ba95 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -1709,23 +1709,25 @@ static void computeKnownBitsFromOperator(const Operator *I, !II->getFunction()->hasFnAttribute(Attribute::VScaleRange)) break; - auto VScaleRange = II->getFunction() - ->getFnAttribute(Attribute::VScaleRange) - .getVScaleRangeArgs(); + auto Attr = II->getFunction()->getFnAttribute(Attribute::VScaleRange); + Optional VScaleMax = Attr.getVScaleRangeMax(); - if (VScaleRange.second == 0) + if (!VScaleMax) break; + unsigned VScaleMin = Attr.getVScaleRangeMin(); + // If vscale min = max then we know the exact value at compile time // and hence we know the exact bits. - if (VScaleRange.first == VScaleRange.second) { - Known.One = VScaleRange.first; - Known.Zero = VScaleRange.first; + if (VScaleMin == VScaleMax) { + Known.One = VScaleMin; + Known.Zero = VScaleMin; Known.Zero.flipAllBits(); break; } - unsigned FirstZeroHighBit = 32 - countLeadingZeros(VScaleRange.second); + unsigned FirstZeroHighBit = + 32 - countLeadingZeros(VScaleMax.getValue()); if (FirstZeroHighBit < BitWidth) Known.Zero.setBitsFrom(FirstZeroHighBit); diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp index 5feabd876e3a7c..d517d0384d6210 100644 --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -1306,7 +1306,8 @@ bool LLParser::parseEnumAttribute(Attribute::AttrKind Attr, AttrBuilder &B, unsigned MinValue, MaxValue; if (parseVScaleRangeArguments(MinValue, MaxValue)) return true; - B.addVScaleRangeAttr(MinValue, MaxValue); + B.addVScaleRangeAttr(MinValue, + MaxValue > 0 ? MaxValue : Optional()); return false; } case Attribute::Dereferenceable: { diff --git a/llvm/lib/IR/AttributeImpl.h b/llvm/lib/IR/AttributeImpl.h index c5bbe6571096d0..1153fb827b5633 100644 --- a/llvm/lib/IR/AttributeImpl.h +++ b/llvm/lib/IR/AttributeImpl.h @@ -253,7 +253,8 @@ class AttributeSetNode final uint64_t getDereferenceableBytes() const; uint64_t getDereferenceableOrNullBytes() const; std::pair> getAllocSizeArgs() const; - std::pair getVScaleRangeArgs() const; + unsigned getVScaleRangeMin() const; + Optional getVScaleRangeMax() const; std::string getAsString(bool InAttrGrp) const; Type *getAttributeType(Attribute::AttrKind Kind) const; diff --git a/llvm/lib/IR/Attributes.cpp b/llvm/lib/IR/Attributes.cpp index d0bc53bc8e4f8a..ede520bea053ee 100644 --- a/llvm/lib/IR/Attributes.cpp +++ b/llvm/lib/IR/Attributes.cpp @@ -78,15 +78,18 @@ unpackAllocSizeArgs(uint64_t Num) { return std::make_pair(ElemSizeArg, NumElemsArg); } -static uint64_t packVScaleRangeArgs(unsigned MinValue, unsigned MaxValue) { - return uint64_t(MinValue) << 32 | MaxValue; +static uint64_t packVScaleRangeArgs(unsigned MinValue, + Optional MaxValue) { + return uint64_t(MinValue) << 32 | MaxValue.getValueOr(0); } -static std::pair unpackVScaleRangeArgs(uint64_t Value) { +static std::pair> +unpackVScaleRangeArgs(uint64_t Value) { unsigned MaxValue = Value & std::numeric_limits::max(); unsigned MinValue = Value >> 32; - return std::make_pair(MinValue, MaxValue); + return std::make_pair(MinValue, + MaxValue > 0 ? MaxValue : Optional()); } Attribute Attribute::get(LLVMContext &Context, Attribute::AttrKind Kind, @@ -354,10 +357,16 @@ std::pair> Attribute::getAllocSizeArgs() const { return unpackAllocSizeArgs(pImpl->getValueAsInt()); } -std::pair Attribute::getVScaleRangeArgs() const { +unsigned Attribute::getVScaleRangeMin() const { assert(hasAttribute(Attribute::VScaleRange) && "Trying to get vscale args from non-vscale attribute"); - return unpackVScaleRangeArgs(pImpl->getValueAsInt()); + return unpackVScaleRangeArgs(pImpl->getValueAsInt()).first; +} + +Optional Attribute::getVScaleRangeMax() const { + assert(hasAttribute(Attribute::VScaleRange) && + "Trying to get vscale args from non-vscale attribute"); + return unpackVScaleRangeArgs(pImpl->getValueAsInt()).second; } std::string Attribute::getAsString(bool InAttrGrp) const { @@ -428,13 +437,13 @@ std::string Attribute::getAsString(bool InAttrGrp) const { } if (hasAttribute(Attribute::VScaleRange)) { - unsigned MinValue, MaxValue; - std::tie(MinValue, MaxValue) = getVScaleRangeArgs(); + unsigned MinValue = getVScaleRangeMin(); + Optional MaxValue = getVScaleRangeMax(); std::string Result = "vscale_range("; Result += utostr(MinValue); Result += ','; - Result += utostr(MaxValue); + Result += utostr(MaxValue.getValueOr(0)); Result += ')'; return Result; } @@ -717,9 +726,12 @@ std::pair> AttributeSet::getAllocSizeArgs() const { : std::pair>(0, 0); } -std::pair AttributeSet::getVScaleRangeArgs() const { - return SetNode ? SetNode->getVScaleRangeArgs() - : std::pair(1, 0); +unsigned AttributeSet::getVScaleRangeMin() const { + return SetNode ? SetNode->getVScaleRangeMin() : 1; +} + +Optional AttributeSet::getVScaleRangeMax() const { + return SetNode ? SetNode->getVScaleRangeMax() : None; } std::string AttributeSet::getAsString(bool InAttrGrp) const { @@ -897,10 +909,16 @@ AttributeSetNode::getAllocSizeArgs() const { return std::make_pair(0, 0); } -std::pair AttributeSetNode::getVScaleRangeArgs() const { +unsigned AttributeSetNode::getVScaleRangeMin() const { if (auto A = findEnumAttribute(Attribute::VScaleRange)) - return A->getVScaleRangeArgs(); - return std::make_pair(1, 0); + return A->getVScaleRangeMin(); + return 1; +} + +Optional AttributeSetNode::getVScaleRangeMax() const { + if (auto A = findEnumAttribute(Attribute::VScaleRange)) + return A->getVScaleRangeMax(); + return None; } std::string AttributeSetNode::getAsString(bool InAttrGrp) const { @@ -1623,8 +1641,12 @@ std::pair> AttrBuilder::getAllocSizeArgs() const { return unpackAllocSizeArgs(getRawIntAttr(Attribute::AllocSize)); } -std::pair AttrBuilder::getVScaleRangeArgs() const { - return unpackVScaleRangeArgs(getRawIntAttr(Attribute::VScaleRange)); +unsigned AttrBuilder::getVScaleRangeMin() const { + return unpackVScaleRangeArgs(getRawIntAttr(Attribute::VScaleRange)).first; +} + +Optional AttrBuilder::getVScaleRangeMax() const { + return unpackVScaleRangeArgs(getRawIntAttr(Attribute::VScaleRange)).second; } AttrBuilder &AttrBuilder::addAlignmentAttr(MaybeAlign Align) { @@ -1669,7 +1691,7 @@ AttrBuilder &AttrBuilder::addAllocSizeAttrFromRawRepr(uint64_t RawArgs) { } AttrBuilder &AttrBuilder::addVScaleRangeAttr(unsigned MinValue, - unsigned MaxValue) { + Optional MaxValue) { return addVScaleRangeAttrFromRawRepr(packVScaleRangeArgs(MinValue, MaxValue)); } diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 1290f5e08d2500..8d3fd4ec0fa8bd 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -2055,13 +2055,12 @@ void Verifier::verifyFunctionAttrs(FunctionType *FT, AttributeList Attrs, } if (Attrs.hasFnAttr(Attribute::VScaleRange)) { - std::pair Args = - Attrs.getFnAttrs().getVScaleRangeArgs(); - - if (Args.first == 0) + unsigned VScaleMin = Attrs.getFnAttrs().getVScaleRangeMin(); + if (VScaleMin == 0) CheckFailed("'vscale_range' minimum must be greater than 0", V); - if (Args.first > Args.second && Args.second != 0) + Optional VScaleMax = Attrs.getFnAttrs().getVScaleRangeMax(); + if (VScaleMax && VScaleMin > VScaleMax) CheckFailed("'vscale_range' minimum cannot be greater than maximum", V); } diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp index ce26c62af61a23..9a72f6dd90ea13 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp @@ -382,10 +382,9 @@ AArch64TargetMachine::getSubtargetImpl(const Function &F) const { unsigned MaxSVEVectorSize = 0; Attribute VScaleRangeAttr = F.getFnAttribute(Attribute::VScaleRange); if (VScaleRangeAttr.isValid()) { - std::tie(MinSVEVectorSize, MaxSVEVectorSize) = - VScaleRangeAttr.getVScaleRangeArgs(); - MinSVEVectorSize *= 128; - MaxSVEVectorSize *= 128; + Optional VScaleMax = VScaleRangeAttr.getVScaleRangeMax(); + MinSVEVectorSize = VScaleRangeAttr.getVScaleRangeMin() * 128; + MaxSVEVectorSize = VScaleMax ? VScaleMax.getValue() * 128 : 0; } else { MinSVEVectorSize = SVEVectorBitsMinOpt; MaxSVEVectorSize = SVEVectorBitsMaxOpt; diff --git a/llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp b/llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp index e72dccdc4b78f9..78e6f38bf85531 100644 --- a/llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp +++ b/llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp @@ -287,10 +287,10 @@ bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) { if (!Attr.isValid()) return false; - unsigned MinVScale, MaxVScale; - std::tie(MinVScale, MaxVScale) = Attr.getVScaleRangeArgs(); + unsigned MinVScale = Attr.getVScaleRangeMin(); + Optional MaxVScale = Attr.getVScaleRangeMax(); // The transform needs to know the exact runtime length of scalable vectors - if (MinVScale != MaxVScale || MinVScale == 0) + if (!MaxVScale || MinVScale != MaxVScale) return false; auto *PredType = @@ -351,10 +351,10 @@ bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) { if (!Attr.isValid()) return false; - unsigned MinVScale, MaxVScale; - std::tie(MinVScale, MaxVScale) = Attr.getVScaleRangeArgs(); + unsigned MinVScale = Attr.getVScaleRangeMin(); + Optional MaxVScale = Attr.getVScaleRangeMax(); // The transform needs to know the exact runtime length of scalable vectors - if (MinVScale != MaxVScale || MinVScale == 0) + if (!MaxVScale || MinVScale != MaxVScale) return false; auto *PredType = diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 33f217659c0182..18eb245779bf95 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -965,13 +965,13 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { if (match(Src, m_VScale(DL))) { if (Trunc.getFunction() && Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { - unsigned MaxVScale = Trunc.getFunction() - ->getFnAttribute(Attribute::VScaleRange) - .getVScaleRangeArgs() - .second; - if (MaxVScale > 0 && Log2_32(MaxVScale) < DestWidth) { - Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); - return replaceInstUsesWith(Trunc, VScale); + Attribute Attr = + Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (Optional MaxVScale = Attr.getVScaleRangeMax()) { + if (Log2_32(MaxVScale.getValue()) < DestWidth) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(Trunc, VScale); + } } } } @@ -1337,14 +1337,13 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { if (match(Src, m_VScale(DL))) { if (CI.getFunction() && CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { - unsigned MaxVScale = CI.getFunction() - ->getFnAttribute(Attribute::VScaleRange) - .getVScaleRangeArgs() - .second; - unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); - if (MaxVScale > 0 && Log2_32(MaxVScale) < TypeWidth) { - Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); - return replaceInstUsesWith(CI, VScale); + Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (Optional MaxVScale = Attr.getVScaleRangeMax()) { + unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); + if (Log2_32(MaxVScale.getValue()) < TypeWidth) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(CI, VScale); + } } } } @@ -1608,13 +1607,12 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { if (match(Src, m_VScale(DL))) { if (CI.getFunction() && CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { - unsigned MaxVScale = CI.getFunction() - ->getFnAttribute(Attribute::VScaleRange) - .getVScaleRangeArgs() - .second; - if (MaxVScale > 0 && Log2_32(MaxVScale) < (SrcBitSize - 1)) { - Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); - return replaceInstUsesWith(CI, VScale); + Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (Optional MaxVScale = Attr.getVScaleRangeMax()) { + if (Log2_32(MaxVScale.getValue()) < (SrcBitSize - 1)) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(CI, VScale); + } } } } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp index 8f63020a33d41c..97b951dfbcd387 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -5366,13 +5366,9 @@ LoopVectorizationCostModel::getMaxLegalScalableVF(unsigned MaxSafeElements) { // Limit MaxScalableVF by the maximum safe dependence distance. Optional MaxVScale = TTI.getMaxVScale(); - if (!MaxVScale && TheFunction->hasFnAttribute(Attribute::VScaleRange)) { - unsigned VScaleMax = TheFunction->getFnAttribute(Attribute::VScaleRange) - .getVScaleRangeArgs() - .second; - if (VScaleMax > 0) - MaxVScale = VScaleMax; - } + if (!MaxVScale && TheFunction->hasFnAttribute(Attribute::VScaleRange)) + MaxVScale = + TheFunction->getFnAttribute(Attribute::VScaleRange).getVScaleRangeMax(); MaxScalableVF = ElementCount::getScalable( MaxVScale ? (MaxSafeElements / MaxVScale.getValue()) : 0); if (!MaxScalableVF)