diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index c39de9d72cf49..51b15510a3add 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -219,7 +219,7 @@ class SCEVPredicate : public FoldingSetNode { FoldingSetNodeIDRef FastID; public: - enum SCEVPredicateKind { P_Union, P_Equal, P_Wrap }; + enum SCEVPredicateKind { P_Union, P_Compare, P_Wrap }; protected: SCEVPredicateKind Kind; @@ -276,16 +276,18 @@ struct FoldingSetTrait : DefaultFoldingSetTrait { } }; -/// This class represents an assumption that two SCEV expressions are equal, -/// and this can be checked at run-time. -class SCEVEqualPredicate final : public SCEVPredicate { - /// We assume that LHS == RHS. +/// This class represents an assumption that the expression LHS Pred RHS +/// evaluates to true, and this can be checked at run-time. +class SCEVComparePredicate final : public SCEVPredicate { + /// We assume that LHS Pred RHS is true. + const ICmpInst::Predicate Pred; const SCEV *LHS; const SCEV *RHS; public: - SCEVEqualPredicate(const FoldingSetNodeIDRef ID, const SCEV *LHS, - const SCEV *RHS); + SCEVComparePredicate(const FoldingSetNodeIDRef ID, + const ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS); /// Implementation of the SCEVPredicate interface bool implies(const SCEVPredicate *N) const override; @@ -293,15 +295,17 @@ class SCEVEqualPredicate final : public SCEVPredicate { bool isAlwaysTrue() const override; const SCEV *getExpr() const override; - /// Returns the left hand side of the equality. + ICmpInst::Predicate getPredicate() const { return Pred; } + + /// Returns the left hand side of the predicate. const SCEV *getLHS() const { return LHS; } - /// Returns the right hand side of the equality. + /// Returns the right hand side of the predicate. const SCEV *getRHS() const { return RHS; } /// Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const SCEVPredicate *P) { - return P->getKind() == P_Equal; + return P->getKind() == P_Compare; } }; diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h index 277eb7acf238c..60b772b94a6f7 100644 --- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h +++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h @@ -293,8 +293,9 @@ class SCEVExpander : public SCEVVisitor { Value *expandCodeForPredicate(const SCEVPredicate *Pred, Instruction *Loc); /// A specialized variant of expandCodeForPredicate, handling the case when - /// we are expanding code for a SCEVEqualPredicate. - Value *expandEqualPredicate(const SCEVEqualPredicate *Pred, Instruction *Loc); + /// we are expanding code for a SCEVComparePredicate. + Value *expandComparePredicate(const SCEVComparePredicate *Pred, + Instruction *Loc); /// Generates code that evaluates if the \p AR expression will overflow. Value *generateOverflowCheck(const SCEVAddRecExpr *AR, Instruction *Loc, diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 977fc09113550..fca5614c7469e 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -13541,14 +13541,15 @@ const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS, assert(LHS->getType() == RHS->getType() && "Type mismatch between LHS and RHS"); // Unique this node based on the arguments - ID.AddInteger(SCEVPredicate::P_Equal); + ID.AddInteger(SCEVPredicate::P_Compare); + ID.AddInteger(ICmpInst::ICMP_EQ); ID.AddPointer(LHS); ID.AddPointer(RHS); void *IP = nullptr; if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) return S; - SCEVEqualPredicate *Eq = new (SCEVAllocator) - SCEVEqualPredicate(ID.Intern(SCEVAllocator), LHS, RHS); + SCEVComparePredicate *Eq = new (SCEVAllocator) + SCEVComparePredicate(ID.Intern(SCEVAllocator), ICmpInst::ICMP_EQ, LHS, RHS); UniquePreds.InsertNode(Eq, IP); return Eq; } @@ -13594,8 +13595,9 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { if (Pred) { auto ExprPreds = Pred->getPredicatesForExpr(Expr); for (auto *Pred : ExprPreds) - if (const auto *IPred = dyn_cast(Pred)) - if (IPred->getLHS() == Expr) + if (const auto *IPred = dyn_cast(Pred)) + if (IPred->getLHS() == Expr && + IPred->getPredicate() == ICmpInst::ICMP_EQ) return IPred->getRHS(); } return convertToAddRecWithPreds(Expr); @@ -13715,28 +13717,38 @@ SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID, SCEVPredicateKind Kind) : FastID(ID), Kind(Kind) {} -SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID, - const SCEV *LHS, const SCEV *RHS) - : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) { +SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID, + const ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS) + : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) { assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match"); assert(LHS != RHS && "LHS and RHS are the same SCEV"); } -bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const { - const auto *Op = dyn_cast(N); +bool SCEVComparePredicate::implies(const SCEVPredicate *N) const { + const auto *Op = dyn_cast(N); if (!Op) return false; + if (Pred != ICmpInst::ICMP_EQ) + return false; + return Op->LHS == LHS && Op->RHS == RHS; } -bool SCEVEqualPredicate::isAlwaysTrue() const { return false; } +bool SCEVComparePredicate::isAlwaysTrue() const { return false; } -const SCEV *SCEVEqualPredicate::getExpr() const { return LHS; } +const SCEV *SCEVComparePredicate::getExpr() const { return LHS; } + +void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const { + if (Pred == ICmpInst::ICMP_EQ) + OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; + else + OS.indent(Depth) << "Compare predicate: " << *LHS + << " " << CmpInst::getPredicateName(Pred) << ") " + << *RHS << "\n"; -void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const { - OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n"; } SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID, diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 277431fbbfd77..fbd42ce41a999 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -2469,8 +2469,8 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred, switch (Pred->getKind()) { case SCEVPredicate::P_Union: return expandUnionPredicate(cast(Pred), IP); - case SCEVPredicate::P_Equal: - return expandEqualPredicate(cast(Pred), IP); + case SCEVPredicate::P_Compare: + return expandComparePredicate(cast(Pred), IP); case SCEVPredicate::P_Wrap: { auto *AddRecPred = cast(Pred); return expandWrapPredicate(AddRecPred, IP); @@ -2479,15 +2479,16 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred, llvm_unreachable("Unknown SCEV predicate type"); } -Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred, - Instruction *IP) { +Value *SCEVExpander::expandComparePredicate(const SCEVComparePredicate *Pred, + Instruction *IP) { Value *Expr0 = expandCodeForImpl(Pred->getLHS(), Pred->getLHS()->getType(), IP, false); Value *Expr1 = expandCodeForImpl(Pred->getRHS(), Pred->getRHS()->getType(), IP, false); Builder.SetInsertPoint(IP); - auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check"); + auto InvPred = ICmpInst::getInversePredicate(Pred->getPredicate()); + auto *I = Builder.CreateICmp(InvPred, Expr0, Expr1, "ident.check"); return I; }