Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change contents of ScalarEvolution from private to protected #83052

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,10 @@ class ScalarEvolution {
LoopComputable ///< The SCEV varies predictably with the loop.
};

bool AssumeLoopFinite = false;
void setAssumeLoopExits();
SmallPtrSet<BasicBlock *, 4> GuaranteedUnreachable;

/// An enum describing the relationship between a SCEV and a basic block.
enum BlockDisposition {
DoesNotDominateBlock, ///< The SCEV does not dominate the block.
Expand Down
179 changes: 143 additions & 36 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,8 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) {
return S;
}

void ScalarEvolution::setAssumeLoopExits() { this->AssumeLoopFinite = true; }

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
const SCEV *op, Type *ty)
: SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
Expand Down Expand Up @@ -7413,7 +7415,8 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
// A mustprogress loop without side effects must be finite.
// TODO: The check used here is very conservative. It's only *specific*
// side effects which are well defined in infinite loops.
return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
return AssumeLoopFinite || isFinite(L) ||
(isMustProgress(L) && loopHasNoSideEffects(L));
}

const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
Expand Down Expand Up @@ -8828,6 +8831,26 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
bool AllowPredicates) {
if (AssumeLoopFinite) {
SmallVector<BasicBlock *, 8> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
for (auto &ExitingBlock : ExitingBlocks) {
BasicBlock *Exit = nullptr;
for (auto *SBB : successors(ExitingBlock)) {
if (!L->contains(SBB)) {
if (GuaranteedUnreachable.count(SBB))
continue;
Exit = SBB;
break;
}
}
if (!Exit)
ExitingBlock = nullptr;
}
ExitingBlocks.erase(
std::remove(ExitingBlocks.begin(), ExitingBlocks.end(), nullptr),
ExitingBlocks.end());
}
assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
// If our exiting block does not dominate the latch, then its connection with
// loop's exit limit may be far from trivial.
Expand All @@ -8853,6 +8876,8 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
BasicBlock *Exit = nullptr;
for (auto *SBB : successors(ExitingBlock))
if (!L->contains(SBB)) {
if (AssumeLoopFinite and GuaranteedUnreachable.count(SBB))
continue;
if (Exit) // Multiple exit successors.
return getCouldNotCompute();
Exit = SBB;
Expand Down Expand Up @@ -8923,6 +8948,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates) {

// Handle BinOp conditions (And, Or).
if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
Expand Down Expand Up @@ -8950,6 +8976,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
if (ExitIfTrue == !CI->getZExtValue())
// The backedge is always taken.
return getCouldNotCompute();

// The backedge is never taken.
return getZero(CI->getType());
}
Expand All @@ -8961,9 +8988,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
const APInt *C;
if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
match(WO->getRHS(), m_APInt(C))) {
ConstantRange NWR =
ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
WO->getNoWrapKind());
ConstantRange NWR = ConstantRange::makeExactNoWrapRegion(
WO->getBinaryOp(), *C, WO->getNoWrapKind());
CmpInst::Predicate Pred;
APInt NewRHSC, Offset;
NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
Expand Down Expand Up @@ -9019,13 +9045,15 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
const SCEV *SymbolicMaxBECount = getCouldNotCompute();
if (EitherMayExit) {
bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);

// Both conditions must be same for the loop to continue executing.
// Choose the less conservative count.
if (EL0.ExactNotTaken != getCouldNotCompute() &&
EL1.ExactNotTaken != getCouldNotCompute()) {
BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
UseSequentialUMin);
}

if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
ConstantMaxBECount = EL1.ConstantMaxNotTaken;
else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
Expand All @@ -9045,6 +9073,12 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
// For now, be conservative.
if (EL0.ExactNotTaken == EL1.ExactNotTaken)
BECount = EL0.ExactNotTaken;
// This was executed in Enzyme's must exit code under the
// logic for when the binary op was OR
if (AssumeLoopFinite && !IsAnd) {
if (EL0.ExactNotTaken == EL1.ExactNotTaken)
ConstantMaxBECount = EL0.ExactNotTaken;
}
}

// There are cases (e.g. PR26207) where computeExitLimitFromCond is able
Expand All @@ -9053,12 +9087,14 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp(
// and
// EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
// EL1.ConstantMaxNotTaken to not.
if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
!isa<SCEVCouldNotCompute>(BECount))
ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
SymbolicMaxBECount =
isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
if (!AssumeLoopFinite || !IsAnd) { // should skip if assume exits and OR
if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
!isa<SCEVCouldNotCompute>(BECount))
ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
SymbolicMaxBECount =
isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
}
return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
{ &EL0.Predicates, &EL1.Predicates });
}
Expand All @@ -9082,8 +9118,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
if (EL.hasAnyInfo())
return EL;

auto *ExhaustiveCount =
computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue);

if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
return ExhaustiveCount;
Expand All @@ -9094,7 +9129,31 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
bool ControlsOnlyExit, bool AllowPredicates) {

if (AssumeLoopFinite) {
#define PROP_PHI(LHS) \
if (auto un = dyn_cast<SCEVUnknown>(LHS)) { \
if (auto pn = dyn_cast_or_null<PHINode>(un->getValue())) { \
const SCEV *sc = nullptr; \
bool failed = false; \
for (auto &a : pn->incoming_values()) { \
auto subsc = getSCEV(a); \
if (sc == nullptr) { \
sc = subsc; \
continue; \
} \
if (subsc != sc) { \
failed = true; \
break; \
} \
} \
if (!failed) { \
LHS = sc; \
} \
} \
}
PROP_PHI(LHS)
PROP_PHI(RHS)
}
// Try to evaluate any dependencies out of the loop.
LHS = getSCEVAtScope(LHS, L);
RHS = getSCEVAtScope(RHS, L);
Expand All @@ -9107,6 +9166,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
Pred = ICmpInst::getSwappedPredicate(Pred);
}

// was not present in Enzyme code, the last condition is true if
// AssumeLoopExits is true
// will the first two checks cause enzyme to fail?
bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
loopIsFiniteByAssumption(L);
// Simplify the operands before analyzing them.
Expand Down Expand Up @@ -9184,15 +9246,19 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
if (EL.hasAnyInfo()) return EL;
break;
}

case ICmpInst::ICMP_SLE:
case ICmpInst::ICMP_ULE:
// Since the loop is finite, an invariant RHS cannot include the boundary
// value, otherwise it would loop forever.
if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
!isLoopInvariant(RHS, L))
break;
RHS = getAddExpr(getOne(RHS->getType()), RHS);
if (!AssumeLoopFinite) {
// Since the loop is finite, an invariant RHS cannot include the boundary
// value, otherwise it would loop forever.
if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
!isLoopInvariant(RHS, L))
break;
RHS = getAddExpr(getOne(RHS->getType()), RHS);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assignment to RHS should be outside the "if"?

}
[[fallthrough]];

case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_ULT: { // while (X < Y)
bool IsSigned = ICmpInst::isSigned(Pred);
Expand All @@ -9204,16 +9270,33 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
}
case ICmpInst::ICMP_SGE:
case ICmpInst::ICMP_UGE:
// Since the loop is finite, an invariant RHS cannot include the boundary
// value, otherwise it would loop forever.
if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
!isLoopInvariant(RHS, L))
break;
RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
if (!AssumeLoopFinite) {
// Since the loop is finite, an invariant RHS cannot include the boundary
// value, otherwise it would loop forever.
if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
!isLoopInvariant(RHS, L))
break;
RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
}
[[fallthrough]];
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_UGT: { // while (X > Y)
bool IsSigned = ICmpInst::isSigned(Pred);
if (AssumeLoopFinite) {
if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) {
if (!isa<IntegerType>(RHS->getType()))
break;
SmallVector<const SCEV *, 2> sv = {
RHS, getConstant(
ConstantInt::get(cast<IntegerType>(RHS->getType()), -1))};
// Since this is not an infinite loop by induction, RHS cannot be
// int_min/uint_min Therefore subtracting 1 does not wrap.
if (IsSigned)
RHS = getAddExpr(sv, SCEV::FlagNSW);
else
RHS = getAddExpr(sv, SCEV::FlagNUW);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like duplicated code from above (the EnableFiniteLoopControl bit).

}
}
ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
AllowPredicates);
if (EL.hasAnyInfo())
Expand All @@ -9238,8 +9321,14 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
if (Switch->getDefaultDest() == ExitingBlock)
return getCouldNotCompute();

assert(L->contains(Switch->getDefaultDest()) &&
"Default case must not exit the loop!");
// if not using enzyme executes by default
// if using enzyme and the code is guaranteed unreachable,
// the default destination doesn't matter
if (!AssumeLoopFinite ||
!GuaranteedUnreachable.count(Switch->getDefaultDest())) {
assert(L->contains(Switch->getDefaultDest()) &&
"Default case must not exit the loop!");
}
const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));

Expand Down Expand Up @@ -12752,9 +12841,9 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
// If RHS <=u Limit, then there must exist a value V in the sequence
// defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
// V <=u UINT_MAX. Thus, we must exit the loop before unsigned
// overflow occurs. This limit also implies that a signed comparison
// (in the wide bitwidth) is equivalent to an unsigned comparison as
// the high bits on both sides must be zero.
// overflow occurs. This limit also implies that a signed
// comparison (in the wide bitwidth) is equivalent to an unsigned
// comparison as the high bits on both sides must be zero.
APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
Limit = Limit.zext(OuterBitWidth);
Expand All @@ -12765,6 +12854,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
Flags = setFlags(Flags, SCEV::FlagNUW);

setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);

if (AR->hasNoUnsignedWrap()) {
// Emulate what getZeroExtendExpr would have done during construction
// if we'd been able to infer the fact just above at that time.
Expand Down Expand Up @@ -12848,6 +12938,13 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
!loopHasNoAbnormalExits(L))
return getCouldNotCompute();

// This bailout is protecting the logic in computeMaxBECountForLT which
// has not yet been sufficiently auditted or tested with negative strides.
// We used to filter out all known-non-positive cases here, we're in the
// process of being less restrictive bit by bit.
if (AssumeLoopFinite && IsSigned && isKnownNonPositive(Stride))
return getCouldNotCompute();

if (!isKnownNonZero(Stride)) {
// If we have a step of zero, and RHS isn't invariant in L, we don't know
// if it might eventually be greater than start and if so, on which
Expand Down Expand Up @@ -12977,13 +13074,20 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (!BECount) {
auto canProveRHSGreaterThanEqualStart = [&]() {
auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);

if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart)) {
return true;

}
// In the Enzyme MustExitScalarEvolutionCode, this check was missing
// I do not have enough context to know if these two checks should be
// mutually Exclusive. If they aren't then this bool check is unnecessary
if (!AssumeLoopFinite) {
const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);

if (isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
return true;
}
// (RHS > Start - 1) implies RHS >= Start.
// * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
// "Start - 1" doesn't overflow.
Expand Down Expand Up @@ -13120,7 +13224,10 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
!isa<SCEVCouldNotCompute>(BECount))
ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));

if (AssumeLoopFinite) {
return ExitLimit(BECount, ConstantMaxBECount, ConstantMaxBECount, MaxOrZero,
Predicates);
}
const SCEV *SymbolicMaxBECount =
isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
Expand Down