Skip to content

Conversation

AZero13
Copy link
Contributor

@AZero13 AZero13 commented Sep 23, 2025

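Add an m_USubWithOverflow matcher (USubWithOverflow_match) so that unsigned-subtract overflow checks are matched consistently with the existing m_UAddWithOverflow matcher (UAddWithOverflow_match).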

@AZero13 AZero13 requested a review from nikic as a code owner September 23, 2025 15:39
@llvmbot llvmbot added llvm:codegen llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:ir llvm:transforms labels Sep 23, 2025
@AZero13 AZero13 marked this pull request as draft September 23, 2025 15:39
@llvmbot
Copy link
Member

llvmbot commented Sep 23, 2025

@llvm/pr-subscribers-backend-x86
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Author: AZero13 (AZero13)

Changes

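Add an m_USubWithOverflow matcher (USubWithOverflow_match) so that unsigned-subtract overflow checks are matched consistently with the existing m_UAddWithOverflow matcher (UAddWithOverflow_match).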


Full diff: https://github.com/llvm/llvm-project/pull/160327.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/PatternMatch.h (+83)
  • (modified) llvm/lib/CodeGen/CodeGenPrepare.cpp (+55-16)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+17)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 6168e24569f99..c894e2b1aba36 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2685,6 +2685,89 @@ m_UAddWithOverflow(const LHS_t &L, const RHS_t &R, const Sum_t &S) {
   return UAddWithOverflow_match<LHS_t, RHS_t, Sum_t>(L, R, S);
 }
 
+template <typename LHS_t, typename RHS_t, typename Diff_t>
+struct USubWithOverflow_match {
+  LHS_t L;
+  RHS_t R;
+  Diff_t S;
+
+  USubWithOverflow_match(const LHS_t &L, const RHS_t &R, const Diff_t &S)
+      : L(L), R(R), S(S) {}
+
+  template <typename OpTy> bool match(OpTy *V) const {
+    Value *ICmpLHS = nullptr, *ICmpRHS = nullptr;
+    CmpPredicate Pred;
+    if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V))
+      return false;
+
+    // Sub temporaries
+    Value *SubLHS = nullptr, *SubRHS = nullptr;
+    auto SubExpr = m_Sub(m_Value(SubLHS), m_Value(SubRHS));
+
+    // Add temporaries (we will only accept add if RHS is an APInt negative const)
+    Value *AddLHS = nullptr, *AddRHS = nullptr;
+    auto AddExpr = m_Add(m_Value(AddLHS), m_Value(AddRHS));
+
+    // (a - b) >u a   OR   (a + (-c)) >u a  (allow add-canonicalized forms
+    // but only where the RHS is a constant APInt that is negative)
+    if (Pred == ICmpInst::ICMP_UGT) {
+      if (SubExpr.match(ICmpLHS) && ICmpRHS == SubLHS)
+        return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS);
+
+      if (AddExpr.match(ICmpLHS)) {
+        const APInt *AddC = nullptr;
+        if (m_APInt(AddC).match(AddRHS) && AddC->isNegative() &&
+            ICmpRHS == AddLHS)
+          return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpLHS);
+      }
+    }
+
+    // a <u (a - b)   OR   a <u (a + (-c))
+    if (Pred == ICmpInst::ICMP_ULT) {
+      if (SubExpr.match(ICmpRHS) && ICmpLHS == SubLHS)
+        return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpRHS);
+
+      if (AddExpr.match(ICmpRHS)) {
+        const APInt *AddC = nullptr;
+        if (m_APInt(AddC).match(AddRHS) && AddC->isNegative() &&
+            ICmpLHS == AddLHS)
+          return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpRHS);
+      }
+    }
+
+    // Simple forms: "op <u other"  or  "other >u op"
+    Value *Op1 = nullptr;
+    if (Pred == ICmpInst::ICMP_ULT) {
+      if (m_Value(Op1).match(ICmpLHS))
+        return L.match(Op1) && R.match(ICmpRHS) && S.match(ICmpLHS);
+    } else if (Pred == ICmpInst::ICMP_UGT) {
+      if (m_Value(Op1).match(ICmpRHS))
+        return L.match(Op1) && R.match(ICmpLHS) && S.match(ICmpRHS);
+    }
+
+    // Special-case for 0 - a != 0 (common canonicalization)
+    if (Pred == ICmpInst::ICMP_NE) {
+      // (0 - a) != 0
+      if (SubExpr.match(ICmpLHS) && m_Zero().match(ICmpRHS) &&
+          m_Zero().match(SubLHS))
+        return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS);
+
+      // 0 != (0 - a)
+      if (m_Zero().match(ICmpLHS) && SubExpr.match(ICmpRHS) &&
+          m_Zero().match(SubLHS))
+        return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpRHS);
+    }
+
+    return false;
+  }
+};
+
+template <typename LHS_t, typename RHS_t, typename Diff_t>
+USubWithOverflow_match<LHS_t, RHS_t, Diff_t>
+m_USubWithOverflow(const LHS_t &L, const RHS_t &R, const Diff_t &S) {
+  return USubWithOverflow_match<LHS_t, RHS_t, Diff_t>(L, R, S);
+}
+
 template <typename Opnd_t> struct Argument_match {
   unsigned OpI;
   Opnd_t Val;
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index d290f202f3cca..246636a4c97a3 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -1695,42 +1695,51 @@ bool CodeGenPrepare::combineToUAddWithOverflow(CmpInst *Cmp,
   return true;
 }
 
-bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
-                                               ModifyDT &ModifiedDT) {
-  // We are not expecting non-canonical/degenerate code. Just bail out.
+static bool matchUSubWithOverflowConstantEdgeCases(CmpInst *Cmp,
+                                                   BinaryOperator *&Sub) {
+  // A - B, A u< B --> usubo(A, B)
   Value *A = Cmp->getOperand(0), *B = Cmp->getOperand(1);
+
+  // We are not expecting non-canonical/degenerate code. Just bail out.
   if (isa<Constant>(A) && isa<Constant>(B))
     return false;
 
-  // Convert (A u> B) to (A u< B) to simplify pattern matching.
   ICmpInst::Predicate Pred = Cmp->getPredicate();
+
+  // Normalize: convert (A u> B) -> (B u< A)
   if (Pred == ICmpInst::ICMP_UGT) {
     std::swap(A, B);
     Pred = ICmpInst::ICMP_ULT;
   }
+
   // Convert special-case: (A == 0) is the same as (A u< 1).
   if (Pred == ICmpInst::ICMP_EQ && match(B, m_ZeroInt())) {
     B = ConstantInt::get(B->getType(), 1);
     Pred = ICmpInst::ICMP_ULT;
   }
   // Convert special-case: (A != 0) is the same as (0 u< A).
-  if (Pred == ICmpInst::ICMP_NE && match(B, m_ZeroInt())) {
+  else if (Pred == ICmpInst::ICMP_NE && match(B, m_ZeroInt())) {
     std::swap(A, B);
     Pred = ICmpInst::ICMP_ULT;
+  } else {
+    return false;
   }
+
   if (Pred != ICmpInst::ICMP_ULT)
     return false;
 
-  // Walk the users of a variable operand of a compare looking for a subtract or
-  // add with that same operand. Also match the 2nd operand of the compare to
-  // the add/sub, but that may be a negated constant operand of an add.
+  // Walk the users of the variable operand of the compare looking for a
+  // subtract or add with that same operand. Also match the 2nd operand of the
+  // compare to the add/sub, but that may be a negated constant operand of an
+  // add.
   Value *CmpVariableOperand = isa<Constant>(A) ? B : A;
-  BinaryOperator *Sub = nullptr;
+  Sub = nullptr;
+
   for (User *U : CmpVariableOperand->users()) {
     // A - B, A u< B --> usubo(A, B)
     if (match(U, m_Sub(m_Specific(A), m_Specific(B)))) {
       Sub = cast<BinaryOperator>(U);
-      break;
+      return true;
     }
 
     // A + (-C), A u< C (canonicalized form of (sub A, C))
@@ -1738,19 +1747,49 @@ bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
     if (match(U, m_Add(m_Specific(A), m_APInt(AddC))) &&
         match(B, m_APInt(CmpC)) && *AddC == -(*CmpC)) {
       Sub = cast<BinaryOperator>(U);
-      break;
+      return true;
     }
   }
-  if (!Sub)
-    return false;
 
+  return false;
+}
+
+bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
+                                               ModifyDT &ModifiedDT) {
+  bool EdgeCase = false;
+  Value *A = nullptr, *B = nullptr;
+  BinaryOperator *Sub = nullptr;
+
+  // If the compare already matches the (sub, icmp) pattern use it directly.
+  if (!match(Cmp, m_USubWithOverflow(m_Value(A), m_Value(B), m_BinOp(Sub)))) {
+    // Otherwise try to recognize constant-edge-case forms like
+    //   icmp ne B, 0   paired with   sub 0, B      or
+    //   icmp eq A, 0   paired with   sub A, 1  (or add A, -1)
+    if (!matchUSubWithOverflowConstantEdgeCases(Cmp, Sub))
+      return false;
+    // Set A/B from the discovered Sub and record that this was an edge-case
+    // match.
+    A = Sub->getOperand(0);
+    B = Sub->getOperand(1);
+    EdgeCase = true;
+  }
+
+  // Check target wants the overflow intrinsic formed. When matching an
+  // edge-case we allow forming the intrinsic with fewer uses (mirror
+  // combineToUAddWithOverflow).
   if (!TLI->shouldFormOverflowOp(ISD::USUBO,
                                  TLI->getValueType(*DL, Sub->getType()),
-                                 Sub->hasNUsesOrMore(1)))
+                                 Sub->hasNUsesOrMore(EdgeCase ? 1 : 2)))
+    return false;
+
+  // Ensure it's safe to create the call in the icmp's basic block (same rule as
+  // UAdd).
+  if (Sub->getParent() != Cmp->getParent() && !Sub->hasOneUse())
     return false;
 
-  if (!replaceMathCmpWithIntrinsic(Sub, Sub->getOperand(0), Sub->getOperand(1),
-                                   Cmp, Intrinsic::usub_with_overflow))
+  // Replace math+cmp pair with the llvm.usub.with.overflow intrinsic.
+  if (!replaceMathCmpWithIntrinsic(Sub, A, B, Cmp,
+                                   Intrinsic::usub_with_overflow))
     return false;
 
   // Reset callers - do not crash by iterating over a dead instruction.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e4cb457499ef5..5c7aae5f91fab 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7829,6 +7829,23 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
       }
     }
 
+    Instruction *SubI = nullptr;
+    if (match(&I, m_USubWithOverflow(m_Value(X), m_Value(Y),
+                                     m_Instruction(SubI))) &&
+        isa<IntegerType>(X->getType())) {
+      Value *Result;
+      Constant *Overflow;
+      // m_USubWithOverflow can match patterns that do not include an explicit
+      // "sub" instruction, so check the opcode of the matched op.
+      if (SubI->getOpcode() == Instruction::Sub &&
+          OptimizeOverflowCheck(Instruction::Sub, /*Signed*/ false, X, Y, *SubI,
+                                Result, Overflow)) {
+        replaceInstUsesWith(*SubI, Result);
+        eraseInstFromFunction(*SubI);
+        return replaceInstUsesWith(I, Overflow);
+      }
+    }
+
     // (zext X) * (zext Y)  --> llvm.umul.with.overflow.
     if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
         match(Op1, m_APInt(C))) {

Copy link

github-actions bot commented Sep 23, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@AZero13 AZero13 force-pushed the arm-commute branch 4 times, most recently from 2d604b0 to e7be838 Compare September 23, 2025 17:19
@AZero13 AZero13 marked this pull request as ready for review September 23, 2025 18:06
To make it consistent with m_UAddWithOverflow_match.
: L(L), R(R), S(S) {}

template <typename OpTy> bool match(OpTy *V) const {
Value *ICmpLHS = nullptr, *ICmpRHS = nullptr;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do these need to be initialized? They aren't initialized in UAddWithOverflow_match

if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V))
return false;

Value *SubLHS = nullptr, *SubRHS = nullptr;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do these need to be initialized?

Value *SubLHS = nullptr, *SubRHS = nullptr;
auto SubExpr = m_Sub(m_Value(SubLHS), m_Value(SubRHS));

Value *AddLHS = nullptr, *AddRHS = nullptr;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do these need to be initialized?

// Special-case for 0 - a != 0 (common canonicalization)
if (Pred == ICmpInst::ICMP_NE) {
// (0 - a) != 0
if (SubExpr.match(ICmpLHS) && m_Zero().match(ICmpRHS) &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m_Zero -> m_ZeroInt

return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS);

// 0 != (0 - a)
if (m_Zero().match(ICmpLHS) && SubExpr.match(ICmpRHS) &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m_ZeroInt


// 0 != (0 - a)
if (m_Zero().match(ICmpLHS) && SubExpr.match(ICmpRHS) &&
m_Zero().match(SubLHS))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m_ZeroInt

isa<IntegerType>(X->getType())) {
Value *Result;
Constant *Overflow;
// m_UAddWithOverflow can match patterns that do not include an explicit
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

m_USubWithOverflow?

isa<IntegerType>(X->getType())) {
Value *Result;
Constant *Overflow;
// m_UAddWithOverflow can match patterns that do not include an explicit
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// m_UAddWithOverflow can match patterns that do not include an explicit
// m_UAddWithOverflow can match patterns that do not include an explicit

Value *Result;
Constant *Overflow;
// m_UAddWithOverflow can match patterns that do not include an explicit
// "add" instruction, so check the opcode of the matched op.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"sub"?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 llvm:codegen llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:ir llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants