diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h index e5a6c8cc0a6aa..3d3ec14796bc1 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1345,6 +1345,7 @@ class ScalarEvolution { class LoopGuards { DenseMap RewriteMap; + SmallDenseSet> NotEqual; bool PreserveNUW = false; bool PreserveNSW = false; ScalarEvolution &SE; diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index e06b0956d4c82..425420fd698ef 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -15740,19 +15740,26 @@ void ScalarEvolution::LoopGuards::collectFromBlock( GetNextSCEVDividesByDivisor(One, DividesBy); To = SE.getUMaxExpr(FromRewritten, OneAlignedUp); } else { + // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS), + // but creating the subtraction eagerly is expensive. Track the + // inequalities in a separate map, and materialize the rewrite lazily + // when encountering a suitable subtraction while re-writing. if (LHS->getType()->isPointerTy()) { LHS = SE.getLosslessPtrToIntExpr(LHS); RHS = SE.getLosslessPtrToIntExpr(RHS); if (isa(LHS) || isa(RHS)) break; } - auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) { - const SCEV *Sub = SE.getMinusSCEV(A, B); - AddRewrite(Sub, Sub, - SE.getUMaxExpr(Sub, SE.getOne(From->getType()))); - }; - AddSubRewrite(LHS, RHS); - AddSubRewrite(RHS, LHS); + const SCEVConstant *C; + const SCEV *A, *B; + if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) && + match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) { + RHS = A; + LHS = B; + } + if (LHS > RHS) + std::swap(LHS, RHS); + Guards.NotEqual.insert({LHS, RHS}); continue; } break; @@ -15886,13 +15893,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { class SCEVLoopGuardRewriter : public SCEVRewriteVisitor { const DenseMap ⤅ + const SmallDenseSet> ≠ SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap; public: SCEVLoopGuardRewriter(ScalarEvolution &SE, const ScalarEvolution::LoopGuards &Guards) - : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) { + : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap), + NotEqual(Guards.NotEqual) { if (Guards.PreserveNUW) FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW); if (Guards.PreserveNSW) @@ -15947,14 +15956,36 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + // Helper to check if S is a subtraction (A - B) where A != B, and if so, + // return UMax(S, 1). + auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * { + const SCEV *LHS, *RHS; + if (MatchBinarySub(S, LHS, RHS)) { + if (LHS > RHS) + std::swap(LHS, RHS); + if (NotEqual.contains({LHS, RHS})) + return SE.getUMaxExpr(S, SE.getOne(S->getType())); + } + return nullptr; + }; + + // Check if Expr itself is a subtraction pattern with guard info. + if (const SCEV *Rewritten = RewriteSubtraction(Expr)) + return Rewritten; + // Trip count expressions sometimes consist of adding 3 operands, i.e. // (Const + A + B). There may be guard info for A + B, and if so, apply // it. // TODO: Could more generally apply guards to Add sub-expressions. if (isa(Expr->getOperand(0)) && Expr->getNumOperands() == 3) { - if (const SCEV *S = Map.lookup( - SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2)))) + const SCEV *Add = + SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2)); + if (const SCEV *Rewritten = RewriteSubtraction(Add)) + return SE.getAddExpr( + Expr->getOperand(0), Rewritten, + ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask)); + if (const SCEV *S = Map.lookup(Add)) return SE.getAddExpr(Expr->getOperand(0), S); } SmallVector Operands; @@ -15989,7 +16020,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const { } }; - if (RewriteMap.empty()) + if (RewriteMap.empty() && NotEqual.empty()) return Expr; SCEVLoopGuardRewriter Rewriter(SE, *this); diff --git a/llvm/test/Transforms/IndVarSimplify/pointer-loop-guards.ll b/llvm/test/Transforms/IndVarSimplify/pointer-loop-guards.ll index 6732efcc38926..dbd572d611632 100644 --- a/llvm/test/Transforms/IndVarSimplify/pointer-loop-guards.ll +++ b/llvm/test/Transforms/IndVarSimplify/pointer-loop-guards.ll @@ -111,7 +111,6 @@ define void @test_sub_cmp(ptr align 8 %start, ptr %end) { ; N32-NEXT: [[CMP_ENTRY:%.*]] = icmp eq ptr [[START]], [[END]] ; N32-NEXT: br i1 [[CMP_ENTRY]], label %[[EXIT:.*]], label %[[LOOP_HEADER_PREHEADER:.*]] ; N32: [[LOOP_HEADER_PREHEADER]]: -; N32-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[PTR_DIFF]], i64 1) ; N32-NEXT: br label %[[LOOP_HEADER:.*]] ; N32: [[LOOP_HEADER]]: ; N32-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT:%.*]], %[[LOOP_LATCH:.*]] ], [ 0, %[[LOOP_HEADER_PREHEADER]] ] @@ -119,7 +118,7 @@ define void @test_sub_cmp(ptr align 8 %start, ptr %end) { ; N32-NEXT: br i1 [[C_1]], label %[[EXIT_EARLY:.*]], label %[[LOOP_LATCH]] ; N32: [[LOOP_LATCH]]: ; N32-NEXT: [[IV_NEXT]] = add nuw i64 [[IV]], 1 -; N32-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[UMAX]] +; N32-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[PTR_DIFF]] ; N32-NEXT: br i1 [[EXITCOND]], label %[[LOOP_HEADER]], label %[[EXIT_LOOPEXIT:.*]] ; N32: [[EXIT_EARLY]]: ; N32-NEXT: br label %[[EXIT]] @@ -162,13 +161,17 @@ define void @test_ptr_diff_with_assume(ptr align 8 %start, ptr align 8 %end, ptr ; CHECK-NEXT: [[PTR_DIFF:%.*]] = sub i64 [[START_INT]], [[END_INT]] ; CHECK-NEXT: [[DIFF_CMP:%.*]] = icmp ult i64 [[PTR_DIFF]], 2 ; CHECK-NEXT: call void @llvm.assume(i1 [[DIFF_CMP]]) +; CHECK-NEXT: [[COMPUTED_END:%.*]] = getelementptr i8, ptr [[START]], i64 [[PTR_DIFF]] ; CHECK-NEXT: [[ENTRY_CMP:%.*]] = icmp eq ptr [[START]], [[END]] ; CHECK-NEXT: br i1 [[ENTRY_CMP]], label %[[EXIT:.*]], label %[[LOOP_BODY_PREHEADER:.*]] ; CHECK: [[LOOP_BODY_PREHEADER]]: ; CHECK-NEXT: br label %[[LOOP_BODY:.*]] ; CHECK: [[LOOP_BODY]]: +; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[IV_NEXT:%.*]], %[[LOOP_BODY]] ], [ [[START]], %[[LOOP_BODY_PREHEADER]] ] ; CHECK-NEXT: [[TMP0:%.*]] = call i1 @cond() -; CHECK-NEXT: br i1 true, label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY]] +; CHECK-NEXT: [[IV_NEXT]] = getelementptr i8, ptr [[IV]], i64 1 +; CHECK-NEXT: [[LOOP_CMP:%.*]] = icmp eq ptr [[IV_NEXT]], [[COMPUTED_END]] +; CHECK-NEXT: br i1 [[LOOP_CMP]], label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY]] ; CHECK: [[EXIT_LOOPEXIT]]: ; CHECK-NEXT: br label %[[EXIT]] ; CHECK: [[EXIT]]: @@ -182,13 +185,17 @@ define void @test_ptr_diff_with_assume(ptr align 8 %start, ptr align 8 %end, ptr ; N32-NEXT: [[PTR_DIFF:%.*]] = sub i64 [[START_INT]], [[END_INT]] ; N32-NEXT: [[DIFF_CMP:%.*]] = icmp ult i64 [[PTR_DIFF]], 2 ; N32-NEXT: call void @llvm.assume(i1 [[DIFF_CMP]]) +; N32-NEXT: [[COMPUTED_END:%.*]] = getelementptr i8, ptr [[START]], i64 [[PTR_DIFF]] ; N32-NEXT: [[ENTRY_CMP:%.*]] = icmp eq ptr [[START]], [[END]] ; N32-NEXT: br i1 [[ENTRY_CMP]], label %[[EXIT:.*]], label %[[LOOP_BODY_PREHEADER:.*]] ; N32: [[LOOP_BODY_PREHEADER]]: ; N32-NEXT: br label %[[LOOP_BODY:.*]] ; N32: [[LOOP_BODY]]: +; N32-NEXT: [[IV:%.*]] = phi ptr [ [[IV_NEXT:%.*]], %[[LOOP_BODY]] ], [ [[START]], %[[LOOP_BODY_PREHEADER]] ] ; N32-NEXT: [[TMP0:%.*]] = call i1 @cond() -; N32-NEXT: br i1 true, label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY]] +; N32-NEXT: [[IV_NEXT]] = getelementptr i8, ptr [[IV]], i64 1 +; N32-NEXT: [[LOOP_CMP:%.*]] = icmp eq ptr [[IV_NEXT]], [[COMPUTED_END]] +; N32-NEXT: br i1 [[LOOP_CMP]], label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY]] ; N32: [[EXIT_LOOPEXIT]]: ; N32-NEXT: br label %[[EXIT]] ; N32: [[EXIT]]: