diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 7256c88f5dc32..3750f31e3cffb 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -785,44 +785,27 @@ Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) { } const APInt *C; - if (match(Src, m_LShr(m_SExt(m_Value(A)), m_APInt(C))) && - A->getType() == DestTy) { - // If the shift is small enough, all zero bits created by the shift are - // removed by the trunc: - // trunc (lshr (sext A), C) --> ashr A, C - if (C->getZExtValue() <= SrcWidth - DestWidth) { - unsigned ShAmt = std::min((unsigned)C->getZExtValue(), DestWidth - 1); - return BinaryOperator::CreateAShr(A, ConstantInt::get(DestTy, ShAmt)); - } - // TODO: Mask high bits with 'and'. - } + if (match(Src, m_LShr(m_SExt(m_Value(A)), m_APInt(C)))) { + unsigned AWidth = A->getType()->getScalarSizeInBits(); + unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth); - // More complicated: deal with mismatched sizes. - // FIXME: This is too restrictive for uses and doesn't work with vectors. - // Transform trunc(lshr (sext A), Cst) to ashr A, Cst to eliminate type - // conversion. - // It works because bits coming from sign extension have the same value as - // the sign bit of the original value; performing ashr instead of lshr - // generates bits of the same value as the sign bit. - if (Src->hasOneUse() && - match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst)))) { - Value *SExt = cast(Src)->getOperand(0); - unsigned ASize = A->getType()->getPrimitiveSizeInBits(); - unsigned MaxAmt = SrcWidth - std::max(DestWidth, ASize); - unsigned ShiftAmt = Cst->getZExtValue(); - - // This optimization can be only performed when zero bits generated by - // the original lshr aren't pulled into the value after truncation, so we - // can only shift by values no larger than the number of extension bits. - // FIXME: Instead of bailing when the shift is too large, use and to clear - // the extra bits. - if (ShiftAmt <= MaxAmt) { - if (SExt->hasOneUse()) { - Value *Shift = Builder.CreateAShr(A, std::min(ShiftAmt, ASize - 1)); - Shift->takeName(Src); + // If the shift is small enough, all zero bits created by the shift are + // removed by the trunc. + if (C->getZExtValue() <= MaxShiftAmt) { + // trunc (lshr (sext A), C) --> ashr A, C + if (A->getType() == DestTy) { + unsigned ShAmt = std::min((unsigned)C->getZExtValue(), DestWidth - 1); + return BinaryOperator::CreateAShr(A, ConstantInt::get(DestTy, ShAmt)); + } + // The types are mismatched, so create a cast after shifting: + // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C) + if (Src->hasOneUse()) { + unsigned ShAmt = std::min((unsigned)C->getZExtValue(), AWidth - 1); + Value *Shift = Builder.CreateAShr(A, ShAmt); return CastInst::CreateIntegerCast(Shift, DestTy, true); } } + // TODO: Mask high bits with 'and'. } if (Instruction *I = narrowBinOp(Trunc)) diff --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll index 89a294c142dfb..a68d81acdde9e 100644 --- a/llvm/test/Transforms/InstCombine/cast.ll +++ b/llvm/test/Transforms/InstCombine/cast.ll @@ -1423,8 +1423,8 @@ define i1 @PR23309v2(i32 %A, i32 %B) { define i16 @PR24763(i8 %V) { ; ALL-LABEL: @PR24763( -; ALL-NEXT: [[L:%.*]] = ashr i8 [[V:%.*]], 1 -; ALL-NEXT: [[T:%.*]] = sext i8 [[L]] to i16 +; ALL-NEXT: [[TMP1:%.*]] = ashr i8 [[V:%.*]], 1 +; ALL-NEXT: [[T:%.*]] = sext i8 [[TMP1]] to i16 ; ALL-NEXT: ret i16 [[T]] ; %conv = sext i8 %V to i32 @@ -1619,8 +1619,8 @@ define i8 @trunc_lshr_overshift_sext_uses3(i8 %A) { define i8 @trunc_lshr_sext_wide_input(i16 %A) { ; ALL-LABEL: @trunc_lshr_sext_wide_input( -; ALL-NEXT: [[C:%.*]] = ashr i16 [[A:%.*]], 9 -; ALL-NEXT: [[D:%.*]] = trunc i16 [[C]] to i8 +; ALL-NEXT: [[TMP1:%.*]] = ashr i16 [[A:%.*]], 9 +; ALL-NEXT: [[D:%.*]] = trunc i16 [[TMP1]] to i8 ; ALL-NEXT: ret i8 [[D]] ; %B = sext i16 %A to i32 @@ -1633,8 +1633,8 @@ define <2 x i8> @trunc_lshr_sext_wide_input_uses1(<2 x i16> %A) { ; ALL-LABEL: @trunc_lshr_sext_wide_input_uses1( ; ALL-NEXT: [[B:%.*]] = sext <2 x i16> [[A:%.*]] to <2 x i32> ; ALL-NEXT: call void @use_v2i32(<2 x i32> [[B]]) -; ALL-NEXT: [[C:%.*]] = lshr <2 x i32> [[B]], -; ALL-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8> +; ALL-NEXT: [[TMP1:%.*]] = ashr <2 x i16> [[A]], +; ALL-NEXT: [[D:%.*]] = trunc <2 x i16> [[TMP1]] to <2 x i8> ; ALL-NEXT: ret <2 x i8> [[D]] ; %B = sext <2 x i16> %A to <2 x i32> @@ -1692,8 +1692,8 @@ define i8 @trunc_lshr_overshift_sext_wide_input_uses1(i16 %A) { ; ALL-LABEL: @trunc_lshr_overshift_sext_wide_input_uses1( ; ALL-NEXT: [[B:%.*]] = sext i16 [[A:%.*]] to i32 ; ALL-NEXT: call void @use_i32(i32 [[B]]) -; ALL-NEXT: [[C:%.*]] = lshr i32 [[B]], 16 -; ALL-NEXT: [[D:%.*]] = trunc i32 [[C]] to i8 +; ALL-NEXT: [[TMP1:%.*]] = ashr i16 [[A]], 15 +; ALL-NEXT: [[D:%.*]] = trunc i16 [[TMP1]] to i8 ; ALL-NEXT: ret i8 [[D]] ; %B = sext i16 %A to i32 @@ -1737,8 +1737,8 @@ define i8 @trunc_lshr_overshift_sext_wide_input_uses3(i16 %A) { define i16 @trunc_lshr_sext_narrow_input(i8 %A) { ; ALL-LABEL: @trunc_lshr_sext_narrow_input( -; ALL-NEXT: [[C:%.*]] = ashr i8 [[A:%.*]], 6 -; ALL-NEXT: [[D:%.*]] = sext i8 [[C]] to i16 +; ALL-NEXT: [[TMP1:%.*]] = ashr i8 [[A:%.*]], 6 +; ALL-NEXT: [[D:%.*]] = sext i8 [[TMP1]] to i16 ; ALL-NEXT: ret i16 [[D]] ; %B = sext i8 %A to i32 @@ -1751,8 +1751,8 @@ define <2 x i16> @trunc_lshr_sext_narrow_input_uses1(<2 x i8> %A) { ; ALL-LABEL: @trunc_lshr_sext_narrow_input_uses1( ; ALL-NEXT: [[B:%.*]] = sext <2 x i8> [[A:%.*]] to <2 x i32> ; ALL-NEXT: call void @use_v2i32(<2 x i32> [[B]]) -; ALL-NEXT: [[C:%.*]] = lshr <2 x i32> [[B]], -; ALL-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i16> +; ALL-NEXT: [[TMP1:%.*]] = ashr <2 x i8> [[A]], +; ALL-NEXT: [[D:%.*]] = sext <2 x i8> [[TMP1]] to <2 x i16> ; ALL-NEXT: ret <2 x i16> [[D]] ; %B = sext <2 x i8> %A to <2 x i32> @@ -1796,9 +1796,8 @@ define <2 x i16> @trunc_lshr_sext_narrow_input_uses3(<2 x i8> %A) { define <2 x i16> @trunc_lshr_overshift_narrow_input_sext(<2 x i8> %A) { ; ALL-LABEL: @trunc_lshr_overshift_narrow_input_sext( -; ALL-NEXT: [[B:%.*]] = sext <2 x i8> [[A:%.*]] to <2 x i32> -; ALL-NEXT: [[C:%.*]] = lshr <2 x i32> [[B]], -; ALL-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i16> +; ALL-NEXT: [[TMP1:%.*]] = ashr <2 x i8> [[A:%.*]], +; ALL-NEXT: [[D:%.*]] = sext <2 x i8> [[TMP1]] to <2 x i16> ; ALL-NEXT: ret <2 x i16> [[D]] ; %B = sext <2 x i8> %A to <2 x i32> @@ -1811,8 +1810,8 @@ define i16 @trunc_lshr_overshift_sext_narrow_input_uses1(i8 %A) { ; ALL-LABEL: @trunc_lshr_overshift_sext_narrow_input_uses1( ; ALL-NEXT: [[B:%.*]] = sext i8 [[A:%.*]] to i32 ; ALL-NEXT: call void @use_i32(i32 [[B]]) -; ALL-NEXT: [[C:%.*]] = lshr i32 [[B]], 8 -; ALL-NEXT: [[D:%.*]] = trunc i32 [[C]] to i16 +; ALL-NEXT: [[TMP1:%.*]] = ashr i8 [[A]], 7 +; ALL-NEXT: [[D:%.*]] = sext i8 [[TMP1]] to i16 ; ALL-NEXT: ret i16 [[D]] ; %B = sext i8 %A to i32 @@ -1930,8 +1929,8 @@ define i8 @pr33078_1(i8 %A) { define i12 @pr33078_2(i8 %A) { ; ALL-LABEL: @pr33078_2( -; ALL-NEXT: [[C:%.*]] = ashr i8 [[A:%.*]], 4 -; ALL-NEXT: [[D:%.*]] = sext i8 [[C]] to i12 +; ALL-NEXT: [[TMP1:%.*]] = ashr i8 [[A:%.*]], 4 +; ALL-NEXT: [[D:%.*]] = sext i8 [[TMP1]] to i12 ; ALL-NEXT: ret i12 [[D]] ; %B = sext i8 %A to i16