
[InstCombine] Improve eq/ne by parts to handle ult/ugt equality pattern #69884

Closed
wants to merge 2 commits

Conversation

goldsteinn
Contributor

  • [InstCombine] Add tests for new eq/ne patterns combining eq/ne by parts; NFC
  • [InstCombine] Improve eq/ne by parts to handle ult/ugt equality pattern.

@llvmbot
Collaborator

llvmbot commented Oct 22, 2023

@llvm/pr-subscribers-llvm-transforms

Author: None (goldsteinn)

Changes
  • [InstCombine] Add tests for new eq/ne patterns combining eq/ne by parts; NFC
  • [InstCombine] Improve eq/ne by parts to handle ult/ugt equality pattern (see the sketch below).
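
As a quick illustration (taken from the new test @eq_optimized_highbits_cmp added below), the eq case matches a high-bits compare that has already been rewritten into the xor/ult form, paired with a truncated low-bits compare, and folds the pair into one full-width compare:

  define i1 @eq_optimized_highbits_cmp(i32 %x, i32 %y) {
    %xor = xor i32 %y, %x
    %cmp_hi = icmp ult i32 %xor, 33554432   ; 33554432 == 1 << 25: high 7 bits of x and y match
    %tx = trunc i32 %x to i25
    %ty = trunc i32 %y to i25
    %cmp_lo = icmp eq i25 %tx, %ty          ; low 25 bits match
    %r = and i1 %cmp_hi, %cmp_lo
    ret i1 %r
  }

With this patch InstCombine reduces the body to:

    %r = icmp eq i32 %y, %x
    ret i1 %r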

Full diff: https://github.com/llvm/llvm-project/pull/69884.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+50-6)
  • (modified) llvm/test/Transforms/InstCombine/eq-of-parts.ll (+104)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 3e0218d9b76d1f7..3bd698f1427a669 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1166,7 +1166,7 @@ static Value *extractIntPart(const IntPart &P, IRBuilderBase &Builder) {
     V = Builder.CreateLShr(V, P.StartBit);
   Type *TruncTy = V->getType()->getWithNewBitWidth(P.NumBits);
   if (TruncTy != V->getType())
-    V = Builder.CreateTrunc(V, TruncTy);
+    V = Builder.CreateZExtOrTrunc(V, TruncTy);
   return V;
 }
 
@@ -1179,13 +1179,57 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
     return nullptr;
 
   CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
-  if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
+  auto MatchPred = [&](ICmpInst *Cmp) -> std::pair<bool, const APInt *> {
+    if (Pred == Cmp->getPredicate())
+      return {true, nullptr};
+
+    const APInt *C;
+    // (icmp eq (lshr x, C), (lshr y, C)) gets optimized to:
+    // (icmp ult (xor x, y), 1 << C) so also look for that.
+    if (Pred == CmpInst::ICMP_EQ && Cmp->getPredicate() == CmpInst::ICMP_ULT)
+      return {match(Cmp->getOperand(1), m_APInt(C)) && C->isPowerOf2() &&
+                  match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())),
+              C};
+
+    // (icmp ne (lshr x, C), (lshr y, C)) gets optimized to:
+    // (icmp ugt (xor x, y), (1 << C) - 1) so also look for that.
+    if (Pred == CmpInst::ICMP_NE && Cmp->getPredicate() == CmpInst::ICMP_UGT)
+      return {match(Cmp->getOperand(1), m_APInt(C)) && C->isMask() &&
+                  !C->isAllOnes() &&
+                  match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())),
+              C};
+
+    return {false, nullptr};
+  };
+
+  auto GetMatchPart = [&](std::pair<bool, const APInt *> MatchResult,
+                          ICmpInst *Cmp,
+                          unsigned OpNo) -> std::optional<IntPart> {
+    // Normal IntPart
+    if (MatchResult.second == nullptr)
+      return matchIntPart(Cmp->getOperand(OpNo));
+
+    // We have one of the ult/ugt patterns.
+    unsigned From;
+    const APInt *C = MatchResult.second;
+    if (Pred == CmpInst::ICMP_NE)
+      From = C->popcount();
+    else
+      From = (*C - 1).popcount();
+    Instruction *I = cast<Instruction>(Cmp->getOperand(0));
+    return {{I->getOperand(OpNo), From,
+             Cmp->getOperand(0)->getType()->getScalarSizeInBits()}};
+  };
+
+  auto Cmp0Match = MatchPred(Cmp0);
+  auto Cmp1Match = MatchPred(Cmp1);
+  if (!Cmp0Match.first || !Cmp1Match.first)
     return nullptr;
 
-  std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
-  std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
-  std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
-  std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
+  std::optional<IntPart> L0 = GetMatchPart(Cmp0Match, Cmp0, 0);
+  std::optional<IntPart> R0 = GetMatchPart(Cmp0Match, Cmp0, 1);
+  std::optional<IntPart> L1 = GetMatchPart(Cmp1Match, Cmp1, 0);
+  std::optional<IntPart> R1 = GetMatchPart(Cmp1Match, Cmp1, 1);
   if (!L0 || !R0 || !L1 || !R1)
     return nullptr;
 
diff --git a/llvm/test/Transforms/InstCombine/eq-of-parts.ll b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
index dbf671aaaa86b40..7c5ec19903a4bc3 100644
--- a/llvm/test/Transforms/InstCombine/eq-of-parts.ll
+++ b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
@@ -1333,3 +1333,107 @@ define i1 @ne_21_wrong_pred2(i32 %x, i32 %y) {
   %c.210 = or i1 %c.2, %c.1
   ret i1 %c.210
 }
+
+define i1 @eq_optimized_highbits_cmp(i32 %x, i32 %y) {
+; CHECK-LABEL: @eq_optimized_highbits_cmp(
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ult i32 %xor, 33554432
+  %tx = trunc i32 %x to i25
+  %ty = trunc i32 %y to i25
+  %cmp_lo = icmp eq i25 %tx, %ty
+  %r = and i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @eq_optimized_highbits_cmp_todo_overlapping(i32 %x, i32 %y) {
+; CHECK-LABEL: @eq_optimized_highbits_cmp_todo_overlapping(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 16777216
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i25
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i25
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp eq i25 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ult i32 %xor, 16777216
+  %tx = trunc i32 %x to i25
+  %ty = trunc i32 %y to i25
+  %cmp_lo = icmp eq i25 %tx, %ty
+  %r = and i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @eq_optimized_highbits_cmp_fail_not_pow2(i32 %x, i32 %y) {
+; CHECK-LABEL: @eq_optimized_highbits_cmp_fail_not_pow2(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 16777215
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i24
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i24
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp eq i24 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ult i32 %xor, 16777215
+  %tx = trunc i32 %x to i24
+  %ty = trunc i32 %y to i24
+  %cmp_lo = icmp eq i24 %tx, %ty
+  %r = and i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @ne_optimized_highbits_cmp(i32 %x, i32 %y) {
+; CHECK-LABEL: @ne_optimized_highbits_cmp(
+; CHECK-NEXT:    [[R:%.*]] = icmp ne i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ugt i32 %xor, 16777215
+  %tx = trunc i32 %x to i24
+  %ty = trunc i32 %y to i24
+  %cmp_lo = icmp ne i24 %tx, %ty
+  %r = or i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @ne_optimized_highbits_cmp_fail_not_mask(i32 %x, i32 %y) {
+; CHECK-LABEL: @ne_optimized_highbits_cmp_fail_not_mask(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777216
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i24
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i24
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp ne i24 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ugt i32 %xor, 16777216
+  %tx = trunc i32 %x to i24
+  %ty = trunc i32 %y to i24
+  %cmp_lo = icmp ne i24 %tx, %ty
+  %r = or i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}
+
+define i1 @ne_optimized_highbits_cmp_fail_no_combined_int(i32 %x, i32 %y) {
+; CHECK-LABEL: @ne_optimized_highbits_cmp_fail_no_combined_int(
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777215
+; CHECK-NEXT:    [[TX:%.*]] = trunc i32 [[X]] to i23
+; CHECK-NEXT:    [[TY:%.*]] = trunc i32 [[Y]] to i23
+; CHECK-NEXT:    [[CMP_LO:%.*]] = icmp ne i23 [[TX]], [[TY]]
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %xor = xor i32 %y, %x
+  %cmp_hi = icmp ugt i32 %xor, 16777215
+  %tx = trunc i32 %x to i23
+  %ty = trunc i32 %y to i23
+  %cmp_lo = icmp ne i23 %tx, %ty
+  %r = or i1 %cmp_hi, %cmp_lo
+  ret i1 %r
+}

dtcxzyw changed the title from "goldsteinn/int parts" to "[InstCombine] Improve eq/ne by parts to handle ult/ugt equality pattern" on Oct 22, 2023
@goldsteinn
Contributor Author

ping.

dtcxzyw (Member) left a comment:

LGTM. Thanks!

Type *TruncTy = V->getType()->getWithNewBitWidth(P.NumBits);
if (TruncTy != V->getType())
  V = Builder.CreateTrunc(V, TruncTy);
Type *OutTy = V->getType()->getWithNewBitWidth(P.NumBits);

Why do you rename TruncTy to OutTy?

unsigned From = Pred == CmpInst::ICMP_NE ? C->popcount() : C->countr_zero();
Instruction *I = cast<Instruction>(Cmp->getOperand(0));
return {{I->getOperand(OpNo), From,
         Cmp->getOperand(0)->getType()->getScalarSizeInBits() - From}};

Suggested change:
-         Cmp->getOperand(0)->getType()->getScalarSizeInBits() - From}};
+         C->getBitWidth() - From}};

It would be simpler :)

[InstCombine] Improve eq/ne by parts to handle ult/ugt equality pattern.

`(icmp eq/ne (lshr x, C), (lshr y, C))` gets optimized to `(icmp
ult/uge (xor x, y), (1 << C))`. This can cause the current equal-by-parts
detection to miss the high bits, as they may already have been rewritten
into the new pattern.

This commit adds support for detecting and combining the ult/ugt
pattern.
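
For context, the canonicalization referred to above looks roughly like this (a sketch; the shift amount 25 is picked only for illustration, and 33554432 == 1 << 25):

  ; high-bits equality written with shifts
  %xhi = lshr i32 %x, 25
  %yhi = lshr i32 %y, 25
  %cmp = icmp eq i32 %xhi, %yhi

  ; what InstCombine turns it into
  %xor = xor i32 %x, %y
  %cmp = icmp ult i32 %xor, 33554432

After that rewrite there is no lshr left for foldEqOfParts to match for the high bits, which is why the fold previously gave up on such pairs.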
goldsteinn closed this in ad91473 on Nov 5, 2023