From 543db09b97774ebf3c5da4a7044f1a94d6ba2975 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Juan=20Manuel=20MARTINEZ=20CAAMA=C3=91O?=
Date: Fri, 6 Jan 2023 09:45:36 +0100
Subject: [PATCH] [CodeGen][AMDGPU] EXTRACT_VECTOR_ELT: input vector element type can differ from output type

In SITargetLowering::performExtractVectorEltCombine, the output type of the
EXTRACT_VECTOR_ELT node was not considered: the combine assumed it always
matched the vector element type, which could lead to type mismatches later.

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D139943
---
 llvm/lib/Target/AMDGPU/SIISelLowering.cpp     |  48 +++++----
 ...dagcomb-extract-vec-elt-different-sizes.ll | 100 ++++++++++++++++++
 2 files changed, 125 insertions(+), 23 deletions(-)
 create mode 100644 llvm/test/CodeGen/AMDGPU/dagcomb-extract-vec-elt-different-sizes.ll

diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 598a941ac2912..329f08004abfb 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -10793,16 +10793,19 @@ SDValue SITargetLowering::performExtractVectorEltCombine(
   SelectionDAG &DAG = DCI.DAG;
 
   EVT VecVT = Vec.getValueType();
-  EVT EltVT = VecVT.getVectorElementType();
+  EVT VecEltVT = VecVT.getVectorElementType();
+  EVT ResVT = N->getValueType(0);
+
+  unsigned VecSize = VecVT.getSizeInBits();
+  unsigned VecEltSize = VecEltVT.getSizeInBits();
 
   if ((Vec.getOpcode() == ISD::FNEG || Vec.getOpcode() == ISD::FABS) &&
       allUsesHaveSourceMods(N)) {
     SDLoc SL(N);
-    EVT EltVT = N->getValueType(0);
     SDValue Idx = N->getOperand(1);
-    SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT,
-                              Vec.getOperand(0), Idx);
-    return DAG.getNode(Vec.getOpcode(), SL, EltVT, Elt);
+    SDValue Elt =
+        DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT, Vec.getOperand(0), Idx);
+    return DAG.getNode(Vec.getOpcode(), SL, ResVT, Elt);
   }
 
   // ScalarRes = EXTRACT_VECTOR_ELT ((vector-BINOP Vec1, Vec2), Idx)
@@ -10810,9 +10813,8 @@ SDValue SITargetLowering::performExtractVectorEltCombine(
   // =>
   // Vec1Elt = EXTRACT_VECTOR_ELT(Vec1, Idx)
   // Vec2Elt = EXTRACT_VECTOR_ELT(Vec2, Idx)
   // ScalarRes = scalar-BINOP Vec1Elt, Vec2Elt
-  if (Vec.hasOneUse() && DCI.isBeforeLegalize()) {
+  if (Vec.hasOneUse() && DCI.isBeforeLegalize() && VecEltVT == ResVT) {
     SDLoc SL(N);
-    EVT EltVT = N->getValueType(0);
     SDValue Idx = N->getOperand(1);
     unsigned Opc = Vec.getOpcode();
@@ -10832,21 +10834,18 @@ SDValue SITargetLowering::performExtractVectorEltCombine(
     case ISD::FMINNUM:
     case ISD::FMAXNUM_IEEE:
     case ISD::FMINNUM_IEEE: {
-      SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT,
+      SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT,
                                  Vec.getOperand(0), Idx);
-      SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT,
+      SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT,
                                  Vec.getOperand(1), Idx);
 
       DCI.AddToWorklist(Elt0.getNode());
       DCI.AddToWorklist(Elt1.getNode());
-      return DAG.getNode(Opc, SL, EltVT, Elt0, Elt1, Vec->getFlags());
+      return DAG.getNode(Opc, SL, ResVT, Elt0, Elt1, Vec->getFlags());
     }
     }
   }
 
-  unsigned VecSize = VecVT.getSizeInBits();
-  unsigned EltSize = EltVT.getSizeInBits();
-
   // EXTRACT_VECTOR_ELT (<n x e>, var-idx) => n x select (e, const-idx)
   if (shouldExpandVectorDynExt(N)) {
     SDLoc SL(N);
@@ -10854,7 +10853,7 @@ SDValue SITargetLowering::performExtractVectorEltCombine(
     SDValue V;
     for (unsigned I = 0, E = VecVT.getVectorNumElements(); I < E; ++I) {
       SDValue IC = DAG.getVectorIdxConstant(I, SL);
-      SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, Vec, IC);
+      SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT, Vec, IC);
       if (I == 0)
         V = Elt;
       else
@@ -10870,15 +10869,11 @@ SDValue SITargetLowering::performExtractVectorEltCombine(
   // elements. This exposes more load reduction opportunities by replacing
   // multiple small extract_vector_elements with a single 32-bit extract.
   auto *Idx = dyn_cast<ConstantSDNode>(N->getOperand(1));
-  if (isa<MemSDNode>(Vec) &&
-      EltSize <= 16 &&
-      EltVT.isByteSized() &&
-      VecSize > 32 &&
-      VecSize % 32 == 0 &&
-      Idx) {
+  if (isa<MemSDNode>(Vec) && VecEltSize <= 16 && VecEltVT.isByteSized() &&
+      VecSize > 32 && VecSize % 32 == 0 && Idx) {
     EVT NewVT = getEquivalentMemType(*DAG.getContext(), VecVT);
 
-    unsigned BitIndex = Idx->getZExtValue() * EltSize;
+    unsigned BitIndex = Idx->getZExtValue() * VecEltSize;
     unsigned EltIdx = BitIndex / 32;
     unsigned LeftoverBitIdx = BitIndex % 32;
     SDLoc SL(N);
@@ -10893,9 +10888,16 @@ SDValue SITargetLowering::performExtractVectorEltCombine(
                               DAG.getConstant(LeftoverBitIdx, SL, MVT::i32));
     DCI.AddToWorklist(Srl.getNode());
 
-    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, EltVT.changeTypeToInteger(), Srl);
+    EVT VecEltAsIntVT = VecEltVT.changeTypeToInteger();
+    SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VecEltAsIntVT, Srl);
     DCI.AddToWorklist(Trunc.getNode());
-    return DAG.getNode(ISD::BITCAST, SL, EltVT, Trunc);
+
+    if (VecEltVT == ResVT) {
+      return DAG.getNode(ISD::BITCAST, SL, VecEltVT, Trunc);
+    }
+
+    assert(ResVT.isScalarInteger());
+    return DAG.getAnyExtOrTrunc(Trunc, SL, ResVT);
   }
 
   return SDValue();
diff --git a/llvm/test/CodeGen/AMDGPU/dagcomb-extract-vec-elt-different-sizes.ll b/llvm/test/CodeGen/AMDGPU/dagcomb-extract-vec-elt-different-sizes.ll
new file mode 100644
index 0000000000000..53acbb6a7bceb
--- /dev/null
+++ b/llvm/test/CodeGen/AMDGPU/dagcomb-extract-vec-elt-different-sizes.ll
@@ -0,0 +1,100 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=amdgcn-amd-amdhsa -mcpu=gfx90a < %s | FileCheck -enable-var-scope %s
+;
+; This code is used to trigger the following dag node, with different return type and vector element type: i16 extract_vec_elt v, 0
+
+define amdgpu_kernel void @eggs(i1 %arg, ptr addrspace(1) %arg1, ptr %arg2, ptr %arg3, ptr %arg4, ptr %arg5, ptr %arg6, ptr %arg7, ptr %arg8, ptr %arg9) {
+; CHECK-LABEL: eggs:
+; CHECK:       ; %bb.0: ; %bb
+; CHECK-NEXT:    s_load_dword s0, s[4:5], 0x0
+; CHECK-NEXT:    s_load_dwordx16 s[8:23], s[4:5], 0x8
+; CHECK-NEXT:    s_waitcnt lgkmcnt(0)
+; CHECK-NEXT:    s_bitcmp0_b32 s0, 0
+; CHECK-NEXT:    s_cbranch_scc1 .LBB0_2
+; CHECK-NEXT:  ; %bb.1: ; %bb10
+; CHECK-NEXT:    v_mov_b32_e32 v0, 0
+; CHECK-NEXT:    global_load_dwordx2 v[0:1], v0, s[8:9]
+; CHECK-NEXT:    s_waitcnt vmcnt(0)
+; CHECK-NEXT:    v_lshrrev_b32_e32 v7, 8, v0
+; CHECK-NEXT:    v_lshrrev_b32_e32 v6, 16, v0
+; CHECK-NEXT:    v_lshrrev_b32_e32 v5, 24, v0
+; CHECK-NEXT:    v_lshrrev_b32_e32 v4, 8, v1
+; CHECK-NEXT:    v_lshrrev_b32_e32 v3, 16, v1
+; CHECK-NEXT:    v_lshrrev_b32_e32 v2, 24, v1
+; CHECK-NEXT:    s_branch .LBB0_3
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    v_mov_b32_e32 v2, 0
+; CHECK-NEXT:    v_mov_b32_e32 v3, 0
+; CHECK-NEXT:    v_mov_b32_e32 v4, 0
+; CHECK-NEXT:    v_mov_b32_e32 v1, 0
+; CHECK-NEXT:    v_mov_b32_e32 v5, 0
+; CHECK-NEXT:    v_mov_b32_e32 v6, 0
+; CHECK-NEXT:    v_mov_b32_e32 v7, 0
+; CHECK-NEXT:    v_mov_b32_e32 v0, 0
+; CHECK-NEXT:  .LBB0_3: ; %bb41
+; CHECK-NEXT:    s_load_dwordx2 s[0:1], s[4:5], 0x48
+; CHECK-NEXT:    v_mov_b32_e32 v8, s10
+; CHECK-NEXT:    v_mov_b32_e32 v9, s11
+; CHECK-NEXT:    v_mov_b32_e32 v10, s12
+; CHECK-NEXT:    v_mov_b32_e32 v11, s13
+; CHECK-NEXT:    v_mov_b32_e32 v12, s14
+; CHECK-NEXT:    v_mov_b32_e32 v13, s15
+; CHECK-NEXT:    v_mov_b32_e32 v14, s16
+; CHECK-NEXT:    v_mov_b32_e32 v15, s17
+; CHECK-NEXT:    v_mov_b32_e32 v16, s18
+; CHECK-NEXT:    v_mov_b32_e32 v17, s19
+; CHECK-NEXT:    v_mov_b32_e32 v18, s20
+; CHECK-NEXT:    v_mov_b32_e32 v19, s21
+; CHECK-NEXT:    v_mov_b32_e32 v20, s22
+; CHECK-NEXT:    v_mov_b32_e32 v21, s23
+; CHECK-NEXT:    flat_store_byte v[8:9], v0
+; CHECK-NEXT:    flat_store_byte v[10:11], v7
+; CHECK-NEXT:    flat_store_byte v[12:13], v6
+; CHECK-NEXT:    flat_store_byte v[14:15], v5
+; CHECK-NEXT:    flat_store_byte v[16:17], v1
+; CHECK-NEXT:    flat_store_byte v[18:19], v4
+; CHECK-NEXT:    flat_store_byte v[20:21], v3
+; CHECK-NEXT:    s_waitcnt lgkmcnt(0)
+; CHECK-NEXT:    v_pk_mov_b32 v[0:1], s[0:1], s[0:1] op_sel:[0,1]
+; CHECK-NEXT:    flat_store_byte v[0:1], v2
+; CHECK-NEXT:    s_endpgm
+bb:
+  br i1 %arg, label %bb10, label %bb41
+
+bb10: ; preds = %bb
+  %tmp12 = load <1 x i8>, ptr addrspace(1) %arg1
+  %tmp13 = getelementptr i8, ptr addrspace(1) %arg1, i64 1
+  %tmp16 = load <1 x i8>, ptr addrspace(1) %tmp13
+  %tmp17 = getelementptr i8, ptr addrspace(1) %arg1, i64 2
+  %tmp20 = load <1 x i8>, ptr addrspace(1) %tmp17
+  %tmp21 = getelementptr i8, ptr addrspace(1) %arg1, i64 3
+  %tmp24 = load <1 x i8>, ptr addrspace(1) %tmp21
+  %tmp25 = getelementptr i8, ptr addrspace(1) %arg1, i64 4
+  %tmp28 = load <1 x i8>, ptr addrspace(1) %tmp25
+  %tmp29 = getelementptr i8, ptr addrspace(1) %arg1, i64 5
+  %tmp32 = load <1 x i8>, ptr addrspace(1) %tmp29
+  %tmp33 = getelementptr i8, ptr addrspace(1) %arg1, i64 6
+  %tmp36 = load <1 x i8>, ptr addrspace(1) %tmp33
+  %tmp37 = getelementptr i8, ptr addrspace(1) %arg1, i64 7
+  %tmp40 = load <1 x i8>, ptr addrspace(1) %tmp37
+  br label %bb41
+
+bb41: ; preds = %bb10, %bb
+  %tmp42 = phi <1 x i8> [ %tmp40, %bb10 ], [ zeroinitializer, %bb ]
+  %tmp43 = phi <1 x i8> [ %tmp36, %bb10 ], [ zeroinitializer, %bb ]
+  %tmp44 = phi <1 x i8> [ %tmp32, %bb10 ], [ zeroinitializer, %bb ]
+  %tmp45 = phi <1 x i8> [ %tmp28, %bb10 ], [ zeroinitializer, %bb ]
+  %tmp46 = phi <1 x i8> [ %tmp24, %bb10 ], [ zeroinitializer, %bb ]
+  %tmp47 = phi <1 x i8> [ %tmp20, %bb10 ], [ zeroinitializer, %bb ]
+  %tmp48 = phi <1 x i8> [ %tmp16, %bb10 ], [ zeroinitializer, %bb ]
+  %tmp49 = phi <1 x i8> [ %tmp12, %bb10 ], [ zeroinitializer, %bb ]
+  store <1 x i8> %tmp49, ptr %arg2
+  store <1 x i8> %tmp48, ptr %arg3
+  store <1 x i8> %tmp47, ptr %arg4
+  store <1 x i8> %tmp46, ptr %arg5
+  store <1 x i8> %tmp45, ptr %arg6
+  store <1 x i8> %tmp44, ptr %arg7
+  store <1 x i8> %tmp43, ptr %arg8
+  store <1 x i8> %tmp42, ptr %arg9
+  ret void
+}
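A reduced sketch of the situation the commit message describes, for readers who do not want to trace the full test above. It is not part of the patch and has not been verified to reproduce the combine on its own; the committed test is the authoritative reproducer, and the kernel name and arguments below are invented for illustration. The idea is that a <1 x i8> vector load whose value is stored out through a plain pointer leaves the element extraction to SelectionDAG, where the EXTRACT_VECTOR_ELT result may be promoted to a wider integer type (for example i16) while the vector element type stays i8.

define amdgpu_kernel void @extract_one_byte(ptr addrspace(1) %in, ptr %out) {
  ; Load a single-element byte vector and store it through a flat pointer.
  ; The expectation (an assumption, not verified here) is that the DAG
  ; scalarizes the store via an extract_vector_elt whose result type is
  ; wider than the i8 element type, which is the mismatch the patch handles.
  %v = load <1 x i8>, ptr addrspace(1) %in
  store <1 x i8> %v, ptr %out
  ret void
}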