From eea887cf6be39856fa441ed48f72c1c9177a76a6 Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Sun, 25 Feb 2024 14:06:02 -0600 Subject: [PATCH 01/15] mainly pushing to switch machines --- llvm/include/llvm/Analysis/ScalarEvolution.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 0880f9c65aa45..1b03437de30c2 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1345,7 +1345,7 @@ class ScalarEvolution { } }; -private: +protected: /// A CallbackVH to arrange for ScalarEvolution to be notified whenever a /// Value is deleted. class SCEVCallbackVH final : public CallbackVH { From e47436b767d635c14c10fc8c0bfc4fe30b8967d6 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 29 Feb 2024 08:35:45 -0600 Subject: [PATCH 02/15] added AssumeLoopExits bool to SE, lifting MustExit code into SE --- llvm/include/llvm/Analysis/ScalarEvolution.h | 9 ++- .../llvm/Analysis/Utils/EnzymeFunctionUtils.h | 71 +++++++++++++++++++ llvm/lib/Analysis/ScalarEvolution.cpp | 8 ++- 3 files changed, 84 insertions(+), 4 deletions(-) create mode 100644 llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 1b03437de30c2..3075358e95791 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -460,6 +460,9 @@ class ScalarEvolution { LoopComputable ///< The SCEV varies predictably with the loop. }; + bool AssumeLoopExists = false; + void setAssumeLoopExists(); + /// An enum describing the relationship between a SCEV and a basic block. enum BlockDisposition { DoesNotDominateBlock, ///< The SCEV does not dominate the block. @@ -1345,7 +1348,7 @@ class ScalarEvolution { } }; -protected: + private: /// A CallbackVH to arrange for ScalarEvolution to be notified whenever a /// Value is deleted. class SCEVCallbackVH final : public CallbackVH { @@ -1364,7 +1367,7 @@ class ScalarEvolution { /// The function we are analyzing. Function &F; - + /// Does the module have any calls to the llvm.experimental.guard intrinsic /// at all? If this is false, we avoid doing work that will only help if /// thare are guards present in the IR. @@ -1765,7 +1768,7 @@ class ScalarEvolution { /// an arbitrary expression as opposed to only constants. const SCEV *computeSymbolicMaxBackedgeTakenCount(const Loop *L); - // Helper functions for computeExitLimitFromCond to avoid exponential time +// Helper functions for computeExitLimitFromCond to avoid exponential time // complexity. class ExitLimitCache { diff --git a/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h b/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h new file mode 100644 index 0000000000000..a211bdca6a47d --- /dev/null +++ b/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h @@ -0,0 +1,71 @@ + +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/TargetLibraryInfo.h" + +#include "llvm/IR/Function.h" + +#include "llvm/IR/Instructions.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include + + +// TODO note this doesn't go through [loop, unreachable], and we could get more +// performance by doing this can consider doing some domtree magic potentially +static inline llvm::SmallPtrSet +getGuaranteedUnreachable(llvm::Function *F) { + llvm::SmallPtrSet knownUnreachables; + if (F->empty()) + return knownUnreachables; + std::deque todo; + for (auto &BB : *F) { + todo.push_back(&BB); + } + + while (!todo.empty()) { + llvm::BasicBlock *next = todo.front(); + todo.pop_front(); + + if (knownUnreachables.find(next) != knownUnreachables.end()) + continue; + + if (llvm::isa(next->getTerminator())) + continue; + + if (llvm::isa(next->getTerminator())) { + knownUnreachables.insert(next); + for (llvm::BasicBlock *Pred : predecessors(next)) { + todo.push_back(Pred); + } + continue; + } + + // Assume resumes don't happen + // TODO consider EH + if (llvm::isa(next->getTerminator())) { + knownUnreachables.insert(next); + for (llvm::BasicBlock *Pred : predecessors(next)) { + todo.push_back(Pred); + } + continue; + } + + bool unreachable = true; + for (llvm::BasicBlock *Succ : llvm::successors(next)) { + if (knownUnreachables.find(Succ) == knownUnreachables.end()) { + unreachable = false; + break; + } + } + + if (!unreachable) + continue; + knownUnreachables.insert(next); + for (llvm::BasicBlock *Pred : llvm::predecessors(next)) { + todo.push_back(Pred); + } + continue; + } + + return knownUnreachables; +} \ No newline at end of file diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 4b2db80bc1ec3..6dc59108f5e18 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -82,6 +82,7 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Config/llvm-config.h" +#include "llvm/Analysis/Utils/EnzymeFunctionUtils.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -509,6 +510,10 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) { return S; } +void ScalarEvolution::setAssumeLoopExists() { + this->AssumeLoopExists=true; +} + SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty) : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {} @@ -7413,7 +7418,7 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { // A mustprogress loop without side effects must be finite. // TODO: The check used here is very conservative. It's only *specific* // side effects which are well defined in infinite loops. - return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); + return this->AssumeLoopExists || isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); } const SCEV *ScalarEvolution::createSCEVIter(Value *V) { @@ -13354,6 +13359,7 @@ const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { return getSizeOfExpr(ETy, Ty); } + //===----------------------------------------------------------------------===// // SCEVCallbackVH Class Implementation //===----------------------------------------------------------------------===// From f55e361a3ba1d4a5ca30f4b9719d23d57d273cc5 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 29 Feb 2024 09:51:55 -0600 Subject: [PATCH 03/15] added MustExitcode for computeExitLimit --- llvm/include/llvm/Analysis/ScalarEvolution.h | 7 ++-- .../llvm/Analysis/Utils/EnzymeFunctionUtils.h | 1 - llvm/lib/Analysis/ScalarEvolution.cpp | 32 +++++++++++++++---- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 3075358e95791..4cc1954c1233f 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -462,6 +462,7 @@ class ScalarEvolution { bool AssumeLoopExists = false; void setAssumeLoopExists(); + llvm::SmallPtrSet GuaranteedUnreachable; /// An enum describing the relationship between a SCEV and a basic block. enum BlockDisposition { @@ -1348,7 +1349,7 @@ class ScalarEvolution { } }; - private: +private: /// A CallbackVH to arrange for ScalarEvolution to be notified whenever a /// Value is deleted. class SCEVCallbackVH final : public CallbackVH { @@ -1367,7 +1368,7 @@ class ScalarEvolution { /// The function we are analyzing. Function &F; - + /// Does the module have any calls to the llvm.experimental.guard intrinsic /// at all? If this is false, we avoid doing work that will only help if /// thare are guards present in the IR. @@ -1768,7 +1769,7 @@ class ScalarEvolution { /// an arbitrary expression as opposed to only constants. const SCEV *computeSymbolicMaxBackedgeTakenCount(const Loop *L); -// Helper functions for computeExitLimitFromCond to avoid exponential time + // Helper functions for computeExitLimitFromCond to avoid exponential time // complexity. class ExitLimitCache { diff --git a/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h b/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h index a211bdca6a47d..59032cbe6dddd 100644 --- a/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h +++ b/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h @@ -9,7 +9,6 @@ #include "llvm/Transforms/Utils/ValueMapper.h" #include - // TODO note this doesn't go through [loop, unreachable], and we could get more // performance by doing this can consider doing some domtree magic potentially static inline llvm::SmallPtrSet diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 6dc59108f5e18..c1071f07b7f28 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -80,9 +80,9 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/Utils/EnzymeFunctionUtils.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Config/llvm-config.h" -#include "llvm/Analysis/Utils/EnzymeFunctionUtils.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" @@ -510,9 +510,7 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) { return S; } -void ScalarEvolution::setAssumeLoopExists() { - this->AssumeLoopExists=true; -} +void ScalarEvolution::setAssumeLoopExists() { this->AssumeLoopExists = true; } SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty) @@ -7418,7 +7416,8 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { // A mustprogress loop without side effects must be finite. // TODO: The check used here is very conservative. It's only *specific* // side effects which are well defined in infinite loops. - return this->AssumeLoopExists || isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); + return this->AssumeLoopExists || isFinite(L) || + (isMustProgress(L) && loopHasNoSideEffects(L)); } const SCEV *ScalarEvolution::createSCEVIter(Value *V) { @@ -8833,6 +8832,26 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, bool AllowPredicates) { + if (AssumeLoopExists) { + SmallVector ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + for (auto &ExitingBlock : ExitingBlocks) { + BasicBlock *Exit = nullptr; + for (auto *SBB : successors(ExitingBlock)) { + if (!L->contains(SBB)) { + if (GuaranteedUnreachable.count(SBB)) + continue; + Exit = SBB; + break; + } + } + if (!Exit) + ExitingBlock = nullptr; + } + ExitingBlocks.erase( + std::remove(ExitingBlocks.begin(), ExitingBlocks.end(), nullptr), + ExitingBlocks.end()); + } assert(L->contains(ExitingBlock) && "Exit count for non-loop block?"); // If our exiting block does not dominate the latch, then its connection with // loop's exit limit may be far from trivial. @@ -8858,6 +8877,8 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, BasicBlock *Exit = nullptr; for (auto *SBB : successors(ExitingBlock)) if (!L->contains(SBB)) { + if (AssumeLoopExists and GuaranteedUnreachable.count(SBB)) + continue; if (Exit) // Multiple exit successors. return getCouldNotCompute(); Exit = SBB; @@ -13359,7 +13380,6 @@ const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { return getSizeOfExpr(ETy, Ty); } - //===----------------------------------------------------------------------===// // SCEVCallbackVH Class Implementation //===----------------------------------------------------------------------===// From 8e85c0653be244e036e68eb31a4022ff05b23257 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 29 Feb 2024 10:33:22 -0600 Subject: [PATCH 04/15] added enzyme mustExit code to computeExitLimitFromSingleExitSwitch --- llvm/lib/Analysis/ScalarEvolution.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index c1071f07b7f28..d28436e02466b 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9264,8 +9264,14 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, if (Switch->getDefaultDest() == ExitingBlock) return getCouldNotCompute(); - assert(L->contains(Switch->getDefaultDest()) && - "Default case must not exit the loop!"); + // if not using enzyme executes by default + // if using enzyme and the code is guaranteed unreachable, + // the default destination doesn't matter + if (!AssumeLoopExists || + !GuaranteedUnreachable.count(Switch->getDefaultDest())) { + assert(L->contains(Switch->getDefaultDest()) && + "Default case must not exit the loop!"); + } const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L); const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); From 3f378b5c9370355e3b5fc66709df06ec4f3970f3 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 29 Feb 2024 11:05:50 -0600 Subject: [PATCH 05/15] add enzyme must exit code to computeExitLimitFromCondImpl --- llvm/lib/Analysis/ScalarEvolution.cpp | 109 ++++++++++++++++++++++++-- 1 file changed, 104 insertions(+), 5 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index d28436e02466b..62f8ddfa72081 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8949,10 +8949,104 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { - // Handle BinOp conditions (And, Or). - if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( - Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates)) - return *LimitFromBinOp; + if (!AssumeLoopExists) { + // Handle BinOp conditions (And, Or). + if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( + Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates)) + return *LimitFromBinOp; + } else { + // Check if the controlling expression for this loop is an And or Or. + if (BinaryOperator *BO = dyn_cast(ExitCond)) { + if (BO->getOpcode() == Instruction::And) { + // Recurse on the operands of the and. + bool EitherMayExit = !ExitIfTrue; + ExitLimit EL0 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(0), ExitIfTrue, + ControlsOnlyExit && !EitherMayExit, AllowPredicates); + ExitLimit EL1 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(1), ExitIfTrue, + ControlsOnlyExit && !EitherMayExit, AllowPredicates); + const SCEV *BECount = getCouldNotCompute(); + const SCEV *MaxBECount = getCouldNotCompute(); + if (EitherMayExit) { + // Both conditions must be true for the loop to continue executing. + // Choose the less conservative count. + if (EL0.ExactNotTaken == getCouldNotCompute() || + EL1.ExactNotTaken == getCouldNotCompute()) + BECount = getCouldNotCompute(); + else + BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, + EL1.ExactNotTaken); + + if (EL0.ConstantMaxNotTaken == getCouldNotCompute()) + MaxBECount = EL1.ConstantMaxNotTaken; + else if (EL1.ConstantMaxNotTaken == getCouldNotCompute()) + MaxBECount = EL0.ConstantMaxNotTaken; + else + MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken, + EL1.ConstantMaxNotTaken); + } else { + // Both conditions must be true at the same time for the loop to exit. + // For now, be conservative. + if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken) + MaxBECount = EL0.ConstantMaxNotTaken; + if (EL0.ExactNotTaken == EL1.ExactNotTaken) + BECount = EL0.ExactNotTaken; + } + + // There are cases (e.g. PR26207) where computeExitLimitFromCond is able + // to be more aggressive when computing BECount than when computing + // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and + // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and + // EL1.ConstantMaxNotTaken to not. + if (isa(MaxBECount) && + !isa(BECount)) + MaxBECount = getConstant(getUnsignedRangeMax(BECount)); + + return ExitLimit(BECount, MaxBECount, MaxBECount, false, + {&EL0.Predicates, &EL1.Predicates}); + } + if (BO->getOpcode() == Instruction::Or) { + // Recurse on the operands of the or. + bool EitherMayExit = ExitIfTrue; + ExitLimit EL0 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(0), ExitIfTrue, + ControlsOnlyExit && !EitherMayExit, AllowPredicates); + ExitLimit EL1 = computeExitLimitFromCondCached( + Cache, L, BO->getOperand(1), ExitIfTrue, + ControlsOnlyExit && !EitherMayExit, AllowPredicates); + const SCEV *BECount = getCouldNotCompute(); + const SCEV *MaxBECount = getCouldNotCompute(); + if (EitherMayExit) { + // Both conditions must be false for the loop to continue executing. + // Choose the less conservative count. + if (EL0.ExactNotTaken == getCouldNotCompute() || + EL1.ExactNotTaken == getCouldNotCompute()) + BECount = getCouldNotCompute(); + else + BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, + EL1.ExactNotTaken); + + if (EL0.ConstantMaxNotTaken == getCouldNotCompute()) + MaxBECount = EL1.ConstantMaxNotTaken; + else if (EL1.ConstantMaxNotTaken == getCouldNotCompute()) + MaxBECount = EL0.ConstantMaxNotTaken; + else + MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken, + EL1.ConstantMaxNotTaken); + } else { + // Both conditions must be false at the same time for the loop to + // exit. For now, be conservative. + if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken) + MaxBECount = EL0.ConstantMaxNotTaken; + if (EL0.ExactNotTaken == EL1.ExactNotTaken) + BECount = EL0.ExactNotTaken; + } + return ExitLimit(BECount, MaxBECount, MaxBECount, false, + {&EL0.Predicates, &EL1.Predicates}); + } + } + } // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. @@ -8973,12 +9067,17 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( // preserve the CFG and is temporarily leaving constant conditions // in place. if (ConstantInt *CI = dyn_cast(ExitCond)) { - if (ExitIfTrue == !CI->getZExtValue()) + if (ExitIfTrue == !CI->getZExtValue()) { // The backedge is always taken. return getCouldNotCompute(); + } // The backedge is never taken. return getZero(CI->getType()); } + // The rest of this code was missing from the MustExitScalarEvolution + // overrides + // so this should never be reached if using enzyme + assert(!AssumeLoopExists); // If we're exiting based on the overflow flag of an x.with.overflow intrinsic // with a constant step, we can form an equivalent icmp predicate and figure From 14a0c6c187d61db2e017202283be20d17cc93ed7 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 29 Feb 2024 11:31:39 -0600 Subject: [PATCH 06/15] implemented enzyme must exit code in computeExitLimitFromICmp --- llvm/lib/Analysis/ScalarEvolution.cpp | 183 +++++++++++++++++++++----- 1 file changed, 151 insertions(+), 32 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 62f8ddfa72081..b6e88b563e272 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9074,35 +9074,32 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( // The backedge is never taken. return getZero(CI->getType()); } - // The rest of this code was missing from the MustExitScalarEvolution - // overrides - // so this should never be reached if using enzyme - assert(!AssumeLoopExists); - - // If we're exiting based on the overflow flag of an x.with.overflow intrinsic - // with a constant step, we can form an equivalent icmp predicate and figure - // out how many iterations will be taken before we exit. - const WithOverflowInst *WO; - const APInt *C; - if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) && - match(WO->getRHS(), m_APInt(C))) { - ConstantRange NWR = - ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C, - WO->getNoWrapKind()); - CmpInst::Predicate Pred; - APInt NewRHSC, Offset; - NWR.getEquivalentICmp(Pred, NewRHSC, Offset); - if (!ExitIfTrue) - Pred = ICmpInst::getInversePredicate(Pred); - auto *LHS = getSCEV(WO->getLHS()); - if (Offset != 0) - LHS = getAddExpr(LHS, getConstant(Offset)); - auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC), - ControlsOnlyExit, AllowPredicates); - if (EL.hasAnyInfo()) - return EL; - } + // block was never executed in MustExitScalarEvolution code + if (!AssumeLoopExists) { + // If we're exiting based on the overflow flag of an x.with.overflow + // intrinsic with a constant step, we can form an equivalent icmp predicate + // and figure out how many iterations will be taken before we exit. + const WithOverflowInst *WO; + const APInt *C; + if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) && + match(WO->getRHS(), m_APInt(C))) { + ConstantRange NWR = ConstantRange::makeExactNoWrapRegion( + WO->getBinaryOp(), *C, WO->getNoWrapKind()); + CmpInst::Predicate Pred; + APInt NewRHSC, Offset; + NWR.getEquivalentICmp(Pred, NewRHSC, Offset); + if (!ExitIfTrue) + Pred = ICmpInst::getInversePredicate(Pred); + auto *LHS = getSCEV(WO->getLHS()); + if (Offset != 0) + LHS = getAddExpr(LHS, getConstant(Offset)); + auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC), + ControlsOnlyExit, AllowPredicates); + if (EL.hasAnyInfo()) + return EL; + } + } // If it's not an integer or pointer comparison then compute it the hard way. return computeExitCountExhaustively(L, ExitCond, ExitIfTrue); } @@ -9201,12 +9198,134 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); + if (!AssumeLoopExists) { + ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit, + AllowPredicates); + if (EL.hasAnyInfo()) + return EL; + } else { +#define PROP_PHI(LHS) \ + if (auto un = dyn_cast(LHS)) { \ + if (auto pn = dyn_cast_or_null(un->getValue())) { \ + const SCEV *sc = nullptr; \ + bool failed = false; \ + for (auto &a : pn->incoming_values()) { \ + auto subsc = getSCEV(a); \ + if (sc == nullptr) { \ + sc = subsc; \ + continue; \ + } \ + if (subsc != sc) { \ + failed = true; \ + break; \ + } \ + } \ + if (!failed) { \ + LHS = sc; \ + } \ + } \ + } + PROP_PHI(LHS) + PROP_PHI(RHS) + + // Try to evaluate any dependencies out of the loop. + LHS = getSCEVAtScope(LHS, L); + RHS = getSCEVAtScope(RHS, L); + + // At this point, we would like to compute how many iterations of the + // loop the predicate will return true for these inputs. + if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) { + // If there is a loop-invariant, force it into the RHS. + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } - ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit, - AllowPredicates); - if (EL.hasAnyInfo()) - return EL; + // Simplify the operands before analyzing them. + (void)SimplifyICmpOperands(Pred, LHS, RHS); + // If we have a comparison of a chrec against a constant, try to use value + // ranges to answer this query. + if (const SCEVConstant *RHSC = dyn_cast(RHS)) + if (const SCEVAddRecExpr *AddRec = dyn_cast(LHS)) + if (AddRec->getLoop() == L) { + // Form the constant range. + ConstantRange CompRange = + ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt()); + + const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); + if (!isa(Ret)) + return Ret; + } + + switch (Pred) { + case ICmpInst::ICMP_NE: { // while (X != Y) + // Convert to: while (X-Y != 0) + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit, + AllowPredicates); + if (EL.hasAnyInfo()) + return EL; + break; + } + case ICmpInst::ICMP_EQ: { // while (X == Y) + // Convert to: while (X-Y == 0) + ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L); + if (EL.hasAnyInfo()) + return EL; + break; + } + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLE: + case ICmpInst::ICMP_ULE: { // while (X < Y) + bool IsSigned = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE; + + if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE) { + if (!isa(RHS->getType())) + break; + SmallVector sv = { + RHS, getConstant( + ConstantInt::get(cast(RHS->getType()), 1))}; + // Since this is not an infinite loop by induction, RHS cannot be + // int_max/uint_max Therefore adding 1 does not wrap. + if (IsSigned) + RHS = getAddExpr(sv, SCEV::FlagNSW); + else + RHS = getAddExpr(sv, SCEV::FlagNUW); + } + ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit, + AllowPredicates); + if (EL.hasAnyInfo()) + return EL; + break; + } + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_UGE: { // while (X > Y) + bool IsSigned = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE; + if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) { + if (!isa(RHS->getType())) + break; + SmallVector sv = { + RHS, getConstant( + ConstantInt::get(cast(RHS->getType()), -1))}; + // Since this is not an infinite loop by induction, RHS cannot be + // int_min/uint_min Therefore subtracting 1 does not wrap. + if (IsSigned) + RHS = getAddExpr(sv, SCEV::FlagNSW); + else + RHS = getAddExpr(sv, SCEV::FlagNUW); + } + ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, + ControlsOnlyExit, AllowPredicates); + if (EL.hasAnyInfo()) + return EL; + break; + } + default: + break; + } + } auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue); From abb0ab463de42b5b66261fed48de69d8980b30c0 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 29 Feb 2024 14:30:36 -0600 Subject: [PATCH 07/15] add Enzyme changes to SE howManyLessThans --- llvm/lib/Analysis/ScalarEvolution.cpp | 100 ++++++++++++++++---------- 1 file changed, 63 insertions(+), 37 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index b6e88b563e272..854cfec1e6805 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -12983,38 +12983,50 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (auto *ZExt = dyn_cast(LHS)) { const SCEVAddRecExpr *AR = dyn_cast(ZExt->getOperand()); if (AR && AR->getLoop() == L && AR->isAffine()) { - auto canProveNUW = [&]() { - // We can use the comparison to infer no-wrap flags only if it fully - // controls the loop exit. - if (!ControlsOnlyExit) - return false; - - if (!isLoopInvariant(RHS, L)) - return false; - - if (!isKnownNonZero(AR->getStepRecurrence(*this))) - // We need the sequence defined by AR to strictly increase in the - // unsigned integer domain for the logic below to hold. - return false; - - const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType()); - const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType()); - // If RHS <=u Limit, then there must exist a value V in the sequence - // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and - // V <=u UINT_MAX. Thus, we must exit the loop before unsigned - // overflow occurs. This limit also implies that a signed comparison - // (in the wide bitwidth) is equivalent to an unsigned comparison as - // the high bits on both sides must be zero. - APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this)); - APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1); - Limit = Limit.zext(OuterBitWidth); - return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit); - }; - auto Flags = AR->getNoWrapFlags(); - if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW()) - Flags = setFlags(Flags, SCEV::FlagNUW); + if (!AssumeLoopExists) { + auto canProveNUW = [&]() { + // We can use the comparison to infer no-wrap flags only if it fully + // controls the loop exit. + if (!ControlsOnlyExit) + return false; - setNoWrapFlags(const_cast(AR), Flags); + if (!isLoopInvariant(RHS, L)) + return false; + + if (!isKnownNonZero(AR->getStepRecurrence(*this))) + // We need the sequence defined by AR to strictly increase in the + // unsigned integer domain for the logic below to hold. + return false; + + const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType()); + const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType()); + // If RHS <=u Limit, then there must exist a value V in the sequence + // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and + // V <=u UINT_MAX. Thus, we must exit the loop before unsigned + // overflow occurs. This limit also implies that a signed + // comparison (in the wide bitwidth) is equivalent to an unsigned + // comparison as the high bits on both sides must be zero. + APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this)); + APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1); + Limit = Limit.zext(OuterBitWidth); + return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit); + }; + auto Flags = AR->getNoWrapFlags(); + if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW()) + Flags = setFlags(Flags, SCEV::FlagNUW); + + setNoWrapFlags(const_cast(AR), Flags); + } else { + auto Flags = AR->getNoWrapFlags(); + if (!hasFlags(Flags, SCEV::FlagNW) && canAssumeNoSelfWrap(AR)) { + Flags = setFlags(Flags, SCEV::FlagNW); + + SmallVector Operands{AR->operands()}; + Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); + + setNoWrapFlags(const_cast(AR), Flags); + } + } if (AR->hasNoUnsignedWrap()) { // Emulate what getZeroExtendExpr would have done during construction // if we'd been able to infer the fact just above at that time. @@ -13098,6 +13110,13 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, !loopHasNoAbnormalExits(L)) return getCouldNotCompute(); + // This bailout is protecting the logic in computeMaxBECountForLT which + // has not yet been sufficiently auditted or tested with negative strides. + // We used to filter out all known-non-positive cases here, we're in the + // process of being less restrictive bit by bit. + if (AssumeLoopExists && IsSigned && isKnownNonPositive(Stride)) + return getCouldNotCompute(); + if (!isKnownNonZero(Stride)) { // If we have a step of zero, and RHS isn't invariant in L, we don't know // if it might eventually be greater than start and if so, on which @@ -13227,13 +13246,17 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (!BECount) { auto canProveRHSGreaterThanEqualStart = [&]() { auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; - const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L); - const SCEV *GuardedStart = applyLoopGuards(OrigStart, L); - if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) || - isKnownPredicate(CondGE, GuardedRHS, GuardedStart)) - return true; + if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart)) { + if (AssumeLoopExists) { + return true; + } + const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L); + const SCEV *GuardedStart = applyLoopGuards(OrigStart, L); + if (isKnownPredicate(CondGE, GuardedRHS, GuardedStart)) + return true; + } // (RHS > Start - 1) implies RHS >= Start. // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if // "Start - 1" doesn't overflow. @@ -13370,7 +13393,10 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (isa(ConstantMaxBECount) && !isa(BECount)) ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); - + if (AssumeLoopExists) { + return ExitLimit(BECount, ConstantMaxBECount, ConstantMaxBECount, MaxOrZero, + Predicates); + } const SCEV *SymbolicMaxBECount = isa(BECount) ? ConstantMaxBECount : BECount; return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero, From c1d83de8b8bb83f3a93e4c271ab9dfd10f7e7950 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Fri, 1 Mar 2024 11:21:01 -0600 Subject: [PATCH 08/15] fixed issue in howManyLessThans where conditions were incorrectly dependent --- llvm/lib/Analysis/ScalarEvolution.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 854cfec1e6805..492b33e0a7c23 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -7416,7 +7416,7 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { // A mustprogress loop without side effects must be finite. // TODO: The check used here is very conservative. It's only *specific* // side effects which are well defined in infinite loops. - return this->AssumeLoopExists || isFinite(L) || + return AssumeLoopExists || isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); } @@ -13248,9 +13248,12 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart)) { - if (AssumeLoopExists) { - return true; - } + return true; + } + // In the Enzyme MustExitScalarEvolutionCode, this check was missing + // I do not have enough context to know if these two checks should be + // mutually Exclusive. If they aren't then this bool check is unnecessary + if (!AssumeLoopExists) { const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L); const SCEV *GuardedStart = applyLoopGuards(OrigStart, L); From 66ab0c3e093988ac51258837ccc063f0955d7417 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 7 Mar 2024 09:47:02 -0600 Subject: [PATCH 09/15] incorporating changes from code review --- llvm/include/llvm/Analysis/ScalarEvolution.h | 4 +- llvm/lib/Analysis/ScalarEvolution.cpp | 139 +++++++++---------- 2 files changed, 66 insertions(+), 77 deletions(-) diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 4cc1954c1233f..50dbe2aeec884 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -460,8 +460,8 @@ class ScalarEvolution { LoopComputable ///< The SCEV varies predictably with the loop. }; - bool AssumeLoopExists = false; - void setAssumeLoopExists(); + bool AssumeLoopExits = false; + void setAssumeLoopExits(); llvm::SmallPtrSet GuaranteedUnreachable; /// An enum describing the relationship between a SCEV and a basic block. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 492b33e0a7c23..3b1fcab6e333b 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -510,7 +510,7 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) { return S; } -void ScalarEvolution::setAssumeLoopExists() { this->AssumeLoopExists = true; } +void ScalarEvolution::setAssumeLoopExits() { this->AssumeLoopExits = true; } SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty) @@ -7416,7 +7416,7 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { // A mustprogress loop without side effects must be finite. // TODO: The check used here is very conservative. It's only *specific* // side effects which are well defined in infinite loops. - return AssumeLoopExists || isFinite(L) || + return AssumeLoopExits || isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); } @@ -8832,7 +8832,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, bool AllowPredicates) { - if (AssumeLoopExists) { + if (AssumeLoopExits) { SmallVector ExitingBlocks; L->getExitingBlocks(ExitingBlocks); for (auto &ExitingBlock : ExitingBlocks) { @@ -8877,7 +8877,7 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, BasicBlock *Exit = nullptr; for (auto *SBB : successors(ExitingBlock)) if (!L->contains(SBB)) { - if (AssumeLoopExists and GuaranteedUnreachable.count(SBB)) + if (AssumeLoopExits and GuaranteedUnreachable.count(SBB)) continue; if (Exit) // Multiple exit successors. return getCouldNotCompute(); @@ -8949,7 +8949,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { - if (!AssumeLoopExists) { + if (!AssumeLoopExits) { // Handle BinOp conditions (And, Or). if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates)) @@ -9076,30 +9076,30 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( } // block was never executed in MustExitScalarEvolution code - if (!AssumeLoopExists) { - // If we're exiting based on the overflow flag of an x.with.overflow - // intrinsic with a constant step, we can form an equivalent icmp predicate - // and figure out how many iterations will be taken before we exit. - const WithOverflowInst *WO; - const APInt *C; - if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) && - match(WO->getRHS(), m_APInt(C))) { - ConstantRange NWR = ConstantRange::makeExactNoWrapRegion( - WO->getBinaryOp(), *C, WO->getNoWrapKind()); - CmpInst::Predicate Pred; - APInt NewRHSC, Offset; - NWR.getEquivalentICmp(Pred, NewRHSC, Offset); - if (!ExitIfTrue) - Pred = ICmpInst::getInversePredicate(Pred); - auto *LHS = getSCEV(WO->getLHS()); - if (Offset != 0) - LHS = getAddExpr(LHS, getConstant(Offset)); - auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC), - ControlsOnlyExit, AllowPredicates); - if (EL.hasAnyInfo()) - return EL; - } + + // If we're exiting based on the overflow flag of an x.with.overflow + // intrinsic with a constant step, we can form an equivalent icmp predicate + // and figure out how many iterations will be taken before we exit. + const WithOverflowInst *WO; + const APInt *C; + if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) && + match(WO->getRHS(), m_APInt(C))) { + ConstantRange NWR = ConstantRange::makeExactNoWrapRegion( + WO->getBinaryOp(), *C, WO->getNoWrapKind()); + CmpInst::Predicate Pred; + APInt NewRHSC, Offset; + NWR.getEquivalentICmp(Pred, NewRHSC, Offset); + if (!ExitIfTrue) + Pred = ICmpInst::getInversePredicate(Pred); + auto *LHS = getSCEV(WO->getLHS()); + if (Offset != 0) + LHS = getAddExpr(LHS, getConstant(Offset)); + auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC), + ControlsOnlyExit, AllowPredicates); + if (EL.hasAnyInfo()) + return EL; } + // If it's not an integer or pointer comparison then compute it the hard way. return computeExitCountExhaustively(L, ExitCond, ExitIfTrue); } @@ -9198,7 +9198,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); - if (!AssumeLoopExists) { + if (!AssumeLoopExits) { ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) @@ -9485,7 +9485,7 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, // if not using enzyme executes by default // if using enzyme and the code is guaranteed unreachable, // the default destination doesn't matter - if (!AssumeLoopExists || + if (!AssumeLoopExits || !GuaranteedUnreachable.count(Switch->getDefaultDest())) { assert(L->contains(Switch->getDefaultDest()) && "Default case must not exit the loop!"); @@ -12983,50 +12983,39 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (auto *ZExt = dyn_cast(LHS)) { const SCEVAddRecExpr *AR = dyn_cast(ZExt->getOperand()); if (AR && AR->getLoop() == L && AR->isAffine()) { - if (!AssumeLoopExists) { - auto canProveNUW = [&]() { - // We can use the comparison to infer no-wrap flags only if it fully - // controls the loop exit. - if (!ControlsOnlyExit) - return false; - - if (!isLoopInvariant(RHS, L)) - return false; - - if (!isKnownNonZero(AR->getStepRecurrence(*this))) - // We need the sequence defined by AR to strictly increase in the - // unsigned integer domain for the logic below to hold. - return false; - - const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType()); - const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType()); - // If RHS <=u Limit, then there must exist a value V in the sequence - // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and - // V <=u UINT_MAX. Thus, we must exit the loop before unsigned - // overflow occurs. This limit also implies that a signed - // comparison (in the wide bitwidth) is equivalent to an unsigned - // comparison as the high bits on both sides must be zero. - APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this)); - APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1); - Limit = Limit.zext(OuterBitWidth); - return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit); - }; - auto Flags = AR->getNoWrapFlags(); - if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW()) - Flags = setFlags(Flags, SCEV::FlagNUW); - - setNoWrapFlags(const_cast(AR), Flags); - } else { - auto Flags = AR->getNoWrapFlags(); - if (!hasFlags(Flags, SCEV::FlagNW) && canAssumeNoSelfWrap(AR)) { - Flags = setFlags(Flags, SCEV::FlagNW); + auto canProveNUW = [&]() { + // We can use the comparison to infer no-wrap flags only if it fully + // controls the loop exit. + if (!ControlsOnlyExit) + return false; + + if (!isLoopInvariant(RHS, L)) + return false; + + if (!isKnownNonZero(AR->getStepRecurrence(*this))) + // We need the sequence defined by AR to strictly increase in the + // unsigned integer domain for the logic below to hold. + return false; + + const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType()); + const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType()); + // If RHS <=u Limit, then there must exist a value V in the sequence + // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and + // V <=u UINT_MAX. Thus, we must exit the loop before unsigned + // overflow occurs. This limit also implies that a signed + // comparison (in the wide bitwidth) is equivalent to an unsigned + // comparison as the high bits on both sides must be zero. + APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this)); + APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1); + Limit = Limit.zext(OuterBitWidth); + return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit); + }; + auto Flags = AR->getNoWrapFlags(); + if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW()) + Flags = setFlags(Flags, SCEV::FlagNUW); - SmallVector Operands{AR->operands()}; - Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); + setNoWrapFlags(const_cast(AR), Flags); - setNoWrapFlags(const_cast(AR), Flags); - } - } if (AR->hasNoUnsignedWrap()) { // Emulate what getZeroExtendExpr would have done during construction // if we'd been able to infer the fact just above at that time. @@ -13114,7 +13103,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // has not yet been sufficiently auditted or tested with negative strides. // We used to filter out all known-non-positive cases here, we're in the // process of being less restrictive bit by bit. - if (AssumeLoopExists && IsSigned && isKnownNonPositive(Stride)) + if (AssumeLoopExits && IsSigned && isKnownNonPositive(Stride)) return getCouldNotCompute(); if (!isKnownNonZero(Stride)) { @@ -13253,7 +13242,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // In the Enzyme MustExitScalarEvolutionCode, this check was missing // I do not have enough context to know if these two checks should be // mutually Exclusive. If they aren't then this bool check is unnecessary - if (!AssumeLoopExists) { + if (!AssumeLoopExits) { const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L); const SCEV *GuardedStart = applyLoopGuards(OrigStart, L); @@ -13396,7 +13385,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (isa(ConstantMaxBECount) && !isa(BECount)) ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); - if (AssumeLoopExists) { + if (AssumeLoopExits) { return ExitLimit(BECount, ConstantMaxBECount, ConstantMaxBECount, MaxOrZero, Predicates); } From 9b57191bf32c57dc62927cbc7d1c17ad04f4d91d Mon Sep 17 00:00:00 2001 From: skewballfox Date: Mon, 11 Mar 2024 08:44:39 -0500 Subject: [PATCH 10/15] removed unrelated change --- llvm/lib/Analysis/ScalarEvolution.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 3b1fcab6e333b..53aa2faacf1cd 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9067,16 +9067,14 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( // preserve the CFG and is temporarily leaving constant conditions // in place. if (ConstantInt *CI = dyn_cast(ExitCond)) { - if (ExitIfTrue == !CI->getZExtValue()) { + if (ExitIfTrue == !CI->getZExtValue()) // The backedge is always taken. return getCouldNotCompute(); - } + // The backedge is never taken. return getZero(CI->getType()); } - // block was never executed in MustExitScalarEvolution code - // If we're exiting based on the overflow flag of an x.with.overflow // intrinsic with a constant step, we can form an equivalent icmp predicate // and figure out how many iterations will be taken before we exit. From 57767932c2ce69a71f672da7bc115e2796e529f9 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 14 Mar 2024 12:13:45 -0500 Subject: [PATCH 11/15] moved mustexit code to other computeExitLimitFromICmp definition --- llvm/lib/Analysis/ScalarEvolution.cpp | 191 +++++++++----------------- 1 file changed, 66 insertions(+), 125 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 53aa2faacf1cd..9c54dcb0e3f90 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9196,12 +9196,25 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); - if (!AssumeLoopExits) { + ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; - } else { + + auto *ExhaustiveCount = + computeExitCountExhaustively(L, ExitCond, ExitIfTrue); + + if (!isa(ExhaustiveCount)) + return ExhaustiveCount; + + return computeShiftCompareExitLimit( + ExitCond->getOperand(0), ExitCond->getOperand(1), L, OriginalPred); +} +ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( + const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + bool ControlsOnlyExit, bool AllowPredicates) { + if (AssumeLoopExits) { #define PROP_PHI(LHS) \ if (auto un = dyn_cast(LHS)) { \ if (auto pn = dyn_cast_or_null(un->getValue())) { \ @@ -9225,118 +9238,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( } PROP_PHI(LHS) PROP_PHI(RHS) - - // Try to evaluate any dependencies out of the loop. - LHS = getSCEVAtScope(LHS, L); - RHS = getSCEVAtScope(RHS, L); - - // At this point, we would like to compute how many iterations of the - // loop the predicate will return true for these inputs. - if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) { - // If there is a loop-invariant, force it into the RHS. - std::swap(LHS, RHS); - Pred = ICmpInst::getSwappedPredicate(Pred); - } - - // Simplify the operands before analyzing them. - (void)SimplifyICmpOperands(Pred, LHS, RHS); - - // If we have a comparison of a chrec against a constant, try to use value - // ranges to answer this query. - if (const SCEVConstant *RHSC = dyn_cast(RHS)) - if (const SCEVAddRecExpr *AddRec = dyn_cast(LHS)) - if (AddRec->getLoop() == L) { - // Form the constant range. - ConstantRange CompRange = - ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt()); - - const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); - if (!isa(Ret)) - return Ret; - } - - switch (Pred) { - case ICmpInst::ICMP_NE: { // while (X != Y) - // Convert to: while (X-Y != 0) - ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit, - AllowPredicates); - if (EL.hasAnyInfo()) - return EL; - break; - } - case ICmpInst::ICMP_EQ: { // while (X == Y) - // Convert to: while (X-Y == 0) - ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L); - if (EL.hasAnyInfo()) - return EL; - break; - } - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_SLE: - case ICmpInst::ICMP_ULE: { // while (X < Y) - bool IsSigned = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE; - - if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE) { - if (!isa(RHS->getType())) - break; - SmallVector sv = { - RHS, getConstant( - ConstantInt::get(cast(RHS->getType()), 1))}; - // Since this is not an infinite loop by induction, RHS cannot be - // int_max/uint_max Therefore adding 1 does not wrap. - if (IsSigned) - RHS = getAddExpr(sv, SCEV::FlagNSW); - else - RHS = getAddExpr(sv, SCEV::FlagNUW); - } - ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit, - AllowPredicates); - if (EL.hasAnyInfo()) - return EL; - break; - } - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_SGE: - case ICmpInst::ICMP_UGE: { // while (X > Y) - bool IsSigned = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE; - if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) { - if (!isa(RHS->getType())) - break; - SmallVector sv = { - RHS, getConstant( - ConstantInt::get(cast(RHS->getType()), -1))}; - // Since this is not an infinite loop by induction, RHS cannot be - // int_min/uint_min Therefore subtracting 1 does not wrap. - if (IsSigned) - RHS = getAddExpr(sv, SCEV::FlagNSW); - else - RHS = getAddExpr(sv, SCEV::FlagNUW); - } - ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, - ControlsOnlyExit, AllowPredicates); - if (EL.hasAnyInfo()) - return EL; - break; - } - default: - break; - } } - auto *ExhaustiveCount = - computeExitCountExhaustively(L, ExitCond, ExitIfTrue); - - if (!isa(ExhaustiveCount)) - return ExhaustiveCount; - - return computeShiftCompareExitLimit(ExitCond->getOperand(0), - ExitCond->getOperand(1), L, OriginalPred); -} -ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( - const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - bool ControlsOnlyExit, bool AllowPredicates) { - // Try to evaluate any dependencies out of the loop. LHS = getSCEVAtScope(LHS, L); RHS = getSCEVAtScope(RHS, L); @@ -9349,6 +9251,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( Pred = ICmpInst::getSwappedPredicate(Pred); } + // was not present in Enzyme code, the last condition is true if + // AssumeLoopExits is true + // will the first two checks cause enzyme to fail? bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) && loopIsFiniteByAssumption(L); // Simplify the operands before analyzing them. @@ -9426,18 +9331,37 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( if (EL.hasAnyInfo()) return EL; break; } + case ICmpInst::ICMP_SLE: case ICmpInst::ICMP_ULE: - // Since the loop is finite, an invariant RHS cannot include the boundary - // value, otherwise it would loop forever. - if (!EnableFiniteLoopControl || !ControllingFiniteLoop || - !isLoopInvariant(RHS, L)) - break; - RHS = getAddExpr(getOne(RHS->getType()), RHS); + if (!AssumeLoopExits) { + // Since the loop is finite, an invariant RHS cannot include the boundary + // value, otherwise it would loop forever. + if (!EnableFiniteLoopControl || !ControllingFiniteLoop || + !isLoopInvariant(RHS, L)) + break; + RHS = getAddExpr(getOne(RHS->getType()), RHS); + } [[fallthrough]]; + case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = ICmpInst::isSigned(Pred); + if (AssumeLoopExits) { + if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE) { + if (!isa(RHS->getType())) + break; + SmallVector sv = { + RHS, getConstant( + ConstantInt::get(cast(RHS->getType()), 1))}; + // Since this is not an infinite loop by induction, RHS cannot be + // int_max/uint_max Therefore adding 1 does not wrap. + if (IsSigned) + RHS = getAddExpr(sv, SCEV::FlagNSW); + else + RHS = getAddExpr(sv, SCEV::FlagNUW); + } + } ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) @@ -9446,16 +9370,33 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( } case ICmpInst::ICMP_SGE: case ICmpInst::ICMP_UGE: - // Since the loop is finite, an invariant RHS cannot include the boundary - // value, otherwise it would loop forever. - if (!EnableFiniteLoopControl || !ControllingFiniteLoop || - !isLoopInvariant(RHS, L)) - break; - RHS = getAddExpr(getMinusOne(RHS->getType()), RHS); + if (!AssumeLoopExits) { + // Since the loop is finite, an invariant RHS cannot include the boundary + // value, otherwise it would loop forever. + if (!EnableFiniteLoopControl || !ControllingFiniteLoop || + !isLoopInvariant(RHS, L)) + break; + RHS = getAddExpr(getMinusOne(RHS->getType()), RHS); + } [[fallthrough]]; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool IsSigned = ICmpInst::isSigned(Pred); + if (AssumeLoopExits) { + if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) { + if (!isa(RHS->getType())) + break; + SmallVector sv = { + RHS, getConstant( + ConstantInt::get(cast(RHS->getType()), -1))}; + // Since this is not an infinite loop by induction, RHS cannot be + // int_min/uint_min Therefore subtracting 1 does not wrap. + if (IsSigned) + RHS = getAddExpr(sv, SCEV::FlagNSW); + else + RHS = getAddExpr(sv, SCEV::FlagNUW); + } + } ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) From bdabce85e51a2da335dd9c669981624c1eb1e6b6 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 14 Mar 2024 12:45:11 -0500 Subject: [PATCH 12/15] reran git clang-format HEAD~1 --- llvm/lib/Analysis/ScalarEvolution.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 9c54dcb0e3f90..e11d02f2c12e1 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -9197,19 +9197,18 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); - ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit, - AllowPredicates); - if (EL.hasAnyInfo()) - return EL; + ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit, + AllowPredicates); + if (EL.hasAnyInfo()) + return EL; - auto *ExhaustiveCount = - computeExitCountExhaustively(L, ExitCond, ExitIfTrue); + auto *ExhaustiveCount = computeExitCountExhaustively(L, ExitCond, ExitIfTrue); - if (!isa(ExhaustiveCount)) - return ExhaustiveCount; + if (!isa(ExhaustiveCount)) + return ExhaustiveCount; - return computeShiftCompareExitLimit( - ExitCond->getOperand(0), ExitCond->getOperand(1), L, OriginalPred); + return computeShiftCompareExitLimit(ExitCond->getOperand(0), + ExitCond->getOperand(1), L, OriginalPred); } ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, From a9c9251be1038e8af49402a30f96249448d68ecc Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 14 Mar 2024 14:25:18 -0500 Subject: [PATCH 13/15] removed redundant binOp code from CondImpl --- llvm/lib/Analysis/ScalarEvolution.cpp | 123 +++++--------------------- 1 file changed, 20 insertions(+), 103 deletions(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index e11d02f2c12e1..4375c254c8361 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8949,104 +8949,11 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates) { - if (!AssumeLoopExits) { - // Handle BinOp conditions (And, Or). - if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( - Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates)) - return *LimitFromBinOp; - } else { - // Check if the controlling expression for this loop is an And or Or. - if (BinaryOperator *BO = dyn_cast(ExitCond)) { - if (BO->getOpcode() == Instruction::And) { - // Recurse on the operands of the and. - bool EitherMayExit = !ExitIfTrue; - ExitLimit EL0 = computeExitLimitFromCondCached( - Cache, L, BO->getOperand(0), ExitIfTrue, - ControlsOnlyExit && !EitherMayExit, AllowPredicates); - ExitLimit EL1 = computeExitLimitFromCondCached( - Cache, L, BO->getOperand(1), ExitIfTrue, - ControlsOnlyExit && !EitherMayExit, AllowPredicates); - const SCEV *BECount = getCouldNotCompute(); - const SCEV *MaxBECount = getCouldNotCompute(); - if (EitherMayExit) { - // Both conditions must be true for the loop to continue executing. - // Choose the less conservative count. - if (EL0.ExactNotTaken == getCouldNotCompute() || - EL1.ExactNotTaken == getCouldNotCompute()) - BECount = getCouldNotCompute(); - else - BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, - EL1.ExactNotTaken); - - if (EL0.ConstantMaxNotTaken == getCouldNotCompute()) - MaxBECount = EL1.ConstantMaxNotTaken; - else if (EL1.ConstantMaxNotTaken == getCouldNotCompute()) - MaxBECount = EL0.ConstantMaxNotTaken; - else - MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken, - EL1.ConstantMaxNotTaken); - } else { - // Both conditions must be true at the same time for the loop to exit. - // For now, be conservative. - if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken) - MaxBECount = EL0.ConstantMaxNotTaken; - if (EL0.ExactNotTaken == EL1.ExactNotTaken) - BECount = EL0.ExactNotTaken; - } - // There are cases (e.g. PR26207) where computeExitLimitFromCond is able - // to be more aggressive when computing BECount than when computing - // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and - // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and - // EL1.ConstantMaxNotTaken to not. - if (isa(MaxBECount) && - !isa(BECount)) - MaxBECount = getConstant(getUnsignedRangeMax(BECount)); - - return ExitLimit(BECount, MaxBECount, MaxBECount, false, - {&EL0.Predicates, &EL1.Predicates}); - } - if (BO->getOpcode() == Instruction::Or) { - // Recurse on the operands of the or. - bool EitherMayExit = ExitIfTrue; - ExitLimit EL0 = computeExitLimitFromCondCached( - Cache, L, BO->getOperand(0), ExitIfTrue, - ControlsOnlyExit && !EitherMayExit, AllowPredicates); - ExitLimit EL1 = computeExitLimitFromCondCached( - Cache, L, BO->getOperand(1), ExitIfTrue, - ControlsOnlyExit && !EitherMayExit, AllowPredicates); - const SCEV *BECount = getCouldNotCompute(); - const SCEV *MaxBECount = getCouldNotCompute(); - if (EitherMayExit) { - // Both conditions must be false for the loop to continue executing. - // Choose the less conservative count. - if (EL0.ExactNotTaken == getCouldNotCompute() || - EL1.ExactNotTaken == getCouldNotCompute()) - BECount = getCouldNotCompute(); - else - BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, - EL1.ExactNotTaken); - - if (EL0.ConstantMaxNotTaken == getCouldNotCompute()) - MaxBECount = EL1.ConstantMaxNotTaken; - else if (EL1.ConstantMaxNotTaken == getCouldNotCompute()) - MaxBECount = EL0.ConstantMaxNotTaken; - else - MaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken, - EL1.ConstantMaxNotTaken); - } else { - // Both conditions must be false at the same time for the loop to - // exit. For now, be conservative. - if (EL0.ConstantMaxNotTaken == EL1.ConstantMaxNotTaken) - MaxBECount = EL0.ConstantMaxNotTaken; - if (EL0.ExactNotTaken == EL1.ExactNotTaken) - BECount = EL0.ExactNotTaken; - } - return ExitLimit(BECount, MaxBECount, MaxBECount, false, - {&EL0.Predicates, &EL1.Predicates}); - } - } - } + // Handle BinOp conditions (And, Or). + if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( + Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates)) + return *LimitFromBinOp; // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. @@ -9139,6 +9046,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( const SCEV *SymbolicMaxBECount = getCouldNotCompute(); if (EitherMayExit) { bool UseSequentialUMin = !isa(ExitCond); + // Both conditions must be same for the loop to continue executing. // Choose the less conservative count. if (EL0.ExactNotTaken != getCouldNotCompute() && @@ -9146,6 +9054,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken, UseSequentialUMin); } + if (EL0.ConstantMaxNotTaken == getCouldNotCompute()) ConstantMaxBECount = EL1.ConstantMaxNotTaken; else if (EL1.ConstantMaxNotTaken == getCouldNotCompute()) @@ -9165,6 +9074,12 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( // For now, be conservative. if (EL0.ExactNotTaken == EL1.ExactNotTaken) BECount = EL0.ExactNotTaken; + // This was executed in Enzyme's must exit code under the + // logic for when the binary op was OR + if (AssumeLoopExits && !IsAnd) { + if (EL0.ExactNotTaken == EL1.ExactNotTaken) + ConstantMaxBECount = EL0.ExactNotTaken; + } } // There are cases (e.g. PR26207) where computeExitLimitFromCond is able @@ -9173,12 +9088,14 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( // and // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and // EL1.ConstantMaxNotTaken to not. - if (isa(ConstantMaxBECount) && - !isa(BECount)) - ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); - if (isa(SymbolicMaxBECount)) - SymbolicMaxBECount = - isa(BECount) ? ConstantMaxBECount : BECount; + if (!AssumeLoopExits || !IsAnd) { // should skip if assume exits and OR + if (isa(ConstantMaxBECount) && + !isa(BECount)) + ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); + if (isa(SymbolicMaxBECount)) + SymbolicMaxBECount = + isa(BECount) ? ConstantMaxBECount : BECount; + } return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false, { &EL0.Predicates, &EL1.Predicates }); } From 7597a836c99cae68a984d83dc3882e5f7ef6fcd3 Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Sat, 20 Apr 2024 16:21:50 -0500 Subject: [PATCH 14/15] implenting requested changes --- llvm/include/llvm/Analysis/ScalarEvolution.h | 4 +- .../llvm/Analysis/Utils/EnzymeFunctionUtils.h | 70 ------------------- llvm/lib/Analysis/ScalarEvolution.cpp | 49 +++++-------- 3 files changed, 19 insertions(+), 104 deletions(-) delete mode 100644 llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index 50dbe2aeec884..fa331c93d712e 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -460,9 +460,9 @@ class ScalarEvolution { LoopComputable ///< The SCEV varies predictably with the loop. }; - bool AssumeLoopExits = false; + bool AssumeLoopFinite = false; void setAssumeLoopExits(); - llvm::SmallPtrSet GuaranteedUnreachable; + SmallPtrSet GuaranteedUnreachable; /// An enum describing the relationship between a SCEV and a basic block. enum BlockDisposition { diff --git a/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h b/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h deleted file mode 100644 index 59032cbe6dddd..0000000000000 --- a/llvm/include/llvm/Analysis/Utils/EnzymeFunctionUtils.h +++ /dev/null @@ -1,70 +0,0 @@ - -#include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/LoopAnalysisManager.h" -#include "llvm/Analysis/TargetLibraryInfo.h" - -#include "llvm/IR/Function.h" - -#include "llvm/IR/Instructions.h" -#include "llvm/Transforms/Utils/ValueMapper.h" -#include - -// TODO note this doesn't go through [loop, unreachable], and we could get more -// performance by doing this can consider doing some domtree magic potentially -static inline llvm::SmallPtrSet -getGuaranteedUnreachable(llvm::Function *F) { - llvm::SmallPtrSet knownUnreachables; - if (F->empty()) - return knownUnreachables; - std::deque todo; - for (auto &BB : *F) { - todo.push_back(&BB); - } - - while (!todo.empty()) { - llvm::BasicBlock *next = todo.front(); - todo.pop_front(); - - if (knownUnreachables.find(next) != knownUnreachables.end()) - continue; - - if (llvm::isa(next->getTerminator())) - continue; - - if (llvm::isa(next->getTerminator())) { - knownUnreachables.insert(next); - for (llvm::BasicBlock *Pred : predecessors(next)) { - todo.push_back(Pred); - } - continue; - } - - // Assume resumes don't happen - // TODO consider EH - if (llvm::isa(next->getTerminator())) { - knownUnreachables.insert(next); - for (llvm::BasicBlock *Pred : predecessors(next)) { - todo.push_back(Pred); - } - continue; - } - - bool unreachable = true; - for (llvm::BasicBlock *Succ : llvm::successors(next)) { - if (knownUnreachables.find(Succ) == knownUnreachables.end()) { - unreachable = false; - break; - } - } - - if (!unreachable) - continue; - knownUnreachables.insert(next); - for (llvm::BasicBlock *Pred : llvm::predecessors(next)) { - todo.push_back(Pred); - } - continue; - } - - return knownUnreachables; -} \ No newline at end of file diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 4375c254c8361..67b491506b8ed 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -510,7 +510,7 @@ const SCEV *ScalarEvolution::getVScale(Type *Ty) { return S; } -void ScalarEvolution::setAssumeLoopExits() { this->AssumeLoopExits = true; } +void ScalarEvolution::setAssumeLoopExits() { this->AssumeLoopFinite = true; } SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty) @@ -7416,7 +7416,7 @@ bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) { // A mustprogress loop without side effects must be finite. // TODO: The check used here is very conservative. It's only *specific* // side effects which are well defined in infinite loops. - return AssumeLoopExits || isFinite(L) || + return AssumeLoopFinite || isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L)); } @@ -8832,7 +8832,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, bool AllowPredicates) { - if (AssumeLoopExits) { + if (AssumeLoopFinite) { SmallVector ExitingBlocks; L->getExitingBlocks(ExitingBlocks); for (auto &ExitingBlock : ExitingBlocks) { @@ -8877,7 +8877,7 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, BasicBlock *Exit = nullptr; for (auto *SBB : successors(ExitingBlock)) if (!L->contains(SBB)) { - if (AssumeLoopExits and GuaranteedUnreachable.count(SBB)) + if (AssumeLoopFinite and GuaranteedUnreachable.count(SBB)) continue; if (Exit) // Multiple exit successors. return getCouldNotCompute(); @@ -8982,9 +8982,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( return getZero(CI->getType()); } - // If we're exiting based on the overflow flag of an x.with.overflow - // intrinsic with a constant step, we can form an equivalent icmp predicate - // and figure out how many iterations will be taken before we exit. + // If we're exiting based on the overflow flag of an x.with.overflow intrinsic + // with a constant step, we can form an equivalent icmp predicate and figure + // out how many iterations will be taken before we exit. const WithOverflowInst *WO; const APInt *C; if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) && @@ -9076,7 +9076,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( BECount = EL0.ExactNotTaken; // This was executed in Enzyme's must exit code under the // logic for when the binary op was OR - if (AssumeLoopExits && !IsAnd) { + if (AssumeLoopFinite && !IsAnd) { if (EL0.ExactNotTaken == EL1.ExactNotTaken) ConstantMaxBECount = EL0.ExactNotTaken; } @@ -9088,7 +9088,7 @@ ScalarEvolution::computeExitLimitFromCondFromBinOp( // and // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and // EL1.ConstantMaxNotTaken to not. - if (!AssumeLoopExits || !IsAnd) { // should skip if assume exits and OR + if (!AssumeLoopFinite || !IsAnd) { // should skip if assume exits and OR if (isa(ConstantMaxBECount) && !isa(BECount)) ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); @@ -9130,7 +9130,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, bool ControlsOnlyExit, bool AllowPredicates) { - if (AssumeLoopExits) { + if (AssumeLoopFinite) { #define PROP_PHI(LHS) \ if (auto un = dyn_cast(LHS)) { \ if (auto pn = dyn_cast_or_null(un->getValue())) { \ @@ -9250,7 +9250,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( case ICmpInst::ICMP_SLE: case ICmpInst::ICMP_ULE: - if (!AssumeLoopExits) { + if (!AssumeLoopFinite) { // Since the loop is finite, an invariant RHS cannot include the boundary // value, otherwise it would loop forever. if (!EnableFiniteLoopControl || !ControllingFiniteLoop || @@ -9263,21 +9263,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = ICmpInst::isSigned(Pred); - if (AssumeLoopExits) { - if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE) { - if (!isa(RHS->getType())) - break; - SmallVector sv = { - RHS, getConstant( - ConstantInt::get(cast(RHS->getType()), 1))}; - // Since this is not an infinite loop by induction, RHS cannot be - // int_max/uint_max Therefore adding 1 does not wrap. - if (IsSigned) - RHS = getAddExpr(sv, SCEV::FlagNSW); - else - RHS = getAddExpr(sv, SCEV::FlagNUW); - } - } ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit, AllowPredicates); if (EL.hasAnyInfo()) @@ -9286,7 +9271,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( } case ICmpInst::ICMP_SGE: case ICmpInst::ICMP_UGE: - if (!AssumeLoopExits) { + if (!AssumeLoopFinite) { // Since the loop is finite, an invariant RHS cannot include the boundary // value, otherwise it would loop forever. if (!EnableFiniteLoopControl || !ControllingFiniteLoop || @@ -9298,7 +9283,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp( case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool IsSigned = ICmpInst::isSigned(Pred); - if (AssumeLoopExits) { + if (AssumeLoopFinite) { if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_UGE) { if (!isa(RHS->getType())) break; @@ -9340,7 +9325,7 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, // if not using enzyme executes by default // if using enzyme and the code is guaranteed unreachable, // the default destination doesn't matter - if (!AssumeLoopExits || + if (!AssumeLoopFinite || !GuaranteedUnreachable.count(Switch->getDefaultDest())) { assert(L->contains(Switch->getDefaultDest()) && "Default case must not exit the loop!"); @@ -12958,7 +12943,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // has not yet been sufficiently auditted or tested with negative strides. // We used to filter out all known-non-positive cases here, we're in the // process of being less restrictive bit by bit. - if (AssumeLoopExits && IsSigned && isKnownNonPositive(Stride)) + if (AssumeLoopFinite && IsSigned && isKnownNonPositive(Stride)) return getCouldNotCompute(); if (!isKnownNonZero(Stride)) { @@ -13097,7 +13082,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, // In the Enzyme MustExitScalarEvolutionCode, this check was missing // I do not have enough context to know if these two checks should be // mutually Exclusive. If they aren't then this bool check is unnecessary - if (!AssumeLoopExits) { + if (!AssumeLoopFinite) { const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L); const SCEV *GuardedStart = applyLoopGuards(OrigStart, L); @@ -13240,7 +13225,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (isa(ConstantMaxBECount) && !isa(BECount)) ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount)); - if (AssumeLoopExits) { + if (AssumeLoopFinite) { return ExitLimit(BECount, ConstantMaxBECount, ConstantMaxBECount, MaxOrZero, Predicates); } From 16d19f64f9ea7cf0ca3485e1ec3a88f48739070b Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Sat, 20 Apr 2024 16:25:00 -0500 Subject: [PATCH 15/15] forgot to remove include for deleted file --- llvm/lib/Analysis/ScalarEvolution.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 67b491506b8ed..f261623dcd069 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -80,7 +80,6 @@ #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/Utils/EnzymeFunctionUtils.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Config/llvm-config.h" #include "llvm/IR/Argument.h"