-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[CodeGenPrepare] Create USubWithOverflow_match #160327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-backend-x86 @llvm/pr-subscribers-llvm-ir Author: AZero13 (AZero13) Changes: To make it consistent with m_UAddWithOverflow_match. Full diff: https://github.com/llvm/llvm-project/pull/160327.diff 3 Files Affected:
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 6168e24569f99..c894e2b1aba36 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2685,6 +2685,89 @@ m_UAddWithOverflow(const LHS_t &L, const RHS_t &R, const Sum_t &S) {
return UAddWithOverflow_match<LHS_t, RHS_t, Sum_t>(L, R, S);
}
+/// Matcher for an integer compare that tests whether an unsigned subtraction
+/// wrapped. On success \p L is matched against the minuend, \p R against the
+/// second operand of the arithmetic instruction, and \p S against the
+/// arithmetic instruction itself (the "difference").
+///
+/// Recognized shapes:
+///   (a - b) u> a        a u< (a - b)
+///   (a + C) u> a        a u< (a + C)    where C is a negative constant
+///   (0 - a) != 0        0 != (0 - a)
+///
+/// Note that for the add-canonicalized shapes \p R is matched against the
+/// negative constant C exactly as it appears in the add, not against its
+/// negation.
+template <typename LHS_t, typename RHS_t, typename Diff_t>
+struct USubWithOverflow_match {
+  LHS_t L;
+  RHS_t R;
+  Diff_t S;
+
+  USubWithOverflow_match(const LHS_t &L, const RHS_t &R, const Diff_t &S)
+      : L(L), R(R), S(S) {}
+
+  /// Returns true when \p V is one of the compare shapes listed on the struct
+  /// and all three sub-matchers accept their operand.
+  template <typename OpTy> bool match(OpTy *V) const {
+    Value *ICmpLHS = nullptr, *ICmpRHS = nullptr;
+    CmpPredicate Pred;
+    if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V))
+      return false;
+
+    // Sub temporaries
+    Value *SubLHS = nullptr, *SubRHS = nullptr;
+    auto SubExpr = m_Sub(m_Value(SubLHS), m_Value(SubRHS));
+
+    // Add temporaries (an add is only accepted when its RHS is a negative
+    // constant, i.e. the canonicalized form of a subtract-by-constant)
+    Value *AddLHS = nullptr, *AddRHS = nullptr;
+    auto AddExpr = m_Add(m_Value(AddLHS), m_Value(AddRHS));
+
+    // (a - b) >u a OR (a + (-c)) >u a
+    if (Pred == ICmpInst::ICMP_UGT) {
+      if (SubExpr.match(ICmpLHS) && ICmpRHS == SubLHS)
+        return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS);
+
+      if (AddExpr.match(ICmpLHS)) {
+        const APInt *AddC = nullptr;
+        if (m_APInt(AddC).match(AddRHS) && AddC->isNegative() &&
+            ICmpRHS == AddLHS)
+          return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpLHS);
+      }
+    }
+
+    // a <u (a - b) OR a <u (a + (-c))
+    if (Pred == ICmpInst::ICMP_ULT) {
+      if (SubExpr.match(ICmpRHS) && ICmpLHS == SubLHS)
+        return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpRHS);
+
+      if (AddExpr.match(ICmpRHS)) {
+        const APInt *AddC = nullptr;
+        if (m_APInt(AddC).match(AddRHS) && AddC->isNegative() &&
+            ICmpLHS == AddLHS)
+          return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpRHS);
+      }
+    }
+
+    // NOTE: a bare "a <u b" / "b >u a" compare is deliberately NOT matched
+    // here. Neither operand of such a compare is the difference, so there is
+    // nothing valid to hand to the S matcher; accepting it would bind S to an
+    // unrelated value.
+
+    // Special-case for 0 - a != 0 (common canonicalization). The operands are
+    // integers, so use the integer-only zero matcher.
+    if (Pred == ICmpInst::ICMP_NE) {
+      // (0 - a) != 0
+      if (SubExpr.match(ICmpLHS) && m_ZeroInt().match(ICmpRHS) &&
+          m_ZeroInt().match(SubLHS))
+        return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS);
+
+      // 0 != (0 - a)
+      if (m_ZeroInt().match(ICmpLHS) && SubExpr.match(ICmpRHS) &&
+          m_ZeroInt().match(SubLHS))
+        return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpRHS);
+    }
+
+    return false;
+  }
+};
+
+/// Builds a USubWithOverflow_match from the three sub-matchers: \p L for the
+/// left operand, \p R for the right operand and \p S for the difference.
+template <typename LHS_t, typename RHS_t, typename Diff_t>
+USubWithOverflow_match<LHS_t, RHS_t, Diff_t>
+m_USubWithOverflow(const LHS_t &L, const RHS_t &R, const Diff_t &S) {
+  using MatcherTy = USubWithOverflow_match<LHS_t, RHS_t, Diff_t>;
+  return MatcherTy(L, R, S);
+}
+
template <typename Opnd_t> struct Argument_match {
unsigned OpI;
Opnd_t Val;
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index d290f202f3cca..246636a4c97a3 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -1695,42 +1695,51 @@ bool CodeGenPrepare::combineToUAddWithOverflow(CmpInst *Cmp,
return true;
}
-bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
- ModifyDT &ModifiedDT) {
- // We are not expecting non-canonical/degenerate code. Just bail out.
+static bool matchUSubWithOverflowConstantEdgeCases(CmpInst *Cmp,
+ BinaryOperator *&Sub) {
+ // A - B, A u< B --> usubo(A, B)
Value *A = Cmp->getOperand(0), *B = Cmp->getOperand(1);
+
+ // We are not expecting non-canonical/degenerate code. Just bail out.
if (isa<Constant>(A) && isa<Constant>(B))
return false;
- // Convert (A u> B) to (A u< B) to simplify pattern matching.
ICmpInst::Predicate Pred = Cmp->getPredicate();
+
+ // Normalize: convert (A u> B) -> (B u< A)
if (Pred == ICmpInst::ICMP_UGT) {
std::swap(A, B);
Pred = ICmpInst::ICMP_ULT;
}
+
// Convert special-case: (A == 0) is the same as (A u< 1).
if (Pred == ICmpInst::ICMP_EQ && match(B, m_ZeroInt())) {
B = ConstantInt::get(B->getType(), 1);
Pred = ICmpInst::ICMP_ULT;
}
// Convert special-case: (A != 0) is the same as (0 u< A).
- if (Pred == ICmpInst::ICMP_NE && match(B, m_ZeroInt())) {
+ else if (Pred == ICmpInst::ICMP_NE && match(B, m_ZeroInt())) {
std::swap(A, B);
Pred = ICmpInst::ICMP_ULT;
+ } else {
+ return false;
}
+
if (Pred != ICmpInst::ICMP_ULT)
return false;
- // Walk the users of a variable operand of a compare looking for a subtract or
- // add with that same operand. Also match the 2nd operand of the compare to
- // the add/sub, but that may be a negated constant operand of an add.
+ // Walk the users of the variable operand of the compare looking for a
+ // subtract or add with that same operand. Also match the 2nd operand of the
+ // compare to the add/sub, but that may be a negated constant operand of an
+ // add.
Value *CmpVariableOperand = isa<Constant>(A) ? B : A;
- BinaryOperator *Sub = nullptr;
+ Sub = nullptr;
+
for (User *U : CmpVariableOperand->users()) {
// A - B, A u< B --> usubo(A, B)
if (match(U, m_Sub(m_Specific(A), m_Specific(B)))) {
Sub = cast<BinaryOperator>(U);
- break;
+ return true;
}
// A + (-C), A u< C (canonicalized form of (sub A, C))
@@ -1738,19 +1747,49 @@ bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
if (match(U, m_Add(m_Specific(A), m_APInt(AddC))) &&
match(B, m_APInt(CmpC)) && *AddC == -(*CmpC)) {
Sub = cast<BinaryOperator>(U);
- break;
+ return true;
}
}
- if (!Sub)
- return false;
+ return false;
+}
+
+bool CodeGenPrepare::combineToUSubWithOverflow(CmpInst *Cmp,
+ ModifyDT &ModifiedDT) {
+ bool EdgeCase = false;
+ Value *A = nullptr, *B = nullptr;
+ BinaryOperator *Sub = nullptr;
+
+ // If the compare already matches the (sub, icmp) pattern use it directly.
+ if (!match(Cmp, m_USubWithOverflow(m_Value(A), m_Value(B), m_BinOp(Sub)))) {
+ // Otherwise try to recognize constant-edge-case forms like
+ // icmp ne (sub 0, B), 0 or
+ // icmp eq (sub A, 1), 0
+ if (!matchUSubWithOverflowConstantEdgeCases(Cmp, Sub))
+ return false;
+ // Set A/B from the discovered Sub and record that this was an edge-case
+ // match.
+ A = Sub->getOperand(0);
+ B = Sub->getOperand(1);
+ EdgeCase = true;
+ }
+
+ // Check target wants the overflow intrinsic formed. When matching an
+ // edge-case we allow forming the intrinsic with fewer uses (mirror
+ // combineToUAddWithOverflow).
if (!TLI->shouldFormOverflowOp(ISD::USUBO,
TLI->getValueType(*DL, Sub->getType()),
- Sub->hasNUsesOrMore(1)))
+ Sub->hasNUsesOrMore(EdgeCase ? 1 : 2)))
+ return false;
+
+ // Ensure it's safe to create the call in the icmp's basic block (same rule as
+ // UAdd).
+ if (Sub->getParent() != Cmp->getParent() && !Sub->hasOneUse())
return false;
- if (!replaceMathCmpWithIntrinsic(Sub, Sub->getOperand(0), Sub->getOperand(1),
- Cmp, Intrinsic::usub_with_overflow))
+ // Replace math+cmp pair with the llvm.usub.with.overflow intrinsic.
+ if (!replaceMathCmpWithIntrinsic(Sub, A, B, Cmp,
+ Intrinsic::usub_with_overflow))
return false;
// Reset callers - do not crash by iterating over a dead instruction.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index e4cb457499ef5..5c7aae5f91fab 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7829,6 +7829,23 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
}
}
+ Instruction *SubI = nullptr;
+ if (match(&I, m_USubWithOverflow(m_Value(X), m_Value(Y),
+ m_Instruction(SubI))) &&
+ isa<IntegerType>(X->getType())) {
+ Value *Result;
+ Constant *Overflow;
+ // m_USubWithOverflow can match patterns that do not include an explicit
+ // "sub" instruction, so check the opcode of the matched op.
+ if (SubI->getOpcode() == Instruction::Sub &&
+ OptimizeOverflowCheck(Instruction::Sub, /*Signed*/ false, X, Y, *SubI,
+ Result, Overflow)) {
+ replaceInstUsesWith(*SubI, Result);
+ eraseInstFromFunction(*SubI);
+ return replaceInstUsesWith(I, Overflow);
+ }
+ }
+
// (zext X) * (zext Y) --> llvm.umul.with.overflow.
if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) &&
match(Op1, m_APInt(C))) {
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
2d604b0
to
e7be838
Compare
a9f9a99
to
d154edd
Compare
To make it consistent with m_UAddWithOverflow_match.
d154edd
to
22352cc
Compare
: L(L), R(R), S(S) {} | ||
|
||
template <typename OpTy> bool match(OpTy *V) const { | ||
Value *ICmpLHS = nullptr, *ICmpRHS = nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do these need to be initialized? They aren't initialized in UAddWithOverflow_match
if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) | ||
return false; | ||
|
||
Value *SubLHS = nullptr, *SubRHS = nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do these need to be initialized?
Value *SubLHS = nullptr, *SubRHS = nullptr; | ||
auto SubExpr = m_Sub(m_Value(SubLHS), m_Value(SubRHS)); | ||
|
||
Value *AddLHS = nullptr, *AddRHS = nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do these need to be initialized?
// Special-case for 0 - a != 0 (common canonicalization) | ||
if (Pred == ICmpInst::ICMP_NE) { | ||
// (0 - a) != 0 | ||
if (SubExpr.match(ICmpLHS) && m_Zero().match(ICmpRHS) && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
m_Zero -> m_ZeroInt
return L.match(SubLHS) && R.match(SubRHS) && S.match(ICmpLHS); | ||
|
||
// 0 != (0 - a) | ||
if (m_Zero().match(ICmpLHS) && SubExpr.match(ICmpRHS) && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
m_ZeroInt
|
||
// 0 != (0 - a) | ||
if (m_Zero().match(ICmpLHS) && SubExpr.match(ICmpRHS) && | ||
m_Zero().match(SubLHS)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
m_ZeroInt
isa<IntegerType>(X->getType())) { | ||
Value *Result; | ||
Constant *Overflow; | ||
// m_UAddWithOverflow can match patterns that do not include an explicit |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
m_USubWithOverflow?
isa<IntegerType>(X->getType())) { | ||
Value *Result; | ||
Constant *Overflow; | ||
// m_UAddWithOverflow can match patterns that do not include an explicit |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// m_UAddWithOverflow can match patterns that do not include an explicit | |
// m_UAddWithOverflow can match patterns that do not include an explicit |
Value *Result; | ||
Constant *Overflow; | ||
// m_UAddWithOverflow can match patterns that do not include an explicit | ||
// "add" instruction, so check the opcode of the matched op. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"sub"?
To make it consistent with m_UAddWithOverflow_match.