Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine more examples to new Checked matcher API #91097

Closed
wants to merge 1 commit into from

Conversation

AtariDreams
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Collaborator

llvmbot commented May 4, 2024

@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-backend-x86

@llvm/pr-subscribers-llvm-transforms

Author: AtariDreams (AtariDreams)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/91097.diff

5 Files Affected:

  • (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+35-21)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+6-5)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+5-3)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+2-2)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp (+4-2)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 4061dae83c10f3..6918263715dfc9 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1028,33 +1028,43 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
     // Make sure that a constant is not the minimum signed value because taking
     // the abs() of that is undefined.
     Type *Ty = X->getType();
-    const APInt *C;
-    if (match(X, m_APInt(C)) && !C->isMinSignedValue()) {
-      // Is the variable divisor magnitude always greater than the constant
-      // dividend magnitude?
-      // |Y| > |C| --> Y < -abs(C) or Y > abs(C)
-      Constant *PosDividendC = ConstantInt::get(Ty, C->abs());
-      Constant *NegDividendC = ConstantInt::get(Ty, -C->abs());
+
+    // Is the variable divisor magnitude always greater than the constant
+    // dividend magnitude?
+    // |Y| > |C| --> Y < -abs(C) or Y > abs(C)
+    auto CheckSignCmp = [Ty, Y, Q, MaxRecurse](const APInt &C) {
+      if (C.isMinSignedValue())
+        return false;
+      Constant *PosDividendC = ConstantInt::get(Ty, C.abs());
+      Constant *NegDividendC = ConstantInt::get(Ty, -C.abs());
       if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) ||
           isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse))
         return true;
-    }
-    if (match(Y, m_APInt(C))) {
+      return false;
+    };
+
+    auto CheckSignCmpY = [Ty, X, Y, Q, MaxRecurse](const APInt &C) {
       // Special-case: we can't take the abs() of a minimum signed value. If
       // that's the divisor, then all we have to do is prove that the dividend
       // is also not the minimum signed value.
-      if (C->isMinSignedValue())
+      if (C.isMinSignedValue())
         return isICmpTrue(CmpInst::ICMP_NE, X, Y, Q, MaxRecurse);
 
       // Is the variable dividend magnitude always less than the constant
       // divisor magnitude?
       // |X| < |C| --> X > -abs(C) and X < abs(C)
-      Constant *PosDivisorC = ConstantInt::get(Ty, C->abs());
-      Constant *NegDivisorC = ConstantInt::get(Ty, -C->abs());
-      if (isICmpTrue(CmpInst::ICMP_SGT, X, NegDivisorC, Q, MaxRecurse) &&
-          isICmpTrue(CmpInst::ICMP_SLT, X, PosDivisorC, Q, MaxRecurse))
+      Constant *PosDividendC = ConstantInt::get(Ty, C.abs());
+      Constant *NegDividendC = ConstantInt::get(Ty, -C.abs());
+      if (isICmpTrue(CmpInst::ICMP_SLT, Y, NegDividendC, Q, MaxRecurse) ||
+          isICmpTrue(CmpInst::ICMP_SGT, Y, PosDividendC, Q, MaxRecurse))
         return true;
-    }
+      return false;
+    };
+
+    if (match(X, m_CheckedInt(CheckSignCmp)))
+      return true;
+    if (match(Y, m_CheckedInt(CheckSignCmpY)))
+      return true;
     return false;
   }
 
@@ -1063,9 +1073,11 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
   // Is the unsigned dividend known to be less than a constant divisor?
   // TODO: Convert this (and above) to range analysis
   //      ("computeConstantRangeIncludingKnownBits")?
-  const APInt *C;
-  if (match(Y, m_APInt(C)) &&
-      computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C))
+
+  auto CheckULT1 = [X, Q](const APInt &C) {
+    return computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(C);
+  };
+  if (match(Y, m_CheckedInt(CheckULT1)))
     return true;
 
   // Try again for any divisor:
@@ -2363,14 +2375,16 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
   // (-1 >> X) | (-1 << (C - X)) --> -1
   // ...with C <= bitwidth (and commuted variants).
   Value *X, *Y;
+  auto CheckULE = [X](const APInt &C) {
+    return C.ule(X->getType()->getScalarSizeInBits());
+  };
   if ((match(Op0, m_Shl(m_AllOnes(), m_Value(X))) &&
        match(Op1, m_LShr(m_AllOnes(), m_Value(Y)))) ||
       (match(Op1, m_Shl(m_AllOnes(), m_Value(X))) &&
        match(Op0, m_LShr(m_AllOnes(), m_Value(Y))))) {
     const APInt *C;
-    if ((match(X, m_Sub(m_APInt(C), m_Specific(Y))) ||
-         match(Y, m_Sub(m_APInt(C), m_Specific(X)))) &&
-        C->ule(X->getType()->getScalarSizeInBits())) {
+    if (match(X, m_Sub(m_CheckedInt(CheckULE), m_Specific(Y))) ||
+        match(Y, m_Sub(m_CheckedInt(CheckULE), m_Specific(X)))) {
       return ConstantInt::getAllOnesValue(X->getType());
     }
   }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index cf4a64ffded2e8..2b2df55f970abe 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -30447,11 +30447,12 @@ static std::pair<Value *, BitTestKind> FindSingleBitChange(Value *V) {
       Value *BitV = I->getOperand(1);
 
       Value *AndOp;
-      const APInt *AndC;
-      if (match(BitV, m_c_And(m_Value(AndOp), m_APInt(AndC)))) {
-        // Read past a shiftmask instruction to find count
-        if (*AndC == (I->getType()->getPrimitiveSizeInBits() - 1))
-          BitV = AndOp;
+      // Read past a shiftmask instruction to find count
+      auto *IsMask = [I](const APInt &AndC) {
+        return AndC == I->getType()->getPrimitiveSizeInBits() - 1;
+      };
+      if (match(BitV, m_c_And(m_Value(AndOp), m_CheckedInt(IsMask)))) {
+        BitV = AndOp;
       }
       return {BitV, BTK};
     }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 51ac77348ed9e3..c2327c405267a0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1762,6 +1762,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   // -->
   // BW - ctlz(A - 1, false)
   const APInt *XorC;
+  auto CheckBW = [A](const APInt &XorC) {
+    return XorC == A->getType()->getScalarSizeInBits() - 1;
+  };
   if (match(&I,
             m_c_Add(
                 m_ZExt(m_ICmp(Pred, m_Intrinsic<Intrinsic::ctpop>(m_Value(A)),
@@ -1769,9 +1772,8 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
                 m_OneUse(m_ZExtOrSelf(m_OneUse(m_Xor(
                     m_OneUse(m_TruncOrSelf(m_OneUse(
                         m_Intrinsic<Intrinsic::ctlz>(m_Deferred(A), m_One())))),
-                    m_APInt(XorC))))))) &&
-      (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_NE) &&
-      *XorC == A->getType()->getScalarSizeInBits() - 1) {
+                    m_CheckedInt(CheckBW))))))) &&
+      (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_NE)) {
     Value *Sub = Builder.CreateAdd(A, Constant::getAllOnesValue(A->getType()));
     Value *Ctlz = Builder.CreateIntrinsic(Intrinsic::ctlz, {A->getType()},
                                           {Sub, Builder.getFalse()});
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 11e31877de38c2..bc6c9fd7deeaf5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -764,8 +764,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
     }
 
     {
-      const APInt *C;
-      if (match(Src, m_Shl(m_APInt(C), m_Value(X))) && (*C)[0] == 1) {
+      auto CheckOdd = [](const APInt &C) { return (C)[0] == 1; };
+      if (match(Src, m_Shl(m_CheckedInt(CheckOdd), m_Value(X)))) {
         // trunc (C << X) to i1 --> X == 0, where C is odd
         return new ICmpInst(ICmpInst::Predicate::ICMP_EQ, X, Zero);
       }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 99f1f8eb34bb5a..8f037c136be933 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -2071,8 +2071,10 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) {
   }
   case Instruction::Or: {
     // or X, C --> add X, C (when X and C have no common bits set)
-    const APInt *C;
-    if (match(BO1, m_APInt(C)) && MaskedValueIsZero(BO0, *C, DL))
+    auto *CheckMaskedValIsZero = [BO0, DL](const APInt &C) {
+      return MaskedValueIsZero(BO0, C, DL);
+    };
+    if (match(BO1, m_CheckedInt(CheckMaskedValIsZero)))
       return {Instruction::Add, BO0, BO1};
     break;
   }

@AtariDreams AtariDreams marked this pull request as draft May 4, 2024 23:37
@AtariDreams AtariDreams force-pushed the checkedint branch 2 times, most recently from 185c775 to ff54fa3 Compare May 5, 2024 00:11
@AtariDreams AtariDreams marked this pull request as ready for review May 5, 2024 00:14
Copy link

github-actions bot commented May 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants