diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 547183d541728..b423b01b8033c 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1390,6 +1390,9 @@ class ScalarEvolution { /// True iff the backedge is taken either exactly Max or zero times. bool MaxOrZero = false; + /// SCEV expressions used in any of the ExitNotTakenInfo counts. + SmallPtrSet Operands; + bool isComplete() const { return IsComplete; } const SCEV *getConstantMax() const { return ConstantMax; } @@ -1458,7 +1461,7 @@ class ScalarEvolution { /// Return true if any backedge taken count expressions refer to the given /// subexpression. - bool hasOperand(const SCEV *S, ScalarEvolution *SE) const; + bool hasOperand(const SCEV *S) const; }; /// Cache the backedge-taken count of the loops for this function as they diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 1687929650f0a..a38ea84acd938 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -7392,18 +7392,8 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero( return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue); } -bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, - ScalarEvolution *SE) const { - if (getConstantMax() && getConstantMax() != SE->getCouldNotCompute() && - SE->hasOperand(getConstantMax(), S)) - return true; - - for (auto &ENT : ExitNotTaken) - if (ENT.ExactNotTaken != SE->getCouldNotCompute() && - SE->hasOperand(ENT.ExactNotTaken, S)) - return true; - - return false; +bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S) const { + return Operands.contains(S); } ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) @@ -7445,6 +7435,19 @@ ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M, "No point in having a non-constant max backedge taken count!"); } +class SCEVRecordOperands { + SmallPtrSetImpl &Operands; + +public: + SCEVRecordOperands(SmallPtrSetImpl &Operands) + : Operands(Operands) {} + bool follow(const SCEV *S) { + Operands.insert(S); + return true; + } + bool isDone() { return false; } +}; + /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( @@ -7473,6 +7476,14 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( assert((isa(ConstantMax) || isa(ConstantMax)) && "No point in having a non-constant max backedge taken count!"); + + SCEVRecordOperands RecordOperands(Operands); + SCEVTraversal ST(RecordOperands); + if (!isa(ConstantMax)) + ST.visitAll(ConstantMax); + for (auto &ENT : ExitNotTaken) + if (!isa(ENT.ExactNotTaken)) + ST.visitAll(ENT.ExactNotTaken); } /// Compute the number of times the backedge of the specified loop will execute. @@ -12627,7 +12638,7 @@ ScalarEvolution::forgetMemoizedResults(const SCEV *S) { [S, this](DenseMap &Map) { for (auto I = Map.begin(), E = Map.end(); I != E;) { BackedgeTakenInfo &BEInfo = I->second; - if (BEInfo.hasOperand(S, this)) + if (BEInfo.hasOperand(S)) Map.erase(I++); else ++I;