Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,7 @@ class ScalarEvolution {

class LoopGuards {
DenseMap<const SCEV *, const SCEV *> RewriteMap;
SmallDenseSet<std::pair<const SCEV *, const SCEV *>> NotEqual;
bool PreserveNUW = false;
bool PreserveNSW = false;
ScalarEvolution &SE;
Expand Down
53 changes: 42 additions & 11 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15740,19 +15740,26 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
GetNextSCEVDividesByDivisor(One, DividesBy);
To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
} else {
// LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
// but creating the subtraction eagerly is expensive. Track the
// inequalities in a separate map, and materialize the rewrite lazily
// when encountering a suitable subtraction while re-writing.
if (LHS->getType()->isPointerTy()) {
LHS = SE.getLosslessPtrToIntExpr(LHS);
RHS = SE.getLosslessPtrToIntExpr(RHS);
if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
break;
}
auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) {
const SCEV *Sub = SE.getMinusSCEV(A, B);
AddRewrite(Sub, Sub,
SE.getUMaxExpr(Sub, SE.getOne(From->getType())));
};
AddSubRewrite(LHS, RHS);
AddSubRewrite(RHS, LHS);
const SCEVConstant *C;
const SCEV *A, *B;
if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't really need the other operand to be SCEVConstant, though I guess the non-constant case may be less practically relevant...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I couldn't find a case where it would matter. I guess we could check both sides to find a common SCEV on both sides.

RHS = A;
LHS = B;
}
if (LHS > RHS)
std::swap(LHS, RHS);
Guards.NotEqual.insert({LHS, RHS});
continue;
}
break;
Expand Down Expand Up @@ -15886,13 +15893,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
class SCEVLoopGuardRewriter
: public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
const DenseMap<const SCEV *, const SCEV *> &Map;
const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;

SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;

public:
SCEVLoopGuardRewriter(ScalarEvolution &SE,
const ScalarEvolution::LoopGuards &Guards)
: SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
: SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
NotEqual(Guards.NotEqual) {
if (Guards.PreserveNUW)
FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
if (Guards.PreserveNSW)
Expand Down Expand Up @@ -15947,14 +15956,36 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}

const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
// Helper to check if S is a subtraction (A - B) where A != B, and if so,
// return UMax(S, 1).
auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
const SCEV *LHS, *RHS;
if (MatchBinarySub(S, LHS, RHS)) {
if (LHS > RHS)
std::swap(LHS, RHS);
if (NotEqual.contains({LHS, RHS}))
return SE.getUMaxExpr(S, SE.getOne(S->getType()));
}
return nullptr;
};

// Check if Expr itself is a subtraction pattern with guard info.
if (const SCEV *Rewritten = RewriteSubtraction(Expr))
return Rewritten;

// Trip count expressions sometimes consist of adding 3 operands, i.e.
// (Const + A + B). There may be guard info for A + B, and if so, apply
// it.
// TODO: Could more generally apply guards to Add sub-expressions.
if (isa<SCEVConstant>(Expr->getOperand(0)) &&
Expr->getNumOperands() == 3) {
if (const SCEV *S = Map.lookup(
SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
const SCEV *Add =
SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
if (const SCEV *Rewritten = RewriteSubtraction(Add))
return SE.getAddExpr(
Expr->getOperand(0), Rewritten,
ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
if (const SCEV *S = Map.lookup(Add))
return SE.getAddExpr(Expr->getOperand(0), S);
}
SmallVector<const SCEV *, 2> Operands;
Expand Down Expand Up @@ -15989,7 +16020,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
}
};

if (RewriteMap.empty())
if (RewriteMap.empty() && NotEqual.empty())
return Expr;

SCEVLoopGuardRewriter Rewriter(SE, *this);
Expand Down
15 changes: 11 additions & 4 deletions llvm/test/Transforms/IndVarSimplify/pointer-loop-guards.ll
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,14 @@ define void @test_sub_cmp(ptr align 8 %start, ptr %end) {
; N32-NEXT: [[CMP_ENTRY:%.*]] = icmp eq ptr [[START]], [[END]]
; N32-NEXT: br i1 [[CMP_ENTRY]], label %[[EXIT:.*]], label %[[LOOP_HEADER_PREHEADER:.*]]
; N32: [[LOOP_HEADER_PREHEADER]]:
; N32-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[PTR_DIFF]], i64 1)
; N32-NEXT: br label %[[LOOP_HEADER:.*]]
; N32: [[LOOP_HEADER]]:
; N32-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT:%.*]], %[[LOOP_LATCH:.*]] ], [ 0, %[[LOOP_HEADER_PREHEADER]] ]
; N32-NEXT: [[C_1:%.*]] = call i1 @cond()
; N32-NEXT: br i1 [[C_1]], label %[[EXIT_EARLY:.*]], label %[[LOOP_LATCH]]
; N32: [[LOOP_LATCH]]:
; N32-NEXT: [[IV_NEXT]] = add nuw i64 [[IV]], 1
; N32-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[UMAX]]
; N32-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[PTR_DIFF]]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found that there's some improvements with the change added in 0590c9e

; N32-NEXT: br i1 [[EXITCOND]], label %[[LOOP_HEADER]], label %[[EXIT_LOOPEXIT:.*]]
; N32: [[EXIT_EARLY]]:
; N32-NEXT: br label %[[EXIT]]
Expand Down Expand Up @@ -162,13 +161,17 @@ define void @test_ptr_diff_with_assume(ptr align 8 %start, ptr align 8 %end, ptr
; CHECK-NEXT: [[PTR_DIFF:%.*]] = sub i64 [[START_INT]], [[END_INT]]
; CHECK-NEXT: [[DIFF_CMP:%.*]] = icmp ult i64 [[PTR_DIFF]], 2
; CHECK-NEXT: call void @llvm.assume(i1 [[DIFF_CMP]])
; CHECK-NEXT: [[COMPUTED_END:%.*]] = getelementptr i8, ptr [[START]], i64 [[PTR_DIFF]]
; CHECK-NEXT: [[ENTRY_CMP:%.*]] = icmp eq ptr [[START]], [[END]]
; CHECK-NEXT: br i1 [[ENTRY_CMP]], label %[[EXIT:.*]], label %[[LOOP_BODY_PREHEADER:.*]]
; CHECK: [[LOOP_BODY_PREHEADER]]:
; CHECK-NEXT: br label %[[LOOP_BODY:.*]]
; CHECK: [[LOOP_BODY]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[IV_NEXT:%.*]], %[[LOOP_BODY]] ], [ [[START]], %[[LOOP_BODY_PREHEADER]] ]
; CHECK-NEXT: [[TMP0:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 true, label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY]]
; CHECK-NEXT: [[IV_NEXT]] = getelementptr i8, ptr [[IV]], i64 1
; CHECK-NEXT: [[LOOP_CMP:%.*]] = icmp eq ptr [[IV_NEXT]], [[COMPUTED_END]]
; CHECK-NEXT: br i1 [[LOOP_CMP]], label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY]]
; CHECK: [[EXIT_LOOPEXIT]]:
; CHECK-NEXT: br label %[[EXIT]]
; CHECK: [[EXIT]]:
Expand All @@ -182,13 +185,17 @@ define void @test_ptr_diff_with_assume(ptr align 8 %start, ptr align 8 %end, ptr
; N32-NEXT: [[PTR_DIFF:%.*]] = sub i64 [[START_INT]], [[END_INT]]
; N32-NEXT: [[DIFF_CMP:%.*]] = icmp ult i64 [[PTR_DIFF]], 2
; N32-NEXT: call void @llvm.assume(i1 [[DIFF_CMP]])
; N32-NEXT: [[COMPUTED_END:%.*]] = getelementptr i8, ptr [[START]], i64 [[PTR_DIFF]]
; N32-NEXT: [[ENTRY_CMP:%.*]] = icmp eq ptr [[START]], [[END]]
; N32-NEXT: br i1 [[ENTRY_CMP]], label %[[EXIT:.*]], label %[[LOOP_BODY_PREHEADER:.*]]
; N32: [[LOOP_BODY_PREHEADER]]:
; N32-NEXT: br label %[[LOOP_BODY:.*]]
; N32: [[LOOP_BODY]]:
; N32-NEXT: [[IV:%.*]] = phi ptr [ [[IV_NEXT:%.*]], %[[LOOP_BODY]] ], [ [[START]], %[[LOOP_BODY_PREHEADER]] ]
; N32-NEXT: [[TMP0:%.*]] = call i1 @cond()
; N32-NEXT: br i1 true, label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY]]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's also one case that regressed, but there are quite a few improvements with unrolling and others on macOS.

; N32-NEXT: [[IV_NEXT]] = getelementptr i8, ptr [[IV]], i64 1
; N32-NEXT: [[LOOP_CMP:%.*]] = icmp eq ptr [[IV_NEXT]], [[COMPUTED_END]]
; N32-NEXT: br i1 [[LOOP_CMP]], label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP_BODY]]
; N32: [[EXIT_LOOPEXIT]]:
; N32-NEXT: br label %[[EXIT]]
; N32: [[EXIT]]:
Expand Down
Loading