Skip to content

Commit

Permalink
[SCEV] Simplify trunc-of-add/mul to add/mul-of-trunc under more circumstances.
Browse files Browse the repository at this point in the history

Summary:
Previously we would do this simplification only if it did not introduce
any new truncs (excepting new truncs which replace other cast ops).

This change weakens this condition: If the number of truncs stays the
same, but we're able to transform trunc(X + Y) to X + trunc(Y), that's
still simpler, and it may open up additional transformations.

While we're here, also clean up some duplicated code.

Reviewers: sanjoy

Subscribers: hiraditya, llvm-commits

Differential Revision: https://reviews.llvm.org/D48160

llvm-svn: 334736
  • Loading branch information
Justin Lebar committed Jun 14, 2018
1 parent 62a0747 commit b326904
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 40 deletions.
54 changes: 22 additions & 32 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Expand Up @@ -1256,42 +1256,32 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
return getTruncateOrZeroExtend(SZ->getOperand(), Ty);

// trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
// eliminate all the truncates, or we replace other casts with truncates.
if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
// trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
// trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
// if after transforming we have at most one truncate, not counting truncates
// that replace other casts.
if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
auto *CommOp = cast<SCEVCommutativeExpr>(Op);
SmallVector<const SCEV *, 4> Operands;
bool hasTrunc = false;
for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
if (!isa<SCEVCastExpr>(SA->getOperand(i)))
hasTrunc = isa<SCEVTruncateExpr>(S);
unsigned numTruncs = 0;
for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
++i) {
const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty);
if (!isa<SCEVCastExpr>(CommOp->getOperand(i)) && isa<SCEVTruncateExpr>(S))
numTruncs++;
Operands.push_back(S);
}
if (!hasTrunc)
return getAddExpr(Operands);
// In spite we checked in the beginning that ID is not in the cache,
// it is possible that during recursion and different modification
// ID came to cache, so if we found it, just return it.
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
}

// trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
// eliminate all the truncates, or we replace other casts with truncates.
if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
SmallVector<const SCEV *, 4> Operands;
bool hasTrunc = false;
for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
if (!isa<SCEVCastExpr>(SM->getOperand(i)))
hasTrunc = isa<SCEVTruncateExpr>(S);
Operands.push_back(S);
if (numTruncs < 2) {
if (isa<SCEVAddExpr>(Op))
return getAddExpr(Operands);
else if (isa<SCEVMulExpr>(Op))
return getMulExpr(Operands);
else
llvm_unreachable("Unexpected SCEV type for Op.");
}
if (!hasTrunc)
return getMulExpr(Operands);
// In spite we checked in the beginning that ID is not in the cache,
// it is possible that during recursion and different modification
// ID came to cache, so if we found it, just return it.
// Although we checked in the beginning that ID is not in the cache, it is
// possible that during recursion and different modification ID was inserted
// into the cache. So if we find it, just return it.
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
}
Expand Down
9 changes: 5 additions & 4 deletions llvm/test/Analysis/ScalarEvolution/different-loops-recs.ll
Expand Up @@ -277,9 +277,10 @@ define void @test_04() {
; CHECK: %tmp11 = add i64 %tmp10, undef
; CHECK-NEXT: --> ((sext i8 %tmp8 to i64) + {(-2 + undef),+,-1}<nw><%loop2>)
; CHECK: %tmp13 = trunc i64 %tmp11 to i32
; CHECK-NEXT: --> ((sext i8 %tmp8 to i32) + {(trunc i64 (-2 + undef) to i32),+,-1}<%loop2>)
; CHECK-NEXT: --> ((sext i8 %tmp8 to i32) + {(-2 + (trunc i64 undef to i32)),+,-1}<%loop2>)
; CHECK: %tmp14 = sub i32 %tmp13, %tmp2
; CHECK-NEXT: --> ((sext i8 %tmp8 to i32) + {{{{}}(-2 + (trunc i64 (-2 + undef) to i32)),+,-1}<%loop1>,+,-1}<%loop2>)
; `{{[{][{]}}` is the ugliness needed to match `{{`
; CHECK-NEXT: --> ((sext i8 %tmp8 to i32) + {{[{][{]}}(-4 + (trunc i64 undef to i32)),+,-1}<%loop1>,+,-1}<%loop2>)
; CHECK: %tmp15 = add nuw nsw i64 %tmp7, 1
; CHECK-NEXT: --> {3,+,1}<nuw><nsw><%loop2>

Expand Down Expand Up @@ -462,9 +463,9 @@ define void @test_08() {
; CHECK: %tmp11 = add i64 %iv.2.2, %iv.2.1
; CHECK-NEXT: --> ({0,+,-1}<nsw><%loop_2> + %iv.2.1)
; CHECK: %tmp12 = trunc i64 %tmp11 to i32
; CHECK-NEXT: --> (trunc i64 ({0,+,-1}<nsw><%loop_2> + %iv.2.1) to i32)
; CHECK-NEXT: --> ((trunc i64 %iv.2.1 to i32) + {0,+,-1}<%loop_2>)
; CHECK: %tmp14 = mul i32 %tmp12, %tmp7
; CHECK-NEXT: --> ((trunc i64 ({0,+,-1}<nsw><%loop_2> + %iv.2.1) to i32) * {-1,+,-1}<%loop_1>)
; CHECK-NEXT: --> (((trunc i64 %iv.2.1 to i32) + {0,+,-1}<%loop_2>) * {-1,+,-1}<%loop_1>)
; CHECK: %tmp16 = mul i64 %iv.2.1, %iv.1.1
; CHECK-NEXT: --> ({2,+,1}<nuw><nsw><%loop_1> * %iv.2.1)

Expand Down
Expand Up @@ -4,7 +4,7 @@

target datalayout = "e-p:32:32:32-p1:16:16:16-p2:8:8:8-p4:64:64:64-n16:32:64"

; CHECK: {%d,+,4}<%bb>{{ U: [^ ]+ S: [^ ]+}}{{ *}}Exits: ((4 * (trunc i32 (-1 + %n) to i16)) + %d)
; CHECK: {%d,+,4}<%bb>{{ U: [^ ]+ S: [^ ]+}}{{ *}} Exits: (-4 + (4 * (trunc i32 %n to i16)) + %d)


define void @foo(i32 addrspace(1)* nocapture %d, i32 %n) nounwind {
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Analysis/ScalarEvolution/sext-inreg.ll
Expand Up @@ -15,10 +15,10 @@ bb:
%t2 = ashr i64 %t1, 7
; CHECK: %t2 = ashr i64 %t1, 7
; CHECK-NEXT: sext i57 {0,+,199}<%bb> to i64
; CHECK-SAME: Exits: (sext i57 (199 * (trunc i64 (-1 + (2780916192016515319 * %n)) to i57)) to i64)
; CHECK-SAME: Exits: (sext i57 (-199 + (trunc i64 %n to i57)) to i64)
; CHECK: %s2 = ashr i64 %s1, 5
; CHECK-NEXT: sext i59 {0,+,199}<%bb> to i64
; CHECK-SAME: Exits: (sext i59 (199 * (trunc i64 (-1 + (2780916192016515319 * %n)) to i59)) to i64)
; CHECK-SAME: Exits: (sext i59 (-199 + (trunc i64 %n to i59)) to i64)
%s1 = shl i64 %i.01, 5
%s2 = ashr i64 %s1, 5
%t3 = getelementptr i64, i64* %x, i64 %i.01
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Analysis/ScalarEvolution/strip-injective-zext.ll
Expand Up @@ -9,7 +9,7 @@

; Check that the backedge taken count was actually computed:
; CHECK: Determining loop execution counts for: @f0
; CHECK-NEXT: Loop %b2: backedge-taken count is (-1 * (trunc i32 (1 + %a1) to i2))
; CHECK-NEXT: Loop %b2: backedge-taken count is (-1 + (-1 * (trunc i32 %a1 to i2)))

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64"

Expand Down
25 changes: 25 additions & 0 deletions llvm/test/Analysis/ScalarEvolution/trunc-simplify.ll
@@ -0,0 +1,25 @@
; RUN: opt < %s -analyze -scalar-evolution | FileCheck %s

; Check that we convert
; trunc(C * a) -> trunc(C) * trunc(a)
; if C is a constant.
; CHECK-LABEL: @trunc_of_mul
define i8 @trunc_of_mul(i32 %a) {
; The constant operand absorbs its truncate, so distributing the trunc over
; the mul leaves only one trunc (of %a) — exactly the case this change allows.
%b = mul i32 %a, 100
; CHECK: %c
; CHECK-NEXT: --> (100 * (trunc i32 %a to i8))
%c = trunc i32 %b to i8
ret i8 %c
}

; Check that we convert
; trunc(C + a) -> trunc(C) + trunc(a)
; if C is a constant.
; CHECK-LABEL: @trunc_of_add
define i8 @trunc_of_add(i32 %a) {
; Same shape as @trunc_of_mul but for add: trunc(100 + %a) should become
; 100 + trunc(%a), since truncating the constant introduces no new trunc.
%b = add i32 %a, 100
; CHECK: %c
; CHECK-NEXT: --> (100 + (trunc i32 %a to i8))
%c = trunc i32 %b to i8
ret i8 %c
}

0 comments on commit b326904

Please sign in to comment.