Skip to content

Commit

Permalink
[SCEV] Fix trip multiple calculation
Browse files Browse the repository at this point in the history
If loop bound containing calculations like min(a,b), the Scalar
Evolution API getSmallConstantTripMultiple returns 4294967295 "-1"
as the trip multiple. The problem is that, SCEV use -1 * umax to
represent umin. The multiple constant -1 was returned, and the logic
of guarding against huge trip counts was skipped. Because -1 has 32
active bits.

The fix attempt to factor more general cases. First try to get the
greatest power of two divisor of trip count expression. In case
overflow happens, the trip count expression is still divisible by the
greatest power of two divisor returned. Returns 1 if not divisible by 2.

Patch by Huihui Zhang <huihuiz@codeaurora.org>

Differential Revision: https://reviews.llvm.org/D30840

llvm-svn: 298301
  • Loading branch information
Eli Friedman committed Mar 20, 2017
1 parent 6043e0f commit b1578d3
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 10 deletions.
26 changes: 16 additions & 10 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Expand Up @@ -5504,17 +5504,16 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
return 1;

// Get the trip count from the BE count by adding 1.
const SCEV *TCMul = getAddExpr(ExitCount, getOne(ExitCount->getType()));
// FIXME: SCEV distributes multiplication as V1*C1 + V2*C1. We could attempt
// to factor simple cases.
if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(TCMul))
TCMul = Mul->getOperand(0);

const SCEVConstant *MulC = dyn_cast<SCEVConstant>(TCMul);
if (!MulC)
return 1;
const SCEV *TCExpr = getAddExpr(ExitCount, getOne(ExitCount->getType()));

const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
if (!TC)
// Attempt to factor more general cases. Returns the greatest power of
// two divisor. If overflow happens, the trip count expression is still
// divisible by the greatest power of 2 divisor returned.
return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr));

ConstantInt *Result = MulC->getValue();
ConstantInt *Result = TC->getValue();

// Guard against huge trip counts (this requires checking
// for zero to handle the case where the trip count == -1 and the
Expand Down Expand Up @@ -9667,6 +9666,13 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
OS << "Unpredictable predicated backedge-taken count. ";
}
OS << "\n";

if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
OS << "Loop ";
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
OS << ": ";
OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
}
}

static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) {
Expand Down
125 changes: 125 additions & 0 deletions llvm/test/Analysis/ScalarEvolution/tripmultiple_calculation.ll
@@ -0,0 +1,125 @@
; RUN: opt -S -analyze -scalar-evolution < %s 2>&1 | FileCheck %s

; umin is represented using -1 * umax in scalar evolution. -1 is considered as the
; constant of the multiply expression (-1 * ((-1 + (-1 * %a)) umax (-1 + (-1 * %b)))).
; Returns the greatest power of 2 divisor by evaluating the minimal trailing zeros
; for the trip count expression.
;
; int foo(uint32_t a, uint32_t b, uint32_t *c) {
; for (uint32_t i = 0; i < (uint32_t)(a < b ? a : b) + 1; i++)
; c[i] = i;
; return 0;
; }
;
; CHECK: Loop %for.body: Trip multiple is 1

define i32 @foo(i32 %a, i32 %b, i32* %c) {
entry:
%cmp = icmp ult i32 %a, %b
%cond = select i1 %cmp, i32 %a, i32 %b
%add = add i32 %cond, 1
%cmp18 = icmp eq i32 %add, 0
br i1 %cmp18, label %for.cond.cleanup, label %for.body.preheader

for.body.preheader: ; preds = %entry
br label %for.body

for.cond.cleanup.loopexit: ; preds = %for.body
br label %for.cond.cleanup

for.cond.cleanup: ; preds = %for.cond.cleanup.loopexit, %entry
ret i32 0

for.body: ; preds = %for.body.preheader, %for.body
%i.09 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
%arrayidx = getelementptr inbounds i32, i32* %c, i32 %i.09
store i32 %i.09, i32* %arrayidx, align 4
%inc = add nuw i32 %i.09, 1
%cmp1 = icmp ult i32 %inc, %add
br i1 %cmp1, label %for.body, label %for.cond.cleanup.loopexit
}

; Overflow may happen for the multiply expression n * 3, verify that trip
; multiple is set to 1 if NUW/NSW are not set.
;
; __attribute__((noinline)) void a(unsigned n) {
; #pragma unroll(3)
; for (unsigned i = 0; i != n * 3; ++i)
; printf("TEST%u\n", i);
; }
; int main() { a(2863311531U); }
;
; CHECK: Loop %for.body: Trip multiple is 1

@.str2 = private unnamed_addr constant [8 x i8] c"TEST%u\0A\00", align 1

define void @foo2(i32 %n) {
entry:
%mul = mul i32 %n, 3
%cmp4 = icmp eq i32 %mul, 0
br i1 %cmp4, label %for.cond.cleanup, label %for.body.preheader

for.body.preheader: ; preds = %entry
br label %for.body

for.cond.cleanup.loopexit: ; preds = %for.body
br label %for.cond.cleanup

for.cond.cleanup: ; preds = %for.cond.cleanup.loopexit, %entry
ret void

for.body: ; preds = %for.body.preheader, %for.body
%i.05 = phi i32 [ %inc, %for.body ], [ 0, %for.body.preheader ]
%call = tail call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([8 x i8], [8 x i8]* @.str2, i32 0, i32 0), i32 %i.05)
%inc = add nuw i32 %i.05, 1
%cmp = icmp eq i32 %inc, %mul
br i1 %cmp, label %for.cond.cleanup.loopexit, label %for.body
}

declare i32 @printf(i8* nocapture readonly, ...)


; If we couldn't prove no overflow for the multiply expression 24 * n,
; returns the greatest power of 2 divisor. If overflows happens
; the trip count is still divisible by the greatest power of 2 divisor.
;
; CHECK: Loop %l3: Trip multiple is 8

declare void @f()

define i32 @foo3(i32 %n) {
entry:
%loop_ctl = mul i32 %n, 24
br label %l3

l3:
%x.0 = phi i32 [ 0, %entry ], [ %inc, %l3 ]
call void @f()
%inc = add i32 %x.0, 1
%exitcond = icmp eq i32 %inc, %loop_ctl
br i1 %exitcond, label %exit, label %l3

exit:
ret i32 0
}

; If the trip count is a constant, verify that we obtained the trip
; count itself. For huge trip counts, or zero, we return 1.
;
; CHECK: Loop %l3: Trip multiple is 3

define i32 @foo4(i32 %n) {
entry:
br label %l3

l3:
%x.0 = phi i32 [ 0, %entry ], [ %inc, %l3 ]
call void @f()
%inc = add i32 %x.0, 1
%exitcond = icmp eq i32 %inc, 3
br i1 %exitcond, label %exit, label %l3

exit:
ret i32 0
}

0 comments on commit b1578d3

Please sign in to comment.