-
Notifications
You must be signed in to change notification settings - Fork 10.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[InstCombine] Remove shl if we only demand known signbits of shift source #79014
Conversation
@llvm/pr-subscribers-llvm-transforms Author: hanbeom (ParkHanbum) Changes: This patch resolves the TODO written in commit 5909c67. Full diff: https://github.com/llvm/llvm-project/pull/79014.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index a8a5f9831e15e3..6ef00bab5307ec 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -640,25 +640,31 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
DemandedMask, Known))
return R;
- // TODO: If we only want bits that already match the signbit then we don't
+ uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);
+ // If we only want bits that already match the signbit then we don't
// need to shift.
+ if (DemandedMask.countr_zero() >= ShiftAmt) {
+ unsigned NumLowDemandedBits = BitWidth - DemandedMask.countr_zero();
+ unsigned SignBits =
+ ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI);
+ if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumLowDemandedBits)
+ return I->getOperand(0);
- // If we can pre-shift a right-shifted constant to the left without
- // losing any high bits amd we don't demand the low bits, then eliminate
- // the left-shift:
- // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X
- uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
- Value *X;
- Constant *C;
- if (DemandedMask.countr_zero() >= ShiftAmt &&
- match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
- Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
- Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
- LeftShiftAmtC, DL);
- if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, LeftShiftAmtC,
- DL) == C) {
- Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
- return InsertNewInstWith(Lshr, I->getIterator());
+ // If we can pre-shift a right-shifted constant to the left without
+ // losing any high bits amd we don't demand the low bits, then eliminate
+ // the left-shift:
+ // (C >> X) << LeftShiftAmtC --> (C << RightShiftAmtC) >> X
+ Value *X;
+ Constant *C;
+ if (match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
+ Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
+ Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
+ LeftShiftAmtC, DL);
+ if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC,
+ LeftShiftAmtC, DL) == C) {
+ Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
+ return InsertNewInstWith(Lshr, I->getIterator());
+ }
}
}
diff --git a/llvm/test/Transforms/InstCombine/shl-demand.ll b/llvm/test/Transforms/InstCombine/shl-demand.ll
index 85752890b4b80d..26175ebbe15358 100644
--- a/llvm/test/Transforms/InstCombine/shl-demand.ll
+++ b/llvm/test/Transforms/InstCombine/shl-demand.ll
@@ -1,6 +1,124 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+; If we only want bits that already match the signbit then we don't need to shift.
+; https://alive2.llvm.org/ce/z/WJBPVt
+define i32 @src_srem_shl_demand_max_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_signbit(
+; CHECK-NEXT: [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[SREM]], -2147483648
+; CHECK-NEXT: ret i32 [[MASK]]
+;
+ %srem = srem i32 %a0, 2 ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD
+ %shl = shl i32 %srem, 30 ; shl = SD000000000000000000000000000000
+ %mask = and i32 %shl, -2147483648 ; mask = 10000000000000000000000000000000
+ ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_min_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_min_signbit(
+; CHECK-NEXT: [[SREM:%.*]] = srem i32 [[A0:%.*]], 1073741823
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[SREM]], -2147483648
+; CHECK-NEXT: ret i32 [[MASK]]
+;
+ %srem = srem i32 %a0, 1073741823 ; srem = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+ %shl = shl i32 %srem, 1 ; shl = SDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+ %mask = and i32 %shl, -2147483648 ; mask = 10000000000000000000000000000000
+ ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_max_mask(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_mask(
+; CHECK-NEXT: [[SREM:%.*]] = srem i32 [[A0:%.*]], 2
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[SREM]], -4
+; CHECK-NEXT: ret i32 [[MASK]]
+;
+ %srem = srem i32 %a0, 2 ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD
+ %shl = shl i32 %srem, 1 ; shl = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSD0
+ %mask = and i32 %shl, -4 ; mask = 11111111111111111111111111111100
+ ret i32 %mask
+}
+
+; Negative test - mask demands non-signbit from shift source
+define i32 @src_srem_shl_demand_max_signbit_mask_hit_first_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_signbit_mask_hit_first_demand(
+; CHECK-NEXT: [[SREM:%.*]] = srem i32 [[A0:%.*]], 4
+; CHECK-NEXT: [[SHL:%.*]] = shl nsw i32 [[SREM]], 29
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[SHL]], -1073741824
+; CHECK-NEXT: ret i32 [[MASK]]
+;
+ %srem = srem i32 %a0, 4 ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD
+ %shl = shl i32 %srem, 29 ; shl = SDD00000000000000000000000000000
+ %mask = and i32 %shl, -1073741824 ; mask = 11000000000000000000000000000000
+ ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_min_signbit_mask_hit_last_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_min_signbit_mask_hit_last_demand(
+; CHECK-NEXT: [[SREM:%.*]] = srem i32 [[A0:%.*]], 536870912
+; CHECK-NEXT: [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[SHL]], -1073741822
+; CHECK-NEXT: ret i32 [[MASK]]
+;
+ %srem = srem i32 %a0, 536870912 ; srem = SSSDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+ %shl = shl i32 %srem, 1 ; shl = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+ %mask = and i32 %shl, -1073741822 ; mask = 11000000000000000000000000000010
+ ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_eliminate_signbit(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_eliminate_signbit(
+; CHECK-NEXT: [[SREM:%.*]] = srem i32 [[A0:%.*]], 1073741824
+; CHECK-NEXT: [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[SHL]], 2
+; CHECK-NEXT: ret i32 [[MASK]]
+;
+ %srem = srem i32 %a0, 1073741824 ; srem = SSDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
+ %shl = shl i32 %srem, 1 ; shl = DDDDDDDDDDDDDDDDDDDDDDDDDDDDDDD0
+ %mask = and i32 %shl, 2 ; mask = 00000000000000000000000000000010
+ ret i32 %mask
+}
+
+define i32 @src_srem_shl_demand_max_mask_hit_demand(i32 %a0) {
+; CHECK-LABEL: @src_srem_shl_demand_max_mask_hit_demand(
+; CHECK-NEXT: [[SREM:%.*]] = srem i32 [[A0:%.*]], 4
+; CHECK-NEXT: [[SHL:%.*]] = shl nsw i32 [[SREM]], 1
+; CHECK-NEXT: [[MASK:%.*]] = and i32 [[SHL]], -4
+; CHECK-NEXT: ret i32 [[MASK]]
+;
+ %srem = srem i32 %a0, 4 ; srem = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD
+ %shl= shl i32 %srem, 1 ; shl = SSSSSSSSSSSSSSSSSSSSSSSSSSSSSDD0
+ %mask = and i32 %shl, -4 ; mask = 11111111111111111111111111111100
+ ret i32 %mask
+}
+
+define <2 x i32> @src_srem_shl_mask_vector(<2 x i32> %a0) {
+; CHECK-LABEL: @src_srem_shl_mask_vector(
+; CHECK-NEXT: [[SREM:%.*]] = srem <2 x i32> [[A0:%.*]], <i32 4, i32 4>
+; CHECK-NEXT: [[SHL:%.*]] = shl nsw <2 x i32> [[SREM]], <i32 29, i32 29>
+; CHECK-NEXT: [[MASK:%.*]] = and <2 x i32> [[SHL]], <i32 -1073741824, i32 -1073741824>
+; CHECK-NEXT: ret <2 x i32> [[MASK]]
+;
+ %srem = srem <2 x i32> %a0, <i32 4, i32 4>
+ %shl = shl <2 x i32> %srem, <i32 29, i32 29>
+ %mask = and <2 x i32> %shl, <i32 -1073741824, i32 -1073741824>
+ ret <2 x i32> %mask
+}
+
+define <2 x i32> @src_srem_shl_mask_vector_nonconstant(<2 x i32> %a0, <2 x i32> %a1) {
+; CHECK-LABEL: @src_srem_shl_mask_vector_nonconstant(
+; CHECK-NEXT: [[SREM:%.*]] = srem <2 x i32> [[A0:%.*]], <i32 4, i32 4>
+; CHECK-NEXT: [[SHL:%.*]] = shl <2 x i32> [[SREM]], [[A1:%.*]]
+; CHECK-NEXT: [[MASK:%.*]] = and <2 x i32> [[SHL]], <i32 -1073741824, i32 -1073741824>
+; CHECK-NEXT: ret <2 x i32> [[MASK]]
+;
+ %srem = srem <2 x i32> %a0, <i32 4, i32 4>
+ %shl = shl <2 x i32> %srem, %a1
+ %mask = and <2 x i32> %shl, <i32 -1073741824, i32 -1073741824>
+ ret <2 x i32> %mask
+}
+
define i16 @sext_shl_trunc_same_size(i16 %x, i32 %y) {
; CHECK-LABEL: @sext_shl_trunc_same_size(
; CHECK-NEXT: [[CONV1:%.*]] = zext i16 [[X:%.*]] to i32
|
// need to shift. | ||
if (DemandedMask.countr_zero() >= ShiftAmt) { | ||
unsigned NumLowDemandedBits = BitWidth - DemandedMask.countr_zero(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this the number of high bits (not low bits)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my bad.. I'll fix it!
unsigned NumLowDemandedBits = BitWidth - DemandedMask.countr_zero(); | ||
unsigned SignBits = | ||
ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); | ||
if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumLowDemandedBits) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be `>=`: https://alive2.llvm.org/ce/z/C3VNoR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@goldsteinn Maybe your link is wrong. (Edit: not anymore — I can open the link now; it may have been a network problem on my side.)
Anyway, consider the case where SignBits equals ShiftAmt:
https://alive2.llvm.org/ce/z/82A769
It doesn't seem to work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that @goldsteinn is in principle correct, but if `SignBits == ShiftAmt` then `SignBits - ShiftAmt` is 0, so the condition would become `0 >= NumHighDemandedBits`, which is the same as `NumHighDemandedBits == 0`, which is a degenerate case that gets handled earlier anyway. So this just doesn't matter either way.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://alive2.llvm.org/ce/z/gaBQk7
I updated the Alive2 link to show verification of both the `SignBits == ShiftAmt` case and the `SignBits > ShiftAmt && SignBits - ShiftAmt >= NumHighDemandedBits` case. I can see that `SignBits == ShiftAmt` does not pass verification when the sign bit is flipped.
If I have misunderstood, please let me know.
011f822
to
8f08a11
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
Outdated
Show resolved
Hide resolved
4132690
to
0a2c8dd
Compare
@ParkHanbum Can you rebase this patch on main? |
@dtcxzyw it is done! |
if (DemandedMask.countr_zero() >= ShiftAmt) { | ||
unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero(); | ||
unsigned SignBits = | ||
ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just check `nsw` here? We will set `nsw` if `computeNumSignBits >= ShCnt` + other conditions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that knowing that `computeNumSignBits` is greater than `ShCnt` is not enough. We need to shift `computeNumSignBits` by `ShCnt` and then check whether `DemandedBits` is still within the range of `SignBits`.
If there is a way to check the range of `SignBits` and `DemandedBits` by checking `nsw`, please let me know.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, you are right, although you should be able to check `nsw` first before doing the `computeNumSignBits`, as without `nsw` I think we always know `ShCnt > computeNumSignBits`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@goldsteinn I understand!
…urce This patch resolves the TODO written in commit 5909c67. Proof: https://alive2.llvm.org/ce/z/WJBPVt
LGTM |
This patch resolves the TODO written in commit 5909c67.