
[InstCombine] Remove over-generalization from computeKnownBitsFromCmp() #72637

Merged
merged 4 commits into llvm:main on Nov 27, 2023

Conversation

nikic
Contributor

@nikic nikic commented Nov 17, 2023

For practical purposes, the only KnownBits patterns we care about are those involving a constant comparison RHS and constant mask. However, the actual implementation is written in a very general way -- and of course, with basically no test coverage of those generalizations.

This patch reduces the implementation to only handle cases with constant operands.

The test changes are all in "make sure we don't crash" tests.

@llvmbot
Collaborator

llvmbot commented Nov 17, 2023

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

Changes

For practical purposes, the only KnownBits patterns we care about are those involving a constant comparison RHS and constant mask. However, the actual implementation is written in a very general way -- and of course, with basically no test coverage of those generalizations.

This patch reduces the implementation to only handle cases with constant operands. The only non-constant case I've kept are plain V pred A comparisons, where I am less confident that this is useless.

The test changes are all in "make sure we don't crash" tests.
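To make the constant-only propagation rules this patch keeps easier to follow, here is a standalone C++ sketch of the same bit arithmetic at a fixed 64-bit width. `KnownBits64` and the `propagate*` helpers are hypothetical illustrations for this discussion, not LLVM's actual `KnownBits` API, though the `Zero`/`One` updates mirror the lines in the patch:

```cpp
#include <cassert>
#include <cstdint>

// Minimal stand-in for llvm::KnownBits: a bit set in Zero means "known 0",
// a bit set in One means "known 1"; a bit set in neither is unknown.
struct KnownBits64 {
  uint64_t Zero = 0;
  uint64_t One = 0;
};

// assume((V & Mask) == C): for bits set in Mask, V must match C.
void propagateAndEq(KnownBits64 &Known, uint64_t Mask, uint64_t C) {
  Known.Zero |= ~C & Mask;
  Known.One |= C & Mask;
}

// assume((V | Mask) == C): for bits clear in Mask, V must match C.
void propagateOrEq(KnownBits64 &Known, uint64_t Mask, uint64_t C) {
  Known.Zero |= ~C & ~Mask;
  Known.One |= C & ~Mask;
}

// assume((V ^ Mask) == C): bits clear in Mask match C directly; bits set
// in Mask match the inverse of C. With both operands constant this pins
// down every bit, i.e. V == C ^ Mask.
void propagateXorEq(KnownBits64 &Known, uint64_t Mask, uint64_t C) {
  Known.Zero |= ~C & ~Mask;
  Known.One |= C & ~Mask;
  Known.Zero |= C & Mask;
  Known.One |= ~C & Mask;
}

// assume((V << ShAmt) == C): shifting C right by ShAmt recovers V's low
// bits; the top ShAmt bits of V were shifted out and stay unknown.
void propagateShlEq(KnownBits64 &Known, unsigned ShAmt, uint64_t C) {
  Known.Zero |= ~C >> ShAmt;
  Known.One |= C >> ShAmt;
}
```

For example, after `assume((V & 0xF0) == 0x30)` this sketch yields `Known.One == 0x30` and `Known.Zero == 0xC0`: bits 4-7 of `V` are fully determined by the constant mask and comparison constant, and all other bits stay unknown.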


Full diff: https://github.com/llvm/llvm-project/pull/72637.diff

4 Files Affected:

  • (modified) llvm/lib/Analysis/AssumptionCache.cpp (+11-21)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+43-56)
  • (modified) llvm/test/Transforms/InstCombine/assume.ll (+2-2)
  • (modified) llvm/test/Transforms/InstCombine/icmp.ll (+3-2)
diff --git a/llvm/lib/Analysis/AssumptionCache.cpp b/llvm/lib/Analysis/AssumptionCache.cpp
index 81b26678ae5d790..3139b3e8f319099 100644
--- a/llvm/lib/Analysis/AssumptionCache.cpp
+++ b/llvm/lib/Analysis/AssumptionCache.cpp
@@ -92,29 +92,19 @@ findAffectedValues(CallBase *CI, TargetTransformInfo *TTI,
     AddAffected(B);
 
     if (Pred == ICmpInst::ICMP_EQ) {
-      // For equality comparisons, we handle the case of bit inversion.
-      auto AddAffectedFromEq = [&AddAffected](Value *V) {
-        Value *A, *B;
-        // (A & B) or (A | B) or (A ^ B).
-        if (match(V, m_BitwiseLogic(m_Value(A), m_Value(B)))) {
-          AddAffected(A);
-          AddAffected(B);
-          // (A << C) or (A >>_s C) or (A >>_u C) where C is some constant.
-        } else if (match(V, m_Shift(m_Value(A), m_ConstantInt()))) {
-          AddAffected(A);
-        }
-      };
-
-      AddAffectedFromEq(A);
-      AddAffectedFromEq(B);
+      if (match(B, m_ConstantInt())) {
+        Value *X;
+        // (X & C) or (X | C) or (X ^ C).
+        // (X << C) or (X >>_s C) or (X >>_u C).
+        if (match(A, m_BitwiseLogic(m_Value(X), m_ConstantInt())) ||
+            match(A, m_Shift(m_Value(X), m_ConstantInt())))
+          AddAffected(X);
+      }
     } else if (Pred == ICmpInst::ICMP_NE) {
-      Value *X, *Y;
-      // Handle (a & b != 0). If a/b is a power of 2 we can use this
-      // information.
-      if (match(A, m_And(m_Value(X), m_Value(Y))) && match(B, m_Zero())) {
+      Value *X;
+      // Handle (X & pow2 != 0).
+      if (match(A, m_And(m_Value(X), m_Power2())) && match(B, m_Zero()))
         AddAffected(X);
-        AddAffected(Y);
-      }
     } else if (Pred == ICmpInst::ICMP_ULT) {
       Value *X;
       // Handle (A + C1) u< C2, which is the canonical form of A > C3 && A < C4,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e25aa9c6863f335..10a4a6e349c3794 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -640,82 +640,69 @@ static void computeKnownBitsFromCmp(const Value *V, const ICmpInst *Cmp,
   QueryNoAC.AC = nullptr;
 
   // Note that ptrtoint may change the bitwidth.
-  Value *A, *B;
   auto m_V =
       m_CombineOr(m_Specific(V), m_PtrToIntSameSize(Q.DL, m_Specific(V)));
 
   CmpInst::Predicate Pred;
-  uint64_t C;
+  const APInt *Mask, *C;
+  uint64_t ShAmt;
   switch (Cmp->getPredicate()) {
   case ICmpInst::ICMP_EQ:
-    // assume(v = a)
-    if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      Known = Known.unionWith(RHSKnown);
-      // assume(v & b = a)
+    // assume(V = C)
+    if (match(Cmp, m_ICmp(Pred, m_V, m_APInt(C)))) {
+      Known = Known.unionWith(KnownBits::makeConstant(*C));
+      // assume(V & Mask = C)
     } else if (match(Cmp,
-                     m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      KnownBits MaskKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
-      // For those bits in the mask that are known to be one, we can propagate
-      // known bits from the RHS to V.
-      Known.Zero |= RHSKnown.Zero & MaskKnown.One;
-      Known.One |= RHSKnown.One & MaskKnown.One;
-      // assume(v | b = a)
+                     m_ICmp(Pred, m_And(m_V, m_APInt(Mask)), m_APInt(C)))) {
+      // For one bits in Mask, we can propagate bits from C to V.
+      Known.Zero |= ~*C & *Mask;
+      Known.One |= *C & *Mask;
+      // assume(V | Mask = C)
     } else if (match(Cmp,
-                     m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
-      // For those bits in B that are known to be zero, we can propagate known
-      // bits from the RHS to V.
-      Known.Zero |= RHSKnown.Zero & BKnown.Zero;
-      Known.One |= RHSKnown.One & BKnown.Zero;
-      // assume(v ^ b = a)
+                     m_ICmp(Pred, m_Or(m_V, m_APInt(Mask)), m_APInt(C)))) {
+      // For zero bits in Mask, we can propagate bits from C to V.
+      Known.Zero |= ~*C & ~*Mask;
+      Known.One |= *C & ~*Mask;
+      // assume(V ^ Mask = C)
     } else if (match(Cmp,
-                     m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A)))) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-      KnownBits BKnown = computeKnownBits(B, Depth + 1, QueryNoAC);
-
-      // For those bits in B that are known to be zero, we can propagate known
-      // bits from the RHS to V. For those bits in B that are known to be one,
-      // we can propagate inverted known bits from the RHS to V.
-      Known.Zero |= RHSKnown.Zero & BKnown.Zero;
-      Known.One |= RHSKnown.One & BKnown.Zero;
-      Known.Zero |= RHSKnown.One & BKnown.One;
-      Known.One |= RHSKnown.Zero & BKnown.One;
-      // assume(v << c = a)
-    } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
-                                   m_Value(A))) &&
-               C < BitWidth) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
-
-      // For those bits in RHS that are known, we can propagate them to known
-      // bits in V shifted to the right by C.
-      RHSKnown.Zero.lshrInPlace(C);
-      RHSKnown.One.lshrInPlace(C);
+                     m_ICmp(Pred, m_Xor(m_V, m_APInt(Mask)), m_APInt(C)))) {
+      // For those bits in Mask that are zero, we can propagate known bits
+      // from C to V. For those bits in Mask that are one, we can propagate
+      // inverted bits from C to V.
+      Known.Zero |= ~*C & ~*Mask;
+      Known.One |= *C & ~*Mask;
+      Known.Zero |= *C & *Mask;
+      Known.One |= ~*C & *Mask;
+      // assume(V << ShAmt = C)
+    } else if (match(Cmp, m_ICmp(Pred, m_Shl(m_V, m_ConstantInt(ShAmt)),
+                                 m_APInt(C))) &&
+               ShAmt < BitWidth) {
+      // For those bits in C that are known, we can propagate them to known
+      // bits in V shifted to the right by ShAmt.
+      KnownBits RHSKnown = KnownBits::makeConstant(*C);
+      RHSKnown.Zero.lshrInPlace(ShAmt);
+      RHSKnown.One.lshrInPlace(ShAmt);
       Known = Known.unionWith(RHSKnown);
-      // assume(v >> c = a)
-    } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
-                                   m_Value(A))) &&
-               C < BitWidth) {
-      KnownBits RHSKnown = computeKnownBits(A, Depth + 1, QueryNoAC);
+      // assume(V >> ShAmt = C)
+    } else if (match(Cmp, m_ICmp(Pred, m_Shr(m_V, m_ConstantInt(ShAmt)),
+                                 m_APInt(C))) &&
+               ShAmt < BitWidth) {
+      KnownBits RHSKnown = KnownBits::makeConstant(*C);
       // For those bits in RHS that are known, we can propagate them to known
       // bits in V shifted to the right by C.
-      Known.Zero |= RHSKnown.Zero << C;
-      Known.One |= RHSKnown.One << C;
+      Known.Zero |= RHSKnown.Zero << ShAmt;
+      Known.One |= RHSKnown.One << ShAmt;
     }
     break;
   case ICmpInst::ICMP_NE: {
-    // assume (v & b != 0) where b is a power of 2
+    // assume (V & B != 0) where B is a power of 2
     const APInt *BPow2;
-    if (match(Cmp, m_ICmp(Pred, m_c_And(m_V, m_Power2(BPow2)), m_Zero()))) {
+    if (match(Cmp, m_ICmp(Pred, m_And(m_V, m_Power2(BPow2)), m_Zero())))
       Known.One |= *BPow2;
-    }
     break;
   }
   default:
+    Value *A;
     const APInt *Offset = nullptr;
     if (match(Cmp, m_ICmp(Pred, m_CombineOr(m_V, m_Add(m_V, m_APInt(Offset))),
                           m_Value(A)))) {
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
index 934e9594f3f7b5a..5b2039b5b480501 100644
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -265,7 +265,7 @@ define i32 @bundle2(ptr %P) {
 
 define i1 @nonnull1(ptr %a) {
 ; CHECK-LABEL: @nonnull1(
-; CHECK-NEXT:    [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull !6, !noundef !6
+; CHECK-NEXT:    [[LOAD:%.*]] = load ptr, ptr [[A:%.*]], align 8, !nonnull [[META6:![0-9]+]], !noundef [[META6]]
 ; CHECK-NEXT:    tail call void @escape(ptr nonnull [[LOAD]])
 ; CHECK-NEXT:    ret i1 false
 ;
@@ -383,7 +383,7 @@ define i1 @nonnull5(ptr %a) {
 define i32 @assumption_conflicts_with_known_bits(i32 %a, i32 %b) {
 ; CHECK-LABEL: @assumption_conflicts_with_known_bits(
 ; CHECK-NEXT:    store i1 true, ptr poison, align 1
-; CHECK-NEXT:    ret i32 poison
+; CHECK-NEXT:    ret i32 1
 ;
   %and1 = and i32 %b, 3
   %B1 = lshr i32 %and1, %and1
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 78ac730cf026ed9..d49cb79e1e27c98 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -4016,9 +4016,10 @@ define i32 @abs_preserve(i32 %x) {
 declare void @llvm.assume(i1)
 define i1 @PR35794(ptr %a) {
 ; CHECK-LABEL: @PR35794(
-; CHECK-NEXT:    [[MASKCOND:%.*]] = icmp eq ptr [[A:%.*]], null
+; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt ptr [[A:%.*]], inttoptr (i64 -1 to ptr)
+; CHECK-NEXT:    [[MASKCOND:%.*]] = icmp eq ptr [[A]], null
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[MASKCOND]])
-; CHECK-NEXT:    ret i1 true
+; CHECK-NEXT:    ret i1 [[CMP]]
 ;
   %cmp = icmp sgt ptr %a, inttoptr (i64 -1 to ptr)
   %maskcond = icmp eq ptr %a, null


github-actions bot commented Nov 17, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@nikic
Contributor Author

nikic commented Nov 20, 2023

I've made two additional updates here:

  • Also drop non-constant RHS support for plain (non-mask) comparison. When using computeKnownBitsFromCmp() for dominating conditions, I found that supporting non-constant RHS here has significant compile-time impact but very little impact on results.
  • Explicitly handle the comparison with null pointer case. I'm not sure this is really important, but there were additional test changes when implementing the previous point, so I just did it to be safe.

// assume(v | b = a)
// assume(V = C)
if (match(Cmp, m_ICmp(Pred, m_V, m_APInt(C)))) {
Known = Known.unionWith(KnownBits::makeConstant(*C));
Contributor

Why unionWith as opposed to Known = KnownBits::makeConstant(*C)?

Contributor Author

This just preserves the existing behavior wrt conflicts. I can drop the unionWith if we don't care about that.

@goldsteinn
Contributor

Some of the simplifications here make sense independent of the general -> constant simplification. But for the general -> constant simplification, IMO it makes more sense to add test coverage rather than just drop the code.
In particular, for a few cases like X == Y, I think it's simple enough to support a non-constant Y, and I can imagine practical value from that type of analysis.

@nikic
Contributor Author

nikic commented Nov 20, 2023

Some of the simplifications here make sense independent of general -> constant simplification. But for the general -> constant simplification, imo it makes more sense to add coverage as opposed to just dropping the code. Particularly a few cases like X == Y, think its simple enough to support non-constant Y and can imagine practical value from that type of analysis.

The context of this change is that I want to handle dominating conditions in computeKnownBits(), and would like it to use the same logic as computeKnownBitsFromCmp() currently used for assumes. Currently, we don't really feel the compile-time impact of this function, because assumes are very rare when compiling C/C++. When handling dominating conditions, the support for non-constants adds quite a lot of compile-time overhead without any significant effect on results (I checked).

I could use completely independent logic for assumes and dominating conditions, but I don't particularly want to.

@goldsteinn
Contributor

Some of the simplifications here make sense independent of general -> constant simplification. But for the general -> constant simplification, imo it makes more sense to add coverage as opposed to just dropping the code. Particularly a few cases like X == Y, think its simple enough to support non-constant Y and can imagine practical value from that type of analysis.

The context of this change is that I want to handle dominating conditions in computeKnownBits(), and would like it to use the same logic as computeKnownBitsFromCmp() currently used for assumes. Currently, we don't really feel the compile-time impact of this function, because assumes are very rare when compiling C/C++. When handling dominating conditions, the support for non-constants adds quite a lot of compile-time overhead without any significant effect on results (I checked).

I could use completely independent logic for assumes and dominating conditions, but I don't particularly want to.

I see. That seems like a reasonable motivation.
Although I'd think a few simple cases (like X == Y) would provide enough value to justify the cost.

@nikic
Contributor Author

nikic commented Nov 20, 2023

I see. That seems like a reasonable motivation. Although I'd think a few simple cases (like X == Y) would provide enough value to justify the cost.

The X == Y case exists for completeness more than anything: For plain equalities, GVN will replace X with Y (or vice versa), independently of whether Y is constant.

@goldsteinn
Contributor

I see. That seems like a reasonable motivation. Although I'd think a few simple cases (like X == Y) would provide enough value to justify the cost.

The X == Y case exists for completeness more than anything: For plain equalities, GVN will replace X with Y (or vice versa), independently of whether Y is constant.

Ahh, you're right. Good point.

@nikic
Contributor Author

nikic commented Nov 27, 2023

ping

Member

@dtcxzyw dtcxzyw left a comment

LGTM.

@nikic nikic merged commit 28a5e6b into llvm:main Nov 27, 2023
3 checks passed