-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[InstCombine] Remove over-generalization from computeKnownBitsFromCmp() #72637
Conversation
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms Author: Nikita Popov (nikic) ChangesFor practical purposes, the only KnownBits patterns we care about are those involving a constant comparison RHS and constant mask. However, the actual implementation is written in a very general way -- and of course, with basically no test coverage of those generalizations. This patch reduces the implementation to only handle cases with constant operands. The only non-constant case I've kept are plain The test changes are all in "make sure we don't crash" tests. Full diff: https://github.com/llvm/llvm-project/pull/72637.diff 4 Files Affected:
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index 81b26678ae5d790..3139b3e8f319099 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -92,29 +92,19 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
AddAffected(B);
if (Pred == ICmpInst::ICMP_EQ) {
- // For equality comparisons, we handle the case of bit inversion.
- auto AddAffectedFromEq = [&AddAffected](Value *V) {
- Value *A, *B;
- // (A & B) or (A | B) or (A ^ B).
- if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
- AddAffected(A);
- AddAffected(B);
- // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
- } else if (match(V, m_Shift(m_Value(A), m_ConstantInt()))) {
- AddAffected(A);
- }
- };
-
- AddAffectedFromEq(A);
- AddAffectedFromEq(B);
+ if (match(B, m_ConstantInt())) {
+ Value *X;
+ // (X & C) or (X | C) or (X ^ C).
+ // (X << C) or (X >>_s C) or (X >>_u C).
+ if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+ match(A, m_Shift(m_Value(X), m_ConstantInt())))
+ AddAffected(X);
+ }
} else if (Pred == ICmpInst::ICMP_NE) {
- Value *X, *Y;
- // Handle (a & b != 0). If a/b is a power of 2 we can use this
- // information.
- if (match(A, m_And(m_Value(X), m_Value(Y))) && match(B, m_Zero())) {
+ Value *X;
+ // Handle (X & pow2 != 0).
+ if (match(A, m_And(m_Value(X), m_Power2())) && match(B, m_Zero()))
AddAffected(X);
- AddAffected(Y);
- }
} else if (Pred == ICmpInst::ICMP_ULT) {
Value *X;
// Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e25aa9c6863f335..10a4a6e349c3794 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -640,82 +640,69 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
QueryNoAC.AC = nullptr;
// Note that ptrtoint may change the bitwidth.
- Value *A, *B;
auto m_V =
m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
CmpInst::Predicate Pred;
- uint64_t C;
+ const APInt *Mask, *C;
+ uint64_t ShAmt;
switch (Cmp->getPredicate()) {
case ICmpInst::ICMP_EQ:
- // assume(v = a)
- if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- Known = Known.unionWith(RHSKnown);
- // assume(v & b = a)
+ // assume(V = C)
+ if (match(Cmp, m_ICmp(Pred, m_V, m_APInt(C)))) {
+ Known = Known.unionWith(KnownBits::makeConstant(*C));
+ // assume(V & Mask = C)
} else if (match(Cmp,
- m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in the mask that are known to be one, we can propagate
- // known bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & MaskKnown.One;
- Known.One |= RHSKnown.One & MaskKnown.One;
- // assume(v | b = a)
+ m_ICmp(Pred, m_And(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For one bits in Mask, we can propagate bits from C to V.
+ Known.Zero |= ~*C & *Mask;
+ Known.One |= *C & *Mask;
+ // assume(V | Mask = C)
} else if (match(Cmp,
- m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in B that are known to be zero, we can propagate known
- // bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & BKnown.Zero;
- Known.One |= RHSKnown.One & BKnown.Zero;
- // assume(v ^ b = a)
+ m_ICmp(Pred, m_Or(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For zero bits in Mask, we can propagate bits from C to V.
+ Known.Zero |= ~*C & ~*Mask;
+ Known.One |= *C & ~*Mask;
+ // assume(V ^ Mask = C)
} else if (match(Cmp,
- m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
- KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
- // For those bits in B that are known to be zero, we can propagate known
- // bits from the RHS to V. For those bits in B that are known to be one,
- // we can propagate inverted known bits from the RHS to V.
- Known.Zero |= RHSKnown.Zero & BKnown.Zero;
- Known.One |= RHSKnown.One & BKnown.Zero;
- Known.Zero |= RHSKnown.One & BKnown.One;
- Known.One |= RHSKnown.Zero & BKnown.One;
- // assume(v << c = a)
- } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
- m_Value(A))) &&
- C < BitWidth) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-
- // For those bits in RHS that are known, we can propagate them to known
- // bits in V shifted to the right by C.
- RHSKnown.Zero.lshrInPlace(C);
- RHSKnown.One.lshrInPlace(C);
+ m_ICmp(Pred, m_Xor(m_V, m_APInt(Mask)), m_APInt(C)))) {
+ // For those bits in Mask that are zero, we can propagate known bits
+ // from C to V. For those bits in Mask that are one, we can propagate
+ // inverted bits from C to V.
+ Known.Zero |= ~*C & ~*Mask;
+ Known.One |= *C & ~*Mask;
+ Known.Zero |= *C & *Mask;
+ Known.One |= ~*C & *Mask;
+ // assume(V << ShAmt = C)
+ } else if (match(Cmp, m_ICmp(Pred, m_Shl(m_V, m_ConstantInt(ShAmt)),
+ m_APInt(C))) &&
+ ShAmt < BitWidth) {
+ // For those bits in C that are known, we can propagate them to known
+ // bits in V shifted to the right by ShAmt.
+ KnownBits RHSKnown = KnownBits::makeConstant(*C);
+ RHSKnown.Zero.lshrInPlace(ShAmt);
+ RHSKnown.One.lshrInPlace(ShAmt);
Known = Known.unionWith(RHSKnown);
- // assume(v >> c = a)
- } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
- m_Value(A))) &&
- C < BitWidth) {
- KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+ // assume(V >> ShAmt = C)
+ } else if (match(Cmp, m_ICmp(Pred, m_Shr(m_V, m_ConstantInt(ShAmt)),
+ m_APInt(C))) &&
+ ShAmt < BitWidth) {
+ KnownBits RHSKnown = KnownBits::makeConstant(*C);
// For those bits in RHS that are known, we can propagate them to known
// bits in V shifted to the right by C.
- Known.Zero |= RHSKnown.Zero << C;
- Known.One |= RHSKnown.One << C;
+ Known.Zero |= RHSKnown.Zero << ShAmt;
+ Known.One |= RHSKnown.One << ShAmt;
}
break;
case ICmpInst::ICMP_NE: {
- // assume (v & b != 0) where b is a power of 2
+ // assume (V & B != 0) where B is a power of 2
const APInt *BPow2;
- if (match(Cmp, m_ICmp(Pred, m_c_And(m_V, m_Power2(BPow2)), m_Zero()))) {
+ if (match(Cmp, m_ICmp(Pred, m_And(m_V, m_Power2(BPow2)), m_Zero())))
Known.One |= *BPow2;
- }
break;
}
default:
+ Value *A;
const APInt *Offset = nullptr;
if (match(Cmp, m_ICmp(Pred, m_CombineOr(m_V, m_Add(m_V, m_APInt(Offset))),
m_Value(A)))) {
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
index 934e9594f3f7b5a..5b2039b5b480501 100644
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -265,7 +265,7 @@ define i32 @bundle2(ptr %P) {
define i1 @nonnull1(ptr %a) {
; CHECK-LABEL: @nonnull1(
-; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull !6, !noundef !6
+; CHECK-NEXT: [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull [[META6:![0-9]+]], !noundef [[META6]]
; CHECK-NEXT: tail call void @escape(ptr nonnull [[LOAD]])
; CHECK-NEXT: ret i1 false
;
@@ -383,7 +383,7 @@ define i1 @nonnull5(ptr %a) {
define i32 @assumption_conflicts_with_known_bits(i32 %a, i32 %b) {
; CHECK-LABEL: @assumption_conflicts_with_known_bits(
; CHECK-NEXT: store i1 true, ptr poison, align 1
-; CHECK-NEXT: ret i32 poison
+; CHECK-NEXT: ret i32 1
;
%and1 = and i32 %b, 3
%B1 = lshr i32 %and1, %and1
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 78ac730cf026ed9..d49cb79e1e27c98 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -4016,9 +4016,10 @@ define i32 @abs_preserve(i32 %x) {
declare void @llvm.assume(i1)
define i1 @PR35794(ptr %a) {
; CHECK-LABEL: @PR35794(
-; CHECK-NEXT: [[MASKCOND:%.*]] = icmp eq ptr [[A:%.*]], null
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt ptr [[A:%.*]], inttoptr (i64 -1 to ptr)
+; CHECK-NEXT: [[MASKCOND:%.*]] = icmp eq ptr [[A]], null
; CHECK-NEXT: tail call void @llvm.assume(i1 [[MASKCOND]])
-; CHECK-NEXT: ret i1 true
+; CHECK-NEXT: ret i1 [[CMP]]
;
%cmp = icmp sgt ptr %a, inttoptr (i64 -1 to ptr)
%maskcond = icmp eq ptr %a, null
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
3d4581c
to
f92c8f5
Compare
These don't get handled by the m_APInt code.
f92c8f5
to
3e6f8bd
Compare
I've made two additional updates here:
|
// assume(v | b = a) | ||
// assume(V = C) | ||
if (match(Cmp, m_ICmp(Pred, m_V, m_APInt(C)))) { | ||
Known = Known.unionWith(KnownBits::makeConstant(*C)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why unionWith as opposed to known = KnownBits::makeConstant(*C)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This just preserves the existing behavior wrt conflicts. I can drop the unionWith if we don't care about that.
Some of the simplifications here make sense independent of general -> constant simplification. But for the general -> constant simplification, imo it makes more sense to add coverage as opposed to just dropping the code. |
The context of this change is that I want to handle dominating conditions in computeKnownBits(), and would like it to use the same logic as computeKnownBitsFromCmp() currently used for assumes. Currently, we don't really feel the compile-time impact of this function, because assumes are very rare when compiling C/C++. When handling dominating conditions, the support for non-constants adds quite a lot of compile-time overhead without any significant effect on results (I checked). I could use completely independent logic for assumes and dominating conditions, but I don't particularly want to. |
I see. That seems like a reasonable motivation. |
The |
Ahh, you're right. Good point. |
ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
For practical purposes, the only KnownBits patterns we care about are those involving a constant comparison RHS and constant mask. However, the actual implementation is written in a very general way -- and of course, with basically no test coverage of those generalizations.
This patch reduces the implementation to only handle cases with constant operands.
The test changes are all in "make sure we don't crash" tests.