diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index cca9fa72d0ca5..792e17eeedab1 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -4217,18 +4217,21 @@ SDValue AMDGPUTargetLowering::performTruncateCombine(
   // trunc (srl (bitcast (build_vector x, y))), 16 -> trunc (bitcast y)
   if (Src.getOpcode() == ISD::SRL && !VT.isVector()) {
     if (auto *K = isConstOrConstSplat(Src.getOperand(1))) {
-      if (2 * K->getZExtValue() == Src.getValueType().getScalarSizeInBits()) {
-        SDValue BV = stripBitcast(Src.getOperand(0));
-        if (BV.getOpcode() == ISD::BUILD_VECTOR &&
-            BV.getValueType().getVectorNumElements() == 2) {
-          SDValue SrcElt = BV.getOperand(1);
-          EVT SrcEltVT = SrcElt.getValueType();
-          if (SrcEltVT.isFloatingPoint()) {
-            SrcElt = DAG.getNode(ISD::BITCAST, SL,
-                                 SrcEltVT.changeTypeToInteger(), SrcElt);
+      SDValue BV = stripBitcast(Src.getOperand(0));
+      if (BV.getOpcode() == ISD::BUILD_VECTOR) {
+        EVT SrcEltVT = BV.getOperand(0).getValueType();
+        unsigned SrcEltSize = SrcEltVT.getSizeInBits();
+        unsigned BitIndex = K->getZExtValue();
+        unsigned PartIndex = BitIndex / SrcEltSize;
+
+        if (PartIndex * SrcEltSize == BitIndex &&
+            PartIndex < BV.getNumOperands()) {
+          if (SrcEltVT.getSizeInBits() == VT.getSizeInBits()) {
+            SDValue SrcElt =
+                DAG.getNode(ISD::BITCAST, SL, SrcEltVT.changeTypeToInteger(),
+                            BV.getOperand(PartIndex));
+            return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
           }
-
-          return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
         }
       }
     }
diff --git a/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll b/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll
new file mode 100644
index 0000000000000..1c3091f6b8d3b
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/truncate-lshr-cast-build-vector-combine.ll
@@ -0,0 +1,140 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx900 < %s | FileCheck %s
+
+; extract element 0 (no shift needed)
+define i32 @cast_v4i32_to_i128_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %trunc = trunc i128 %bigint to i32
+  ret i32 %trunc
+}
+
+; extract element 1 as shift
+define i32 @cast_v4i32_to_i128_lshr_32_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_32_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v1
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 32
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract element 2 as shift
+define i32 @cast_v4i32_to_i128_lshr_64_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_64_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v2
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 64
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract element 3 as shift
+define i32 @cast_v4i32_to_i128_lshr_96_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_96_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v3
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 96
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; Shift not aligned to element, not a simple extract
+define i32 @cast_v4i32_to_i128_lshr_33_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_33_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_alignbit_b32 v0, v2, v1, 1
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 33
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract misaligned element
+define i32 @cast_v4i32_to_i128_lshr_31_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_31_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_alignbit_b32 v0, v1, v0, 31
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 31
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract misaligned element
+define i32 @cast_v4i32_to_i128_lshr_48_trunc_i32(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_48_trunc_i32:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    s_mov_b32 s4, 0x1000706
+; CHECK-NEXT:    v_perm_b32 v0, v1, v2, s4
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 48
+  %trunc = trunc i128 %srl to i32
+  ret i32 %trunc
+}
+
+; extract elements 1 and 2 with shift
+define i64 @cast_v4i32_to_i128_lshr_32_trunc_i64(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_32_trunc_i64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v0, v1
+; CHECK-NEXT:    v_mov_b32_e32 v1, v2
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 32
+  %trunc = trunc i128 %srl to i64
+  ret i64 %trunc
+}
+
+; extract elements 2 and 3 with shift
+define i64 @cast_v4i32_to_i128_lshr_64_trunc_i64(<4 x i32> %arg) {
+; CHECK-LABEL: cast_v4i32_to_i128_lshr_64_trunc_i64:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    v_mov_b32_e32 v1, v3
+; CHECK-NEXT:    v_mov_b32_e32 v0, v2
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %bigint = bitcast <4 x i32> %arg to i128
+  %srl = lshr i128 %bigint, 64
+  %trunc = trunc i128 %srl to i64
+  ret i64 %trunc
+}
+
+; FIXME: We don't process this case because we see multiple bitcasts
+; before a 32-bit build_vector
+define i32 @build_vector_i16_to_shift(i16 %arg0, i16 %arg1, i16 %arg2, i16 %arg3) {
+; CHECK-LABEL: build_vector_i16_to_shift:
+; CHECK:       ; %bb.0:
+; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
+; CHECK-NEXT:    s_mov_b32 s4, 0x5040100
+; CHECK-NEXT:    v_perm_b32 v0, v3, v2, s4
+; CHECK-NEXT:    s_setpc_b64 s[30:31]
+  %ins.0 = insertelement <4 x i16> poison, i16 %arg0, i32 0
+  %ins.1 = insertelement <4 x i16> %ins.0, i16 %arg1, i32 1
+  %ins.2 = insertelement <4 x i16> %ins.1, i16 %arg2, i32 2
+  %ins.3 = insertelement <4 x i16> %ins.2, i16 %arg3, i32 3
+
+  %cast = bitcast <4 x i16> %ins.3 to i64
+  %srl = lshr i64 %cast, 32
+  %trunc = trunc i64 %srl to i32
+  ret i32 %trunc
+}
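
Reviewer note, not part of the patch: the new code maps the constant shift amount onto a build_vector operand index, and the fold fires only when the shift amount is an exact multiple of the source element size, the resulting index is in range, and the truncated type is exactly one element wide (the SrcEltVT.getSizeInBits() == VT.getSizeInBits() guard). Below is a minimal standalone sketch of just the index arithmetic; selectedElement is a hypothetical illustrative helper, not LLVM API.

#include <cassert>
#include <cstdint>
#include <optional>

// Maps a constant lshr amount (in bits) onto a build_vector operand index.
// Returns std::nullopt when the shift is not element-aligned or points past
// the last element, in which case the combine leaves the node alone.
static std::optional<uint64_t> selectedElement(uint64_t BitIndex,
                                               uint64_t SrcEltSizeBits,
                                               uint64_t NumElts) {
  uint64_t PartIndex = BitIndex / SrcEltSizeBits;
  if (PartIndex * SrcEltSizeBits != BitIndex || PartIndex >= NumElts)
    return std::nullopt;
  return PartIndex;
}

int main() {
  // Mirrors the v4i32 -> i128 tests above (32-bit elements, 4 operands).
  assert(selectedElement(0, 32, 4) == 0);  // cast_v4i32_to_i128_trunc_i32
  assert(selectedElement(32, 32, 4) == 1); // ..._lshr_32_trunc_i32 -> v1
  assert(selectedElement(96, 32, 4) == 3); // ..._lshr_96_trunc_i32 -> v3
  assert(!selectedElement(33, 32, 4));     // misaligned -> v_alignbit_b32
  assert(!selectedElement(48, 32, 4));     // misaligned -> v_perm_b32
}

The misaligned shift amounts (31, 33, 48) fail the alignment check and fall through to the generic shift lowering, which is why those tests still select v_alignbit_b32/v_perm_b32. The i64-truncate tests are likewise untouched by this combine, since the result type spans two source elements; their MOV pairs come from generic legalization.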