Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 15 additions & 48 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15749,51 +15749,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
return RewriteMap.lookup_or(S, S);
};

// Check for the SCEV expression (A /u B) * B while B is a constant, inside
// \p Expr. The check is done recuresively on \p Expr, which is assumed to
// be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
// /u B) * B was found, and return the divisor B in \p DividesBy. For
// example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
// (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
// DividesBy.
std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
[&](const SCEV *Expr, const SCEV *&DividesBy) {
if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
if (Mul->getNumOperands() != 2)
return false;
auto *MulLHS = Mul->getOperand(0);
auto *MulRHS = Mul->getOperand(1);
if (isa<SCEVConstant>(MulLHS))
std::swap(MulLHS, MulRHS);
if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
if (Div->getOperand(1) == MulRHS) {
DividesBy = MulRHS;
return true;
}
}
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
return false;
};

// Return true if Expr known to divide by \p DividesBy.
std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
[&](const SCEV *Expr, const SCEV *DividesBy) {
if (SE.getURemExpr(Expr, DividesBy)->isZero())
return true;
if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
return false;
};

const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
const SCEV *DividesBy = nullptr;
if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
// Check that the whole expression is divided by DividesBy
DividesBy =
IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
const APInt &Multiple = SE.getConstantMultiple(RewrittenLHS);
if (!Multiple.isOne())
DividesBy = SE.getConstant(Multiple);

// Collect rewrites for LHS and its transitive operands based on the
// condition.
Expand Down Expand Up @@ -15865,17 +15825,23 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
EnqueueOperands(SMax);
break;
case CmpInst::ICMP_UGT:
case CmpInst::ICMP_UGE:
To = SE.getUMaxExpr(FromRewritten, RHS);
case CmpInst::ICMP_UGE: {
const SCEV *OpAlignedUp =
DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
To = SE.getUMaxExpr(FromRewritten, OpAlignedUp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, are the changes here necessary at all? It looks like we are already doing these next/prev divisor adjustments for RHS in the switch above this. Maybe the generalization of the divisor logic is sufficient?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep it looks like we get all the benfits from the patch with just the switch to getConstantMultiple: #162617

if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
EnqueueOperands(UMin);
break;
}
case CmpInst::ICMP_SGT:
case CmpInst::ICMP_SGE:
To = SE.getSMaxExpr(FromRewritten, RHS);
case CmpInst::ICMP_SGE: {
const SCEV *OpAlignedUp =
DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
To = SE.getSMaxExpr(FromRewritten, OpAlignedUp);
if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
EnqueueOperands(SMin);
break;
}
case CmpInst::ICMP_EQ:
if (isa<SCEVConstant>(RHS))
To = RHS;
Expand Down Expand Up @@ -16005,7 +15971,8 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
for (const SCEV *Expr : ExprsToRewrite) {
const SCEV *RewriteTo = Guards.RewriteMap[Expr];
Guards.RewriteMap.erase(Expr);
Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
const SCEV *Rewritten = Guards.rewrite(RewriteTo);
Guards.RewriteMap.insert({Expr, Rewritten});
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Analysis/ScalarEvolution/trip-count-minmax.ll
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ define void @umin(i32 noundef %a, i32 noundef %b) {
; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 2147483646
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a) umin (4 * %b)))
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
; CHECK-NEXT: Loop %for.body: Trip multiple is 2
;
; void umin(unsigned a, unsigned b) {
; a *= 2;
Expand Down Expand Up @@ -157,7 +157,7 @@ define void @smin(i32 noundef %a, i32 noundef %b) {
; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is i32 2147483646
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + ((2 * %a)<nsw> smin (4 * %b)<nsw>))
; CHECK-NEXT: Loop %for.body: Trip multiple is 1
; CHECK-NEXT: Loop %for.body: Trip multiple is 2
;
; void smin(signed a, signed b) {
; a *= 2;
Expand Down
47 changes: 40 additions & 7 deletions llvm/test/Transforms/LoopVectorize/single_early_exit.ll
Original file line number Diff line number Diff line change
Expand Up @@ -546,19 +546,50 @@ define i64 @loop_guards_needed_to_prove_deref_multiple(i32 %x, i1 %c, ptr derefe
; CHECK-NEXT: call void @llvm.assume(i1 [[PRE_2]])
; CHECK-NEXT: [[N:%.*]] = add i32 [[SEL]], -1
; CHECK-NEXT: [[N_EXT:%.*]] = zext i32 [[N]] to i64
; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[SEL]], -2
; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[TMP0]] to i64
; CHECK-NEXT: [[TMP2:%.*]] = add nuw nsw i64 [[TMP1]], 2
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ult i64 [[TMP2]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK: vector.ph:
; CHECK-NEXT: [[N_MOD_VF:%.*]] = urem i64 [[TMP2]], 4
; CHECK-NEXT: [[IV_NEXT:%.*]] = sub i64 [[TMP2]], [[N_MOD_VF]]
; CHECK-NEXT: br label [[LOOP_HEADER:%.*]]
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[LOOP_HEADER]] ]
; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[INDEX]]
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <4 x i8>, ptr [[TMP3]], align 1
; CHECK-NEXT: [[TMP4:%.*]] = icmp eq <4 x i8> [[WIDE_LOAD]], zeroinitializer
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4
; CHECK-NEXT: [[TMP5:%.*]] = freeze <4 x i1> [[TMP4]]
; CHECK-NEXT: [[TMP6:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP5]])
; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[IV_NEXT]]
; CHECK-NEXT: [[TMP8:%.*]] = or i1 [[TMP6]], [[TMP7]]
; CHECK-NEXT: br i1 [[TMP8]], label [[MIDDLE_SPLIT:%.*]], label [[LOOP_HEADER]], !llvm.loop [[LOOP11:![0-9]+]]
; CHECK: middle.split:
; CHECK-NEXT: br i1 [[TMP6]], label [[VECTOR_EARLY_EXIT:%.*]], label [[LOOP_LATCH:%.*]]
; CHECK: middle.block:
; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 [[TMP2]], [[IV_NEXT]]
; CHECK-NEXT: br i1 [[CMP_N]], label [[EXIT_LOOPEXIT:%.*]], label [[SCALAR_PH]]
; CHECK: vector.early.exit:
; CHECK-NEXT: [[TMP9:%.*]] = call i64 @llvm.experimental.cttz.elts.i64.v4i1(<4 x i1> [[TMP4]], i1 true)
; CHECK-NEXT: [[TMP10:%.*]] = add i64 [[INDEX]], [[TMP9]]
; CHECK-NEXT: br label [[EXIT_LOOPEXIT]]
; CHECK: scalar.ph:
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT]], [[LOOP_LATCH]] ], [ 0, [[PH]] ]
; CHECK-NEXT: br label [[LOOP_HEADER1:%.*]]
; CHECK: loop.header:
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT:%.*]], [[LOOP_LATCH:%.*]] ], [ 0, [[PH]] ]
; CHECK-NEXT: [[GEP_SRC_I:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[IV]]
; CHECK-NEXT: [[IV1:%.*]] = phi i64 [ [[IV_NEXT1:%.*]], [[LOOP_LATCH1:%.*]] ], [ [[IV]], [[SCALAR_PH]] ]
; CHECK-NEXT: [[GEP_SRC_I:%.*]] = getelementptr i8, ptr [[SRC]], i64 [[IV1]]
; CHECK-NEXT: [[L:%.*]] = load i8, ptr [[GEP_SRC_I]], align 1
; CHECK-NEXT: [[C_1:%.*]] = icmp eq i8 [[L]], 0
; CHECK-NEXT: br i1 [[C_1]], label [[EXIT_LOOPEXIT:%.*]], label [[LOOP_LATCH]]
; CHECK-NEXT: br i1 [[C_1]], label [[EXIT_LOOPEXIT]], label [[LOOP_LATCH1]]
; CHECK: loop.latch:
; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1
; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV]], [[N_EXT]]
; CHECK-NEXT: br i1 [[EC]], label [[EXIT_LOOPEXIT]], label [[LOOP_HEADER]]
; CHECK-NEXT: [[IV_NEXT1]] = add i64 [[IV1]], 1
; CHECK-NEXT: [[EC:%.*]] = icmp eq i64 [[IV1]], [[N_EXT]]
; CHECK-NEXT: br i1 [[EC]], label [[EXIT_LOOPEXIT]], label [[LOOP_HEADER1]], !llvm.loop [[LOOP12:![0-9]+]]
; CHECK: exit.loopexit:
; CHECK-NEXT: [[RES_PH:%.*]] = phi i64 [ [[IV]], [[LOOP_HEADER]] ], [ 0, [[LOOP_LATCH]] ]
; CHECK-NEXT: [[RES_PH:%.*]] = phi i64 [ [[IV1]], [[LOOP_HEADER1]] ], [ 0, [[LOOP_LATCH1]] ], [ 0, [[LOOP_LATCH]] ], [ [[TMP10]], [[VECTOR_EARLY_EXIT]] ]
; CHECK-NEXT: br label [[EXIT]]
; CHECK: exit:
; CHECK-NEXT: [[RES:%.*]] = phi i64 [ -1, [[ENTRY:%.*]] ], [ -2, [[THEN]] ], [ [[RES_PH]], [[EXIT_LOOPEXIT]] ]
Expand Down Expand Up @@ -609,4 +640,6 @@ exit:
; CHECK: [[LOOP8]] = distinct !{[[LOOP8]], [[META2]], [[META1]]}
; CHECK: [[LOOP9]] = distinct !{[[LOOP9]], [[META1]], [[META2]]}
; CHECK: [[LOOP10]] = distinct !{[[LOOP10]], [[META2]], [[META1]]}
; CHECK: [[LOOP11]] = distinct !{[[LOOP11]], [[META1]], [[META2]]}
; CHECK: [[LOOP12]] = distinct !{[[LOOP12]], [[META2]], [[META1]]}
;.
Loading