diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 51b67d1756b9aa..548c6c503dc26a 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -660,10 +660,8 @@ class ScalarEvolution { return getConstant(Ty, -1, /*isSigned=*/true); } - /// Return an expression for sizeof ScalableTy that is type IntTy, where - /// ScalableTy is a scalable vector type. - const SCEV *getSizeOfScalableVectorExpr(Type *IntTy, - ScalableVectorType *ScalableTy); + /// Return an expression for a TypeSize. + const SCEV *getSizeOfExpr(Type *IntTy, TypeSize Size); /// Return an expression for the alloc size of AllocTy that is type IntTy const SCEV *getSizeOfExpr(Type *IntTy, Type *AllocTy); diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h index 9245cf1e5f4b74..1b14d74e015ec0 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -579,14 +579,8 @@ class SCEVUnknown final : public SCEV, private CallbackVH { public: Value *getValue() const { return getValPtr(); } - /// @{ - /// Test whether this is a special constant representing a type size in a - /// target-independent manner, and hasn't happened to have been folded with - /// other operations into something unrecognizable. This is mainly only - /// useful for pretty-printing and other situations where it isn't - /// absolutely required for these to succeed. - bool isSizeOf(Type *&AllocTy) const; - /// @} + /// Check whether this represents vscale. + bool isVScale() const; Type *getType() const { return getValPtr()->getType(); } diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 96cc518e6039bc..a820879b22082c 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -368,9 +368,8 @@ void SCEV::print(raw_ostream &OS) const { } case scUnknown: { const SCEVUnknown *U = cast(this); - Type *AllocTy; - if (U->isSizeOf(AllocTy)) { - OS << "sizeof(" << *AllocTy << ")"; + if (U->isVScale()) { + OS << "vscale"; return; } @@ -561,20 +560,8 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) { setValPtr(New); } -bool SCEVUnknown::isSizeOf(Type *&AllocTy) const { - if (ConstantExpr *VCE = dyn_cast(getValue())) - if (VCE->getOpcode() == Instruction::PtrToInt) - if (ConstantExpr *CE = dyn_cast(VCE->getOperand(0))) - if (CE->getOpcode() == Instruction::GetElementPtr && - CE->getOperand(0)->isNullValue() && - CE->getNumOperands() == 2) - if (ConstantInt *CI = dyn_cast(CE->getOperand(1))) - if (CI->isOne()) { - AllocTy = cast(CE)->getSourceElementType(); - return true; - } - - return false; +bool SCEVUnknown::isVScale() const { + return match(getValue(), m_VScale()); } //===----------------------------------------------------------------------===// @@ -4326,33 +4313,26 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl &Ops, } const SCEV * -ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy, - ScalableVectorType *ScalableTy) { - Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo()); - Constant *One = ConstantInt::get(IntTy, 1); - Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One); - // Note that the expression we created is the final expression, we don't - // want to simplify it any further Also, if we call a normal getSCEV(), - // we'll end up in an endless recursion. So just create an SCEVUnknown. - return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy)); +ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) { + const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue()); + if (Size.isScalable()) { + // TODO: Why is there no ConstantExpr::getVScale()? + Type *SrcElemTy = ScalableVectorType::get(Type::getInt8Ty(getContext()), 1); + Constant *NullPtr = Constant::getNullValue(SrcElemTy->getPointerTo()); + Constant *One = ConstantInt::get(IntTy, 1); + Constant *GEP = ConstantExpr::getGetElementPtr(SrcElemTy, NullPtr, One); + Constant *VScale = ConstantExpr::getPtrToInt(GEP, IntTy); + Res = getMulExpr(Res, getUnknown(VScale)); + } + return Res; } const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { - if (auto *ScalableAllocTy = dyn_cast(AllocTy)) - return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy); - // We can bypass creating a target-independent constant expression and then - // folding it back into a ConstantInt. This is just a compile-time - // optimization. - return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); + return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); } const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) { - if (auto *ScalableStoreTy = dyn_cast(StoreTy)) - return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy); - // We can bypass creating a target-independent constant expression and then - // folding it back into a ConstantInt. This is just a compile-time - // optimization. - return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy)); + return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy)); } const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, diff --git a/llvm/test/Analysis/ScalarEvolution/scalable-vector.ll b/llvm/test/Analysis/ScalarEvolution/scalable-vector.ll index 798a36023235d7..b5bd724cb269c1 100644 --- a/llvm/test/Analysis/ScalarEvolution/scalable-vector.ll +++ b/llvm/test/Analysis/ScalarEvolution/scalable-vector.ll @@ -5,9 +5,9 @@ define void @a(ptr %p) { ; CHECK-LABEL: 'a' ; CHECK-NEXT: Classifying expressions for: @a ; CHECK-NEXT: %1 = getelementptr , ptr null, i32 3 -; CHECK-NEXT: --> ((3 * sizeof()) + null) U: [0,-15) S: [-9223372036854775808,9223372036854775793) +; CHECK-NEXT: --> ((48 * vscale) + null) U: [0,-15) S: [-9223372036854775808,9223372036854775793) ; CHECK-NEXT: %2 = getelementptr , ptr %p, i32 1 -; CHECK-NEXT: --> (sizeof() + %p) U: full-set S: full-set +; CHECK-NEXT: --> ((8 * vscale) + %p) U: full-set S: full-set ; CHECK-NEXT: Determining loop execution counts for: @a ; getelementptr , ptr null, i32 3