
[InstCombine] Remove shl if we only demand known signbits of shift source #79014

Merged
merged 2 commits into from
Apr 10, 2024

Conversation

ParkHanbum (Contributor)

This patch resolves the TODO noted in commit 5909c67.

@llvmbot (Collaborator)

llvmbot commented Jan 22, 2024

@llvm/pr-subscribers-llvm-transforms

Author: hanbeom (ParkHanbum)

Changes

This patch resolves the TODO noted in commit 5909c67.


Full diff: https://github.com/llvm/llvm-project/pull/79014.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp (+23-17)
  • (modified) llvm/test/Transforms/InstCombine/shl-demand.ll (+118)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index a8a5f9831e15e3..6ef00bab5307ec 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -640,25 +640,31 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
                                                     DemandedMask, Known))
             return R;
 
-      // TODO: If we only want bits that already match the signbit then we don't
+      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);
+      // If we only want bits that already match the signbit then we don't
       // need to shift.
+      if (DemandedMask.countr_zero() >= ShiftAmt) {
+        unsigned NumLowDemandedBits = BitWidth - DemandedMask.countr_zero();
+        unsigned SignBits =
+            ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+        if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumLowDemandedBits)
+          return I->getOperand(0);
 
-      // If we can pre-shift a right-shifted constant to the left without
-      // losing any high bits amd we don't demand the low bits, then eliminate
-      // the left-shift:
-      // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X
-      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
-      Value *X;
-      Constant *C;
-      if (DemandedMask.countr_zero() >= ShiftAmt &&
-          match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
-        Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
-        Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
-                                                      LeftShiftAmtC, DL);
-        if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, LeftShiftAmtC,
-                                         DL) == C) {
-          Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
-          return InsertNewInstWith(Lshr, I->getIterator());
+        // If we can pre-shift a right-shifted constant to the left without
+        // losing any high bits amd we don't demand the low bits, then eliminate
+        // the left-shift:
+        // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X
+        Value *X;
+        Constant *C;
+        if (match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
+          Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
+          Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
+                                                        LeftShiftAmtC, DL);
+          if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC,
+                                           LeftShiftAmtC, DL) == C) {
+            Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
+            return InsertNewInstWith(Lshr, I->getIterator());
+          }
         }
       }
 
diff --git a/llvm/test/Transforms/InstCombine/shl-demand.ll b/llvm/test/Transforms/InstCombine/shl-demand.ll
index 85752890b4b80d..26175ebbe15358 100644
--- a/llvm/test/Transforms/InstCombine/shl-demand.ll
+++ b/llvm/test/Transforms/InstCombine/shl-demand.ll
@@ -1,6 +1,124 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt -passes=instcombine -S < %s | FileCheck %s
 
+
+; If we only want bits that already match the signbit then we don't need to shift.
+; https://alive2.llvm.org/ce/z/WJBPVt
+define i32 @src_srem_shl_demand_max_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_signbit(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SREM]], -2147483648
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 2           ; srem  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD
+  %shl = shl i32 %srem, 30          ; shl   = SD000000000000000000000000000000
+  %mask = and i32 %shl, -2147483648 ; mask  = 10000000000000000000000000000000
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_min_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_min_signbit(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 1073741823
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SREM]], -2147483648
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 1073741823  ; srem  = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+  %shl = shl i32 %srem, 1           ; shl   = SDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+  %mask = and i32 %shl, -2147483648 ; mask  = 10000000000000000000000000000000
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_max_mask(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_mask(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SREM]], -4
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 2           ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD
+  %shl = shl i32 %srem, 1           ; shl  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD0
+  %mask = and i32 %shl, -4          ; mask = 11111111111111111111111111111100
+  ret i32 %mask
+}
+
+; Negative test - mask demands non-signbit from shift source
+define i32 @src_srem_shl_demand_max_signbit_mask_hit_first_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_signbit_mask_hit_first_demand(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 4
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 29
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -1073741824
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 4           ; srem  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD
+  %shl = shl i32 %srem, 29          ; shl   = SDD00000000000000000000000000000
+  %mask = and i32 %shl, -1073741824 ; mask  = 11000000000000000000000000000000
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_min_signbit_mask_hit_last_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_min_signbit_mask_hit_last_demand(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 536870912
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -1073741822
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 536870912   ; srem  = SSSDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+  %shl = shl i32 %srem, 1           ; shl   = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+  %mask = and i32 %shl, -1073741822 ; mask  = 11000000000000000000000000000010
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_eliminate_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_eliminate_signbit(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 1073741824
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], 2
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 1073741824  ; srem  = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+  %shl = shl i32 %srem, 1           ; shl   = DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+  %mask = and i32 %shl, 2           ; mask  = 00000000000000000000000000000010
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_max_mask_hit_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_mask_hit_demand(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 4
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -4
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 4           ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD
+  %shl= shl i32 %srem, 1            ; shl  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD0
+  %mask = and i32 %shl, -4          ; mask = 11111111111111111111111111111100
+  ret i32 %mask
+}
+
+define <2 x i32> @src_srem_shl_mask_vector(<2 x i32> %a0) {
+; CHECK-LABEL: @src_srem_shl_mask_vector(
+; CHECK-NEXT:    [[SREM:%.*]] = srem <2 x i32> [[A0:%.*]], <i32 4, i32 4>
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw <2 x i32> [[SREM]], <i32 29, i32 29>
+; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[SHL]], <i32 -1073741824, i32 -1073741824>
+; CHECK-NEXT:    ret <2 x i32> [[MASK]]
+;
+  %srem = srem <2 x i32> %a0, <i32 4, i32 4>
+  %shl = shl <2 x i32> %srem, <i32 29, i32 29>
+  %mask = and <2 x i32> %shl, <i32 -1073741824, i32 -1073741824>
+  ret <2 x i32> %mask
+}
+
+define <2 x i32> @src_srem_shl_mask_vector_nonconstant(<2 x i32> %a0, <2 x i32> %a1) {
+; CHECK-LABEL: @src_srem_shl_mask_vector_nonconstant(
+; CHECK-NEXT:    [[SREM:%.*]] = srem <2 x i32> [[A0:%.*]], <i32 4, i32 4>
+; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i32> [[SREM]], [[A1:%.*]]
+; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[SHL]], <i32 -1073741824, i32 -1073741824>
+; CHECK-NEXT:    ret <2 x i32> [[MASK]]
+;
+  %srem = srem <2 x i32> %a0, <i32 4, i32 4>
+  %shl = shl <2 x i32> %srem, %a1
+  %mask = and <2 x i32> %shl, <i32 -1073741824, i32 -1073741824>
+  ret <2 x i32> %mask
+}
+
 define i16 @sext_shl_trunc_same_size(i16 %x, i32 %y) {
 ; CHECK-LABEL: @sext_shl_trunc_same_size(
 ; CHECK-NEXT:    [[CONV1:%.*]] = zext i16 [[X:%.*]] to i32

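As a sanity check on the first new test above (src_srem_shl_demand_max_signbit), here is a small numeric sketch (my own illustration in Python, not part of the patch) showing that after srem by 2, which leaves 31 sign bits in an i32, a shl by 30 cannot change the bit selected by the sign-bit mask, so the shift is dead:

```python
# Model i32 two's-complement patterns as Python ints masked to 32 bits.
M32 = (1 << 32) - 1
SIGN = 1 << 31

def srem2(x):
    """C-style truncated i32 remainder by 2: result is in {-1, 0, 1}."""
    r = abs(x) % 2
    r = -r if x < 0 else r
    return r & M32  # return the 32-bit pattern

for x in (-5, -2, -1, 0, 1, 2, 7):
    r = srem2(x)
    shifted = (r << 30) & M32  # shl i32 %srem, 30
    # The demanded sign bit of the shifted value equals that of the source.
    assert (shifted & SIGN) == (r & SIGN)
```

Since srem by 2 leaves 31 sign bits and only the sign bit is demanded, the guard in the patch (SignBits - ShiftAmt = 31 - 30 = 1 >= NumHighDemandedBits = 1) fires, matching the CHECK lines above.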
// need to shift.
if (DemandedMask.countr_zero() >= ShiftAmt) {
unsigned NumLowDemandedBits = BitWidth - DemandedMask.countr_zero();
Contributor:

Isn't this highbits?

Contributor Author:

My bad, I'll fix it!

unsigned NumLowDemandedBits = BitWidth - DemandedMask.countr_zero();
unsigned SignBits =
ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumLowDemandedBits)
Contributor:

Contributor Author:

@goldsteinn maybe your link is wrong. (Edit: not anymore; I can open the link now, so it was probably a network problem on my end.)
Anyway, considering the case where SignBits equals ShiftAmt: https://alive2.llvm.org/ce/z/82A769
It doesn't seem to verify.

Contributor:

I think that @goldsteinn is in principle correct, but if SignBits == ShiftAmt then SignBits - ShiftAmt is 0, so the condition would become 0 >= NumHighDemandedBits, which is the same as NumHighDemandedBits == 0, which is a degenerate case that gets handled earlier anyway. So this just doesn't matter either way.
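The reasoning above can be checked exhaustively at a small bit width. Below is a brute-force sketch (my own illustration in Python, not part of the patch; it models i8 and only contiguous high-bit demanded masks) confirming that whenever SignBits > ShiftAmt and SignBits - ShiftAmt >= NumHighDemandedBits, the shifted value and the shift source agree on every demanded bit:

```python
BITS = 8
MASK = (1 << BITS) - 1

def num_sign_bits(x):
    """Number of identical leading bits (including the sign bit) of x as an i8 pattern."""
    sign = (x >> (BITS - 1)) & 1
    n = 0
    for i in range(BITS - 1, -1, -1):
        if (x >> i) & 1 == sign:
            n += 1
        else:
            break
    return n

violations = 0
for x in range(1 << BITS):
    sb = num_sign_bits(x)
    for amt in range(BITS):                 # ShiftAmt
        for tz in range(amt, BITS):         # DemandedMask.countr_zero() >= ShiftAmt
            demanded = (MASK << tz) & MASK  # contiguous high-bit demanded mask
            high = BITS - tz                # NumHighDemandedBits
            if sb > amt and sb - amt >= high:
                shifted = (x << amt) & MASK
                if (shifted ^ x) & demanded:
                    violations += 1

print(violations)  # 0: the guard is sound for i8
```

Any demanded mask whose countr_zero() is tz is a subset of the contiguous mask checked here, so agreement on the contiguous mask covers the general case.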

Contributor Author:

https://alive2.llvm.org/ce/z/gaBQk7
I updated the alive2 link to show verification of both the SignBits == ShiftAmt case and the SignBits > ShiftAmt && SignBits - ShiftAmt >= NumHighDemandedBits case. I can see that SignBits == ShiftAmt does not pass verification when the sign bit is flipped.

If I have misunderstood, please let me know.

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Jan 22, 2024
@nikic (Contributor) left a comment:

LGTM

@ParkHanbum force-pushed the do_5e90224 branch 2 times, most recently from 4132690 to 0a2c8dd on January 29, 2024 at 16:41.
@dtcxzyw (Member)

dtcxzyw commented Apr 7, 2024

@ParkHanbum Can you rebase this patch on main?

@ParkHanbum (Contributor Author)

@dtcxzyw it is done!

if (DemandedMask.countr_zero() >= ShiftAmt) {
unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
unsigned SignBits =
ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
Contributor:

Can we just check nsw here? We will set nsw if computeNumSignBits >= ShCnt + other conditions.

Contributor Author:

It seems that knowing computeNumSignBits is greater than ShCnt is not enough. We need to subtract ShCnt from computeNumSignBits and then check whether the demanded bits still fall within the remaining sign bits.
If there is a way to check the relationship between the sign bits and the demanded bits via nsw, please let me know.
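A small concrete case (my own illustration, not from the thread) of why SignBits > ShCnt alone is not sufficient, and why the subtraction against the demanded high bits is needed:

```python
# Model i8 patterns as Python ints masked to 8 bits.
BITS = 8
M = (1 << BITS) - 1

x = 0b00001000         # 4 leading sign bits (sign = 0)
amt = 2                # SignBits (4) > ShCnt (2), but ...
demanded = 0b11100000  # 3 high bits demanded; 4 - 2 = 2 < 3

shifted = (x << amt) & M  # 0b00100000
# The demanded bits of x and of (x << amt) disagree (0 vs 32),
# so the shl cannot be removed here even though SignBits > ShCnt.
print(x & demanded, shifted & demanded)  # 0 32
```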

Contributor:

Ah, you are right. Although you should be able to check nsw first, before doing the computeNumSignBits call: without nsw, I think we always know ShCnt > computeNumSignBits.

Contributor Author:

@goldsteinn I understand!

@goldsteinn (Contributor)

LGTM

@nikic nikic merged commit 44c79da into llvm:main Apr 10, 2024
4 checks passed