-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[InstCombine] Improve eq/ne by parts to handle ult/ugt equality pattern #69884
Conversation
@llvm/pr-subscribers-llvm-transforms Author: None (goldsteinn) Changes
Full diff: https://github.com/llvm/llvm-project/pull/69884.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 3e0218d9b76d1f7..3bd698f1427a669 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1166,7 +1166,7 @@ static Value *extractIntPart(const IntPart &P, IRBuilderBase &Builder) {
V = Builder.CreateLShr(V, P.StartBit);
Type *TruncTy = V->getType()->getWithNewBitWidth(P.NumBits);
if (TruncTy != V->getType())
- V = Builder.CreateTrunc(V, TruncTy);
+ V = Builder.CreateZExtOrTrunc(V, TruncTy);
return V;
}
@@ -1179,13 +1179,57 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
return nullptr;
CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
- if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
+ auto MatchPred = [&](ICmpInst *Cmp) -> std::pair<bool, const APInt *> {
+ if (Pred == Cmp->getPredicate())
+ return {true, nullptr};
+
+ const APInt *C;
+ // (icmp eq (lshr x, C), (lshr y, C)) gets optimized to:
+ // (icmp ult (xor x, y), 1 << C) so also look for that.
+ if (Pred == CmpInst::ICMP_EQ && Cmp->getPredicate() == CmpInst::ICMP_ULT)
+ return {match(Cmp->getOperand(1), m_APInt(C)) && C->isPowerOf2() &&
+ match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())),
+ C};
+
+ // (icmp ne (lshr x, C), (lshr y, C)) gets optimized to:
+ // (icmp ugt (xor x, y), (1 << C) - 1) so also look for that.
+ if (Pred == CmpInst::ICMP_NE && Cmp->getPredicate() == CmpInst::ICMP_UGT)
+ return {match(Cmp->getOperand(1), m_APInt(C)) && C->isMask() &&
+ !C->isAllOnes() &&
+ match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())),
+ C};
+
+ return {false, nullptr};
+ };
+
+ auto GetMatchPart = [&](std::pair<bool, const APInt *> MatchResult,
+ ICmpInst *Cmp,
+ unsigned OpNo) -> std::optional<IntPart> {
+ // Normal IntPart
+ if (MatchResult.second == nullptr)
+ return matchIntPart(Cmp->getOperand(OpNo));
+
+ // We have one of the ult/ugt patterns.
+ unsigned From;
+ const APInt *C = MatchResult.second;
+ if (Pred == CmpInst::ICMP_NE)
+ From = C->popcount();
+ else
+ From = (*C - 1).popcount();
+ Instruction *I = cast<Instruction>(Cmp->getOperand(0));
+ return {{I->getOperand(OpNo), From,
+ Cmp->getOperand(0)->getType()->getScalarSizeInBits()}};
+ };
+
+ auto Cmp0Match = MatchPred(Cmp0);
+ auto Cmp1Match = MatchPred(Cmp1);
+ if (!Cmp0Match.first || !Cmp1Match.first)
return nullptr;
- std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
- std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
- std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
- std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
+ std::optional<IntPart> L0 = GetMatchPart(Cmp0Match, Cmp0, 0);
+ std::optional<IntPart> R0 = GetMatchPart(Cmp0Match, Cmp0, 1);
+ std::optional<IntPart> L1 = GetMatchPart(Cmp1Match, Cmp1, 0);
+ std::optional<IntPart> R1 = GetMatchPart(Cmp1Match, Cmp1, 1);
if (!L0 || !R0 || !L1 || !R1)
return nullptr;
diff --git a/llvm/test/Transforms/InstCombine/eq-of-parts.ll b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
index dbf671aaaa86b40..7c5ec19903a4bc3 100644
--- a/llvm/test/Transforms/InstCombine/eq-of-parts.ll
+++ b/llvm/test/Transforms/InstCombine/eq-of-parts.ll
@@ -1333,3 +1333,107 @@ define i1 @ne_21_wrong_pred2(i32 %x, i32 %y) {
%c.210 = or i1 %c.2, %c.1
ret i1 %c.210
}
+
+define i1 @eq_optimized_highbits_cmp(i32 %x, i32 %y) {
+; CHECK-LABEL: @eq_optimized_highbits_cmp(
+; CHECK-NEXT: [[R:%.*]] = icmp eq i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: ret i1 [[R]]
+;
+ %xor = xor i32 %y, %x
+ %cmp_hi = icmp ult i32 %xor, 33554432
+ %tx = trunc i32 %x to i25
+ %ty = trunc i32 %y to i25
+ %cmp_lo = icmp eq i25 %tx, %ty
+ %r = and i1 %cmp_hi, %cmp_lo
+ ret i1 %r
+}
+
+define i1 @eq_optimized_highbits_cmp_todo_overlapping(i32 %x, i32 %y) {
+; CHECK-LABEL: @eq_optimized_highbits_cmp_todo_overlapping(
+; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 16777216
+; CHECK-NEXT: [[TX:%.*]] = trunc i32 [[X]] to i25
+; CHECK-NEXT: [[TY:%.*]] = trunc i32 [[Y]] to i25
+; CHECK-NEXT: [[CMP_LO:%.*]] = icmp eq i25 [[TX]], [[TY]]
+; CHECK-NEXT: [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT: ret i1 [[R]]
+;
+ %xor = xor i32 %y, %x
+ %cmp_hi = icmp ult i32 %xor, 16777216
+ %tx = trunc i32 %x to i25
+ %ty = trunc i32 %y to i25
+ %cmp_lo = icmp eq i25 %tx, %ty
+ %r = and i1 %cmp_hi, %cmp_lo
+ ret i1 %r
+}
+
+define i1 @eq_optimized_highbits_cmp_fail_not_pow2(i32 %x, i32 %y) {
+; CHECK-LABEL: @eq_optimized_highbits_cmp_fail_not_pow2(
+; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: [[CMP_HI:%.*]] = icmp ult i32 [[XOR]], 16777215
+; CHECK-NEXT: [[TX:%.*]] = trunc i32 [[X]] to i24
+; CHECK-NEXT: [[TY:%.*]] = trunc i32 [[Y]] to i24
+; CHECK-NEXT: [[CMP_LO:%.*]] = icmp eq i24 [[TX]], [[TY]]
+; CHECK-NEXT: [[R:%.*]] = and i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT: ret i1 [[R]]
+;
+ %xor = xor i32 %y, %x
+ %cmp_hi = icmp ult i32 %xor, 16777215
+ %tx = trunc i32 %x to i24
+ %ty = trunc i32 %y to i24
+ %cmp_lo = icmp eq i24 %tx, %ty
+ %r = and i1 %cmp_hi, %cmp_lo
+ ret i1 %r
+}
+
+define i1 @ne_optimized_highbits_cmp(i32 %x, i32 %y) {
+; CHECK-LABEL: @ne_optimized_highbits_cmp(
+; CHECK-NEXT: [[R:%.*]] = icmp ne i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: ret i1 [[R]]
+;
+ %xor = xor i32 %y, %x
+ %cmp_hi = icmp ugt i32 %xor, 16777215
+ %tx = trunc i32 %x to i24
+ %ty = trunc i32 %y to i24
+ %cmp_lo = icmp ne i24 %tx, %ty
+ %r = or i1 %cmp_hi, %cmp_lo
+ ret i1 %r
+}
+
+define i1 @ne_optimized_highbits_cmp_fail_not_mask(i32 %x, i32 %y) {
+; CHECK-LABEL: @ne_optimized_highbits_cmp_fail_not_mask(
+; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777216
+; CHECK-NEXT: [[TX:%.*]] = trunc i32 [[X]] to i24
+; CHECK-NEXT: [[TY:%.*]] = trunc i32 [[Y]] to i24
+; CHECK-NEXT: [[CMP_LO:%.*]] = icmp ne i24 [[TX]], [[TY]]
+; CHECK-NEXT: [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT: ret i1 [[R]]
+;
+ %xor = xor i32 %y, %x
+ %cmp_hi = icmp ugt i32 %xor, 16777216
+ %tx = trunc i32 %x to i24
+ %ty = trunc i32 %y to i24
+ %cmp_lo = icmp ne i24 %tx, %ty
+ %r = or i1 %cmp_hi, %cmp_lo
+ ret i1 %r
+}
+
+define i1 @ne_optimized_highbits_cmp_fail_no_combined_int(i32 %x, i32 %y) {
+; CHECK-LABEL: @ne_optimized_highbits_cmp_fail_no_combined_int(
+; CHECK-NEXT: [[XOR:%.*]] = xor i32 [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT: [[CMP_HI:%.*]] = icmp ugt i32 [[XOR]], 16777215
+; CHECK-NEXT: [[TX:%.*]] = trunc i32 [[X]] to i23
+; CHECK-NEXT: [[TY:%.*]] = trunc i32 [[Y]] to i23
+; CHECK-NEXT: [[CMP_LO:%.*]] = icmp ne i23 [[TX]], [[TY]]
+; CHECK-NEXT: [[R:%.*]] = or i1 [[CMP_HI]], [[CMP_LO]]
+; CHECK-NEXT: ret i1 [[R]]
+;
+ %xor = xor i32 %y, %x
+ %cmp_hi = icmp ugt i32 %xor, 16777215
+ %tx = trunc i32 %x to i23
+ %ty = trunc i32 %y to i23
+ %cmp_lo = icmp ne i23 %tx, %ty
+ %r = or i1 %cmp_hi, %cmp_lo
+ ret i1 %r
+}
|
534c79f
to
e31b1f6
Compare
e31b1f6
to
6a25403
Compare
ping.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
Type *TruncTy = V->getType()->getWithNewBitWidth(P.NumBits); | ||
if (TruncTy != V->getType()) | ||
V = Builder.CreateTrunc(V, TruncTy); | ||
Type *OutTy = V->getType()->getWithNewBitWidth(P.NumBits); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you rename `TruncTy` to `OutTy`?
unsigned From = Pred == CmpInst::ICMP_NE ? C->popcount() : C->countr_zero(); | ||
Instruction *I = cast<Instruction>(Cmp->getOperand(0)); | ||
return {{I->getOperand(OpNo), From, | ||
Cmp->getOperand(0)->getType()->getScalarSizeInBits() - From}}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cmp->getOperand(0)->getType()->getScalarSizeInBits() - From}}; | |
C->getBitWidth() - From}}; |
It would be simpler :)
…tern. `(icmp eq/ne (lshr x, C), (lshr y, C))` gets optimized to `(icmp ult/uge (xor x, y), (1 << C))`. This can cause the current equal-by-parts detection to miss the high bits, as they may get optimized to the new pattern. This commit adds support for detecting / combining the ult/ugt pattern.
6a25403
to
1834ec7
Compare
`ult/ugt` equality pattern.