
[InstCombine] Remove shl if we only demand known signbits of shift source #79013

Closed
wants to merge 2 commits

Conversation

ParkHanbum
Contributor

This patch resolves the TODO left in commit 5909c67.
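
For context, the TODO lives in the shl handling of SimplifyDemandedUseBits: when every demanded bit of the shift result is already a copy of the sign bit of the shift source, the shift cannot change any demanded bit and can be dropped. A minimal IR sketch of the fold, mirroring the src_srem_shl_demand_max_signbit test added below (function names are illustrative):

; Before: %srem is an srem by 2, so it has at least 31 sign bits, and the mask
; demands only the sign bit; shifting left by 30 cannot change that bit.
define i32 @before(i32 %a0) {
  %srem = srem i32 %a0, 2
  %shl = shl i32 %srem, 30
  %mask = and i32 %shl, -2147483648
  ret i32 %mask
}

; After this patch, instcombine drops the shl and masks the srem directly:
define i32 @after(i32 %a0) {
  %srem = srem i32 %a0, 2
  %mask = and i32 %srem, -2147483648
  ret i32 %mask
}

The guard in the patch (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumLowDemandedBits) is what establishes that the demanded bits are unchanged by the shift.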

@llvmbot
Collaborator

llvmbot commented Jan 22, 2024

@llvm/pr-subscribers-llvm-transforms

Author: hanbeom (ParkHanbum)

Changes

This patch resolves the TODO left in commit 5909c67.


Full diff: https://github.com/llvm/llvm-project/pull/79013.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp (+23-17)
  • (modified) llvm/test/Transforms/InstCombine/shl-demand.ll (+118)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index a8a5f9831e15e3..6ef00bab5307ec 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -640,25 +640,31 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
                                                     DemandedMask, Known))
             return R;
 
-      // TODO: If we only want bits that already match the signbit then we don't
+      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);
+      // If we only want bits that already match the signbit then we don't
       // need to shift.
+      if (DemandedMask.countr_zero() >= ShiftAmt) {
+        unsigned NumLowDemandedBits = BitWidth - DemandedMask.countr_zero();
+        unsigned SignBits =
+            ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+        if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumLowDemandedBits)
+          return I->getOperand(0);
 
-      // If we can pre-shift a right-shifted constant to the left without
-      // losing any high bits amd we don't demand the low bits, then eliminate
-      // the left-shift:
-      // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X
-      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
-      Value *X;
-      Constant *C;
-      if (DemandedMask.countr_zero() >= ShiftAmt &&
-          match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
-        Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
-        Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
-                                                      LeftShiftAmtC, DL);
-        if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, LeftShiftAmtC,
-                                         DL) == C) {
-          Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
-          return InsertNewInstWith(Lshr, I->getIterator());
+        // If we can pre-shift a right-shifted constant to the left without
+        // losing any high bits and we don't demand the low bits, then eliminate
+        // the left-shift:
+        // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X
+        Value *X;
+        Constant *C;
+        if (match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
+          Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
+          Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
+                                                        LeftShiftAmtC, DL);
+          if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC,
+                                           LeftShiftAmtC, DL) == C) {
+            Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
+            return InsertNewInstWith(Lshr, I->getIterator());
+          }
         }
       }
 
diff --git a/llvm/test/Transforms/InstCombine/shl-demand.ll b/llvm/test/Transforms/InstCombine/shl-demand.ll
index 85752890b4b80d..26175ebbe15358 100644
--- a/llvm/test/Transforms/InstCombine/shl-demand.ll
+++ b/llvm/test/Transforms/InstCombine/shl-demand.ll
@@ -1,6 +1,124 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
 ; RUN: opt -passes=instcombine -S < %s | FileCheck %s
 
+
+; If we only want bits that already match the signbit then we don't need to shift.
+; https://alive2.llvm.org/ce/z/WJBPVt
+define i32 @src_srem_shl_demand_max_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_signbit(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SREM]], -2147483648
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 2           ; srem  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD
+  %shl = shl i32 %srem, 30          ; shl   = SD000000000000000000000000000000
+  %mask = and i32 %shl, -2147483648 ; mask  = 10000000000000000000000000000000
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_min_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_min_signbit(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 1073741823
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SREM]], -2147483648
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 1073741823  ; srem  = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+  %shl = shl i32 %srem, 1           ; shl   = SDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+  %mask = and i32 %shl, -2147483648 ; mask  = 10000000000000000000000000000000
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_max_mask(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_mask(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SREM]], -4
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 2           ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD
+  %shl = shl i32 %srem, 1           ; shl  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD0
+  %mask = and i32 %shl, -4          ; mask = 11111111111111111111111111111100
+  ret i32 %mask
+}
+
+; Negative test - mask demands non-signbit from shift source
+define i32 @src_srem_shl_demand_max_signbit_mask_hit_first_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_signbit_mask_hit_first_demand(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 4
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 29
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -1073741824
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 4           ; srem  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD
+  %shl = shl i32 %srem, 29          ; shl   = SDD00000000000000000000000000000
+  %mask = and i32 %shl, -1073741824 ; mask  = 11000000000000000000000000000000
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_min_signbit_mask_hit_last_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_min_signbit_mask_hit_last_demand(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 536870912
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -1073741822
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 536870912   ; srem  = SSSDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+  %shl = shl i32 %srem, 1           ; shl   = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+  %mask = and i32 %shl, -1073741822 ; mask  = 11000000000000000000000000000010
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_eliminate_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_eliminate_signbit(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 1073741824
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], 2
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 1073741824  ; srem  = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+  %shl = shl i32 %srem, 1           ; shl   = DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+  %mask = and i32 %shl, 2           ; mask  = 00000000000000000000000000000010
+  ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_max_mask_hit_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_mask_hit_demand(
+; CHECK-NEXT:    [[SREM:%.*]] = srem i32 [[A0:%.*]], 4
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[SHL]], -4
+; CHECK-NEXT:    ret i32 [[MASK]]
+;
+  %srem = srem i32 %a0, 4           ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD
+  %shl= shl i32 %srem, 1            ; shl  = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD0
+  %mask = and i32 %shl, -4          ; mask = 11111111111111111111111111111100
+  ret i32 %mask
+}
+
+define <2 x i32> @src_srem_shl_mask_vector(<2 x i32> %a0) {
+; CHECK-LABEL: @src_srem_shl_mask_vector(
+; CHECK-NEXT:    [[SREM:%.*]] = srem <2 x i32> [[A0:%.*]], <i32 4, i32 4>
+; CHECK-NEXT:    [[SHL:%.*]] = shl nsw <2 x i32> [[SREM]], <i32 29, i32 29>
+; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[SHL]], <i32 -1073741824, i32 -1073741824>
+; CHECK-NEXT:    ret <2 x i32> [[MASK]]
+;
+  %srem = srem <2 x i32> %a0, <i32 4, i32 4>
+  %shl = shl <2 x i32> %srem, <i32 29, i32 29>
+  %mask = and <2 x i32> %shl, <i32 -1073741824, i32 -1073741824>
+  ret <2 x i32> %mask
+}
+
+define <2 x i32> @src_srem_shl_mask_vector_nonconstant(<2 x i32> %a0, <2 x i32> %a1) {
+; CHECK-LABEL: @src_srem_shl_mask_vector_nonconstant(
+; CHECK-NEXT:    [[SREM:%.*]] = srem <2 x i32> [[A0:%.*]], <i32 4, i32 4>
+; CHECK-NEXT:    [[SHL:%.*]] = shl <2 x i32> [[SREM]], [[A1:%.*]]
+; CHECK-NEXT:    [[MASK:%.*]] = and <2 x i32> [[SHL]], <i32 -1073741824, i32 -1073741824>
+; CHECK-NEXT:    ret <2 x i32> [[MASK]]
+;
+  %srem = srem <2 x i32> %a0, <i32 4, i32 4>
+  %shl = shl <2 x i32> %srem, %a1
+  %mask = and <2 x i32> %shl, <i32 -1073741824, i32 -1073741824>
+  ret <2 x i32> %mask
+}
+
 define i16 @sext_shl_trunc_same_size(i16 %x, i32 %y) {
 ; CHECK-LABEL: @sext_shl_trunc_same_size(
 ; CHECK-NEXT:    [[CONV1:%.*]] = zext i16 [[X:%.*]] to i32
