@@ -15034,6 +15034,91 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
1503415034    if (MatchRangeCheckIdiom())
1503515035      return;
1503615036
15037+     // Return true if \p Expr is a MinMax SCEV expression with a constant
15038+     // operand. If so, return in \p SCTy the SCEV type and in \p RHS the
15039+     // non-constant operand and in \p LHS the constant operand.
15040+     auto IsMinMaxSCEVWithConstant = [&](const SCEV *Expr, SCEVTypes &SCTy,
15041+                                         const SCEV *&LHS, const SCEV *&RHS) {
15042+       if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15043+         if (MinMax->getNumOperands() != 2)
15044+           return false;
15045+         SCTy = MinMax->getSCEVType();
15046+         if (!isa<SCEVConstant>(MinMax->getOperand(0)))
15047+           return false;
15048+         LHS = MinMax->getOperand(0);
15049+         RHS = MinMax->getOperand(1);
15050+         return true;
15051+       }
15052+       return false;
15053+     };
15054+ 
15055+     // Checks whether Expr is a non-negative constant, and Divisor is a positive
15056+     // constant, and returns their APInt in ExprVal and in DivisorVal.
15057+     auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15058+                                           APInt &ExprVal, APInt &DivisorVal) {
15059+       if (!isKnownNonNegative(Expr) || !isKnownPositive(Divisor))
15060+         return false;
15061+       auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15062+       auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15063+       if (!ConstExpr || !ConstDivisor)
15064+         return false;
15065+       ExprVal = ConstExpr->getAPInt();
15066+       DivisorVal = ConstDivisor->getAPInt();
15067+       return true;
15068+     };
15069+ 
15070+     // Return a new SCEV that modifies \p Expr to the closest number divides by
15071+     // \p Divisor and greater or equal than Expr.
15072+     // For now, only handle constant Expr and Divisor.
15073+     auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15074+                                            const SCEV *Divisor) {
15075+       APInt ExprVal;
15076+       APInt DivisorVal;
15077+       if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15078+         return Expr;
15079+       APInt Rem = ExprVal.urem(DivisorVal);
15080+       if (!Rem.isZero())
15081+         // return the SCEV: Expr + Divisor - Expr % Divisor
15082+         return getConstant(ExprVal + DivisorVal - Rem);
15083+       return Expr;
15084+     };
15085+ 
15086+     // Return a new SCEV that modifies \p Expr to the closest number divides by
15087+     // \p Divisor and less or equal than Expr.
15088+     // For now, only handle constant Expr and Divisor.
15089+     auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15090+                                                const SCEV *Divisor) {
15091+       APInt ExprVal;
15092+       APInt DivisorVal;
15093+       if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15094+         return Expr;
15095+       APInt Rem = ExprVal.urem(DivisorVal);
15096+       // return the SCEV: Expr - Expr % Divisor
15097+       return getConstant(ExprVal - Rem);
15098+     };
15099+ 
15100+     // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15101+     // recursively. This is done by aligning up/down the constant value to the
15102+     // Divisor.
15103+     std::function<const SCEV *(const SCEV *, const SCEV *)>
15104+         ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15105+                                            const SCEV *Divisor) {
15106+           const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15107+           SCEVTypes SCTy;
15108+           if (!IsMinMaxSCEVWithConstant(MinMaxExpr, SCTy, MinMaxLHS, MinMaxRHS))
15109+             return MinMaxExpr;
15110+           auto IsMin =
15111+               isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15112+           assert(isKnownNonNegative(MinMaxLHS) &&
15113+                  "Expected non-negative operand!");
15114+           auto *DivisibleExpr =
15115+               IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15116+                     : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15117+           SmallVector<const SCEV *> Ops = {
15118+               ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15119+           return getMinMaxExpr(SCTy, Ops);
15120+         };
15121+ 
1503715122    // If we have LHS == 0, check if LHS is computing a property of some unknown
1503815123    // SCEV %v which we can rewrite %v to express explicitly.
1503915124    const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
@@ -15045,7 +15130,12 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
1504515130      const SCEV *URemRHS = nullptr;
1504615131      if (matchURem(LHS, URemLHS, URemRHS)) {
1504715132        if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15048-           const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
15133+           auto I = RewriteMap.find(LHSUnknown);
15134+           const SCEV *RewrittenLHS =
15135+               I != RewriteMap.end() ? I->second : LHSUnknown;
15136+           RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15137+           const auto *Multiple =
15138+               getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
1504915139          RewriteMap[LHSUnknown] = Multiple;
1505015140          ExprsToRewrite.push_back(LHSUnknown);
1505115141          return;
@@ -15068,48 +15158,128 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
1506815158    auto I = RewriteMap.find(LHS);
1506915159    const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
1507015160
15161+     // Check for the SCEV expression (A /u B) * B while B is a constant, inside
15162+     // \p Expr. The check is done recuresively on \p Expr, which is assumed to
15163+     // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15164+     // /u B) * B was found, and return the divisor B in \p DividesBy. For
15165+     // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15166+     // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15167+     // DividesBy.
15168+     std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15169+         [&](const SCEV *Expr, const SCEV *&DividesBy) {
15170+           if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15171+             if (Mul->getNumOperands() != 2)
15172+               return false;
15173+             auto *MulLHS = Mul->getOperand(0);
15174+             auto *MulRHS = Mul->getOperand(1);
15175+             if (isa<SCEVConstant>(MulLHS))
15176+               std::swap(MulLHS, MulRHS);
15177+             if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS)) {
15178+               if (Div->getOperand(1) == MulRHS) {
15179+                 DividesBy = MulRHS;
15180+                 return true;
15181+               }
15182+             }
15183+           }
15184+           if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15185+             return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15186+                    HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15187+           }
15188+           return false;
15189+         };
15190+ 
15191+     // Return true if Expr known to divide by \p DividesBy.
15192+     std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15193+         [&](const SCEV *Expr, const SCEV *DividesBy) {
15194+           if (getURemExpr(Expr, DividesBy)->isZero())
15195+             return true;
15196+           if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15197+             return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15198+                    IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15199+           }
15200+           return false;
15201+         };
15202+ 
15203+     const SCEV *DividesBy = nullptr;
15204+     if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15205+       // Check that the whole expression is divided by DividesBy
15206+       DividesBy =
15207+           IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15208+ 
1507115209    const SCEV *RewrittenRHS = nullptr;
1507215210    switch (Predicate) {
1507315211    case CmpInst::ICMP_ULT: {
1507415212      if (RHS->getType()->isPointerTy())
1507515213        break;
1507615214      const SCEV *One = getOne(RHS->getType());
15077-       RewrittenRHS =
15078-           getUMinExpr(RewrittenLHS, getMinusSCEV(getUMaxExpr(RHS, One), One));
15215+       auto *ModifiedRHS = getMinusSCEV(getUMaxExpr(RHS, One), One);
15216+       ModifiedRHS =
15217+           DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15218+                     : ModifiedRHS;
15219+       RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS);
1507915220      break;
1508015221    }
15081-     case CmpInst::ICMP_SLT:
15082-       RewrittenRHS =
15083-           getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType())));
15222+     case CmpInst::ICMP_SLT: {
15223+       auto *ModifiedRHS = getMinusSCEV(RHS, getOne(RHS->getType()));
15224+       ModifiedRHS =
15225+           DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15226+                     : ModifiedRHS;
15227+       RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS);
1508415228      break;
15085-     case CmpInst::ICMP_ULE:
15086-       RewrittenRHS = getUMinExpr(RewrittenLHS, RHS);
15229+     }
15230+     case CmpInst::ICMP_ULE: {
15231+       auto *ModifiedRHS =
15232+           DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15233+       RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS);
1508715234      break;
15088-     case CmpInst::ICMP_SLE:
15089-       RewrittenRHS = getSMinExpr(RewrittenLHS, RHS);
15235+     }
15236+     case CmpInst::ICMP_SLE: {
15237+       auto *ModifiedRHS =
15238+           DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15239+       RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS);
1509015240      break;
15091-     case CmpInst::ICMP_UGT:
15092-       RewrittenRHS =
15093-           getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
15241+     }
15242+     case CmpInst::ICMP_UGT: {
15243+       auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType()));
15244+       ModifiedRHS = DividesBy
15245+                         ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15246+                         : ModifiedRHS;
15247+       RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
1509415248      break;
15095-     case CmpInst::ICMP_SGT:
15096-       RewrittenRHS =
15097-           getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType())));
15249+     }
15250+     case CmpInst::ICMP_SGT: {
15251+       auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType()));
15252+       ModifiedRHS = DividesBy
15253+                         ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15254+                         : ModifiedRHS;
15255+       RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS);
1509815256      break;
15099-     case CmpInst::ICMP_UGE:
15100-       RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS);
15257+     }
15258+     case CmpInst::ICMP_UGE: {
15259+       auto *ModifiedRHS =
15260+           DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15261+       RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
1510115262      break;
15102-     case CmpInst::ICMP_SGE:
15103-       RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS);
15263+     }
15264+     case CmpInst::ICMP_SGE: {
15265+       auto *ModifiedRHS =
15266+           DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15267+       RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS);
1510415268      break;
15269+     }
1510515270    case CmpInst::ICMP_EQ:
1510615271      if (isa<SCEVConstant>(RHS))
1510715272        RewrittenRHS = RHS;
1510815273      break;
1510915274    case CmpInst::ICMP_NE:
1511015275      if (isa<SCEVConstant>(RHS) &&
15111-           cast<SCEVConstant>(RHS)->getValue()->isNullValue())
15112-         RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType()));
15276+           cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15277+         auto *ModifiedRHS = getOne(RHS->getType());
15278+         ModifiedRHS = DividesBy
15279+                           ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy)
15280+                           : ModifiedRHS;
15281+         RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS);
15282+       }
1511315283      break;
1511415284    default:
1511515285      break;
0 commit comments