-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[SCEV] Generalize (C * A /u C) -> A fold to (C1 * A /u C2) -> C1/C2 * A. #157159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-analysis Author: Florian Hahn (fhahn) Changes: Generalize fold added in 74ec38f to support multiplying and dividing by different constants, given they are both powers-of-2 and C1 is a multiple of C2, checked via their trailing zeros. https://alive2.llvm.org/ce/z/eqJ2xj Full diff: https://github.com/llvm/llvm-project/pull/157159.diff 2 Files Affected:
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index bd57d1192eb94..f9fbcb05798da 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -3216,13 +3216,16 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
};
}
- // Try to fold (C * D /u C) -> D, if C is a power-of-2 and D is a multiple
- // of C.
+ // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
+ // D is a multiple of C2, and C1 is a multiple of C1.
const SCEV *D;
- if (match(Ops[1], m_scev_UDiv(m_SCEV(D), m_scev_Specific(LHSC))) &&
- LHSC->getAPInt().isPowerOf2() &&
- LHSC->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
- return D;
+ const SCEVConstant *C2;
+ if (LHSC->getAPInt().isPowerOf2() &&
+ match(Ops[1], m_scev_UDiv(m_SCEV(D), m_SCEVConstant(C2))) &&
+ C2->getAPInt().isPowerOf2() &&
+ getMinTrailingZeros(LHSC) >= getMinTrailingZeros(C2) &&
+ getMinTrailingZeros(LHSC) <= getMinTrailingZeros(D)) {
+ return getMulExpr(getUDivExpr(LHSC, C2), D);
}
}
}
diff --git a/llvm/test/Analysis/ScalarEvolution/mul-udiv-folds.ll b/llvm/test/Analysis/ScalarEvolution/mul-udiv-folds.ll
index 9f4360d2ae383..afe69ceb148aa 100644
--- a/llvm/test/Analysis/ScalarEvolution/mul-udiv-folds.ll
+++ b/llvm/test/Analysis/ScalarEvolution/mul-udiv-folds.ll
@@ -19,7 +19,7 @@ define void @udiv4_and_udiv2_mul_4(i1 %c, ptr %A) {
; CHECK-NEXT: %iv = phi i64 [ %iv.start, %entry ], [ %iv.next, %loop ]
; CHECK-NEXT: --> {((zext i32 %start to i64) /u 4),+,1}<%loop> U: full-set S: full-set Exits: ((zext i32 %start to i64) /u 2) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %gep = getelementptr i32, ptr %A, i64 %iv
-; CHECK-NEXT: --> {((zext i32 %start to i64) + %A),+,4}<%loop> U: full-set S: full-set Exits: ((zext i32 %start to i64) + (4 * ((zext i32 %start to i64) /u 2))<nuw><nsw> + (-4 * ((zext i32 %start to i64) /u 4))<nsw> + %A) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {((zext i32 %start to i64) + %A),+,4}<%loop> U: full-set S: full-set Exits: ((3 * (zext i32 %start to i64))<nuw><nsw> + (-4 * ((zext i32 %start to i64) /u 4))<nsw> + %A) LoopDispositions: { %loop: Computable }
; CHECK-NEXT: %iv.next = add i64 %iv, 1
; CHECK-NEXT: --> {(1 + ((zext i32 %start to i64) /u 4))<nuw><nsw>,+,1}<%loop> U: full-set S: full-set Exits: (1 + ((zext i32 %start to i64) /u 2))<nuw><nsw> LoopDispositions: { %loop: Computable }
; CHECK-NEXT: Determining loop execution counts for: @udiv4_and_udiv2_mul_4
@@ -48,6 +48,52 @@ exit:
ret void
}
+define void @udiv4_and_udiv2_mul_6(i1 %c, ptr %A) {
+; CHECK-LABEL: 'udiv4_and_udiv2_mul_6'
+; CHECK-NEXT: Classifying expressions for: @udiv4_and_udiv2_mul_6
+; CHECK-NEXT: %start = select i1 %c, i32 512, i32 0
+; CHECK-NEXT: --> %start U: [0,513) S: [0,513)
+; CHECK-NEXT: %div.2 = lshr i32 %start, 1
+; CHECK-NEXT: --> (%start /u 2) U: [0,257) S: [0,257)
+; CHECK-NEXT: %div.4 = lshr i32 %start, 2
+; CHECK-NEXT: --> (%start /u 4) U: [0,129) S: [0,129)
+; CHECK-NEXT: %iv.start = zext i32 %div.4 to i64
+; CHECK-NEXT: --> ((zext i32 %start to i64) /u 4) U: [0,129) S: [0,129)
+; CHECK-NEXT: %wide.trip.count = zext i32 %div.2 to i64
+; CHECK-NEXT: --> ((zext i32 %start to i64) /u 2) U: [0,257) S: [0,257)
+; CHECK-NEXT: %iv = phi i64 [ %iv.start, %entry ], [ %iv.next, %loop ]
+; CHECK-NEXT: --> {((zext i32 %start to i64) /u 4),+,1}<%loop> U: full-set S: full-set Exits: ((zext i32 %start to i64) /u 2) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: %gep = getelementptr <{ i32, i16 }>, ptr %A, i64 %iv
+; CHECK-NEXT: --> {((6 * ((zext i32 %start to i64) /u 4))<nuw><nsw> + %A),+,6}<%loop> U: full-set S: full-set Exits: ((6 * ((zext i32 %start to i64) /u 2))<nuw><nsw> + %A) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: %iv.next = add i64 %iv, 1
+; CHECK-NEXT: --> {(1 + ((zext i32 %start to i64) /u 4))<nuw><nsw>,+,1}<%loop> U: full-set S: full-set Exits: (1 + ((zext i32 %start to i64) /u 2))<nuw><nsw> LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: Determining loop execution counts for: @udiv4_and_udiv2_mul_6
+; CHECK-NEXT: Loop %loop: backedge-taken count is ((-1 * ((zext i32 %start to i64) /u 4))<nsw> + ((zext i32 %start to i64) /u 2))
+; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 -1
+; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-1 * ((zext i32 %start to i64) /u 4))<nsw> + ((zext i32 %start to i64) /u 2))
+; CHECK-NEXT: Loop %loop: Trip multiple is 1
+;
+entry:
+ %start = select i1 %c, i32 512, i32 0
+ %div.2 = lshr i32 %start, 1
+ %div.4 = lshr i32 %start, 2
+ %iv.start = zext i32 %div.4 to i64
+ %wide.trip.count = zext i32 %div.2 to i64
+ br label %loop
+
+loop:
+ %iv = phi i64 [ %iv.start, %entry ], [ %iv.next, %loop ]
+ %gep = getelementptr <{i32, i16}>, ptr %A, i64 %iv
+ call void @use(ptr %gep)
+ %iv.next = add i64 %iv, 1
+ %ec = icmp eq i64 %iv, %wide.trip.count
+ br i1 %ec, label %exit, label %loop
+
+exit:
+ ret void
+}
+
+
define void @udiv4_and_udiv2_mul_1(i1 %c, ptr %A) {
; CHECK-LABEL: 'udiv4_and_udiv2_mul_1'
; CHECK-NEXT: Classifying expressions for: @udiv4_and_udiv2_mul_1
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can handle the case where C2 is greater than C1: just generate D /u (C2 / C1).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, but it is probably better to do this separately: when doing so (bc09333), it appears that another fold is missing, which is needed to avoid regressions:
(-2 * ((zext i32 %start to i64) /u 4))<nsw> -> (-1 * ((zext i32 %start to i64) /u 2))<nsw>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay.
Is there some reason to prefer getMinTrailingZeros(LHSC) over LHSC->getAPInt().logBase2()? I slightly prefer the latter because it makes it clearer that we expect an exact number, not just a "minimum".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consolidate tests for multiple divisors in a single loop, add multiplies by 1, 2, 5, 6. Extends test coverage for #157159.
…g test. Consolidate tests for multiple divisors in a single loop, add multiplies by 1, 2, 5, 6. Extends test coverage for llvm/llvm-project#157159.
a06fe6f
to
b2d8a35
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Also a couple minor suggestions; accept or reject if you want.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
APInt LHSV = LHSC->getAPInt(); | |
const APInt &LHSV = LHSC->getAPInt(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LHSV.logBase2() >= C2->getAPInt().logBase2() && | |
LHSV.uge(C2->getAPInt()) && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done thanks
Generalize fold added in 74ec38f to support multiplying and dividing by different constants, given they are both powers-of-2 and C1 is a multiple of C2, checked via their trailing zeros. https://alive2.llvm.org/ce/z/eqJ2xj
b2d8a35
to
000be20
Compare
000be20
to
81b23c3
Compare
// Try to fold (C * D /u C) -> D, if C is a power-of-2 and D is a multiple | ||
// of C. | ||
// Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2, | ||
// D is a multiple of C2, and C1 is a multiple of C1. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should probably say C1 is a multiple of C2?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, I adjusted this in a follow-up PR (#157555), but can also fix separately.
if (LHSV.isPowerOf2() && | ||
match(Ops[1], m_scev_UDiv(m_SCEV(D), m_SCEVConstant(C2))) && | ||
C2->getAPInt().isPowerOf2() && LHSV.uge(C2->getAPInt()) && | ||
LHSV.logBase2() <= getMinTrailingZeros(D)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW, the actual underlying API for getMinTrailingZeros() is getConstantMultiple(), so this is probably easy to generalize to the non-pow2 case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, initially I was checking for multiples directly, but decided to start with just the power-of-2 cases and then extend it further.
… -> C1/C2 * A. (#157159) Generalize fold added in 74ec38f (llvm/llvm-project#156730) to support multiplying and dividing by different constants, given they are both powers-of-2 and C1 is a multiple of C2, checked via logBase2. https://alive2.llvm.org/ce/z/eqJ2xj PR: llvm/llvm-project#157159
Generalize fold added in 74ec38f / #156730 to support multiplying and
dividing by different constants, given they are both powers-of-2 and C1 is a
multiple of C2, checked via their trailing zeros.
https://alive2.llvm.org/ce/z/eqJ2xj