Skip to content

Commit

Permalink
[CodeGen][AMDGPU] EXTRACT_VECTOR_ELT: input vector element type can d…
Browse files Browse the repository at this point in the history
…iffer from output type

In function SITargetLowering::performExtractVectorElt,
the output type was not considered which could lead to type mismatches
later.

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D139943
  • Loading branch information
Juan Manuel MARTINEZ CAAMAÑO committed Jan 6, 2023
1 parent c8ec751 commit 543db09
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 23 deletions.
48 changes: 25 additions & 23 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Expand Up @@ -10793,26 +10793,28 @@ SDValue SITargetLowering::performExtractVectorEltCombine(
SelectionDAG &DAG = DCI.DAG;

EVT VecVT = Vec.getValueType();
EVT EltVT = VecVT.getVectorElementType();
EVT VecEltVT = VecVT.getVectorElementType();
EVT ResVT = N->getValueType(0);

unsigned VecSize = VecVT.getSizeInBits();
unsigned VecEltSize = VecEltVT.getSizeInBits();

if ((Vec.getOpcode() == ISD::FNEG ||
Vec.getOpcode() == ISD::FABS) && allUsesHaveSourceMods(N)) {
SDLoc SL(N);
EVT EltVT = N->getValueType(0);
SDValue Idx = N->getOperand(1);
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT,
Vec.getOperand(0), Idx);
return DAG.getNode(Vec.getOpcode(), SL, EltVT, Elt);
SDValue Elt =
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT, Vec.getOperand(0), Idx);
return DAG.getNode(Vec.getOpcode(), SL, ResVT, Elt);
}

// ScalarRes = EXTRACT_VECTOR_ELT ((vector-BINOP Vec1, Vec2), Idx)
// =>
// Vec1Elt = EXTRACT_VECTOR_ELT(Vec1, Idx)
// Vec2Elt = EXTRACT_VECTOR_ELT(Vec2, Idx)
// ScalarRes = scalar-BINOP Vec1Elt, Vec2Elt
if (Vec.hasOneUse() && DCI.isBeforeLegalize()) {
if (Vec.hasOneUse() && DCI.isBeforeLegalize() && VecEltVT == ResVT) {
SDLoc SL(N);
EVT EltVT = N->getValueType(0);
SDValue Idx = N->getOperand(1);
unsigned Opc = Vec.getOpcode();

Expand All @@ -10832,29 +10834,26 @@ SDValue SITargetLowering::performExtractVectorEltCombine(
case ISD::FMINNUM:
case ISD::FMAXNUM_IEEE:
case ISD::FMINNUM_IEEE: {
SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT,
SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT,
Vec.getOperand(0), Idx);
SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT,
SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT,
Vec.getOperand(1), Idx);

DCI.AddToWorklist(Elt0.getNode());
DCI.AddToWorklist(Elt1.getNode());
return DAG.getNode(Opc, SL, EltVT, Elt0, Elt1, Vec->getFlags());
return DAG.getNode(Opc, SL, ResVT, Elt0, Elt1, Vec->getFlags());
}
}
}

unsigned VecSize = VecVT.getSizeInBits();
unsigned EltSize = EltVT.getSizeInBits();

// EXTRACT_VECTOR_ELT (<n x e>, var-idx) => n x select (e, const-idx)
if (shouldExpandVectorDynExt(N)) {
SDLoc SL(N);
SDValue Idx = N->getOperand(1);
SDValue V;
for (unsigned I = 0, E = VecVT.getVectorNumElements(); I < E; ++I) {
SDValue IC = DAG.getVectorIdxConstant(I, SL);
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, Vec, IC);
SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT, Vec, IC);
if (I == 0)
V = Elt;
else
Expand All @@ -10870,15 +10869,11 @@ SDValue SITargetLowering::performExtractVectorEltCombine(
// elements. This exposes more load reduction opportunities by replacing
// multiple small extract_vector_elements with a single 32-bit extract.
auto *Idx = dyn_cast<ConstantSDNode>(N->getOperand(1));
if (isa<MemSDNode>(Vec) &&
EltSize <= 16 &&
EltVT.isByteSized() &&
VecSize > 32 &&
VecSize % 32 == 0 &&
Idx) {
if (isa<MemSDNode>(Vec) && VecEltSize <= 16 && VecEltVT.isByteSized() &&
VecSize > 32 && VecSize % 32 == 0 && Idx) {
EVT NewVT = getEquivalentMemType(*DAG.getContext(), VecVT);

unsigned BitIndex = Idx->getZExtValue() * EltSize;
unsigned BitIndex = Idx->getZExtValue() * VecEltSize;
unsigned EltIdx = BitIndex / 32;
unsigned LeftoverBitIdx = BitIndex % 32;
SDLoc SL(N);
Expand All @@ -10893,9 +10888,16 @@ SDValue SITargetLowering::performExtractVectorEltCombine(
DAG.getConstant(LeftoverBitIdx, SL, MVT::i32));
DCI.AddToWorklist(Srl.getNode());

SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, EltVT.changeTypeToInteger(), Srl);
EVT VecEltAsIntVT = VecEltVT.changeTypeToInteger();
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VecEltAsIntVT, Srl);
DCI.AddToWorklist(Trunc.getNode());
return DAG.getNode(ISD::BITCAST, SL, EltVT, Trunc);

if (VecEltVT == ResVT) {
return DAG.getNode(ISD::BITCAST, SL, VecEltVT, Trunc);
}

assert(ResVT.isScalarInteger());
return DAG.getAnyExtOrTrunc(Trunc, SL, ResVT);
}

return SDValue();
Expand Down
100 changes: 100 additions & 0 deletions llvm/test/CodeGen/AMDGPU/dagcomb-extract-vec-elt-different-sizes.ll
@@ -0,0 +1,100 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx90a < %s | FileCheck -enable-var-scope %s
;
; This code is used to trigger the following dag node, with different return type and vector element type: i16 extract_vec_elt <N x i8> v, 0

define amdgpu_kernel void @eggs(i1 %arg, ptr addrspace(1) %arg1, ptr %arg2, ptr %arg3, ptr %arg4, ptr %arg5, ptr %arg6, ptr %arg7, ptr %arg8, ptr %arg9) {
; CHECK-LABEL: eggs:
; CHECK: ; %bb.0: ; %bb
; CHECK-NEXT: s_load_dword s0, s[4:5], 0x0
; CHECK-NEXT: s_load_dwordx16 s[8:23], s[4:5], 0x8
; CHECK-NEXT: s_waitcnt lgkmcnt(0)
; CHECK-NEXT: s_bitcmp0_b32 s0, 0
; CHECK-NEXT: s_cbranch_scc1 .LBB0_2
; CHECK-NEXT: ; %bb.1: ; %bb10
; CHECK-NEXT: v_mov_b32_e32 v0, 0
; CHECK-NEXT: global_load_dwordx2 v[0:1], v0, s[8:9]
; CHECK-NEXT: s_waitcnt vmcnt(0)
; CHECK-NEXT: v_lshrrev_b32_e32 v7, 8, v0
; CHECK-NEXT: v_lshrrev_b32_e32 v6, 16, v0
; CHECK-NEXT: v_lshrrev_b32_e32 v5, 24, v0
; CHECK-NEXT: v_lshrrev_b32_e32 v4, 8, v1
; CHECK-NEXT: v_lshrrev_b32_e32 v3, 16, v1
; CHECK-NEXT: v_lshrrev_b32_e32 v2, 24, v1
; CHECK-NEXT: s_branch .LBB0_3
; CHECK-NEXT: .LBB0_2:
; CHECK-NEXT: v_mov_b32_e32 v2, 0
; CHECK-NEXT: v_mov_b32_e32 v3, 0
; CHECK-NEXT: v_mov_b32_e32 v4, 0
; CHECK-NEXT: v_mov_b32_e32 v1, 0
; CHECK-NEXT: v_mov_b32_e32 v5, 0
; CHECK-NEXT: v_mov_b32_e32 v6, 0
; CHECK-NEXT: v_mov_b32_e32 v7, 0
; CHECK-NEXT: v_mov_b32_e32 v0, 0
; CHECK-NEXT: .LBB0_3: ; %bb41
; CHECK-NEXT: s_load_dwordx2 s[0:1], s[4:5], 0x48
; CHECK-NEXT: v_mov_b32_e32 v8, s10
; CHECK-NEXT: v_mov_b32_e32 v9, s11
; CHECK-NEXT: v_mov_b32_e32 v10, s12
; CHECK-NEXT: v_mov_b32_e32 v11, s13
; CHECK-NEXT: v_mov_b32_e32 v12, s14
; CHECK-NEXT: v_mov_b32_e32 v13, s15
; CHECK-NEXT: v_mov_b32_e32 v14, s16
; CHECK-NEXT: v_mov_b32_e32 v15, s17
; CHECK-NEXT: v_mov_b32_e32 v16, s18
; CHECK-NEXT: v_mov_b32_e32 v17, s19
; CHECK-NEXT: v_mov_b32_e32 v18, s20
; CHECK-NEXT: v_mov_b32_e32 v19, s21
; CHECK-NEXT: v_mov_b32_e32 v20, s22
; CHECK-NEXT: v_mov_b32_e32 v21, s23
; CHECK-NEXT: flat_store_byte v[8:9], v0
; CHECK-NEXT: flat_store_byte v[10:11], v7
; CHECK-NEXT: flat_store_byte v[12:13], v6
; CHECK-NEXT: flat_store_byte v[14:15], v5
; CHECK-NEXT: flat_store_byte v[16:17], v1
; CHECK-NEXT: flat_store_byte v[18:19], v4
; CHECK-NEXT: flat_store_byte v[20:21], v3
; CHECK-NEXT: s_waitcnt lgkmcnt(0)
; CHECK-NEXT: v_pk_mov_b32 v[0:1], s[0:1], s[0:1] op_sel:[0,1]
; CHECK-NEXT: flat_store_byte v[0:1], v2
; CHECK-NEXT: s_endpgm
bb:
br i1 %arg, label %bb10, label %bb41

bb10: ; preds = %bb
%tmp12 = load <1 x i8>, ptr addrspace(1) %arg1
%tmp13 = getelementptr i8, ptr addrspace(1) %arg1, i64 1
%tmp16 = load <1 x i8>, ptr addrspace(1) %tmp13
%tmp17 = getelementptr i8, ptr addrspace(1) %arg1, i64 2
%tmp20 = load <1 x i8>, ptr addrspace(1) %tmp17
%tmp21 = getelementptr i8, ptr addrspace(1) %arg1, i64 3
%tmp24 = load <1 x i8>, ptr addrspace(1) %tmp21
%tmp25 = getelementptr i8, ptr addrspace(1) %arg1, i64 4
%tmp28 = load <1 x i8>, ptr addrspace(1) %tmp25
%tmp29 = getelementptr i8, ptr addrspace(1) %arg1, i64 5
%tmp32 = load <1 x i8>, ptr addrspace(1) %tmp29
%tmp33 = getelementptr i8, ptr addrspace(1) %arg1, i64 6
%tmp36 = load <1 x i8>, ptr addrspace(1) %tmp33
%tmp37 = getelementptr i8, ptr addrspace(1) %arg1, i64 7
%tmp40 = load <1 x i8>, ptr addrspace(1) %tmp37
br label %bb41

bb41: ; preds = %bb10, %bb
%tmp42 = phi <1 x i8> [ %tmp40, %bb10 ], [ zeroinitializer, %bb ]
%tmp43 = phi <1 x i8> [ %tmp36, %bb10 ], [ zeroinitializer, %bb ]
%tmp44 = phi <1 x i8> [ %tmp32, %bb10 ], [ zeroinitializer, %bb ]
%tmp45 = phi <1 x i8> [ %tmp28, %bb10 ], [ zeroinitializer, %bb ]
%tmp46 = phi <1 x i8> [ %tmp24, %bb10 ], [ zeroinitializer, %bb ]
%tmp47 = phi <1 x i8> [ %tmp20, %bb10 ], [ zeroinitializer, %bb ]
%tmp48 = phi <1 x i8> [ %tmp16, %bb10 ], [ zeroinitializer, %bb ]
%tmp49 = phi <1 x i8> [ %tmp12, %bb10 ], [ zeroinitializer, %bb ]
store <1 x i8> %tmp49, ptr %arg2
store <1 x i8> %tmp48, ptr %arg3
store <1 x i8> %tmp47, ptr %arg4
store <1 x i8> %tmp46, ptr %arg5
store <1 x i8> %tmp45, ptr %arg6
store <1 x i8> %tmp44, ptr %arg7
store <1 x i8> %tmp43, ptr %arg8
store <1 x i8> %tmp42, ptr %arg9
ret void
}

0 comments on commit 543db09

Please sign in to comment.