Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[InstCombine] Fold ((X << nuw Z) binop nuw Y) >>u Z --> X binop nuw (Y >>u Z) #88193

Merged
merged 2 commits into from
May 7, 2024

Conversation

@llvmbot
Copy link
Collaborator

llvmbot commented Apr 9, 2024

@llvm/pr-subscribers-llvm-transforms

Author: AtariDreams (AtariDreams)

Changes

I also added mul nsw/nuw 3, div 2 since this was the canonical version of ((x << 1) + x) / 2, which is a specific expression which canonicalization causes the InstCombine to miss it.

Proof: https://alive2.llvm.org/ce/z/kDVTiL


Full diff: https://github.com/llvm/llvm-project/pull/88193.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+50-2)
  • (modified) llvm/test/Transforms/InstCombine/lshr.ll (+56)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 95aa2119e2d88b..1ae98be7f46d21 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1267,6 +1267,19 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)))
     return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
 
+  // Special Case:
+  // if both the add and the shift are nuw, we can omit the AND entirely
+  // ((X << Y) nuw + Z nuw) >>u Z --> (X + (Y >>u Z))
+  Value *Y;
+  if (match(Op0, m_OneUse(m_c_NUWAdd((m_NUWShl(m_Value(X), m_Specific(Op1))),
+                                     m_Value(Y))))) {
+    Value *NewLshr = Builder.CreateLShr(Y, Op1, "", I.isExact());
+    auto *newAdd = BinaryOperator::CreateNUWAdd(NewLshr, X);
+    if (auto *Op0Bin = cast<BinaryOperator>(Op0))
+      newAdd->setHasNoSignedWrap(Op0Bin->hasNoSignedWrap());
+    return newAdd;
+  }
+
   if (match(Op1, m_APInt(C))) {
     unsigned ShAmtC = C->getZExtValue();
     auto *II = dyn_cast<IntrinsicInst>(Op0);
@@ -1283,7 +1296,6 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
       return new ZExtInst(Cmp, Ty);
     }
 
-    Value *X;
     const APInt *C1;
     if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
       if (C1->ult(ShAmtC)) {
@@ -1328,7 +1340,7 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
     // ((X << C) + Y) >>u C --> (X + (Y >>u C)) & (-1 >>u C)
     // TODO: Consolidate with the more general transform that starts from shl
     //       (the shifts are in the opposite order).
-    Value *Y;
+
     if (match(Op0,
               m_OneUse(m_c_Add(m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))),
                                m_Value(Y))))) {
@@ -1450,9 +1462,25 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
           NewMul->setHasNoSignedWrap(true);
           return NewMul;
         }
+
+        // Special case:
+        // lshr nuw (mul (X, 3), 1) -> add nuw nsw (X, lshr(X, 1)
+        if (ShAmtC == 1 && MulC->getZExtValue() == 3) {
+          auto *NewAdd = BinaryOperator::CreateNUWAdd(
+              X,
+              Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact()));
+          NewAdd->setHasNoSignedWrap(true);
+          return NewAdd;
+        }
       }
     }
 
+    // // lshr nsw (mul (X, 3), 1) -> add nsw (X, lshr(X, 1)
+    if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_SpecificInt(3)))) &&
+        ShAmtC == 1)
+      return BinaryOperator::CreateNSWAdd(
+          X, Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact()));
+
     // Try to narrow bswap.
     // In the case where the shift amount equals the bitwidth difference, the
     // shift is eliminated.
@@ -1656,6 +1684,26 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
       if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y)))))
         return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty);
     }
+
+    // Special case: ashr nuw (mul (X, 3), 1) -> add nuw nsw (X, lshr(X, 1)
+    if (match(Op0, m_OneUse(m_NSWMul(m_Value(X), m_SpecificInt(3)))) &&
+        ShAmt == 1) {
+      Value *Shift;
+      if (auto *Op0Bin = cast<BinaryOperator>(Op0)) {
+        if (Op0Bin->hasNoUnsignedWrap())
+          // We can use lshr if the mul is nuw and nsw
+          Shift =
+              Builder.CreateLShr(X, ConstantInt::get(Ty, 1), "", I.isExact());
+        else
+          Shift =
+              Builder.CreateAShr(X, ConstantInt::get(Ty, 1), "", I.isExact());
+
+        auto *NewAdd = BinaryOperator::CreateNSWAdd(X, Shift);
+        NewAdd->setHasNoUnsignedWrap(Op0Bin->hasNoUnsignedWrap());
+
+        return NewAdd;
+      }
+    }
   }
 
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
diff --git a/llvm/test/Transforms/InstCombine/lshr.ll b/llvm/test/Transforms/InstCombine/lshr.ll
index 02c2bbc2819b8c..3124d05248c2eb 100644
--- a/llvm/test/Transforms/InstCombine/lshr.ll
+++ b/llvm/test/Transforms/InstCombine/lshr.ll
@@ -360,8 +360,64 @@ define <3 x i14> @mul_splat_fold_vec(<3 x i14> %x) {
   ret <3 x i14> %t
 }
 
+define i32 @mul_times_3_div_2 (i32 %x) {
+; CHECK-LABEL: @mul_times_3_div_2(
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1:%.*]], 1
+; CHECK-NEXT:    [[TMP3:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+  %2 = mul nsw nuw i32 %x, 3
+  %3 = lshr i32 %2, 1
+  ret i32 %3
+}
+
+  define i32 @shl_add_lshr (i32 %x, i32 %c, i32 %y) {
+; CHECK-LABEL: @shl_add_lshr(
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr exact i32 [[TMP2:%.*]], [[C:%.*]]
+; CHECK-NEXT:    [[TMP4:%.*]] = add nuw nsw i32 [[TMP3]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[TMP4]]
+;
+  %2 = shl nuw i32 %x, %c
+  %3 = add nsw nuw i32 %2, %y
+  %4 = lshr exact i32 %3, %c
+  ret i32 %4
+}
+
+  define i32 @ashr_mul_times_3_div_2 (i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2(
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP0:%.*]], 1
+; CHECK-NEXT:    [[TMP3:%.*]] = add nuw nsw i32 [[TMP2]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[TMP3]]
+;
+  %2 = mul nsw nuw i32 %0, 3
+  %3 = ashr i32 %2, 1
+  ret i32 %3
+}
+
+define i32 @ashr_mul_times_3_div_2_exact (i32 %0) {
+; CHECK-LABEL: @ashr_mul_times_3_div_2_exact(
+; CHECK-NEXT:    [[TMP3:%.*]] = ashr exact i32 [[TMP2:%.*]], 1
+; CHECK-NEXT:    [[TMP4:%.*]] = add nsw i32 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    ret i32 [[TMP4]]
+;
+  %2 = mul nsw i32 %0, 3
+  %3 = ashr exact i32 %2, 1
+  ret i32 %3
+}
+
 ; Negative test
 
+define i32 @mul_times_3_div_2_no_nsw (i32 %x) {
+; CHECK-LABEL: @mul_times_3_div_2_no_nsw(
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[X:%.*]], 3
+; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[TMP1]], 1
+; CHECK-NEXT:    ret i32 [[TMP2]]
+;
+  %2 = mul i32 %x, 3
+  %3 = lshr i32 %2, 1
+  ret i32 %3
+}
+
 define i32 @mul_splat_fold_wrong_mul_const(i32 %x) {
 ; CHECK-LABEL: @mul_splat_fold_wrong_mul_const(
 ; CHECK-NEXT:    [[M:%.*]] = mul nuw i32 [[X:%.*]], 65538

@AtariDreams AtariDreams force-pushed the times3 branch 5 times, most recently from 0f47cfd to 906aa22 Compare April 11, 2024 01:37
@AtariDreams
Copy link
Contributor Author

@nikic What do you think about this?

@AtariDreams
Copy link
Contributor Author

@arsenm What do you think about this PR?

llvm/test/Transforms/InstCombine/lshr.ll Outdated Show resolved Hide resolved
llvm/test/Transforms/InstCombine/lshr.ll Outdated Show resolved Hide resolved
llvm/test/Transforms/InstCombine/lshr.ll Outdated Show resolved Hide resolved
llvm/test/Transforms/InstCombine/lshr.ll Show resolved Hide resolved
@AtariDreams AtariDreams force-pushed the times3 branch 2 times, most recently from f15664c to 860242d Compare April 21, 2024 21:57
@AtariDreams
Copy link
Contributor Author

@arsenm This is ready for review and merge! (I don't have commit perms)

@AtariDreams
Copy link
Contributor Author

@arsenm I removed the constant specific stuff and moved a generalization of that to another PR.

@AtariDreams
Copy link
Contributor Author

Okay it is resolved now, and I added new tests @dtcxzyw, I apologize for the inconvenience.

@dtcxzyw
Copy link
Member

dtcxzyw commented May 7, 2024

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Please don't force-push in this PR. The notifications flood my mailbox :(

llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp Outdated Show resolved Hide resolved
llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp Outdated Show resolved Hide resolved
llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp Outdated Show resolved Hide resolved
llvm/test/Transforms/InstCombine/lshr.ll Show resolved Hide resolved
@AtariDreams
Copy link
Contributor Author

@arsenm I do not have commit perms by the way.

@arsenm arsenm merged commit 1e36c96 into llvm:main May 7, 2024
3 of 4 checks passed
@AtariDreams AtariDreams deleted the times3 branch May 7, 2024 19:24
AtariDreams added a commit to AtariDreams/llvm-project that referenced this pull request May 8, 2024
AtariDreams added a commit to AtariDreams/llvm-project that referenced this pull request May 8, 2024
It is inaccurate and needs to be corrected.
AtariDreams added a commit to AtariDreams/llvm-project that referenced this pull request May 8, 2024
It is inaccurate and needs to be corrected.
nikic pushed a commit that referenced this pull request May 9, 2024
It is inaccurate and needs to be corrected.
nikic pushed a commit that referenced this pull request May 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants