Skip to content

Commit

Permalink
[SCEV] Simplify invalidation after BE count calculation (NFCI)
Browse files Browse the repository at this point in the history
After backedge taken counts have been calculated, we want to
invalidate all addrecs and dependent expressions in the loop,
because we might compute better results with the newly available
backedge taken counts. Previously this was done with a forgetLoop()
style use-def walk. With recent improvements to SCEV invalidation,
we can instead directly invalidate any SCEVs using addrecs in this
loop. This requires a great deal less subtlety to avoid invalidating
more than necessary, and in particular gets rid of the hack from
D113349. The change is similar to D114263 in spirit.
  • Loading branch information
nikic committed Nov 27, 2021
1 parent 1b2d58b commit c2550e3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 67 deletions.
5 changes: 2 additions & 3 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Expand Up @@ -1912,11 +1912,10 @@ class ScalarEvolution {
SCEV::NoWrapFlags &Flags);

/// Drop memoized information for all \p SCEVs.
void forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs,
bool SkipUnknownPhis = false);
void forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs);

/// Helper for forgetMemoizedResults.
void forgetMemoizedResultsImpl(const SCEV *S, bool SkipUnknownPhis = false);
void forgetMemoizedResultsImpl(const SCEV *S);

/// Return an existing SCEV for V if there is one, otherwise return nullptr.
const SCEV *getExistingSCEV(Value *V);
Expand Down
77 changes: 13 additions & 64 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Expand Up @@ -7610,62 +7610,19 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
// Now that we know more about the trip count for this loop, forget any
// existing SCEV values for PHI nodes in this loop since they are only
// conservative estimates made without the benefit of trip count
// information. This is similar to the code in forgetLoop, except that
// it handles SCEVUnknown PHI nodes specially.
// information. This invalidation is not necessary for correctness, and is
// only done to produce more precise results.
if (Result.hasAnyInfo()) {
SmallVector<Instruction *, 16> Worklist;
SmallPtrSet<Instruction *, 8> Discovered;
// Invalidate any expression using an addrec in this loop.
SmallVector<const SCEV *, 8> ToForget;
PushLoopPHIs(L, Worklist, Discovered);
while (!Worklist.empty()) {
Instruction *I = Worklist.pop_back_val();
auto LoopUsersIt = LoopUsers.find(L);
if (LoopUsersIt != LoopUsers.end())
append_range(ToForget, LoopUsersIt->second);
forgetMemoizedResults(ToForget);

ValueExprMapType::iterator It =
ValueExprMap.find_as(static_cast<Value *>(I));
if (It != ValueExprMap.end()) {
const SCEV *Old = It->second;

// SCEVUnknown for a PHI either means that it has an unrecognized
// structure, or it's a PHI that's in the progress of being computed
// by createNodeForPHI. In the former case, additional loop trip
// count information isn't going to change anything. In the later
// case, createNodeForPHI will perform the necessary updates on its
// own when it gets to that point.
if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
eraseValueFromMap(It->first);
ToForget.push_back(Old);
}
if (PHINode *PN = dyn_cast<PHINode>(I))
ConstantEvolutionLoopExitValue.erase(PN);
}

// Since we don't need to invalidate anything for correctness and we're
// only invalidating to make SCEV's results more precise, we get to stop
// early to avoid invalidating too much. This is especially important in
// cases like:
//
// %v = f(pn0, pn1) // pn0 and pn1 used through some other phi node
// loop0:
// %pn0 = phi
// ...
// loop1:
// %pn1 = phi
// ...
//
// where both loop0 and loop1's backedge taken count uses the SCEV
// expression for %v. If we don't have the early stop below then in cases
// like the above, getBackedgeTakenInfo(loop1) will clear out the trip
// count for loop0 and getBackedgeTakenInfo(loop0) will clear out the trip
// count for loop1, effectively nullifying SCEV's trip count cache.
for (auto *U : I->users())
if (auto *I = dyn_cast<Instruction>(U)) {
auto *LoopForUser = LI.getLoopFor(I->getParent());
if (LoopForUser && L->contains(LoopForUser) &&
Discovered.insert(I).second)
Worklist.push_back(I);
}
}
forgetMemoizedResults(ToForget, /* SkipUnknownPhis */ true);
// Invalidate constant-evolved loop header phis.
for (PHINode &PN : L->getHeader()->phis())
ConstantEvolutionLoopExitValue.erase(&PN);
}

// Re-lookup the insert position, since the call to
Expand Down Expand Up @@ -12958,8 +12915,7 @@ bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
}

void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs,
bool SkipUnknownPhis) {
void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());

Expand All @@ -12973,7 +12929,7 @@ void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs,
}

for (auto *S : ToForget)
forgetMemoizedResultsImpl(S, SkipUnknownPhis);
forgetMemoizedResultsImpl(S);

for (auto I = PredicatedSCEVRewrites.begin();
I != PredicatedSCEVRewrites.end();) {
Expand All @@ -13000,8 +12956,7 @@ void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs,
RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
}

void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S,
bool SkipUnknownPhis) {
void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
ValuesAtScopes.erase(S);
LoopDispositions.erase(S);
BlockDispositions.erase(S);
Expand All @@ -13013,12 +12968,6 @@ void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S,
auto ExprIt = ExprValueMap.find(S);
if (ExprIt != ExprValueMap.end()) {
for (auto &ValueAndOffset : ExprIt->second) {
// For some invalidations, it's important that symbolic SCEVUnknown
// placeholders do not get removed.
if (SkipUnknownPhis && isa<SCEVUnknown>(S) &&
isa<PHINode>(ValueAndOffset.first))
continue;

if (ValueAndOffset.second == nullptr) {
auto ValueIt = ValueExprMap.find_as(ValueAndOffset.first);
if (ValueIt != ValueExprMap.end())
Expand Down

0 comments on commit c2550e3

Please sign in to comment.