diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInsertDelayAlu.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInsertDelayAlu.cpp index 44eaebffb70dc..9a90787963d7b 100644 --- a/llvm/lib/Target/AMDGPU/AMDGPUInsertDelayAlu.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUInsertDelayAlu.cpp @@ -25,6 +25,7 @@ namespace { class AMDGPUInsertDelayAlu { public: + const GCNSubtarget *ST; const SIInstrInfo *SII; const TargetRegisterInfo *TRI; @@ -65,13 +66,16 @@ class AMDGPUInsertDelayAlu { // Types of delay that can be encoded in an s_delay_alu instruction. enum DelayType { VALU, TRANS, SALU, OTHER }; - // Get the delay type for an instruction with the specified TSFlags. - static DelayType getDelayType(uint64_t TSFlags) { - if (TSFlags & SIInstrFlags::TRANS) + // Get the delay type for a MachineInstr. + DelayType getDelayType(const MachineInstr &MI) { + if (SIInstrInfo::isTRANS(MI)) return TRANS; - if (TSFlags & SIInstrFlags::VALU) + // WMMA XDL ops are treated the same as TRANS. + if (AMDGPU::isGFX1250(*ST) && SII->isXDLWMMA(MI)) + return TRANS; + if (SIInstrInfo::isVALU(MI)) return VALU; - if (TSFlags & SIInstrFlags::SALU) + if (SIInstrInfo::isSALU(MI)) return SALU; return OTHER; } @@ -368,7 +372,7 @@ class AMDGPUInsertDelayAlu { continue; } - DelayType Type = getDelayType(MI.getDesc().TSFlags); + DelayType Type = getDelayType(MI); if (instructionWaitsForSGPRWrites(MI)) { auto It = State.find(LastSGPRFromVALU); @@ -456,12 +460,12 @@ class AMDGPUInsertDelayAlu { LLVM_DEBUG(dbgs() << "AMDGPUInsertDelayAlu running on " << MF.getName() << "\n"); - const GCNSubtarget &ST = MF.getSubtarget(); - if (!ST.hasDelayAlu()) + ST = &MF.getSubtarget(); + if (!ST->hasDelayAlu()) return false; - SII = ST.getInstrInfo(); - TRI = ST.getRegisterInfo(); + SII = ST->getInstrInfo(); + TRI = ST->getRegisterInfo(); SchedModel = &SII->getSchedModel(); // Calculate the delay state for each basic block, iterating until we reach diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp index a1e14d90ebcab..46cec413ba8bd 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp @@ -10466,10 +10466,23 @@ bool SIInstrInfo::isGlobalMemoryObject(const MachineInstr *MI) const { return TargetInstrInfo::isGlobalMemoryObject(MI); } +bool SIInstrInfo::isXDLWMMA(const MachineInstr &MI) const { + if (!isWMMA(MI) && !isSWMMAC(MI)) + return false; + + if (AMDGPU::isGFX1250(ST)) + return AMDGPU::getWMMAIsXDL(MI.getOpcode()); + + return true; +} + bool SIInstrInfo::isXDL(const MachineInstr &MI) const { unsigned Opcode = MI.getOpcode(); - if (!SIInstrInfo::isMAI(MI) || isDGEMM(Opcode) || + if (AMDGPU::isGFX12Plus(ST)) + return isDOT(MI) || isXDLWMMA(MI); + + if (!isMAI(MI) || isDGEMM(Opcode) || Opcode == AMDGPU::V_ACCVGPR_WRITE_B32_e64 || Opcode == AMDGPU::V_ACCVGPR_READ_B32_e64) return false; diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.h b/llvm/lib/Target/AMDGPU/SIInstrInfo.h index a380199977616..3a48e6579238e 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.h +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.h @@ -867,6 +867,8 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo { return get(Opcode).TSFlags & SIInstrFlags::IsDOT; } + bool isXDLWMMA(const MachineInstr &MI) const; + bool isXDL(const MachineInstr &MI) const; static bool isDGEMM(unsigned Opcode) { return AMDGPU::getMAIIsDGEMM(Opcode); } diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp index 9df2bdededa13..77258810dd68c 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.cpp @@ -296,6 +296,7 @@ unsigned getCompletionActionImplicitArgPosition(unsigned CodeObjectVersion) { #define GET_MIMGOffsetMappingTable_IMPL #define GET_MIMGG16MappingTable_IMPL #define GET_MAIInstInfoTable_IMPL +#define GET_WMMAInstInfoTable_IMPL #include "AMDGPUGenSearchableTables.inc" int getMIMGOpcode(unsigned BaseOpcode, unsigned MIMGEncoding, @@ -568,6 +569,11 @@ bool getMAIIsGFX940XDL(unsigned Opc) { return Info && Info->is_gfx940_xdl; } +bool getWMMAIsXDL(unsigned Opc) { + const WMMAInstInfo *Info = getWMMAInstInfoHelper(Opc); + return Info ? Info->is_wmma_xdl : false; +} + uint8_t mfmaScaleF8F6F4FormatToNumRegs(unsigned EncodingVal) { switch (EncodingVal) { case MFMAScaleFormats::FP6_E2M3: diff --git a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h index 6708e0a3f4549..c9d2c286bf237 100644 --- a/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h +++ b/llvm/lib/Target/AMDGPU/Utils/AMDGPUBaseInfo.h @@ -119,6 +119,11 @@ struct True16D16Info { unsigned LoOp; }; +struct WMMAInstInfo { + uint16_t Opcode; + bool is_wmma_xdl; +}; + #define GET_MIMGBaseOpcode_DECL #define GET_MIMGDim_DECL #define GET_MIMGEncoding_DECL @@ -129,6 +134,7 @@ struct True16D16Info { #define GET_isMFMA_F8F6F4Table_DECL #define GET_isCvtScaleF32_F32F16ToF8F4Table_DECL #define GET_True16D16Table_DECL +#define GET_WMMAInstInfoTable_DECL #include "AMDGPUGenSearchableTables.inc" namespace IsaInfo { @@ -593,6 +599,9 @@ bool getMAIIsDGEMM(unsigned Opc); LLVM_READONLY bool getMAIIsGFX940XDL(unsigned Opc); +LLVM_READONLY +bool getWMMAIsXDL(unsigned Opc); + // Get an equivalent BitOp3 for a binary logical \p Opc. // \returns BitOp3 modifier for the logical operation or zero. // Used in VOPD3 conversion. diff --git a/llvm/test/CodeGen/AMDGPU/insert-delay-alu-wmma-xdl.mir b/llvm/test/CodeGen/AMDGPU/insert-delay-alu-wmma-xdl.mir new file mode 100644 index 0000000000000..7c3170d8d1e9f --- /dev/null +++ b/llvm/test/CodeGen/AMDGPU/insert-delay-alu-wmma-xdl.mir @@ -0,0 +1,84 @@ +# RUN: llc -mtriple=amdgcn -mcpu=gfx1250 -start-before=amdgpu-insert-delay-alu %s -o - | FileCheck %s + +--- +name: wmma_xdl_twoaddr_trans +tracksRegLiveness: true +body: | + bb.0: + ; CHECK-LABEL: {{^}}wmma_xdl_twoaddr_trans: + ; CHECK: %bb.0: + ; CHECK-NEXT: v_wmma_f32_16x16x64_fp8_fp8 v[8:15], v[0:7], v[0:7], v[8:15] + ; CHECK-NEXT: v_exp_f32_e32 v16, v16 + ; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2) + ; CHECK-NEXT: v_add_nc_u32_e32 v17, v17, v8 + liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, $vgpr16, $vgpr17 + $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = V_WMMA_F32_16X16X64_FP8_FP8_w32_twoaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, 8, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, 0, 0, 0, 0, implicit $exec + $vgpr16 = V_EXP_F32_e32 $vgpr16, implicit $exec, implicit $mode + $vgpr17 = V_ADD_U32_e32 $vgpr17, $vgpr8, implicit $exec +... + +--- +name: wmma_xdl_threeaddr_trans +tracksRegLiveness: true +body: | + bb.0: + ; CHECK-LABEL: {{^}}wmma_xdl_threeaddr_trans: + ; CHECK: %bb.0: + ; CHECK-NEXT: v_wmma_f32_16x16x64_fp8_fp8 v[8:15], v[0:7], v[0:7], v[16:23] + ; CHECK-NEXT: v_exp_f32_e32 v24, v24 + ; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2) + ; CHECK-NEXT: v_add_nc_u32_e32 v25, v25, v8 + liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15, $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24, $vgpr25 + $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = V_WMMA_F32_16X16X64_FP8_FP8_w32_threeaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, 8, $vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, 0, 0, 0, 0, implicit $exec + $vgpr24 = V_EXP_F32_e32 $vgpr24, implicit $exec, implicit $mode + $vgpr25 = V_ADD_U32_e32 $vgpr25, $vgpr8, implicit $exec +... + +name: swmmac_xdl_twoaddr_trans +tracksRegLiveness: true +body: | + bb.0: + ; CHECK-LABEL: {{^}}swmmac_xdl_twoaddr_trans: + ; CHECK: %bb.0: + ; CHECK-NEXT: v_swmmac_f16_16x16x128_bf8_bf8 v[24:27], v[0:7], v[8:23], v[28:29] + ; CHECK-NEXT: v_exp_f32_e32 v30, v30 + ; CHECK-NEXT: s_delay_alu instid0(TRANS32_DEP_2) + ; CHECK-NEXT: v_add_nc_u32_e32 v31, v31, v24 + liveins: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24_vgpr25_vgpr26_vgpr27, $vgpr28, $vgpr29, $vgpr30, $vgpr31 + $vgpr24_vgpr25_vgpr26_vgpr27 = V_SWMMAC_F16_16X16X128_BF8_BF8_w32_twoaddr $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7, $vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15_vgpr16_vgpr17_vgpr18_vgpr19_vgpr20_vgpr21_vgpr22_vgpr23, $vgpr24_vgpr25_vgpr26_vgpr27, $vgpr28_vgpr29, 0, 0, 0, implicit $exec + $vgpr30 = V_EXP_F32_e32 $vgpr30, implicit $exec, implicit $mode + $vgpr31 = V_ADD_U32_e32 $vgpr31, $vgpr24, implicit $exec +... + +name: wmma_non_xdl_large_data_valu +tracksRegLiveness: true +body: | + bb.0: + ; CHECK-LABEL: {{^}}wmma_non_xdl_large_data_valu: + ; CHECK: %bb.0: + ; CHECK-NEXT: v_wmma_f32_16x16x4_f32 v[4:11], v[0:1], v[2:3], v[4:11] matrix_b_reuse + ; CHECK-NEXT: v_exp_f32_e32 v12, v12 + ; CHECK-NEXT: s_delay_alu instid0(VALU_DEP_1) + ; CHECK-NEXT: v_add_nc_u32_e32 v13, v13, v8 + liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4_vgpr5_vgpr6_vgpr7, $vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11, $vgpr12, $vgpr13 + $vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11 = V_WMMA_F32_16X16X4_F32_w32_twoaddr 8, $vgpr0_vgpr1, 8, $vgpr2_vgpr3, 8, $vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11, 0, -1, 0, 0, implicit $exec + $vgpr12 = V_EXP_F32_e32 $vgpr12, implicit $exec, implicit $mode + $vgpr13 = V_ADD_U32_e32 $vgpr13, $vgpr8, implicit $exec +... + +--- +name: dot_xdl_dep_2 +tracksRegLiveness: true +body: | + bb.0: + ; CHECK-LABEL: {{^}}dot_xdl_dep_2: + ; CHECK: %bb.0: + ; CHECK-NEXT: v_dot4_i32_iu8 v0, s2, s3, v0 neg_lo:[1,1,0] + ; CHECK-NEXT: v_dot4_i32_iu8 v1, s2, s3, v2 neg_lo:[1,1,0] + ; CHECK-NEXT: s_delay_alu instid0(VALU_DEP_2) + ; CHECK-NEXT: v_add_nc_u32_e32 v2, v0, v0 + liveins: $vgpr0, $sgpr2, $sgpr3, $vgpr0, $vgpr1, $vgpr2 + $vgpr0 = V_DOT4_I32_IU8 9, $sgpr2, 9, $sgpr3, 8, $vgpr0, 0, 0, 0, implicit $exec + $vgpr1 = V_DOT4_I32_IU8 9, $sgpr2, 9, $sgpr3, 8, $vgpr2, 0, 0, 0, implicit $exec + $vgpr2 = V_ADD_U32_e32 $vgpr0, $vgpr0, implicit $exec +...