[InstCombine] generalize fold for (trunc (X u>> C1)) u>> C
This is another step toward re-applying D110170:
it eliminates conflicting transforms that can cause infinite loops.
a47c8e4 was a previous patch in this direction.

The diffs here are mostly cosmetic, but intentional:
1. The existing code in FoldShiftByConstant() that would handle this
   pattern is now limited to 'shl' only. The rename to IsLeftShift shows
   that several of these transforms could move directly into visitShl()
   for efficiency, because they are not common shift transforms.

2. The tests are regenerated to show the new instruction names,
   demonstrating that we produce (almost) identical results.

3. The one case where we differ ("trunc_sandwich_small_shift1") shows that
   we now create a narrow 'and' instruction (see the sketch below).
   Previously, we relied on another transform for that, but it is limited
   to legal types. That appears to be a legacy constraint from when IR
   analysis and codegen were less robust.
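
A before/after sketch of that divergent case, reconstructed from the
regenerated "trunc_sandwich_small_shift1" test (the function names here
are illustrative; the i12 constant 2047 is -1 u>> 1):

  ; old result: the mask is applied with a wide (i32) 'and'
  define i12 @before(i32 %x) {
    %sh = lshr i32 %x, 20
    %m = and i32 %sh, 2047
    %r = trunc i32 %m to i12
    ret i12 %r
  }

  ; new result: the mask is applied with a narrow (i12) 'and'
  define i12 @after(i32 %x) {
    %sh = lshr i32 %x, 20
    %t = trunc i32 %sh to i12
    %r = and i12 %t, 2047
    ret i12 %r
  }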

https://alive2.llvm.org/ce/z/JxyGA4

  declare void @llvm.assume(i1)

  define i8 @src(i32 %x, i32 %c0, i8 %c1) {
    ; The combined shift amount must be less than the source width.
    %z1 = zext i8 %c1 to i32
    %sum = add i32 %c0, %z1
    %ov = icmp ult i32 %sum, 32
    call void @llvm.assume(i1 %ov)

    %sh1 = lshr i32 %x, %c0
    %tr = trunc i32 %sh1 to i8
    %sh2 = lshr i8 %tr, %c1
    ret i8 %sh2
  }

  define i8 @tgt(i32 %x, i32 %c0, i8 %c1) {
    %z1 = zext i8 %c1 to i32
    %sum = add i32 %c0, %z1
    %maskc = lshr i8 -1, %c1

    %s = lshr i32 %x, %sum
    %t = trunc i32 %s to i8
    %a = and i8 %t, %maskc
    ret i8 %a
  }
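
As an illustrative constant instance of the proof above (the shift
amounts C0 = 20 and C1 = 2 are assumed for the example), the mask
becomes -1 u>> 2 = 63 in i8:

  define i8 @src_const(i32 %x) {
    %sh1 = lshr i32 %x, 20
    %tr = trunc i32 %sh1 to i8
    %sh2 = lshr i8 %tr, 2
    ret i8 %sh2
  }

  define i8 @tgt_const(i32 %x) {
    %s = lshr i32 %x, 22      ; 20 + 2 = 22 < 32
    %t = trunc i32 %s to i8
    %a = and i8 %t, 63        ; -1 u>> 2 in i8
    ret i8 %a
  }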
rotateright committed Sep 27, 2021
1 parent 025a805 commit 21429cf
Showing 3 changed files with 60 additions and 49 deletions.
41 changes: 26 additions & 15 deletions llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -661,7 +661,7 @@ static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift,

Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
BinaryOperator &I) {
- bool isLeftShift = I.getOpcode() == Instruction::Shl;
+ bool IsLeftShift = I.getOpcode() == Instruction::Shl;

const APInt *Op1C;
if (!match(Op1, m_APInt(Op1C)))
@@ -670,14 +670,14 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
// See if we can propagate this shift into the input, this covers the trivial
// cast of lshr(shl(x,c1),c2) as well as other more complex cases.
if (I.getOpcode() != Instruction::AShr &&
- canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
+ canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) {
LLVM_DEBUG(
dbgs() << "ICE: GetShiftedValue propagating shift through expression"
" to eliminate shift:\n IN: "
<< *Op0 << "\n SH: " << I << "\n");

return replaceInstUsesWith(
- I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
+ I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL));
}

// See if we can simplify any instructions used by the instruction whose sole
@@ -701,7 +701,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
// xform in more cases, but it is unlikely to be profitable.
Instruction *TrOp;
const APInt *TrShiftAmt;
- if (I.isLogicalShift() && match(Op0, m_Trunc(m_Instruction(TrOp))) &&
+ if (IsLeftShift && match(Op0, m_Trunc(m_Instruction(TrOp))) &&
match(TrOp, m_OneUse(m_Shift(m_Value(), m_APInt(TrShiftAmt)))) &&
TrShiftAmt->ult(TrOp->getType()->getScalarSizeInBits())) {
Type *SrcTy = TrOp->getType();
@@ -743,7 +743,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
case Instruction::Xor: {
// These operators commute.
// Turn (Y + (X >> C)) << C -> (X + (Y << C)) & (~0 << C)
- if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() &&
+ if (IsLeftShift && Op0BO->getOperand(1)->hasOneUse() &&
match(Op0BO->getOperand(1), m_Shr(m_Value(V1), m_Specific(Op1)))) {
Value *YS = // (Y << C)
Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName());
@@ -758,7 +758,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,

// Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C))
Value *Op0BOOp1 = Op0BO->getOperand(1);
- if (isLeftShift && Op0BOOp1->hasOneUse() &&
+ if (IsLeftShift && Op0BOOp1->hasOneUse() &&
match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
m_APInt(CC)))) {
Value *YS = // (Y << C)
@@ -774,7 +774,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,

case Instruction::Sub: {
// Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C)
- if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
+ if (IsLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
match(Op0BO->getOperand(0), m_Shr(m_Value(V1), m_Specific(Op1)))) {
Value *YS = // (Y << C)
Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName());
@@ -788,7 +788,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
}

// Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C)
- if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
+ if (IsLeftShift && Op0BO->getOperand(0)->hasOneUse() &&
match(Op0BO->getOperand(0),
m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))),
m_APInt(CC)))) {
@@ -824,7 +824,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1,
// If the operand is a subtract with a constant LHS, and the shift
// is the only use, we can pull it out of the shift.
// This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2))
- if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub &&
+ if (IsLeftShift && Op0BO->getOpcode() == Instruction::Sub &&
match(Op0BO->getOperand(0), m_APInt(Op0C))) {
Constant *NewRHS = ConstantExpr::get(
I.getOpcode(), cast<Constant>(Op0BO->getOperand(0)), Op1);
@@ -1158,15 +1158,26 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
}

- // If the first shift covers the number of bits truncated and the combined
- // shift fits in the source width:
- // (trunc (X >>u C1)) >>u C --> trunc (X >>u (C1 + C))
- if (match(Op0, m_OneUse(m_Trunc(m_LShr(m_Value(X), m_APInt(C1)))))) {
+ Instruction *TruncSrc;
+ if (match(Op0, m_OneUse(m_Trunc(m_Instruction(TruncSrc)))) &&
+     match(TruncSrc, m_LShr(m_Value(X), m_APInt(C1)))) {
unsigned SrcWidth = X->getType()->getScalarSizeInBits();
unsigned AmtSum = ShAmtC + C1->getZExtValue();
- if (C1->uge(SrcWidth - BitWidth) && AmtSum < SrcWidth) {
+
+ // If the combined shift fits in the source width:
+ // (trunc (X >>u C1)) >>u C --> and (trunc (X >>u (C1 + C))), MaskC
+ //
+ // If the first shift covers the number of bits truncated, then the
+ // mask instruction is eliminated (and so the use check is relaxed).
+ if (AmtSum < SrcWidth &&
+     (TruncSrc->hasOneUse() || C1->uge(SrcWidth - BitWidth))) {
Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift");
- return new TruncInst(SumShift, Ty);
+ Value *Trunc = Builder.CreateTrunc(SumShift, Ty, I.getName());
+
+ // If the first shift does not cover the number of bits truncated, then
+ // we require a mask to get rid of high bits in the result.
+ APInt MaskC = APInt::getAllOnes(BitWidth).lshr(ShAmtC);
+ return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, MaskC));
}
}
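
A note on the relaxed use check above: when the first shift covers all of
the truncated bits (C1 >= SrcWidth - BitWidth), the high bits are already
zero, so the mask is omitted and an extra use of the inner shift is
tolerated. A minimal sketch modeled on the "trunc_sandwich_use1" test
below (the function name is illustrative):

  declare void @use(i32)

  ; 28 >= 32 - 12, so this folds to:
  ;   %sum.shift = lshr i32 %x, 30
  ;   %r1 = trunc i32 %sum.shift to i12
  ; even though %sh has a second use.
  define i12 @mask_free_extra_use(i32 %x) {
    %sh = lshr i32 %x, 28
    call void @use(i32 %sh)
    %tr = trunc i32 %sh to i12
    %r = lshr i12 %tr, 2
    ret i12 %r
  }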

56 changes: 28 additions & 28 deletions llvm/test/Transforms/InstCombine/lshr.ll
@@ -392,9 +392,9 @@ define i32 @srem2_lshr30(i32 %x) {

define i12 @trunc_sandwich(i32 %x) {
; CHECK-LABEL: @trunc_sandwich(
- ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 30
- ; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SH]] to i12
- ; CHECK-NEXT: ret i12 [[R]]
+ ; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X:%.*]], 30
+ ; CHECK-NEXT: [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+ ; CHECK-NEXT: ret i12 [[R1]]
;
%sh = lshr i32 %x, 28
%tr = trunc i32 %sh to i12
@@ -404,9 +404,9 @@ define i12 @trunc_sandwich(i32 %x) {

define <2 x i12> @trunc_sandwich_splat_vec(<2 x i32> %x) {
; CHECK-LABEL: @trunc_sandwich_splat_vec(
- ; CHECK-NEXT: [[SH:%.*]] = lshr <2 x i32> [[X:%.*]], <i32 30, i32 30>
- ; CHECK-NEXT: [[R:%.*]] = trunc <2 x i32> [[SH]] to <2 x i12>
- ; CHECK-NEXT: ret <2 x i12> [[R]]
+ ; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr <2 x i32> [[X:%.*]], <i32 30, i32 30>
+ ; CHECK-NEXT: [[R1:%.*]] = trunc <2 x i32> [[SUM_SHIFT]] to <2 x i12>
+ ; CHECK-NEXT: ret <2 x i12> [[R1]]
;
%sh = lshr <2 x i32> %x, <i32 22, i32 22>
%tr = trunc <2 x i32> %sh to <2 x i12>
@@ -416,9 +416,9 @@ define <2 x i12> @trunc_sandwich_splat_vec(<2 x i32> %x) {

define i12 @trunc_sandwich_min_shift1(i32 %x) {
; CHECK-LABEL: @trunc_sandwich_min_shift1(
- ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 21
- ; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SH]] to i12
- ; CHECK-NEXT: ret i12 [[R]]
+ ; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X:%.*]], 21
+ ; CHECK-NEXT: [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+ ; CHECK-NEXT: ret i12 [[R1]]
;
%sh = lshr i32 %x, 20
%tr = trunc i32 %sh to i12
@@ -428,9 +428,9 @@ define i12 @trunc_sandwich_min_shift1(i32 %x) {

define i12 @trunc_sandwich_small_shift1(i32 %x) {
; CHECK-LABEL: @trunc_sandwich_small_shift1(
- ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 20
- ; CHECK-NEXT: [[TR2:%.*]] = and i32 [[SH]], 2047
- ; CHECK-NEXT: [[R:%.*]] = trunc i32 [[TR2]] to i12
+ ; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X:%.*]], 20
+ ; CHECK-NEXT: [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+ ; CHECK-NEXT: [[R:%.*]] = and i12 [[R1]], 2047
; CHECK-NEXT: ret i12 [[R]]
;
%sh = lshr i32 %x, 19
@@ -441,9 +441,9 @@ define i12 @trunc_sandwich_small_shift1(i32 %x) {

define i12 @trunc_sandwich_max_sum_shift(i32 %x) {
; CHECK-LABEL: @trunc_sandwich_max_sum_shift(
- ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 31
- ; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SH]] to i12
- ; CHECK-NEXT: ret i12 [[R]]
+ ; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X:%.*]], 31
+ ; CHECK-NEXT: [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+ ; CHECK-NEXT: ret i12 [[R1]]
;
%sh = lshr i32 %x, 20
%tr = trunc i32 %sh to i12
@@ -453,9 +453,9 @@ define i12 @trunc_sandwich_max_sum_shift(i32 %x) {

define i12 @trunc_sandwich_max_sum_shift2(i32 %x) {
; CHECK-LABEL: @trunc_sandwich_max_sum_shift2(
- ; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 31
- ; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SH]] to i12
- ; CHECK-NEXT: ret i12 [[R]]
+ ; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X:%.*]], 31
+ ; CHECK-NEXT: [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+ ; CHECK-NEXT: ret i12 [[R1]]
;
%sh = lshr i32 %x, 30
%tr = trunc i32 %sh to i12
@@ -488,8 +488,8 @@ define i12 @trunc_sandwich_use1(i32 %x) {
; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 28
; CHECK-NEXT: call void @use(i32 [[SH]])
; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 30
- ; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
- ; CHECK-NEXT: ret i12 [[R]]
+ ; CHECK-NEXT: [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+ ; CHECK-NEXT: ret i12 [[R1]]
;
%sh = lshr i32 %x, 28
call void @use(i32 %sh)
@@ -503,8 +503,8 @@ define <3 x i9> @trunc_sandwich_splat_vec_use1(<3 x i14> %x) {
; CHECK-NEXT: [[SH:%.*]] = lshr <3 x i14> [[X:%.*]], <i14 6, i14 6, i14 6>
; CHECK-NEXT: call void @usevec(<3 x i14> [[SH]])
; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr <3 x i14> [[X]], <i14 11, i14 11, i14 11>
- ; CHECK-NEXT: [[R:%.*]] = trunc <3 x i14> [[SUM_SHIFT]] to <3 x i9>
- ; CHECK-NEXT: ret <3 x i9> [[R]]
+ ; CHECK-NEXT: [[R1:%.*]] = trunc <3 x i14> [[SUM_SHIFT]] to <3 x i9>
+ ; CHECK-NEXT: ret <3 x i9> [[R1]]
;
%sh = lshr <3 x i14> %x, <i14 6, i14 6, i14 6>
call void @usevec(<3 x i14> %sh)
@@ -518,8 +518,8 @@ define i12 @trunc_sandwich_min_shift1_use1(i32 %x) {
; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 20
; CHECK-NEXT: call void @use(i32 [[SH]])
; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 21
- ; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
- ; CHECK-NEXT: ret i12 [[R]]
+ ; CHECK-NEXT: [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+ ; CHECK-NEXT: ret i12 [[R1]]
;
%sh = lshr i32 %x, 20
call void @use(i32 %sh)
@@ -550,8 +550,8 @@ define i12 @trunc_sandwich_max_sum_shift_use1(i32 %x) {
; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 20
; CHECK-NEXT: call void @use(i32 [[SH]])
; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 31
- ; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
- ; CHECK-NEXT: ret i12 [[R]]
+ ; CHECK-NEXT: [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+ ; CHECK-NEXT: ret i12 [[R1]]
;
%sh = lshr i32 %x, 20
call void @use(i32 %sh)
@@ -565,8 +565,8 @@ define i12 @trunc_sandwich_max_sum_shift2_use1(i32 %x) {
; CHECK-NEXT: [[SH:%.*]] = lshr i32 [[X:%.*]], 30
; CHECK-NEXT: call void @use(i32 [[SH]])
; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i32 [[X]], 31
- ; CHECK-NEXT: [[R:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
- ; CHECK-NEXT: ret i12 [[R]]
+ ; CHECK-NEXT: [[R1:%.*]] = trunc i32 [[SUM_SHIFT]] to i12
+ ; CHECK-NEXT: ret i12 [[R1]]
;
%sh = lshr i32 %x, 30
call void @use(i32 %sh)
12 changes: 6 additions & 6 deletions llvm/test/Transforms/InstCombine/shift.ll
@@ -444,9 +444,9 @@ bb2:
define i32 @test29(i64 %d18) {
; CHECK-LABEL: @test29(
; CHECK-NEXT: entry:
- ; CHECK-NEXT: [[I916:%.*]] = lshr i64 [[D18:%.*]], 63
- ; CHECK-NEXT: [[I10:%.*]] = trunc i64 [[I916]] to i32
- ; CHECK-NEXT: ret i32 [[I10]]
+ ; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr i64 [[D18:%.*]], 63
+ ; CHECK-NEXT: [[I101:%.*]] = trunc i64 [[SUM_SHIFT]] to i32
+ ; CHECK-NEXT: ret i32 [[I101]]
;
entry:
%i916 = lshr i64 %d18, 32
@@ -458,9 +458,9 @@ entry:
define <2 x i32> @test29_uniform(<2 x i64> %d18) {
; CHECK-LABEL: @test29_uniform(
; CHECK-NEXT: entry:
- ; CHECK-NEXT: [[I916:%.*]] = lshr <2 x i64> [[D18:%.*]], <i64 63, i64 63>
- ; CHECK-NEXT: [[I10:%.*]] = trunc <2 x i64> [[I916]] to <2 x i32>
- ; CHECK-NEXT: ret <2 x i32> [[I10]]
+ ; CHECK-NEXT: [[SUM_SHIFT:%.*]] = lshr <2 x i64> [[D18:%.*]], <i64 63, i64 63>
+ ; CHECK-NEXT: [[I101:%.*]] = trunc <2 x i64> [[SUM_SHIFT]] to <2 x i32>
+ ; CHECK-NEXT: ret <2 x i32> [[I101]]
;
entry:
%i916 = lshr <2 x i64> %d18, <i64 32, i64 32>
