diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 4a5680d6a23df6..5c2c12c7d0591d 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -15034,91 +15034,6 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { if (MatchRangeCheckIdiom()) return; - // Return true if \p Expr is a MinMax SCEV expression with a constant - // operand. If so, return in \p SCTy the SCEV type and in \p RHS the - // non-constant operand and in \p LHS the constant operand. - auto IsMinMaxSCEVWithConstant = [&](const SCEV *Expr, SCEVTypes &SCTy, - const SCEV *&LHS, const SCEV *&RHS) { - if (auto *MinMax = dyn_cast(Expr)) { - if (MinMax->getNumOperands() != 2) - return false; - SCTy = MinMax->getSCEVType(); - if (!isa(MinMax->getOperand(0))) - return false; - LHS = MinMax->getOperand(0); - RHS = MinMax->getOperand(1); - return true; - } - return false; - }; - - // Checks whether Expr is a non-negative constant, and Divisor is a positive - // constant, and returns their APInt in ExprVal and in DivisorVal. - auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor, - APInt &ExprVal, APInt &DivisorVal) { - if (!isKnownNonNegative(Expr) || !isKnownPositive(Divisor)) - return false; - auto *ConstExpr = dyn_cast(Expr); - auto *ConstDivisor = dyn_cast(Divisor); - if (!ConstExpr || !ConstDivisor) - return false; - ExprVal = ConstExpr->getAPInt(); - DivisorVal = ConstDivisor->getAPInt(); - return true; - }; - - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and greater or equal than Expr. - // For now, only handle constant Expr and Divisor. - auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr, - const SCEV *Divisor) { - APInt ExprVal; - APInt DivisorVal; - if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) - return Expr; - APInt Rem = ExprVal.urem(DivisorVal); - if (!Rem.isZero()) - // return the SCEV: Expr + Divisor - Expr % Divisor - return getConstant(ExprVal + DivisorVal - Rem); - return Expr; - }; - - // Return a new SCEV that modifies \p Expr to the closest number divides by - // \p Divisor and less or equal than Expr. - // For now, only handle constant Expr and Divisor. - auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr, - const SCEV *Divisor) { - APInt ExprVal; - APInt DivisorVal; - if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal)) - return Expr; - APInt Rem = ExprVal.urem(DivisorVal); - // return the SCEV: Expr - Expr % Divisor - return getConstant(ExprVal - Rem); - }; - - // Apply divisibilty by \p Divisor on MinMaxExpr with constant values, - // recursively. This is done by aligning up/down the constant value to the - // Divisor. - std::function - ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr, - const SCEV *Divisor) { - const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr; - SCEVTypes SCTy; - if (!IsMinMaxSCEVWithConstant(MinMaxExpr, SCTy, MinMaxLHS, MinMaxRHS)) - return MinMaxExpr; - auto IsMin = - isa(MinMaxExpr) || isa(MinMaxExpr); - assert(isKnownNonNegative(MinMaxLHS) && - "Expected non-negative operand!"); - auto *DivisibleExpr = - IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor) - : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor); - SmallVector Ops = { - ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr}; - return getMinMaxExpr(SCTy, Ops); - }; - // If we have LHS == 0, check if LHS is computing a property of some unknown // SCEV %v which we can rewrite %v to express explicitly. const SCEVConstant *RHSC = dyn_cast(RHS); @@ -15130,12 +15045,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { const SCEV *URemRHS = nullptr; if (matchURem(LHS, URemLHS, URemRHS)) { if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { - auto I = RewriteMap.find(LHSUnknown); - const SCEV *RewrittenLHS = - I != RewriteMap.end() ? I->second : LHSUnknown; - RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS); - const auto *Multiple = - getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS); + const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS); RewriteMap[LHSUnknown] = Multiple; ExprsToRewrite.push_back(LHSUnknown); return; @@ -15158,128 +15068,48 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { auto I = RewriteMap.find(LHS); const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS; - // Check for the SCEV expression (A /u B) * B while B is a constant, inside - // \p Expr. The check is done recuresively on \p Expr, which is assumed to - // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A - // /u B) * B was found, and return the divisor B in \p DividesBy. For - // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since - // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p - // DividesBy. - std::function HasDivisibiltyInfo = - [&](const SCEV *Expr, const SCEV *&DividesBy) { - if (auto *Mul = dyn_cast(Expr)) { - if (Mul->getNumOperands() != 2) - return false; - auto *MulLHS = Mul->getOperand(0); - auto *MulRHS = Mul->getOperand(1); - if (isa(MulLHS)) - std::swap(MulLHS, MulRHS); - if (auto *Div = dyn_cast(MulLHS)) { - if (Div->getOperand(1) == MulRHS) { - DividesBy = MulRHS; - return true; - } - } - } - if (auto *MinMax = dyn_cast(Expr)) { - return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) || - HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy); - } - return false; - }; - - // Return true if Expr known to divide by \p DividesBy. - std::function IsKnownToDivideBy = - [&](const SCEV *Expr, const SCEV *DividesBy) { - if (getURemExpr(Expr, DividesBy)->isZero()) - return true; - if (auto *MinMax = dyn_cast(Expr)) { - return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) && - IsKnownToDivideBy(MinMax->getOperand(1), DividesBy); - } - return false; - }; - - const SCEV *DividesBy = nullptr; - if (HasDivisibiltyInfo(RewrittenLHS, DividesBy)) - // Check that the whole expression is divided by DividesBy - DividesBy = - IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr; - const SCEV *RewrittenRHS = nullptr; switch (Predicate) { case CmpInst::ICMP_ULT: { if (RHS->getType()->isPointerTy()) break; const SCEV *One = getOne(RHS->getType()); - auto *ModifiedRHS = getMinusSCEV(getUMaxExpr(RHS, One), One); - ModifiedRHS = - DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy) - : ModifiedRHS; - RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS); + RewrittenRHS = + getUMinExpr(RewrittenLHS, getMinusSCEV(getUMaxExpr(RHS, One), One)); break; } - case CmpInst::ICMP_SLT: { - auto *ModifiedRHS = getMinusSCEV(RHS, getOne(RHS->getType())); - ModifiedRHS = - DividesBy ? GetPreviousSCEVDividesByDivisor(ModifiedRHS, DividesBy) - : ModifiedRHS; - RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS); + case CmpInst::ICMP_SLT: + RewrittenRHS = + getSMinExpr(RewrittenLHS, getMinusSCEV(RHS, getOne(RHS->getType()))); break; - } - case CmpInst::ICMP_ULE: { - auto *ModifiedRHS = - DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; - RewrittenRHS = getUMinExpr(RewrittenLHS, ModifiedRHS); + case CmpInst::ICMP_ULE: + RewrittenRHS = getUMinExpr(RewrittenLHS, RHS); break; - } - case CmpInst::ICMP_SLE: { - auto *ModifiedRHS = - DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS; - RewrittenRHS = getSMinExpr(RewrittenLHS, ModifiedRHS); + case CmpInst::ICMP_SLE: + RewrittenRHS = getSMinExpr(RewrittenLHS, RHS); break; - } - case CmpInst::ICMP_UGT: { - auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType())); - ModifiedRHS = DividesBy - ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy) - : ModifiedRHS; - RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); + case CmpInst::ICMP_UGT: + RewrittenRHS = + getUMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType()))); break; - } - case CmpInst::ICMP_SGT: { - auto *ModifiedRHS = getAddExpr(RHS, getOne(RHS->getType())); - ModifiedRHS = DividesBy - ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy) - : ModifiedRHS; - RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS); + case CmpInst::ICMP_SGT: + RewrittenRHS = + getSMaxExpr(RewrittenLHS, getAddExpr(RHS, getOne(RHS->getType()))); break; - } - case CmpInst::ICMP_UGE: { - auto *ModifiedRHS = - DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; - RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); + case CmpInst::ICMP_UGE: + RewrittenRHS = getUMaxExpr(RewrittenLHS, RHS); break; - } - case CmpInst::ICMP_SGE: { - auto *ModifiedRHS = - DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS; - RewrittenRHS = getSMaxExpr(RewrittenLHS, ModifiedRHS); + case CmpInst::ICMP_SGE: + RewrittenRHS = getSMaxExpr(RewrittenLHS, RHS); break; - } case CmpInst::ICMP_EQ: if (isa(RHS)) RewrittenRHS = RHS; break; case CmpInst::ICMP_NE: if (isa(RHS) && - cast(RHS)->getValue()->isNullValue()) { - auto *ModifiedRHS = getOne(RHS->getType()); - ModifiedRHS = DividesBy - ? GetNextSCEVDividesByDivisor(ModifiedRHS, DividesBy) - : ModifiedRHS; - RewrittenRHS = getUMaxExpr(RewrittenLHS, ModifiedRHS); - } + cast(RHS)->getValue()->isNullValue()) + RewrittenRHS = getUMaxExpr(RewrittenLHS, getOne(RHS->getType())); break; default: break; diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll index 492ed9c4d26537..cfa91e3cc74731 100644 --- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll +++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll @@ -125,7 +125,7 @@ define void @test_trip_multiple_4_ugt_5_order_swapped(i32 %num) { ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 4 +; CHECK: Loop %for.body: Trip multiple is 2 ; entry: %u = urem i32 %num, 4 @@ -196,7 +196,7 @@ define void @test_trip_multiple_4_sgt_5_order_swapped(i32 %num) { ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 4 +; CHECK: Loop %for.body: Trip multiple is 2 ; entry: %u = urem i32 %num, 4 @@ -267,7 +267,7 @@ define void @test_trip_multiple_4_uge_5_order_swapped(i32 %num) { ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 4 +; CHECK: Loop %for.body: Trip multiple is 1 ; entry: %u = urem i32 %num, 4 @@ -338,7 +338,7 @@ define void @test_trip_multiple_4_sge_5_order_swapped(i32 %num) { ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 4 +; CHECK: Loop %for.body: Trip multiple is 1 ; entry: %u = urem i32 %num, 4 @@ -409,7 +409,7 @@ define void @test_trip_multiple_4_upper_lower_bounds(i32 %num) { ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 4 +; CHECK: Loop %for.body: Trip multiple is 1 ; entry: %cmp.1 = icmp uge i32 %num, 5 @@ -446,7 +446,7 @@ define void @test_trip_multiple_4_upper_lower_bounds_swapped1(i32 %num) { ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 4 +; CHECK: Loop %for.body: Trip multiple is 1 ; entry: %cmp.1 = icmp uge i32 %num, 5 @@ -483,7 +483,7 @@ define void @test_trip_multiple_4_upper_lower_bounds_swapped2(i32 %num) { ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num) ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num) ; CHECK-NEXT: Predicates: -; CHECK: Loop %for.body: Trip multiple is 4 +; CHECK: Loop %for.body: Trip multiple is 1 ; entry: %cmp.1 = icmp uge i32 %num, 5 diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp index d0fec1ab2c570d..8756e2c66c25a6 100644 --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -1744,42 +1744,4 @@ TEST_F(ScalarEvolutionsTest, ComputeMaxTripCountFromMultiDemArray) { }); } -TEST_F(ScalarEvolutionsTest, ApplyLoopGuards) { - LLVMContext C; - SMDiagnostic Err; - std::unique_ptr M = parseAssemblyString( - "declare void @llvm.assume(i1)\n" - "define void @test(i32 %num) {\n" - "entry:\n" - " %u = urem i32 %num, 4\n" - " %cmp = icmp eq i32 %u, 0\n" - " tail call void @llvm.assume(i1 %cmp)\n" - " %cmp.1 = icmp ugt i32 %num, 0\n" - " tail call void @llvm.assume(i1 %cmp.1)\n" - " br label %for.body\n" - "for.body:\n" - " %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]\n" - " %inc = add nuw nsw i32 %i.010, 1\n" - " %cmp2 = icmp ult i32 %inc, %num\n" - " br i1 %cmp2, label %for.body, label %exit\n" - "exit:\n" - " ret void\n" - "}\n", - Err, C); - - ASSERT_TRUE(M && "Could not parse module?"); - ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); - - runWithSE(*M, "test", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { - auto *TCScev = SE.getSCEV(getArgByName(F, "num")); - auto *ApplyLoopGuardsTC = SE.applyLoopGuards(TCScev, *LI.begin()); - // Assert that the new TC is (4 * ((4 umax %num) /u 4)) - APInt Four(32, 4); - auto *Constant4 = SE.getConstant(Four); - auto *Max = SE.getUMaxExpr(TCScev, Constant4); - auto *Mul = SE.getMulExpr(SE.getUDivExpr(Max, Constant4), Constant4); - ASSERT_TRUE(Mul == ApplyLoopGuardsTC); - }); -} - } // end namespace llvm