Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Add helper simplifying Instruction w/ constants with eq/ne Constants #86346

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

goldsteinn
Copy link
Contributor

@goldsteinn goldsteinn commented Mar 22, 2024

  • Pre-commit: add test
  • Fold 'switch(rol(x, C1))'
  • [InstCombine] Add tests for simplifying Instruction w/ constants with eq/ne Constants; NFC
  • [InstCombine] Add helper simplifying Instruction w/ constants with eq/ne Constants; NFC
  • [InstCombine] Use simplifyOpWithConstantEqConsts in foldICmpEquality
  • [InstCombine] Use simplifyOpWithConstantEqConsts in visitSwitchInst

This is based on #86307, I will rebase once #86307 gets merged.

Proofs for the mul/div cases: https://alive2.llvm.org/ce/z/KBdR38

YanWQ-monad and others added 6 commits March 22, 2024 16:42
…/ne Constants; NFC

We want to be able to simplify these types of relationships in
multiple places (simplify `icmp eq/ne` and `switch` statements) so it
makes sense to roll all our logic into a single helper.
I dropped the obviously redundant cases, but we still have some
sitting around.
Replaces bespoke logic with logic shared by `visitICmp`. The new logic
is more complete.
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 22, 2024

@llvm/pr-subscribers-llvm-transforms

Author: None (goldsteinn)

Changes
  • Pre-commit: add test
  • Fold 'switch(rol(x, C1))'
  • [InstCombine] Add tests for simplifying Instruction w/ constants with eq/ne Constants; NFC
  • [InstCombine] Add helper simplifying Instruction w/ constants with eq/ne Constants; NFC
  • [InstCombine] Use simplifyOpWithConstantEqConsts in foldICmpEquality
  • [InstCombine] Use simplifyOpWithConstantEqConsts in visitSwitchInst

This is based on #86307, I will rebase once #86307 gets merged.


Patch is 64.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86346.diff

12 Files Affected:

  • (modified) llvm/include/llvm/Transforms/InstCombine/InstCombiner.h (+18)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+8-24)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+230-60)
  • (modified) llvm/test/Transforms/InstCombine/2009-02-20-InstCombine-SROA.ll (+8-8)
  • (modified) llvm/test/Transforms/InstCombine/icmp-add.ll (+2-4)
  • (modified) llvm/test/Transforms/InstCombine/icmp-equality-xor.ll (+1-2)
  • (modified) llvm/test/Transforms/InstCombine/icmp-sub.ll (+2-4)
  • (modified) llvm/test/Transforms/InstCombine/narrow-switch.ll (+8-8)
  • (modified) llvm/test/Transforms/InstCombine/prevent-cmp-merge.ll (+6-6)
  • (added) llvm/test/Transforms/InstCombine/simplify-cmp-eq.ll (+1606)
  • (modified) llvm/test/Transforms/InstCombine/switch-constant-expr.ll (+4-4)
  • (added) llvm/test/Transforms/InstCombine/switch-rol.ll (+63)
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 93090431cbb69f..4b89ea73f16e37 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -198,6 +198,24 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
                                                 PatternMatch::m_Value()));
   }
 
+  /// Assumes that we have `Op eq/ne Vals` (either icmp or switch). Will try to
+  /// constant fold `Vals` so that we can use `Op' eq/ne Vals'`. For example if
+  /// we have `Op` as `add X, C0`, it will simplify all `Vals` as `Vals[i] - C0`
+  /// and return `X`.
+  Value *simplifyOpWithConstantEqConsts(Value *Op, BuilderTy &Builder,
+                                        SmallVector<Constant *> &Vals,
+                                        bool ReqOneUseAdd = true);
+
+  Value *simplifyOpWithConstantEqConsts(Value *Op, BuilderTy &Builder,
+                                        Constant *&Val,
+                                        bool ReqOneUseAdd = true) {
+    SmallVector<Constant *> CVals;
+    CVals.push_back(Val);
+    Value *R = simplifyOpWithConstantEqConsts(Op, Builder, CVals, ReqOneUseAdd);
+    Val = CVals[0];
+    return R;
+  }
+
   /// Return nonnull value if V is free to invert under the condition of
   /// WillInvertAllUses.
   /// If Builder is nonnull, it will return a simplified ~V.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index db302d7e526844..19a59bcc96d506 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3572,16 +3572,6 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
       return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C));
     break;
 
-  case Intrinsic::bswap:
-    // bswap(A) == C  ->  A == bswap(C)
-    return new ICmpInst(Pred, II->getArgOperand(0),
-                        ConstantInt::get(Ty, C.byteSwap()));
-
-  case Intrinsic::bitreverse:
-    // bitreverse(A) == C  ->  A == bitreverse(C)
-    return new ICmpInst(Pred, II->getArgOperand(0),
-                        ConstantInt::get(Ty, C.reverseBits()));
-
   case Intrinsic::ctlz:
   case Intrinsic::cttz: {
     // ctz(A) == bitwidth(A)  ->  A == 0 and likewise for !=
@@ -3618,20 +3608,6 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
     break;
   }
 
-  case Intrinsic::fshl:
-  case Intrinsic::fshr:
-    if (II->getArgOperand(0) == II->getArgOperand(1)) {
-      const APInt *RotAmtC;
-      // ror(X, RotAmtC) == C --> X == rol(C, RotAmtC)
-      // rol(X, RotAmtC) == C --> X == ror(C, RotAmtC)
-      if (match(II->getArgOperand(2), m_APInt(RotAmtC)))
-        return new ICmpInst(Pred, II->getArgOperand(0),
-                            II->getIntrinsicID() == Intrinsic::fshl
-                                ? ConstantInt::get(Ty, C.rotr(*RotAmtC))
-                                : ConstantInt::get(Ty, C.rotl(*RotAmtC)));
-    }
-    break;
-
   case Intrinsic::umax:
   case Intrinsic::uadd_sat: {
     // uadd.sat(a, b) == 0  ->  (a | b) == 0
@@ -5456,6 +5432,14 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   const CmpInst::Predicate Pred = I.getPredicate();
+  {
+    Constant *C;
+    if (match(Op1, m_ImmConstant(C))) {
+      if (auto *R = simplifyOpWithConstantEqConsts(Op0, Builder, C))
+        return new ICmpInst(Pred, R, C);
+    }
+  }
+
   Value *A, *B, *C, *D;
   if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) {
     if (A == Op1 || B == Op1) { // (A^B) == A  ->  B == 0
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 7c40fb4fc86082..f4a93434761c63 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3572,78 +3572,248 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
   return nullptr;
 }
 
-Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
-  Value *Cond = SI.getCondition();
-  Value *Op0;
-  ConstantInt *AddRHS;
-  if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) {
-    // Change 'switch (X+4) case 1:' into 'switch (X) case -3'.
-    for (auto Case : SI.cases()) {
-      Constant *NewCase = ConstantExpr::getSub(Case.getCaseValue(), AddRHS);
-      assert(isa<ConstantInt>(NewCase) &&
-             "Result of expression should be constant");
-      Case.setValue(cast<ConstantInt>(NewCase));
+Value *
+InstCombiner::simplifyOpWithConstantEqConsts(Value *Op, BuilderTy &Builder,
+                                             SmallVector<Constant *> &Vals,
+                                             bool ReqOneUseAdd) {
+
+  Operator *I = dyn_cast<Operator>(Op);
+  if (!I)
+    return nullptr;
+
+  auto ReverseAll = [&](function_ref<Constant *(Constant *)> ReverseF) {
+    for (size_t i = 0, e = Vals.size(); i < e; ++i) {
+      Vals[i] = ReverseF(Vals[i]);
     }
-    return replaceOperand(SI, 0, Op0);
+  };
+
+  SmallVector<const APInt *, 4> ValsAsAPInt;
+  for (Constant *C : Vals) {
+    const APInt *CAPInt;
+    if (!match(C, m_APInt(CAPInt)))
+      break;
+    ValsAsAPInt.push_back(CAPInt);
   }
+  bool UseAPInt = ValsAsAPInt.size() == Vals.size();
 
-  ConstantInt *SubLHS;
-  if (match(Cond, m_Sub(m_ConstantInt(SubLHS), m_Value(Op0)))) {
-    // Change 'switch (1-X) case 1:' into 'switch (X) case 0'.
-    for (auto Case : SI.cases()) {
-      Constant *NewCase = ConstantExpr::getSub(SubLHS, Case.getCaseValue());
-      assert(isa<ConstantInt>(NewCase) &&
-             "Result of expression should be constant");
-      Case.setValue(cast<ConstantInt>(NewCase));
+  auto ReverseAllAPInt = [&](function_ref<APInt(const APInt *)> ReverseF) {
+    assert(UseAPInt && "Can't reverse non-apint constants!");
+    for (size_t i = 0, e = Vals.size(); i < e; ++i) {
+      Vals[i] = ConstantInt::get(Vals[i]->getType(), ReverseF(ValsAsAPInt[i]));
     }
-    return replaceOperand(SI, 0, Op0);
-  }
-
-  uint64_t ShiftAmt;
-  if (match(Cond, m_Shl(m_Value(Op0), m_ConstantInt(ShiftAmt))) &&
-      ShiftAmt < Op0->getType()->getScalarSizeInBits() &&
-      all_of(SI.cases(), [&](const auto &Case) {
-        return Case.getCaseValue()->getValue().countr_zero() >= ShiftAmt;
-      })) {
-    // Change 'switch (X << 2) case 4:' into 'switch (X) case 1:'.
-    OverflowingBinaryOperator *Shl = cast<OverflowingBinaryOperator>(Cond);
-    if (Shl->hasNoUnsignedWrap() || Shl->hasNoSignedWrap() ||
-        Shl->hasOneUse()) {
-      Value *NewCond = Op0;
-      if (!Shl->hasNoUnsignedWrap() && !Shl->hasNoSignedWrap()) {
-        // If the shift may wrap, we need to mask off the shifted bits.
-        unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
-        NewCond = Builder.CreateAnd(
-            Op0, APInt::getLowBitsSet(BitWidth, BitWidth - ShiftAmt));
-      }
-      for (auto Case : SI.cases()) {
-        const APInt &CaseVal = Case.getCaseValue()->getValue();
-        APInt ShiftedCase = Shl->hasNoSignedWrap() ? CaseVal.ashr(ShiftAmt)
-                                                   : CaseVal.lshr(ShiftAmt);
-        Case.setValue(ConstantInt::get(SI.getContext(), ShiftedCase));
-      }
-      return replaceOperand(SI, 0, NewCond);
+  };
+
+  Constant *C;
+  switch (I->getOpcode()) {
+  default:
+    break;
+  case Instruction::Or:
+    if (!match(I, m_DisjointOr(m_Value(), m_Value())))
+      break;
+    // Can treat `or disjoint` as add
+    [[fallthrough]];
+  case Instruction::Add:
+    // We get some regressions if we drop the OneUse for add in some cases.
+    // See discussion in D58633.
+    if (ReqOneUseAdd && !I->hasOneUse())
+      break;
+    if (!match(I->getOperand(1), m_ImmConstant(C)))
+      break;
+    // X + C0 == C1 -> X == C1 - C0
+    ReverseAll([&](Constant *Val) { return ConstantExpr::getSub(Val, C); });
+    return I->getOperand(0);
+  case Instruction::Sub:
+    if (!match(I->getOperand(0), m_ImmConstant(C)))
+      break;
+    // C0 - X == C1 -> X == C0 - C1
+    ReverseAll([&](Constant *Val) { return ConstantExpr::getSub(C, Val); });
+    return I->getOperand(1);
+  case Instruction::Xor:
+    if (!match(I->getOperand(1), m_ImmConstant(C)))
+      break;
+    // X ^ C0 == C1 -> X == C1 ^ C0
+    ReverseAll([&](Constant *Val) { return ConstantExpr::getXor(Val, C); });
+    return I->getOperand(0);
+  case Instruction::Mul: {
+    const APInt *MC;
+    if (!UseAPInt || !match(I->getOperand(1), m_APInt(MC)) || MC->isZero())
+      break;
+    OverflowingBinaryOperator *Mul = cast<OverflowingBinaryOperator>(I);
+    if (!Mul->hasNoUnsignedWrap())
+      break;
+
+    // X nuw C0 == C1 -> X == C1 u/ C0 iff C1 u% C0 == 0
+    if (all_of(ValsAsAPInt,
+               [&](const APInt * AC) { return AC->urem(*MC).isZero(); })) {
+      ReverseAllAPInt([&](const APInt *Val) { return Val->udiv(*MC); });
+      return I->getOperand(0);
     }
+
+    // X nuw C0 == C1 -> X == C1 s/ C0 iff C1 s% C0 == 0
+    if (all_of(ValsAsAPInt, [&](const APInt * AC) {
+          return (!AC->isMinSignedValue() || !MC->isAllOnes()) &&
+                 AC->srem(*MC).isZero();
+        })) {
+      ReverseAllAPInt([&](const APInt *Val) { return Val->sdiv(*MC); });
+      return I->getOperand(0);
+    }
+    break;
   }
+  case Instruction::UDiv:
+  case Instruction::SDiv: {
+    const APInt *DC;
+    if (!UseAPInt)
+      break;
+    if (!UseAPInt || !match(I->getOperand(1), m_APInt(DC)))
+      break;
+    if (!cast<PossiblyExactOperator>(Op)->isExact())
+      break;
+    // X u/ C0 == C1 -> X == C0 * C1 iff C0 * C1 is nuw
+    // X s/ C0 == C1 -> X == C0 * C1 iff C0 * C1 is nsw
+    if (!all_of(ValsAsAPInt, [&](const APInt *AC) {
+          bool Ov;
+          (void)(I->getOpcode() == Instruction::UDiv ? DC->umul_ov(*AC, Ov)
+                                                     : DC->smul_ov(*AC, Ov));
+          return !Ov;
+        }))
+      break;
 
-  // Fold switch(zext/sext(X)) into switch(X) if possible.
-  if (match(Cond, m_ZExtOrSExt(m_Value(Op0)))) {
-    bool IsZExt = isa<ZExtInst>(Cond);
-    Type *SrcTy = Op0->getType();
+    ReverseAllAPInt([&](const APInt *Val) { return (*Val) * (*DC); });
+    return I->getOperand(0);
+  }
+  case Instruction::ZExt:
+  case Instruction::SExt: {
+    if (!UseAPInt)
+      break;
+    bool IsZExt = isa<ZExtInst>(I);
+    Type *SrcTy = I->getOperand(0)->getType();
     unsigned NewWidth = SrcTy->getScalarSizeInBits();
+    // zext(X) == C1 -> X == trunc C1 iff zext(trunc(C1)) == C1
+    // sext(X) == C1 -> X == trunc C1 iff sext(trunc(C1)) == C1
+    if (!all_of(ValsAsAPInt, [&](const APInt *AC) {
+          return IsZExt ? AC->isIntN(NewWidth) : AC->isSignedIntN(NewWidth);
+        }))
+      break;
 
-    if (all_of(SI.cases(), [&](const auto &Case) {
-          const APInt &CaseVal = Case.getCaseValue()->getValue();
-          return IsZExt ? CaseVal.isIntN(NewWidth)
-                        : CaseVal.isSignedIntN(NewWidth);
-        })) {
-      for (auto &Case : SI.cases()) {
-        APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth);
-        Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase));
+    for (size_t i = 0, e = Vals.size(); i < e; ++i) {
+      Vals[i] = ConstantInt::get(SrcTy, ValsAsAPInt[i]->trunc(NewWidth));
+    }
+    return I->getOperand(0);
+  }
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr: {
+    if (!UseAPInt)
+      break;
+    uint64_t ShAmtC;
+    if (!match(I->getOperand(1), m_ConstantInt(ShAmtC)))
+      break;
+    if (ShAmtC >= I->getType()->getScalarSizeInBits())
+      break;
+
+    // X << C0 == C1 -> X == C1 >> C0 iff C1 >> C0 is exact
+    // X u>> C0 == C1 -> X == C1 << C0 iff C1 << C0 is nuw
+    // X s>> C0 == C1 -> X == C1 << C0 iff C1 << C0 is nsw
+    if (!all_of(ValsAsAPInt, [&](const APInt *AC) {
+          switch (I->getOpcode()) {
+          case Instruction::Shl:
+            return AC->countr_zero() >= ShAmtC;
+          case Instruction::LShr:
+            return AC->countl_zero() >= ShAmtC;
+          case Instruction::AShr:
+            return AC->getNumSignBits() >= ShAmtC;
+            return false;
+          default:
+            llvm_unreachable("Already checked Opcode");
+          }
+        }))
+      break;
+
+    bool HasExact = false, HasNUW = false, HasNSW = false;
+    if (I->getOpcode() == Instruction::Shl) {
+      OverflowingBinaryOperator *Shl = cast<OverflowingBinaryOperator>(I);
+      HasNUW = Shl->hasNoUnsignedWrap();
+      HasNSW = Shl->hasNoSignedWrap();
+    } else {
+      HasExact = cast<PossiblyExactOperator>(Op)->isExact();
+    }
+
+    Value *R = I->getOperand(0);
+    if (!HasExact && !HasNUW && !HasNSW) {
+      if (!I->hasOneUse())
+        break;
+
+      // We may be shifting out 1s from X, so need to mask it.
+      unsigned BitWidth = R->getType()->getScalarSizeInBits();
+      R = Builder.CreateAnd(
+          R, I->getOpcode() == Instruction::Shl
+                 ? APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)
+                 : APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
+    }
+
+    ReverseAllAPInt([&](const APInt *Val) {
+      if (I->getOpcode() == Instruction::Shl)
+        return HasNSW ? Val->ashr(ShAmtC) : Val->lshr(ShAmtC);
+      return Val->shl(ShAmtC);
+    });
+    return R;
+  }
+  case Instruction::Call:
+  case Instruction::Invoke: {
+    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+      switch (II->getIntrinsicID()) {
+      default:
+        break;
+      case Intrinsic::bitreverse:
+        if (!UseAPInt)
+          break;
+        // bitreverse(X) == C -> X == bitreverse(C)
+        ReverseAllAPInt([&](const APInt *Val) { return Val->reverseBits(); });
+        return II->getArgOperand(0);
+      case Intrinsic::bswap:
+        if (!UseAPInt)
+          break;
+        // bswap(X) == C -> X == bswap(C)
+        ReverseAllAPInt([&](const APInt *Val) { return Val->byteSwap(); });
+        return II->getArgOperand(0);
+      case Intrinsic::fshr:
+      case Intrinsic::fshl: {
+        if (!UseAPInt)
+          break;
+        if (II->getArgOperand(0) != II->getArgOperand(1))
+          break;
+        const APInt *RotAmtC;
+        if (!match(II->getArgOperand(2), m_APInt(RotAmtC)))
+          break;
+        // rol(X, C0) == C1 -> X == ror(C0, C1)
+        // ror(X, C0) == C1 -> X == rol(C0, C1)
+        ReverseAllAPInt([&](const APInt *Val) {
+          return II->getIntrinsicID() == Intrinsic::fshl ? Val->rotr(*RotAmtC)
+                                                         : Val->rotl(*RotAmtC);
+        });
+        return II->getArgOperand(0);
+      }
       }
-      return replaceOperand(SI, 0, Op0);
     }
   }
+  }
+  return nullptr;
+}
+
+Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
+  Value *Cond = SI.getCondition();
+
+  SmallVector<Constant *> CaseVals;
+  for (const auto &Case : SI.cases())
+    CaseVals.push_back(Case.getCaseValue());
+
+  if (auto *R = simplifyOpWithConstantEqConsts(Cond, Builder, CaseVals,
+                                               /*ReqOneUseAdd=*/false)) {
+    unsigned i = 0;
+    for (auto &Case : SI.cases())
+      Case.setValue(cast<ConstantInt>(CaseVals[i++]));
+    return replaceOperand(SI, 0, R);
+  }
 
   KnownBits Known = computeKnownBits(Cond, 0, &SI);
   unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
diff --git a/llvm/test/Transforms/InstCombine/2009-02-20-InstCombine-SROA.ll b/llvm/test/Transforms/InstCombine/2009-02-20-InstCombine-SROA.ll
index b532c815567389..978b4e29ed628a 100644
--- a/llvm/test/Transforms/InstCombine/2009-02-20-InstCombine-SROA.ll
+++ b/llvm/test/Transforms/InstCombine/2009-02-20-InstCombine-SROA.ll
@@ -92,11 +92,11 @@ define ptr @_Z3fooRSt6vectorIiSaIiEE(ptr %X) {
 ; IC-NEXT:    [[TMP37:%.*]] = load ptr, ptr [[__FIRST_ADDR_I_I]], align 4
 ; IC-NEXT:    [[TMP38:%.*]] = ptrtoint ptr [[TMP37]] to i32
 ; IC-NEXT:    [[TMP39:%.*]] = sub i32 [[TMP36]], [[TMP38]]
-; IC-NEXT:    [[TMP40:%.*]] = ashr i32 [[TMP39]], 2
+; IC-NEXT:    [[TMP40:%.*]] = and i32 [[TMP39]], -4
 ; IC-NEXT:    switch i32 [[TMP40]], label [[BB26_I_I:%.*]] [
-; IC-NEXT:      i32 1, label [[BB22_I_I:%.*]]
-; IC-NEXT:      i32 2, label [[BB18_I_I:%.*]]
-; IC-NEXT:      i32 3, label [[BB14_I_I:%.*]]
+; IC-NEXT:      i32 4, label [[BB22_I_I:%.*]]
+; IC-NEXT:      i32 8, label [[BB18_I_I:%.*]]
+; IC-NEXT:      i32 12, label [[BB14_I_I:%.*]]
 ; IC-NEXT:    ]
 ; IC:       bb14.i.i:
 ; IC-NEXT:    [[TMP41:%.*]] = load ptr, ptr [[__FIRST_ADDR_I_I]], align 4
@@ -199,11 +199,11 @@ define ptr @_Z3fooRSt6vectorIiSaIiEE(ptr %X) {
 ; IC_SROA-NEXT:    [[TMP21:%.*]] = ptrtoint ptr [[TMP1]] to i32
 ; IC_SROA-NEXT:    [[TMP22:%.*]] = ptrtoint ptr [[__FIRST_ADDR_I_I_SROA_0_0]] to i32
 ; IC_SROA-NEXT:    [[TMP23:%.*]] = sub i32 [[TMP21]], [[TMP22]]
-; IC_SROA-NEXT:    [[TMP24:%.*]] = ashr i32 [[TMP23]], 2
+; IC_SROA-NEXT:    [[TMP24:%.*]] = and i32 [[TMP23]], -4
 ; IC_SROA-NEXT:    switch i32 [[TMP24]], label [[BB26_I_I:%.*]] [
-; IC_SROA-NEXT:      i32 1, label [[BB22_I_I:%.*]]
-; IC_SROA-NEXT:      i32 2, label [[BB18_I_I:%.*]]
-; IC_SROA-NEXT:      i32 3, label [[BB14_I_I:%.*]]
+; IC_SROA-NEXT:      i32 4, label [[BB22_I_I:%.*]]
+; IC_SROA-NEXT:      i32 8, label [[BB18_I_I:%.*]]
+; IC_SROA-NEXT:      i32 12, label [[BB14_I_I:%.*]]
 ; IC_SROA-NEXT:    ]
 ; IC_SROA:       bb14.i.i:
 ; IC_SROA-NEXT:    [[TMP25:%.*]] = load i32, ptr [[__FIRST_ADDR_I_I_SROA_0_0]], align 4
diff --git a/llvm/test/Transforms/InstCombine/icmp-add.ll b/llvm/test/Transforms/InstCombine/icmp-add.ll
index b99ed20d7d431c..5caf881a7d6d4f 100644
--- a/llvm/test/Transforms/InstCombine/icmp-add.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-add.ll
@@ -2371,8 +2371,7 @@ define <2 x i1> @icmp_eq_add_non_splat(<2 x i32> %a) {
 
 define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
 ; CHECK-LABEL: @icmp_eq_add_undef2(
-; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i32> [[A:%.*]], <i32 5, i32 5>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 undef>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 undef>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %add = add <2 x i32> %a, <i32 5, i32 5>
@@ -2382,8 +2381,7 @@ define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
 
 define <2 x i1> @icmp_eq_add_non_splat2(<2 x i32> %a) {
 ; CHECK-LABEL: @icmp_eq_add_non_splat2(
-; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i32> [[A:%.*]], <i32 5, i32 5>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 11>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 6>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %add = add <2 x i32> %a, <i32 5, i32 5>
diff --git a/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll b/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
index f5d5ef32c81e81..f9ba74bcbf7b99 100644
--- a/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-equality-xor.ll
@@ -136,8 +136,7 @@ define i1 @foo2(i32 %x, i32 %y) {
 define <2 x i1> @foo3(<2 x i8> %x) {
 ; CHECK-LABEL: @foo3(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -2, i8 -1>
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i8> [[XOR]], <i8 9, i8 79>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i8> [[X:%.*]], <i8 -9, i8 -80>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
 entry:
diff --git a/llvm/test/Transforms/InstCombine/icmp-sub.ll b/llvm/test/Transforms/InstCombine/icmp-sub.ll
index 5645dededf2e4b..422e8116f1b38c 100644
--- a/llvm/test/Transforms/InstCombine/icmp-sub.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-sub.ll
@@ -164,8 +164,7 @@ define <2 x i1> @icmp_eq_sub_non_splat(<2 x i32> %a) {
 
 define <2 x i1> @icmp_eq_sub_undef2(<2 x i32> %a) {
 ; CHECK-LABEL: @icmp_eq_sub_undef2(
-; CHECK-NEXT:    [[SUB:%.*]] = sub <2 x i32> <i32 15, i32 15>, [[A:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[SUB]], <i32 10, i32 undef>
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 undef>
 ; CHECK-NEXT:    ret <2 x i1> [[CMP]]
 ;
   %sub = sub <2 x i32> <i32 15, i32 15>, %a
@@ -175,8 +174,7 @@ define <2 x i1> @icmp_eq_sub_undef2(<2 x i32...
[truncated]

@goldsteinn goldsteinn changed the title goldsteinn/helper for icmp eq ne [InstCombine] Add helper simplifying Instruction w/ constants with eq/ne Constants Mar 22, 2024
@goldsteinn goldsteinn requested a review from dtcxzyw March 22, 2024 21:48
Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff bbcfe6f4311af8cf6095a5bc5937fa68a87b4289 f4a845a1a67851566a3e760bdbdaadf535875bd9 -- llvm/include/llvm/Transforms/InstCombine/InstCombiner.h llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index f4a9343476..40437f79e7 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -3644,13 +3644,13 @@ InstCombiner::simplifyOpWithConstantEqConsts(Value *Op, BuilderTy &Builder,
 
     // X nuw C0 == C1 -> X == C1 u/ C0 iff C1 u% C0 == 0
     if (all_of(ValsAsAPInt,
-               [&](const APInt * AC) { return AC->urem(*MC).isZero(); })) {
+               [&](const APInt *AC) { return AC->urem(*MC).isZero(); })) {
       ReverseAllAPInt([&](const APInt *Val) { return Val->udiv(*MC); });
       return I->getOperand(0);
     }
 
     // X nuw C0 == C1 -> X == C1 s/ C0 iff C1 s% C0 == 0
-    if (all_of(ValsAsAPInt, [&](const APInt * AC) {
+    if (all_of(ValsAsAPInt, [&](const APInt *AC) {
           return (!AC->isMinSignedValue() || !MC->isAllOnes()) &&
                  AC->srem(*MC).isZero();
         })) {

Copy link

✅ With the latest revision this PR passed the Python code formatter.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants