diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 5ff7c1027108b..5d8e822eaddff 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -6719,49 +6719,23 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, return std::nullopt; } - // For `IsToFold`, a primary IV can be replaced by other affine AddRec when it - // is only used by the terminating condition. To check for this, we may need - // to traverse through a chain of use-def until we can examine the final - // usage. - // *----------------------* - // *---->| LoopHeader: | - // | | PrimaryIV = phi ... | - // | *----------------------* - // | | - // | | - // | chain of - // | single use - // used by | - // phi | - // | Value - // | / \ - // | chain of chain of - // | single use single use - // | / \ - // | / \ - // *- Value Value --> used by terminating condition - auto IsToFold = [&](PHINode &PN) -> bool { - Value *V = &PN; - - while (V->getNumUses() == 1) - V = *V->user_begin(); - - if (V->getNumUses() != 2) - return false; + BinaryOperator *LHS = dyn_cast(TermCond->getOperand(0)); + Value *RHS = TermCond->getOperand(1); + if (!LHS || !L->isLoopInvariant(RHS)) + // We could pattern match the inverse form of the icmp, but that is + // non-canonical, and this pass is running *very* late in the pipeline. + return std::nullopt; - Value *VToPN = nullptr; - Value *VToTermCond = nullptr; - for (User *U : V->users()) { - while (U->getNumUses() == 1) { - if (isa(U)) - VToPN = U; - if (U == TermCond) - VToTermCond = U; - U = *U->user_begin(); - } - } - return VToPN && VToTermCond; - }; + // Find the IV used by the current exit condition. + PHINode *ToFold; + Value *ToFoldStart, *ToFoldStep; + if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep)) + return std::nullopt; + + // If that IV isn't dead after we rewrite the exit condition in terms of + // another IV, there's no point in doing the transform. + if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond)) + return std::nullopt; // If this is an IV which we could replace the terminating condition, return // the final value of the alternative IV on the last iteration. @@ -6789,11 +6763,13 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, return TermValueS; }; - PHINode *ToFold = nullptr; PHINode *ToHelpFold = nullptr; const SCEV *TermValueS = nullptr; for (PHINode &PN : L->getHeader()->phis()) { + if (ToFold == &PN) + continue; + if (!SE.isSCEVable(PN.getType())) { LLVM_DEBUG(dbgs() << "IV of phi '" << PN << "' is not SCEV-able, not qualified for the " @@ -6809,9 +6785,7 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, continue; } - if (IsToFold(PN)) - ToFold = &PN; - else if (auto P = getAlternateIVEnd(PN)) { + if (auto P = getAlternateIVEnd(PN)) { ToHelpFold = &PN; TermValueS = P; } diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll index 8682351a4e30c..1b9b58f79b480 100644 --- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll +++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll @@ -106,15 +106,14 @@ define void @NonAddRecIV(ptr %a) { ; CHECK-SAME: (ptr [[A:%.*]]) { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[UGLYGEP:%.*]] = getelementptr i8, ptr [[A]], i32 84 +; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[A]], i64 148 ; CHECK-NEXT: br label [[FOR_BODY:%.*]] ; CHECK: for.body: ; CHECK-NEXT: [[LSR_IV1:%.*]] = phi ptr [ [[UGLYGEP2:%.*]], [[FOR_BODY]] ], [ [[UGLYGEP]], [[ENTRY:%.*]] ] -; CHECK-NEXT: [[LSR_IV:%.*]] = phi i32 [ [[LSR_IV_NEXT:%.*]], [[FOR_BODY]] ], [ 1, [[ENTRY]] ] ; CHECK-NEXT: store i32 1, ptr [[LSR_IV1]], align 4 -; CHECK-NEXT: [[LSR_IV_NEXT]] = mul nsw i32 [[LSR_IV]], 2 ; CHECK-NEXT: [[UGLYGEP2]] = getelementptr i8, ptr [[LSR_IV1]], i64 4 -; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i32 [[LSR_IV_NEXT]], 65536 -; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_END:%.*]], label [[FOR_BODY]] +; CHECK-NEXT: [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[UGLYGEP2]], [[SCEVGEP]] +; CHECK-NEXT: br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END:%.*]], label [[FOR_BODY]] ; CHECK: for.end: ; CHECK-NEXT: ret void ; diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll index 7da1a73a21d1e..a72e859791574 100644 --- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll +++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold.ll @@ -297,39 +297,6 @@ define void @IcmpSgt(ptr %a) { ; CHECK-LABEL: @IcmpSgt( ; CHECK-NEXT: entry: ; CHECK-NEXT: [[UGLYGEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i32 84 -; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[A]], i64 88 -; CHECK-NEXT: br label [[FOR_BODY:%.*]] -; CHECK: for.body: -; CHECK-NEXT: [[LSR_IV1:%.*]] = phi ptr [ [[UGLYGEP2:%.*]], [[FOR_BODY]] ], [ [[UGLYGEP]], [[ENTRY:%.*]] ] -; CHECK-NEXT: store i32 1, ptr [[LSR_IV1]], align 4 -; CHECK-NEXT: [[UGLYGEP2]] = getelementptr i8, ptr [[LSR_IV1]], i32 4 -; CHECK-NEXT: [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[UGLYGEP2]], [[SCEVGEP]] -; CHECK-NEXT: br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END:%.*]], label [[FOR_BODY]] -; CHECK: for.end: -; CHECK-NEXT: ret void -; -entry: - %uglygep = getelementptr i8, ptr %a, i32 84 - br label %for.body - -for.body: ; preds = %for.body, %entry - %lsr.iv1 = phi ptr [ %uglygep2, %for.body ], [ %uglygep, %entry ] - %lsr.iv = phi i32 [ %lsr.iv.next, %for.body ], [ 379, %entry ] - store i32 1, ptr %lsr.iv1, align 4 - %lsr.iv.next = add nsw i32 %lsr.iv, -1 - %uglygep2 = getelementptr i8, ptr %lsr.iv1, i32 4 - %exitcond.not = icmp sgt i32 0, %lsr.iv.next - br i1 %exitcond.not, label %for.body, label %for.end - -for.end: ; preds = %for.body - ret void -} - -; Invert predicate and branches -define void @IcmpSgt2(ptr %a) { -; CHECK-LABEL: @IcmpSgt2( -; CHECK-NEXT: entry: -; CHECK-NEXT: [[UGLYGEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i32 84 ; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i8, ptr [[A]], i64 1600 ; CHECK-NEXT: br label [[FOR_BODY:%.*]] ; CHECK: for.body: