Skip to content

Commit

Permalink
[SCEV] Add SCEVType to represent vscale.
Browse files Browse the repository at this point in the history
This is part of an effort to remove ConstantExpr based
representations of `vscale` so that its LangRef definiton can
be relaxed to accommodate a less strict definition of constant.

Differential Revision: https://reviews.llvm.org/D144891
  • Loading branch information
paulwalker-arm committed Mar 2, 2023
1 parent b8b8aa6 commit 7912f5c
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 27 deletions.
1 change: 1 addition & 0 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ class ScalarEvolution {
const SCEV *getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth = 0);
const SCEV *getPtrToIntExpr(const SCEV *Op, Type *Ty);
const SCEV *getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0);
const SCEV *getVScale(Type *Ty);
const SCEV *getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth = 0);
const SCEV *getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
unsigned Depth = 0);
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolutionDivision.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {

void visitConstant(const SCEVConstant *Numerator);

void visitVScale(const SCEVVScale *Numerator);

void visitAddRecExpr(const SCEVAddRecExpr *Numerator);

void visitAddExpr(const SCEVAddExpr *Numerator);
Expand Down
26 changes: 23 additions & 3 deletions llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ enum SCEVTypes : unsigned short {
// These should be ordered in terms of increasing complexity to make the
// folders simpler.
scConstant,
scVScale,
scTruncate,
scZeroExtend,
scSignExtend,
Expand Down Expand Up @@ -75,6 +76,23 @@ class SCEVConstant : public SCEV {
static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
};

/// This class represents the value of vscale, as used when defining the length
/// of a scalable vector or returned by the llvm.vscale() intrinsic.
class SCEVVScale : public SCEV {
friend class ScalarEvolution;

SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
: SCEV(ID, scVScale, 0), Ty(ty) {}

Type *Ty;

public:
Type *getType() const { return Ty; }

/// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
};

inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
APInt Size(16, 1);
for (const auto *Arg : Args)
Expand Down Expand Up @@ -579,9 +597,6 @@ class SCEVUnknown final : public SCEV, private CallbackVH {
public:
Value *getValue() const { return getValPtr(); }

/// Check whether this represents vscale.
bool isVScale() const;

Type *getType() const { return getValPtr()->getType(); }

/// Methods for support type inquiry through isa, cast, and dyn_cast:
Expand All @@ -595,6 +610,8 @@ template <typename SC, typename RetVal = void> struct SCEVVisitor {
switch (S->getSCEVType()) {
case scConstant:
return ((SC *)this)->visitConstant((const SCEVConstant *)S);
case scVScale:
return ((SC *)this)->visitVScale((const SCEVVScale *)S);
case scPtrToInt:
return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
case scTruncate:
Expand Down Expand Up @@ -662,6 +679,7 @@ template <typename SV> class SCEVTraversal {

switch (S->getSCEVType()) {
case scConstant:
case scVScale:
case scUnknown:
continue;
case scPtrToInt:
Expand Down Expand Up @@ -751,6 +769,8 @@ class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {

const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }

const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }

const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
return Operand == Expr->getOperand()
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,8 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {

Value *visitConstant(const SCEVConstant *S) { return S->getValue(); }

Value *visitVScale(const SCEVVScale *S);

Value *visitPtrToIntExpr(const SCEVPtrToIntExpr *S);

Value *visitTruncateExpr(const SCEVTruncateExpr *S);
Expand Down
66 changes: 43 additions & 23 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ void SCEV::print(raw_ostream &OS) const {
case scConstant:
cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
return;
case scVScale:
OS << "vscale";
return;
case scPtrToInt: {
const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
const SCEV *Op = PtrToInt->getOperand();
Expand Down Expand Up @@ -366,17 +369,9 @@ void SCEV::print(raw_ostream &OS) const {
OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
return;
}
case scUnknown: {
const SCEVUnknown *U = cast<SCEVUnknown>(this);
if (U->isVScale()) {
OS << "vscale";
return;
}

// Otherwise just print it normally.
U->getValue()->printAsOperand(OS, false);
case scUnknown:
cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
return;
}
case scCouldNotCompute:
OS << "***COULDNOTCOMPUTE***";
return;
Expand All @@ -388,6 +383,8 @@ Type *SCEV::getType() const {
switch (getSCEVType()) {
case scConstant:
return cast<SCEVConstant>(this)->getType();
case scVScale:
return cast<SCEVVScale>(this)->getType();
case scPtrToInt:
case scTruncate:
case scZeroExtend:
Expand Down Expand Up @@ -419,6 +416,7 @@ Type *SCEV::getType() const {
ArrayRef<const SCEV *> SCEV::operands() const {
switch (getSCEVType()) {
case scConstant:
case scVScale:
case scUnknown:
return {};
case scPtrToInt:
Expand Down Expand Up @@ -501,6 +499,18 @@ ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
return getConstant(ConstantInt::get(ITy, V, isSigned));
}

const SCEV *ScalarEvolution::getVScale(Type *Ty) {
FoldingSetNodeID ID;
ID.AddInteger(scVScale);
ID.AddPointer(Ty);
void *IP = nullptr;
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
UniqueSCEVs.InsertNode(S, IP);
return S;
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
const SCEV *op, Type *ty)
: SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
Expand Down Expand Up @@ -560,10 +570,6 @@ void SCEVUnknown::allUsesReplacedWith(Value *New) {
setValPtr(New);
}

bool SCEVUnknown::isVScale() const {
return match(getValue(), m_VScale());
}

//===----------------------------------------------------------------------===//
// SCEV Utilities
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -714,6 +720,12 @@ CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
return LA.ult(RA) ? -1 : 1;
}

case scVScale: {
const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
return LTy->getBitWidth() - RTy->getBitWidth();
}

case scAddRecExpr: {
const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
Expand Down Expand Up @@ -4015,6 +4027,8 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final

RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }

RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }

RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }

RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
Expand Down Expand Up @@ -4061,6 +4075,7 @@ class SCEVSequentialMinMaxDeduplicatingVisitor final
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
switch (Kind) {
case scConstant:
case scVScale:
case scTruncate:
case scZeroExtend:
case scSignExtend:
Expand Down Expand Up @@ -4104,6 +4119,7 @@ static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
if (!scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType())) {
switch (S->getSCEVType()) {
case scConstant:
case scVScale:
case scTruncate:
case scZeroExtend:
case scSignExtend:
Expand Down Expand Up @@ -4315,15 +4331,8 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
const SCEV *
ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
if (Size.isScalable()) {
// TODO: Why is there no ConstantExpr::getVScale()?
Type *SrcElemTy = ScalableVectorType::get(Type::getInt8Ty(getContext()), 1);
Constant *NullPtr = Constant::getNullValue(SrcElemTy->getPointerTo());
Constant *One = ConstantInt::get(IntTy, 1);
Constant *GEP = ConstantExpr::getGetElementPtr(SrcElemTy, NullPtr, One);
Constant *VScale = ConstantExpr::getPtrToInt(GEP, IntTy);
Res = getMulExpr(Res, getUnknown(VScale));
}
if (Size.isScalable())
Res = getMulExpr(Res, getVScale(IntTy));
return Res;
}

Expand Down Expand Up @@ -5887,6 +5896,7 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
bool follow(const SCEV *S) {
switch (S->getSCEVType()) {
case scConstant:
case scVScale:
case scPtrToInt:
case scTruncate:
case scZeroExtend:
Expand Down Expand Up @@ -6274,6 +6284,8 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
switch (S->getSCEVType()) {
case scConstant:
return cast<SCEVConstant>(S)->getAPInt().countr_zero();
case scVScale:
return 0;
case scTruncate: {
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
return std::min(GetMinTrailingZeros(T->getOperand()),
Expand Down Expand Up @@ -6504,6 +6516,7 @@ ScalarEvolution::getRangeRefIter(const SCEV *S,
break;
[[fallthrough]];
case scConstant:
case scVScale:
case scTruncate:
case scZeroExtend:
case scSignExtend:
Expand Down Expand Up @@ -6607,6 +6620,8 @@ const ConstantRange &ScalarEvolution::getRangeRef(
switch (S->getSCEVType()) {
case scConstant:
llvm_unreachable("Already handled above.");
case scVScale:
return setRange(S, SignHint, std::move(ConservativeResult));
case scTruncate: {
const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
Expand Down Expand Up @@ -9711,6 +9726,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
switch (V->getSCEVType()) {
case scCouldNotCompute:
case scAddRecExpr:
case scVScale:
return nullptr;
case scConstant:
return cast<SCEVConstant>(V)->getValue();
Expand Down Expand Up @@ -9794,6 +9810,7 @@ static Constant *BuildConstantFromSCEV(const SCEV *V) {
const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
switch (V->getSCEVType()) {
case scConstant:
case scVScale:
return V;
case scAddRecExpr: {
// If this is a loop recurrence for a loop that does not contain L, then we
Expand Down Expand Up @@ -9892,6 +9909,7 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
case scSequentialUMinExpr:
return getSequentialMinMaxExpr(V->getSCEVType(), NewOps);
case scConstant:
case scVScale:
case scAddRecExpr:
case scUnknown:
case scCouldNotCompute:
Expand Down Expand Up @@ -13677,6 +13695,7 @@ ScalarEvolution::LoopDisposition
ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
switch (S->getSCEVType()) {
case scConstant:
case scVScale:
return LoopInvariant;
case scAddRecExpr: {
const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
Expand Down Expand Up @@ -13775,6 +13794,7 @@ ScalarEvolution::BlockDisposition
ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
switch (S->getSCEVType()) {
case scConstant:
case scVScale:
return ProperlyDominatesBlock;
case scAddRecExpr: {
// This uses a "dominates" query instead of "properly dominates" query
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Analysis/ScalarEvolutionDivision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ void SCEVDivision::visitConstant(const SCEVConstant *Numerator) {
}
}

void SCEVDivision::visitVScale(const SCEVVScale *Numerator) {
return cannotDivide(Numerator);
}

void SCEVDivision::visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
const SCEV *StartQ, *StartR, *StepQ, *StepR;
if (!Numerator->isAffine())
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,7 @@ static bool isHighCostExpansion(const SCEV *S,
switch (S->getSCEVType()) {
case scUnknown:
case scConstant:
case scVScale:
return false;
case scTruncate:
return isHighCostExpansion(cast<SCEVTruncateExpr>(S)->getOperand(),
Expand Down Expand Up @@ -2812,9 +2813,10 @@ static bool isCompatibleIVType(Value *LVal, Value *RVal) {
/// SCEVUnknown, we simply return the rightmost SCEV operand.
static const SCEV *getExprBase(const SCEV *S) {
switch (S->getSCEVType()) {
default: // uncluding scUnknown.
default: // including scUnknown.
return S;
case scConstant:
case scVScale:
return nullptr;
case scTruncate:
return getExprBase(cast<SCEVTruncateExpr>(S)->getOperand());
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ const Loop *SCEVExpander::getRelevantLoop(const SCEV *S) {

switch (S->getSCEVType()) {
case scConstant:
case scVScale:
return nullptr; // A constant has no relevant loops.
case scTruncate:
case scZeroExtend:
Expand Down Expand Up @@ -1744,6 +1745,10 @@ Value *SCEVExpander::visitSequentialUMinExpr(const SCEVSequentialUMinExpr *S) {
return expandMinMaxExpr(S, Intrinsic::umin, "umin", /*IsSequential*/true);
}

Value *SCEVExpander::visitVScale(const SCEVVScale *S) {
return Builder.CreateVScale(ConstantInt::get(S->getType(), 1));
}

Value *SCEVExpander::expandCodeForImpl(const SCEV *SH, Type *Ty,
Instruction *IP) {
setInsertPoint(IP);
Expand Down Expand Up @@ -2124,6 +2129,7 @@ template<typename T> static InstructionCost costAndCollectOperands(
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
case scUnknown:
case scConstant:
case scVScale:
return 0;
case scPtrToInt:
Cost = CastCost(Instruction::PtrToInt);
Expand Down Expand Up @@ -2260,6 +2266,7 @@ bool SCEVExpander::isHighCostExpansionHelper(
case scCouldNotCompute:
llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
case scUnknown:
case scVScale:
// Assume to be zero-cost.
return false;
case scConstant: {
Expand Down

0 comments on commit 7912f5c

Please sign in to comment.