Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ValueTracking] Add more conditions in isTruePredicate #86083

Closed
wants to merge 2 commits into from

Conversation

goldsteinn
Copy link
Contributor

@goldsteinn goldsteinn commented Mar 21, 2024

NB: This comes after #86082, I was adding in the addlike variants and just saw
the function was lacking in general.

There is one notable "regression". This patch replaces the bespoke or disjoint logic with a direct match. This means we fail some
simplification during instsimplify.
All the cases we fail in instsimplify we do handle in instcombine
as we add disjoint flags.

Other than that, just some basic cases.

See proofs: https://alive2.llvm.org/ce/z/_-g7C8

@goldsteinn goldsteinn requested a review from nikic as a code owner March 21, 2024 05:06
@goldsteinn goldsteinn changed the title goldsteinn/vt istruepred [ValueTracking] [ValueTracking] Add more conditions in isTruePredicate Mar 21, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 21, 2024

@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-analysis

Author: None (goldsteinn)

Changes
  • [IR] Add helpers for NUWAddLike and NSWAddLike to also match or disjoint; NFC
  • [InstCombine] Add tests for integrating N{U,S}WAddLike; NFC
  • [InstCombine] integrate N{U,S}WAddLike into existing folds
  • [ValueTracking] Add tests for deducing more conditions in isTruePredicate; NFC
  • [ValueTracking] Add more conditions in to isTruePredicate

Patch is 35.21 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86083.diff

13 Files Affected:

  • (modified) llvm/include/llvm/IR/PatternMatch.h (+22)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+53-35)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+5-3)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+3-2)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+4-4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+1-1)
  • (modified) llvm/test/Transforms/InstCombine/add.ll (+76)
  • (modified) llvm/test/Transforms/InstCombine/div.ll (+22)
  • (added) llvm/test/Transforms/InstCombine/implies.ll (+424)
  • (modified) llvm/test/Transforms/InstCombine/sadd-with-overflow.ll (+32)
  • (modified) llvm/test/Transforms/InstCombine/shift-add.ll (+29)
  • (modified) llvm/test/Transforms/InstCombine/uadd-with-overflow.ll (+23)
  • (modified) llvm/test/Transforms/InstSimplify/implies.ll (+14-2)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 382009d9df785d..3e298eff9b56ea 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1202,6 +1202,7 @@ m_NSWAdd(const LHS &L, const RHS &R) {
                                    OverflowingBinaryOperator::NoSignedWrap>(L,
                                                                             R);
 }
+
 template <typename LHS, typename RHS>
 inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Sub,
                                  OverflowingBinaryOperator::NoSignedWrap>
@@ -1235,6 +1236,7 @@ m_NUWAdd(const LHS &L, const RHS &R) {
                                    OverflowingBinaryOperator::NoUnsignedWrap>(
       L, R);
 }
+
 template <typename LHS, typename RHS>
 inline OverflowingBinaryOp_match<LHS, RHS, Instruction::Sub,
                                  OverflowingBinaryOperator::NoUnsignedWrap>
@@ -1319,6 +1321,26 @@ m_AddLike(const LHS &L, const RHS &R) {
   return m_CombineOr(m_Add(L, R), m_DisjointOr(L, R));
 }
 
+/// Match either "add nsw" or "or disjoint"
+template <typename LHS, typename RHS>
+inline match_combine_or<
+    OverflowingBinaryOp_match<LHS, RHS, Instruction::Add,
+                              OverflowingBinaryOperator::NoSignedWrap>,
+    DisjointOr_match<LHS, RHS>>
+m_NSWAddLike(const LHS &L, const RHS &R) {
+  return m_CombineOr(m_NSWAdd(L, R), m_DisjointOr(L, R));
+}
+
+/// Match either "add nuw" or "or disjoint"
+template <typename LHS, typename RHS>
+inline match_combine_or<
+    OverflowingBinaryOp_match<LHS, RHS, Instruction::Add,
+                              OverflowingBinaryOperator::NoUnsignedWrap>,
+    DisjointOr_match<LHS, RHS>>
+m_NUWAddLike(const LHS &L, const RHS &R) {
+  return m_CombineOr(m_NUWAdd(L, R), m_DisjointOr(L, R));
+}
+
 //===----------------------------------------------------------------------===//
 // Class that matches a group of binary opcodes.
 //
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 797665cf06c875..c7c151a1e9cf25 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -8390,8 +8390,7 @@ bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
 
 /// Return true if "icmp Pred LHS RHS" is always true.
 static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
-                            const Value *RHS, const DataLayout &DL,
-                            unsigned Depth) {
+                            const Value *RHS, const DataLayout &DL) {
   if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
     return true;
 
@@ -8403,8 +8402,26 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
     const APInt *C;
 
     // LHS s<= LHS +_{nsw} C   if C >= 0
-    if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C))))
+    // LHS s<= LHS | C         if C >= 0
+    if (match(RHS, m_NSWAdd(m_Specific(LHS), m_APInt(C))) ||
+        match(RHS, m_Or(m_Specific(LHS), m_APInt(C))))
       return !C->isNegative();
+
+    // LHS s<= smax(LHS, V) for any V
+    if (match(RHS, m_c_SMax(m_Specific(LHS), m_Value())))
+      return true;
+
+    // smin(RHS, V) s<= RHS for any V
+    if (match(LHS, m_c_SMin(m_Specific(RHS), m_Value())))
+      return true;
+
+    // Match A to (X +_{nsw} CA) and B to (X +_{nsw} CB)
+    const Value *X;
+    const APInt *CLHS, *CRHS;
+    if (match(LHS, m_NSWAddLike(m_Value(X), m_APInt(CLHS))) &&
+        match(RHS, m_NSWAddLike(m_Specific(X), m_APInt(CRHS))))
+      return CLHS->sle(*CRHS);
+
     return false;
   }
 
@@ -8414,34 +8431,36 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
         cast<OverflowingBinaryOperator>(RHS)->hasNoUnsignedWrap())
       return true;
 
+    // LHS u<= LHS | V for any V
+    if (match(RHS, m_c_Or(m_Specific(LHS), m_Value())))
+      return true;
+
+    // LHS u<= umax(LHS, V) for any V
+    if (match(RHS, m_c_UMax(m_Specific(LHS), m_Value())))
+      return true;
+
     // RHS >> V u<= RHS for any V
     if (match(LHS, m_LShr(m_Specific(RHS), m_Value())))
       return true;
 
-    // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB)
-    auto MatchNUWAddsToSameValue = [&](const Value *A, const Value *B,
-                                       const Value *&X,
-                                       const APInt *&CA, const APInt *&CB) {
-      if (match(A, m_NUWAdd(m_Value(X), m_APInt(CA))) &&
-          match(B, m_NUWAdd(m_Specific(X), m_APInt(CB))))
-        return true;
+    // RHS u/ C_ugt_1 u<= RHS
+    const APInt *C;
+    if (match(LHS, m_UDiv(m_Specific(RHS), m_APInt(C))) && C->ugt(1))
+      return true;
 
-      // If X & C == 0 then (X | C) == X +_{nuw} C
-      if (match(A, m_Or(m_Value(X), m_APInt(CA))) &&
-          match(B, m_Or(m_Specific(X), m_APInt(CB)))) {
-        KnownBits Known(CA->getBitWidth());
-        computeKnownBits(X, Known, DL, Depth + 1, /*AC*/ nullptr,
-                         /*CxtI*/ nullptr, /*DT*/ nullptr);
-        if (CA->isSubsetOf(Known.Zero) && CB->isSubsetOf(Known.Zero))
-          return true;
-      }
+    // RHS & V u<= RHS for any V
+    if (match(LHS, m_c_And(m_Specific(RHS), m_Value())))
+      return true;
 
-      return false;
-    };
+    // umin(RHS, V) u<= RHS for any V
+    if (match(LHS, m_c_UMin(m_Specific(RHS), m_Value())))
+      return true;
 
+    // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB)
     const Value *X;
     const APInt *CLHS, *CRHS;
-    if (MatchNUWAddsToSameValue(LHS, RHS, X, CLHS, CRHS))
+    if (match(LHS, m_NUWAddLike(m_Value(X), m_APInt(CLHS))) &&
+        match(RHS, m_NUWAddLike(m_Specific(X), m_APInt(CRHS))))
       return CLHS->ule(*CRHS);
 
     return false;
@@ -8454,36 +8473,36 @@ static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
 static std::optional<bool>
 isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
                       const Value *ARHS, const Value *BLHS, const Value *BRHS,
-                      const DataLayout &DL, unsigned Depth) {
+                      const DataLayout &DL) {
   switch (Pred) {
   default:
     return std::nullopt;
 
   case CmpInst::ICMP_SLT:
   case CmpInst::ICMP_SLE:
-    if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL, Depth) &&
-        isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL, Depth))
+    if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL) &&
+        isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL))
       return true;
     return std::nullopt;
 
   case CmpInst::ICMP_SGT:
   case CmpInst::ICMP_SGE:
-    if (isTruePredicate(CmpInst::ICMP_SLE, ALHS, BLHS, DL, Depth) &&
-        isTruePredicate(CmpInst::ICMP_SLE, BRHS, ARHS, DL, Depth))
+    if (isTruePredicate(CmpInst::ICMP_SLE, ALHS, BLHS, DL) &&
+        isTruePredicate(CmpInst::ICMP_SLE, BRHS, ARHS, DL))
       return true;
     return std::nullopt;
 
   case CmpInst::ICMP_ULT:
   case CmpInst::ICMP_ULE:
-    if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth) &&
-        isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth))
+    if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL) &&
+        isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL))
       return true;
     return std::nullopt;
 
   case CmpInst::ICMP_UGT:
   case CmpInst::ICMP_UGE:
-    if (isTruePredicate(CmpInst::ICMP_ULE, ALHS, BLHS, DL, Depth) &&
-        isTruePredicate(CmpInst::ICMP_ULE, BRHS, ARHS, DL, Depth))
+    if (isTruePredicate(CmpInst::ICMP_ULE, ALHS, BLHS, DL) &&
+        isTruePredicate(CmpInst::ICMP_ULE, BRHS, ARHS, DL))
       return true;
     return std::nullopt;
   }
@@ -8527,7 +8546,7 @@ static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
                                               CmpInst::Predicate RPred,
                                               const Value *R0, const Value *R1,
                                               const DataLayout &DL,
-                                              bool LHSIsTrue, unsigned Depth) {
+                                              bool LHSIsTrue) {
   Value *L0 = LHS->getOperand(0);
   Value *L1 = LHS->getOperand(1);
 
@@ -8574,7 +8593,7 @@ static std::optional<bool> isImpliedCondICmps(const ICmpInst *LHS,
     return LPred == RPred;
 
   if (LPred == RPred)
-    return isImpliedCondOperands(LPred, L0, L1, R0, R1, DL, Depth);
+    return isImpliedCondOperands(LPred, L0, L1, R0, R1, DL);
 
   return std::nullopt;
 }
@@ -8636,8 +8655,7 @@ llvm::isImpliedCondition(const Value *LHS, CmpInst::Predicate RHSPred,
   // Both LHS and RHS are icmps.
   const ICmpInst *LHSCmp = dyn_cast<ICmpInst>(LHS);
   if (LHSCmp)
-    return isImpliedCondICmps(LHSCmp, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue,
-                              Depth);
+    return isImpliedCondICmps(LHSCmp, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue);
 
   /// The LHS should be an 'or', 'and', or a 'select' instruction.  We expect
   /// the RHS to be an icmp.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index aaf7184a5562cd..a978e9a643f5e9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -819,7 +819,7 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add,
   Value *X;
   const APInt *C1, *C2;
   if (match(Op1, m_APInt(C1)) &&
-      match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2))))) &&
+      match(Op0, m_OneUse(m_ZExt(m_NUWAddLike(m_Value(X), m_APInt(C2))))) &&
       C1->isNegative() && C1->sge(-C2->sext(C1->getBitWidth()))) {
     Constant *NewC =
         ConstantInt::get(X->getType(), *C2 + C1->trunc(C2->getBitWidth()));
@@ -829,14 +829,16 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add,
   // More general combining of constants in the wide type.
   // (sext (X +nsw NarrowC)) + C --> (sext X) + (sext(NarrowC) + C)
   Constant *NarrowC;
-  if (match(Op0, m_OneUse(m_SExt(m_NSWAdd(m_Value(X), m_Constant(NarrowC)))))) {
+  if (match(Op0,
+            m_OneUse(m_SExt(m_NSWAddLike(m_Value(X), m_Constant(NarrowC)))))) {
     Value *WideC = Builder.CreateSExt(NarrowC, Ty);
     Value *NewC = Builder.CreateAdd(WideC, Op1C);
     Value *WideX = Builder.CreateSExt(X, Ty);
     return BinaryOperator::CreateAdd(WideX, NewC);
   }
   // (zext (X +nuw NarrowC)) + C --> (zext X) + (zext(NarrowC) + C)
-  if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_Constant(NarrowC)))))) {
+  if (match(Op0,
+            m_OneUse(m_ZExt(m_NUWAddLike(m_Value(X), m_Constant(NarrowC)))))) {
     Value *WideC = Builder.CreateZExt(NarrowC, Ty);
     Value *NewC = Builder.CreateAdd(WideC, Op1C);
     Value *WideX = Builder.CreateZExt(X, Ty);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 426b548c074adf..526fe5d080599f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2093,8 +2093,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     Value *Arg0 = II->getArgOperand(0);
     Value *Arg1 = II->getArgOperand(1);
     bool IsSigned = IID == Intrinsic::sadd_with_overflow;
-    bool HasNWAdd = IsSigned ? match(Arg0, m_NSWAdd(m_Value(X), m_APInt(C0)))
-                             : match(Arg0, m_NUWAdd(m_Value(X), m_APInt(C0)));
+    bool HasNWAdd = IsSigned
+                        ? match(Arg0, m_NSWAddLike(m_Value(X), m_APInt(C0)))
+                        : match(Arg0, m_NUWAddLike(m_Value(X), m_APInt(C0)));
     if (HasNWAdd && match(Arg1, m_APInt(C1))) {
       bool Overflow;
       APInt NewC =
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 9d4c271f990d19..345feeedc2707c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1160,14 +1160,14 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
     // We need a multiple of the divisor for a signed add constant, but
     // unsigned is fine with any constant pair.
     if (IsSigned &&
-        match(Op0, m_NSWAdd(m_NSWMul(m_Value(X), m_SpecificInt(*C2)),
-                            m_APInt(C1))) &&
+        match(Op0, m_NSWAddLike(m_NSWMul(m_Value(X), m_SpecificInt(*C2)),
+                                m_APInt(C1))) &&
         isMultiple(*C1, *C2, Quotient, IsSigned)) {
       return BinaryOperator::CreateNSWAdd(X, ConstantInt::get(Ty, Quotient));
     }
     if (!IsSigned &&
-        match(Op0, m_NUWAdd(m_NUWMul(m_Value(X), m_SpecificInt(*C2)),
-                            m_APInt(C1)))) {
+        match(Op0, m_NUWAddLike(m_NUWMul(m_Value(X), m_SpecificInt(*C2)),
+                                m_APInt(C1)))) {
       return BinaryOperator::CreateNUWAdd(X,
                                           ConstantInt::get(Ty, C1->udiv(*C2)));
     }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index eafd2889ec50bd..95aa2119e2d88b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -437,7 +437,7 @@ Instruction *InstCombinerImpl::commonShiftTransforms(BinaryOperator &I) {
   Value *A;
   Constant *C, *C1;
   if (match(Op0, m_Constant(C)) &&
-      match(Op1, m_NUWAdd(m_Value(A), m_Constant(C1)))) {
+      match(Op1, m_NUWAddLike(m_Value(A), m_Constant(C1)))) {
     Value *NewC = Builder.CreateBinOp(I.getOpcode(), C, C1);
     BinaryOperator *NewShiftOp = BinaryOperator::Create(I.getOpcode(), NewC, A);
     if (I.getOpcode() == Instruction::Shl) {
diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll
index 522dcf8db27f42..ec3aca26514caf 100644
--- a/llvm/test/Transforms/InstCombine/add.ll
+++ b/llvm/test/Transforms/InstCombine/add.ll
@@ -3986,5 +3986,81 @@ define i32 @add_reduce_sqr_sum_varC_invalid2(i32 %a, i32 %b) {
   ret i32 %ab2
 }
 
+define i32 @fold_sext_addition_or_disjoint(i8 %x) {
+; CHECK-LABEL: @fold_sext_addition_or_disjoint(
+; CHECK-NEXT:    [[SE:%.*]] = sext i8 [[XX:%.*]] to i32
+; CHECK-NEXT:    [[R:%.*]] = add nsw i32 [[SE]], 1246
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %xx = or disjoint i8 %x, 12
+  %se = sext i8 %xx to i32
+  %r = add i32 %se, 1234
+  ret i32 %r
+}
+
+define i32 @fold_sext_addition_fail(i8 %x) {
+; CHECK-LABEL: @fold_sext_addition_fail(
+; CHECK-NEXT:    [[XX:%.*]] = or i8 [[X:%.*]], 12
+; CHECK-NEXT:    [[SE:%.*]] = sext i8 [[XX]] to i32
+; CHECK-NEXT:    [[R:%.*]] = add nsw i32 [[SE]], 1234
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %xx = or i8 %x, 12
+  %se = sext i8 %xx to i32
+  %r = add i32 %se, 1234
+  ret i32 %r
+}
+
+define i32 @fold_zext_addition_or_disjoint(i8 %x) {
+; CHECK-LABEL: @fold_zext_addition_or_disjoint(
+; CHECK-NEXT:    [[SE:%.*]] = zext i8 [[XX:%.*]] to i32
+; CHECK-NEXT:    [[R:%.*]] = add nuw nsw i32 [[SE]], 1246
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %xx = or disjoint i8 %x, 12
+  %se = zext i8 %xx to i32
+  %r = add i32 %se, 1234
+  ret i32 %r
+}
+
+define i32 @fold_zext_addition_or_disjoint2(i8 %x) {
+; CHECK-LABEL: @fold_zext_addition_or_disjoint2(
+; CHECK-NEXT:    [[XX:%.*]] = add nuw i8 [[X:%.*]], 4
+; CHECK-NEXT:    [[SE:%.*]] = zext i8 [[XX]] to i32
+; CHECK-NEXT:    ret i32 [[SE]]
+;
+  %xx = or disjoint i8 %x, 18
+  %se = zext i8 %xx to i32
+  %r = add i32 %se, -14
+  ret i32 %r
+}
+
+define i32 @fold_zext_addition_fail(i8 %x) {
+; CHECK-LABEL: @fold_zext_addition_fail(
+; CHECK-NEXT:    [[XX:%.*]] = or i8 [[X:%.*]], 12
+; CHECK-NEXT:    [[SE:%.*]] = zext i8 [[XX]] to i32
+; CHECK-NEXT:    [[R:%.*]] = add nuw nsw i32 [[SE]], 1234
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %xx = or i8 %x, 12
+  %se = zext i8 %xx to i32
+  %r = add i32 %se, 1234
+  ret i32 %r
+}
+
+define i32 @fold_zext_addition_fail2(i8 %x) {
+; CHECK-LABEL: @fold_zext_addition_fail2(
+; CHECK-NEXT:    [[XX:%.*]] = or i8 [[X:%.*]], 18
+; CHECK-NEXT:    [[SE:%.*]] = zext i8 [[XX]] to i32
+; CHECK-NEXT:    [[R:%.*]] = add nsw i32 [[SE]], -14
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %xx = or i8 %x, 18
+  %se = zext i8 %xx to i32
+  %r = add i32 %se, -14
+  ret i32 %r
+}
+
+
 declare void @llvm.assume(i1)
 declare void @fake_func(i32)
diff --git a/llvm/test/Transforms/InstCombine/div.ll b/llvm/test/Transforms/InstCombine/div.ll
index 1309dee817cf65..e8a25ff44d0296 100644
--- a/llvm/test/Transforms/InstCombine/div.ll
+++ b/llvm/test/Transforms/InstCombine/div.ll
@@ -1810,3 +1810,25 @@ define i6 @udiv_distribute_mul_nsw_add_nuw(i6 %x) {
   %div = udiv i6 %add, 3
   ret i6 %div
 }
+
+define i32 @fold_disjoint_or_over_sdiv(i32 %x) {
+; CHECK-LABEL: @fold_disjoint_or_over_sdiv(
+; CHECK-NEXT:    [[R:%.*]] = add nsw i32 [[X:%.*]], 9
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %mul = mul nsw i32 %x, 9
+  %or = or disjoint i32 %mul, 81
+  %r = sdiv i32 %or, 9
+  ret i32 %r
+}
+
+define i32 @fold_disjoint_or_over_udiv(i32 %x) {
+; CHECK-LABEL: @fold_disjoint_or_over_udiv(
+; CHECK-NEXT:    [[R:%.*]] = add nuw i32 [[X:%.*]], 9
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %mul = mul nuw i32 %x, 9
+  %or = or disjoint i32 %mul, 81
+  %r = udiv i32 %or, 9
+  ret i32 %r
+}
diff --git a/llvm/test/Transforms/InstCombine/implies.ll b/llvm/test/Transforms/InstCombine/implies.ll
new file mode 100644
index 00000000000000..c02d84d3f83711
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/implies.ll
@@ -0,0 +1,424 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i1 @or_implies_sle(i8 %x, i8 %y, i1 %other) {
+; CHECK-LABEL: @or_implies_sle(
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[X:%.*]], 23
+; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp sgt i8 [[OR]], [[Y:%.*]]
+; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
+; CHECK:       T:
+; CHECK-NEXT:    ret i1 true
+; CHECK:       F:
+; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
+;
+  %or = or i8 %x, 23
+  %cond = icmp sle i8 %or, %y
+  br i1 %cond, label %T, label %F
+T:
+  %r = icmp sle i8 %x, %y
+  ret i1 %r
+F:
+  ret i1 %other
+}
+
+define i1 @or_implies_sle_fail(i8 %x, i8 %y, i1 %other) {
+; CHECK-LABEL: @or_implies_sle_fail(
+; CHECK-NEXT:    [[OR:%.*]] = or i8 [[X:%.*]], -34
+; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp sgt i8 [[OR]], [[Y:%.*]]
+; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
+; CHECK:       T:
+; CHECK-NEXT:    [[R:%.*]] = icmp sle i8 [[X]], [[Y]]
+; CHECK-NEXT:    ret i1 [[R]]
+; CHECK:       F:
+; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
+;
+  %or = or i8 %x, -34
+  %cond = icmp sle i8 %or, %y
+  br i1 %cond, label %T, label %F
+T:
+  %r = icmp sle i8 %x, %y
+  ret i1 %r
+F:
+  ret i1 %other
+}
+
+define i1 @or_distjoint_implies_ule(i8 %x, i8 %y, i1 %other) {
+; CHECK-LABEL: @or_distjoint_implies_ule(
+; CHECK-NEXT:    [[X2:%.*]] = or disjoint i8 [[X:%.*]], 24
+; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp ugt i8 [[X2]], [[Y:%.*]]
+; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
+; CHECK:       T:
+; CHECK-NEXT:    ret i1 true
+; CHECK:       F:
+; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
+;
+  %x1 = or disjoint i8 %x, 23
+  %x2 = or disjoint i8 %x, 24
+
+  %cond = icmp ule i8 %x2, %y
+  br i1 %cond, label %T, label %F
+T:
+  %r = icmp ule i8 %x1, %y
+  ret i1 %r
+F:
+  ret i1 %other
+}
+
+define i1 @or_distjoint_implies_ule_fail(i8 %x, i8 %y, i1 %other) {
+; CHECK-LABEL: @or_distjoint_implies_ule_fail(
+; CHECK-NEXT:    [[X2:%.*]] = or disjoint i8 [[X:%.*]], 24
+; CHECK-NEXT:    [[COND_NOT:%.*]] = icmp ugt i8 [[X2]], [[Y:%.*]]
+; CHECK-NEXT:    br i1 [[COND_NOT]], label [[F:%.*]], label [[T:%.*]]
+; CHECK:       T:
+; CHECK-NEXT:    [[X1:%.*]] = or disjoint i8 [[X]], 28
+; CHECK-NEXT:    [[R:%.*]] = icmp ule i8 [[X1]], [[Y]]
+; CHECK-NEXT:    ret i1 [[R]]
+; CHECK:       F:
+; CHECK-NEXT:    ret i1 [[OTHER:%.*]]
+;
+  %x1 = or disjoint i8 %x, 28
+  %x2 = or disjoint i8 %x, 24
+
+  %cond = icmp ule i8 %x2, %y
+  br i1 %cond, label %T, label %F
+T:
+  %r = icmp ule i8 %x1, %y
+  ret i1 %r
+F:
+  ret i1 %other
+}
+
+define i1 @or_prove_distjoin_implies_ule(i8 %xx, i8 %y, i1 %other) {
+; CHECK-LABEL: @or_prove_distjoin_implies_ule(
+; CHECK-NEXT:    [[X:%.*]] = and i8 [[XX:%.*]], -16
+; CHECK-NEXT:    [[X2:%.*]] = or disjoint i8 [[X]], 10
+; CHECK-NEXT:    [[COND_NOT:%.*]] = icm...
[truncated]

@goldsteinn
Copy link
Contributor Author

@nikic, IIRC you had started a patch on a helper function for value relationships and would be happy to drop this and pick that up. I couldn't find your patch though. Maybe I misremember?

@nikic
Copy link
Contributor

nikic commented Mar 21, 2024

@goldsteinn I think the patch you have in mind is #69471. Totally forgot about that one...

My general feedback on this patch is that we're re-implementing simplifyICmpInst() here. I think it would be good to at least check what the compile-time impact of using simplifyICmpInst() instead would be -- I suspect "bad", but best to confirm that...

Otherwise, I am fine with having some pragmatic duplication here, if having it is useful. I just want to make sure we don't end up with a full reimplementation down the line because a certain someone keeps spotting missing cases :)

@goldsteinn
Copy link
Contributor Author

@goldsteinn I think the patch you have in mind is #69471. Totally forgot about that one...

My general feedback on this patch is that we're re-implementing simplifyICmpInst() here. I think it would be good to at least check what the compile-time impact of using simplifyICmpInst() instead would be -- I suspect "bad", but best to confirm that...

Otherwise, I am fine with having some pragmatic duplication here, if having it is useful. I just want to make sure we don't end up with a full reimplementation down the line because a certain someone keeps spotting missing cases :)

Let me look into creating a common helper for simlifyICmpInst and this.

@goldsteinn
Copy link
Contributor Author

@goldsteinn I think the patch you have in mind is #69471. Totally forgot about that one...
My general feedback on this patch is that we're re-implementing simplifyICmpInst() here. I think it would be good to at least check what the compile-time impact of using simplifyICmpInst() instead would be -- I suspect "bad", but best to confirm that...
Otherwise, I am fine with having some pragmatic duplication here, if having it is useful. I just want to make sure we don't end up with a full reimplementation down the line because a certain someone keeps spotting missing cases :)

Let me look into creating a common helper for simlifyICmpInst and this.

I think impl is relatively novel, we have a scattering if the cases supported in
instsimplify like simplifyICmpWithMinMax, but the and/or icmp instructions
generally defer to is implied.

There is one notable "regression". This patch replaces the bespoke `or
disjoint` logic we a direct match. This means we fail some
simplification during `instsimplify`.
All the cases we fail in `instsimplify` we do handle in `instcombine`
as we add `disjoint` flags.

Other than that, just some basic cases.

See proofs: https://alive2.llvm.org/ce/z/_-g7C8
@dtcxzyw dtcxzyw changed the title [ValueTracking] [ValueTracking] Add more conditions in isTruePredicate [ValueTracking] Add more conditions in isTruePredicate Mar 21, 2024
dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Mar 24, 2024
@goldsteinn
Copy link
Contributor Author

ping.

@goldsteinn
Copy link
Contributor Author

ping

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

llvm/lib/Analysis/ValueTracking.cpp Show resolved Hide resolved
@goldsteinn goldsteinn closed this in 678f32a Apr 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants