diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 0201942183e68..6cb8fec4737fa 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -566,6 +566,7 @@ class ScalarEvolution { const SCEV *getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth = 0); const SCEV *getPtrToIntExpr(const SCEV *Op, Type *Ty); const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0); + const SCEV *getVScale(Type *Ty); const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0); const SCEV *getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth = 0); diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h index 7d5902d317952..3283d438ccb51 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionDivision.h @@ -48,6 +48,8 @@ struct SCEVDivision : public SCEVVisitor { void visitConstant(const SCEVConstant *Numerator); + void visitVScale(const SCEVVScale *Numerator); + void visitAddRecExpr(const SCEVAddRecExpr *Numerator); void visitAddExpr(const SCEVAddExpr *Numerator); diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h index 1b14d74e015ec..0a1c900c3954b 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -39,6 +39,7 @@ enum SCEVTypes : unsigned short { // These should be ordered in terms of increasing complexity to make the // folders simpler. scConstant, + scVScale, scTruncate, scZeroExtend, scSignExtend, @@ -75,6 +76,23 @@ class SCEVConstant : public SCEV { static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; } }; +/// This class represents the value of vscale, as used when defining the length +/// of a scalable vector or returned by the llvm.vscale() intrinsic. +class SCEVVScale : public SCEV { + friend class ScalarEvolution; + + SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty) + : SCEV(ID, scVScale, 0), Ty(ty) {} + + Type *Ty; + +public: + Type *getType() const { return Ty; } + + /// Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; } +}; + inline unsigned short computeExpressionSize(ArrayRef Args) { APInt Size(16, 1); for (const auto *Arg : Args) @@ -579,9 +597,6 @@ class SCEVUnknown final : public SCEV, private CallbackVH { public: Value *getValue() const { return getValPtr(); } - /// Check whether this represents vscale. - bool isVScale() const; - Type *getType() const { return getValPtr()->getType(); } /// Methods for support type inquiry through isa, cast, and dyn_cast: @@ -595,6 +610,8 @@ template struct SCEVVisitor { switch (S->getSCEVType()) { case scConstant: return ((SC *)this)->visitConstant((const SCEVConstant *)S); + case scVScale: + return ((SC *)this)->visitVScale((const SCEVVScale *)S); case scPtrToInt: return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S); case scTruncate: @@ -662,6 +679,7 @@ template class SCEVTraversal { switch (S->getSCEVType()) { case scConstant: + case scVScale: case scUnknown: continue; case scPtrToInt: @@ -751,6 +769,8 @@ class SCEVRewriteVisitor : public SCEVVisitor { const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; } + const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; } + const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand()); return Operand == Expr->getOperand() diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h index 131e24f685e89..555897083469a 100644 --- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h +++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h @@ -457,6 +457,8 @@ class SCEVExpander : public SCEVVisitor { Value *visitConstant(const SCEVConstant *S) { return S->getValue(); } + Value *visitVScale(const SCEVVScale *S); + Value *visitPtrToIntExpr(const SCEVPtrToIntExpr *S); Value *visitTruncateExpr(const SCEVTruncateExpr *S); diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index f997b1950a4d4..b07429492d814 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -271,6 +271,9 @@ void SCEV::print(raw_ostream &OS) const { case scConstant: cast(this)->getValue()->printAsOperand(OS, false); return; + case scVScale: + OS << "vscale"; + return; case scPtrToInt: { const SCEVPtrToIntExpr *PtrToInt = cast(this); const SCEV *Op = PtrToInt->getOperand(); @@ -366,17 +369,9 @@ void SCEV::print(raw_ostream &OS) const { OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")"; return; } - case scUnknown: { - const SCEVUnknown *U = cast(this); - if (U->isVScale()) { - OS << "vscale"; - return; - } - - // Otherwise just print it normally. - U->getValue()->printAsOperand(OS, false); + case scUnknown: + cast(this)->getValue()->printAsOperand(OS, false); return; - } case scCouldNotCompute: OS << "***COULDNOTCOMPUTE***"; return; @@ -388,6 +383,8 @@ Type *SCEV::getType() const { switch (getSCEVType()) { case scConstant: return cast(this)->getType(); + case scVScale: + return cast(this)->getType(); case scPtrToInt: case scTruncate: case scZeroExtend: @@ -419,6 +416,7 @@ Type *SCEV::getType() const { ArrayRef SCEV::operands() const { switch (getSCEVType()) { case scConstant: + case scVScale: case scUnknown: return {}; case scPtrToInt: @@ -501,6 +499,18 @@ ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { return getConstant(ConstantInt::get(ITy, V, isSigned)); } +const SCEV *ScalarEvolution::getVScale(Type *Ty) { + FoldingSetNodeID ID; + ID.AddInteger(scVScale); + ID.AddPointer(Ty); + void *IP = nullptr; + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return S; + SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty); + UniqueSCEVs.InsertNode(S, IP); + return S; +} + SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty) : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {} @@ -560,10 +570,6 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) { setValPtr(New); } -bool SCEVUnknown::isVScale() const { - return match(getValue(), m_VScale()); -} - //===----------------------------------------------------------------------===// // SCEV Utilities //===----------------------------------------------------------------------===// @@ -714,6 +720,12 @@ CompareSCEVComplexity(EquivalenceClasses &EqCacheSCEV, return LA.ult(RA) ? -1 : 1; } + case scVScale: { + const auto *LTy = cast(cast(LHS)->getType()); + const auto *RTy = cast(cast(RHS)->getType()); + return LTy->getBitWidth() - RTy->getBitWidth(); + } + case scAddRecExpr: { const SCEVAddRecExpr *LA = cast(LHS); const SCEVAddRecExpr *RA = cast(RHS); @@ -4015,6 +4027,8 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final RetVal visitConstant(const SCEVConstant *Constant) { return Constant; } + RetVal visitVScale(const SCEVVScale *VScale) { return VScale; } + RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; } RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; } @@ -4061,6 +4075,7 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) { switch (Kind) { case scConstant: + case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -4104,6 +4119,7 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) { if (!scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType())) { switch (S->getSCEVType()) { case scConstant: + case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -4315,15 +4331,8 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl &Ops, const SCEV * ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) { const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue()); - if (Size.isScalable()) { - // TODO: Why is there no ConstantExpr::getVScale()? - Type *SrcElemTy = ScalableVectorType::get(Type::getInt8Ty(getContext()), 1); - Constant *NullPtr = Constant::getNullValue(SrcElemTy->getPointerTo()); - Constant *One = ConstantInt::get(IntTy, 1); - Constant *GEP = ConstantExpr::getGetElementPtr(SrcElemTy, NullPtr, One); - Constant *VScale = ConstantExpr::getPtrToInt(GEP, IntTy); - Res = getMulExpr(Res, getUnknown(VScale)); - } + if (Size.isScalable()) + Res = getMulExpr(Res, getVScale(IntTy)); return Res; } @@ -5887,6 +5896,7 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, bool follow(const SCEV *S) { switch (S->getSCEVType()) { case scConstant: + case scVScale: case scPtrToInt: case scTruncate: case scZeroExtend: @@ -6274,6 +6284,8 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { switch (S->getSCEVType()) { case scConstant: return cast(S)->getAPInt().countr_zero(); + case scVScale: + return 0; case scTruncate: { const SCEVTruncateExpr *T = cast(S); return std::min(GetMinTrailingZeros(T->getOperand()), @@ -6504,6 +6516,7 @@ ScalarEvolution::getRangeRefIter(const SCEV *S, break; [[fallthrough]]; case scConstant: + case scVScale: case scTruncate: case scZeroExtend: case scSignExtend: @@ -6607,6 +6620,8 @@ const ConstantRange &ScalarEvolution::getRangeRef( switch (S->getSCEVType()) { case scConstant: llvm_unreachable("Already handled above."); + case scVScale: + return setRange(S, SignHint, std::move(ConservativeResult)); case scTruncate: { const SCEVTruncateExpr *Trunc = cast(S); ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1); @@ -9711,6 +9726,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { switch (V->getSCEVType()) { case scCouldNotCompute: case scAddRecExpr: + case scVScale: return nullptr; case scConstant: return cast(V)->getValue(); @@ -9794,6 +9810,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) { const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { switch (V->getSCEVType()) { case scConstant: + case scVScale: return V; case scAddRecExpr: { // If this is a loop recurrence for a loop that does not contain L, then we @@ -9892,6 +9909,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { case scSequentialUMinExpr: return getSequentialMinMaxExpr(V->getSCEVType(), NewOps); case scConstant: + case scVScale: case scAddRecExpr: case scUnknown: case scCouldNotCompute: @@ -13677,6 +13695,7 @@ ScalarEvolution::LoopDisposition ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { switch (S->getSCEVType()) { case scConstant: + case scVScale: return LoopInvariant; case scAddRecExpr: { const SCEVAddRecExpr *AR = cast(S); @@ -13775,6 +13794,7 @@ ScalarEvolution::BlockDisposition ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { switch (S->getSCEVType()) { case scConstant: + case scVScale: return ProperlyDominatesBlock; case scAddRecExpr: { // This uses a "dominates" query instead of "properly dominates" query diff --git a/llvm/lib/Analysis/ScalarEvolutionDivision.cpp b/llvm/lib/Analysis/ScalarEvolutionDivision.cpp index 0619569bf8168..e1dd834cfb100 100644 --- a/llvm/lib/Analysis/ScalarEvolutionDivision.cpp +++ b/llvm/lib/Analysis/ScalarEvolutionDivision.cpp @@ -126,6 +126,10 @@ void SCEVDivision::visitConstant(const SCEVConstant *Numerator) { } } +void SCEVDivision::visitVScale(const SCEVVScale *Numerator) { + return cannotDivide(Numerator); +} + void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) { const SCEV *StartQ, *StartR, *StepQ, *StepR; if (!Numerator->isAffine()) diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 67c404a085df7..e5da0652a4ab7 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -976,6 +976,7 @@ static bool isHighCostExpansion(const SCEV *S, switch (S->getSCEVType()) { case scUnknown: case scConstant: + case scVScale: return false; case scTruncate: return isHighCostExpansion(cast(S)->getOperand(), @@ -2812,9 +2813,10 @@ static bool isCompatibleIVType(Value *LVal, Value *RVal) { /// SCEVUnknown, we simply return the rightmost SCEV operand. static const SCEV *getExprBase(const SCEV *S) { switch (S->getSCEVType()) { - default: // uncluding scUnknown. + default: // including scUnknown. return S; case scConstant: + case scVScale: return nullptr; case scTruncate: return getExprBase(cast(S)->getOperand()); diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 24f1966edd37a..902eee26a4567 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -680,6 +680,7 @@ const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) { switch (S->getSCEVType()) { case scConstant: + case scVScale: return nullptr; // A constant has no relevant loops. case scTruncate: case scZeroExtend: @@ -1744,6 +1745,10 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) { return expandMinMaxExpr(S, Intrinsic::umin, "umin", /*IsSequential*/true); } +Value *SCEVExpander::visitVScale(const SCEVVScale *S) { + return Builder.CreateVScale(ConstantInt::get(S->getType(), 1)); +} + Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty, Instruction *IP) { setInsertPoint(IP); @@ -2124,6 +2129,7 @@ template static InstructionCost costAndCollectOperands( llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); case scUnknown: case scConstant: + case scVScale: return 0; case scPtrToInt: Cost = CastCost(Instruction::PtrToInt); @@ -2260,6 +2266,7 @@ bool SCEVExpander::isHighCostExpansionHelper( case scCouldNotCompute: llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); case scUnknown: + case scVScale: // Assume to be zero-cost. return false; case scConstant: {