diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp index 9dc2bcb363227..0bcc156a50170 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -1440,10 +1440,10 @@ bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI, case CASE_VFMA_SPLATS(FNMSUB): case CASE_VFMA_SPLATS(FNMACC): case CASE_VFMA_SPLATS(FNMSAC): - case CASE_VFMA_OPCODE_LMULS(FMACC, VV): - case CASE_VFMA_OPCODE_LMULS(FMSAC, VV): - case CASE_VFMA_OPCODE_LMULS(FNMACC, VV): - case CASE_VFMA_OPCODE_LMULS(FNMSAC, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FMACC, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FMSAC, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FNMACC, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FNMSAC, VV): case CASE_VFMA_OPCODE_LMULS(MADD, VX): case CASE_VFMA_OPCODE_LMULS(NMSUB, VX): case CASE_VFMA_OPCODE_LMULS(MACC, VX): @@ -1464,10 +1464,10 @@ bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI, return false; return true; } - case CASE_VFMA_OPCODE_LMULS(FMADD, VV): - case CASE_VFMA_OPCODE_LMULS(FMSUB, VV): - case CASE_VFMA_OPCODE_LMULS(FNMADD, VV): - case CASE_VFMA_OPCODE_LMULS(FNMSUB, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FMADD, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FMSUB, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FNMADD, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FNMSUB, VV): case CASE_VFMA_OPCODE_LMULS(MADD, VV): case CASE_VFMA_OPCODE_LMULS(NMSUB, VV): { // If the tail policy is undisturbed we can't commute. @@ -1585,10 +1585,10 @@ MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI, case CASE_VFMA_SPLATS(FNMADD): case CASE_VFMA_SPLATS(FNMSAC): case CASE_VFMA_SPLATS(FNMSUB): - case CASE_VFMA_OPCODE_LMULS(FMACC, VV): - case CASE_VFMA_OPCODE_LMULS(FMSAC, VV): - case CASE_VFMA_OPCODE_LMULS(FNMACC, VV): - case CASE_VFMA_OPCODE_LMULS(FNMSAC, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FMACC, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FMSAC, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FNMACC, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FNMSAC, VV): case CASE_VFMA_OPCODE_LMULS(MADD, VX): case CASE_VFMA_OPCODE_LMULS(NMSUB, VX): case CASE_VFMA_OPCODE_LMULS(MACC, VX): @@ -1611,10 +1611,10 @@ MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI, CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMADD, FNMACC) CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMSAC, FNMSUB) CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMSUB, FNMSAC) - CASE_VFMA_CHANGE_OPCODE_LMULS(FMACC, FMADD, VV) - CASE_VFMA_CHANGE_OPCODE_LMULS(FMSAC, FMSUB, VV) - CASE_VFMA_CHANGE_OPCODE_LMULS(FNMACC, FNMADD, VV) - CASE_VFMA_CHANGE_OPCODE_LMULS(FNMSAC, FNMSUB, VV) + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(FMACC, FMADD, VV) + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(FMSAC, FMSUB, VV) + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(FNMACC, FNMADD, VV) + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(FNMSAC, FNMSUB, VV) CASE_VFMA_CHANGE_OPCODE_LMULS(MACC, MADD, VX) CASE_VFMA_CHANGE_OPCODE_LMULS(MADD, MACC, VX) CASE_VFMA_CHANGE_OPCODE_LMULS(NMSAC, NMSUB, VX) @@ -1628,10 +1628,10 @@ MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI, return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI=*/false, OpIdx1, OpIdx2); } - case CASE_VFMA_OPCODE_LMULS(FMADD, VV): - case CASE_VFMA_OPCODE_LMULS(FMSUB, VV): - case CASE_VFMA_OPCODE_LMULS(FNMADD, VV): - case CASE_VFMA_OPCODE_LMULS(FNMSUB, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FMADD, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FMSUB, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FNMADD, VV): + case CASE_VFMA_OPCODE_LMULS_MF4(FNMSUB, VV): case CASE_VFMA_OPCODE_LMULS(MADD, VV): case CASE_VFMA_OPCODE_LMULS(NMSUB, VV): { assert((OpIdx1 == 1 || OpIdx2 == 1) && "Unexpected opcode index"); @@ -1642,10 +1642,10 @@ MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI, switch (MI.getOpcode()) { default: llvm_unreachable("Unexpected opcode"); - CASE_VFMA_CHANGE_OPCODE_LMULS(FMADD, FMACC, VV) - CASE_VFMA_CHANGE_OPCODE_LMULS(FMSUB, FMSAC, VV) - CASE_VFMA_CHANGE_OPCODE_LMULS(FNMADD, FNMACC, VV) - CASE_VFMA_CHANGE_OPCODE_LMULS(FNMSUB, FNMSAC, VV) + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(FMADD, FMACC, VV) + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(FMSUB, FMSAC, VV) + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(FNMADD, FNMACC, VV) + CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(FNMSUB, FNMSAC, VV) CASE_VFMA_CHANGE_OPCODE_LMULS(MADD, MACC, VV) CASE_VFMA_CHANGE_OPCODE_LMULS(NMSUB, NMSAC, VV) } @@ -1674,13 +1674,16 @@ MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI, #define CASE_WIDEOP_OPCODE_COMMON(OP, LMUL) \ RISCV::PseudoV##OP##_##LMUL##_TIED -#define CASE_WIDEOP_OPCODE_LMULS(OP) \ - CASE_WIDEOP_OPCODE_COMMON(OP, MF8): \ - case CASE_WIDEOP_OPCODE_COMMON(OP, MF4): \ +#define CASE_WIDEOP_OPCODE_LMULS_MF4(OP) \ + CASE_WIDEOP_OPCODE_COMMON(OP, MF4): \ case CASE_WIDEOP_OPCODE_COMMON(OP, MF2): \ case CASE_WIDEOP_OPCODE_COMMON(OP, M1): \ case CASE_WIDEOP_OPCODE_COMMON(OP, M2): \ case CASE_WIDEOP_OPCODE_COMMON(OP, M4) + +#define CASE_WIDEOP_OPCODE_LMULS(OP) \ + CASE_WIDEOP_OPCODE_COMMON(OP, MF8): \ + case CASE_WIDEOP_OPCODE_LMULS_MF4(OP) // clang-format on #define CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, LMUL) \ @@ -1688,22 +1691,25 @@ MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI, NewOpc = RISCV::PseudoV##OP##_##LMUL; \ break; -#define CASE_WIDEOP_CHANGE_OPCODE_LMULS(OP) \ - CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF8) \ +#define CASE_WIDEOP_CHANGE_OPCODE_LMULS_MF4(OP) \ CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF4) \ CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF2) \ CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, M1) \ CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, M2) \ CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4) +#define CASE_WIDEOP_CHANGE_OPCODE_LMULS(OP) \ + CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF8) \ + CASE_WIDEOP_CHANGE_OPCODE_LMULS_MF4(OP) + MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI, LiveVariables *LV, LiveIntervals *LIS) const { switch (MI.getOpcode()) { default: break; - case CASE_WIDEOP_OPCODE_LMULS(FWADD_WV): - case CASE_WIDEOP_OPCODE_LMULS(FWSUB_WV): + case CASE_WIDEOP_OPCODE_LMULS_MF4(FWADD_WV): + case CASE_WIDEOP_OPCODE_LMULS_MF4(FWSUB_WV): case CASE_WIDEOP_OPCODE_LMULS(WADD_WV): case CASE_WIDEOP_OPCODE_LMULS(WADDU_WV): case CASE_WIDEOP_OPCODE_LMULS(WSUB_WV): @@ -1713,8 +1719,8 @@ MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI, switch (MI.getOpcode()) { default: llvm_unreachable("Unexpected opcode"); - CASE_WIDEOP_CHANGE_OPCODE_LMULS(FWADD_WV) - CASE_WIDEOP_CHANGE_OPCODE_LMULS(FWSUB_WV) + CASE_WIDEOP_CHANGE_OPCODE_LMULS_MF4(FWADD_WV) + CASE_WIDEOP_CHANGE_OPCODE_LMULS_MF4(FWSUB_WV) CASE_WIDEOP_CHANGE_OPCODE_LMULS(WADD_WV) CASE_WIDEOP_CHANGE_OPCODE_LMULS(WADDU_WV) CASE_WIDEOP_CHANGE_OPCODE_LMULS(WSUB_WV) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td index 40ab0bb204020..fe06006c9798a 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td @@ -72,9 +72,13 @@ def V_MF2 : LMULInfo<0b111, 4, VR, VR, VR, VR,/*NoVReg*/VR, "M // Used to iterate over all possible LMULs. defvar MxList = [V_MF8, V_MF4, V_MF2, V_M1, V_M2, V_M4, V_M8]; +// For floating point which don't need MF8. +defvar MxListF = [V_MF4, V_MF2, V_M1, V_M2, V_M4, V_M8]; // Used for widening and narrowing instructions as it doesn't contain M8. defvar MxListW = [V_MF8, V_MF4, V_MF2, V_M1, V_M2, V_M4]; +// For floating point which don't need MF8. +defvar MxListFW = [V_MF4, V_MF2, V_M1, V_M2, V_M4]; // Use for zext/sext.vf2 defvar MxListVF2 = [V_MF4, V_MF2, V_M1, V_M2, V_M4, V_M8]; @@ -1592,6 +1596,12 @@ multiclass VPseudoBinaryV_VV { defm _VV : VPseudoBinary; } +// Similar to VPseudoBinaryV_VV, but uses MxListF. +multiclass VPseudoBinaryFV_VV { + foreach m = MxListF in + defm _VV : VPseudoBinary; +} + multiclass VPseudoVGTR_VV_EEW { foreach m = MxList in { foreach sew = EEWList in { @@ -1654,8 +1664,8 @@ multiclass VPseudoVALU_MM { // * The destination EEW is greater than the source EEW, the source EMUL is // at least 1, and the overlap is in the highest-numbered part of the // destination register group is legal. Otherwise, it is illegal. -multiclass VPseudoBinaryW_VV { - foreach m = MxListW in +multiclass VPseudoBinaryW_VV mxlist = MxListW> { + foreach m = mxlist in defm _VV : VPseudoBinary; } @@ -1674,8 +1684,8 @@ multiclass VPseudoBinaryW_VF { "@earlyclobber $rd">; } -multiclass VPseudoBinaryW_WV { - foreach m = MxListW in { +multiclass VPseudoBinaryW_WV mxlist = MxListW> { + foreach m = mxlist in { defm _WV : VPseudoBinary; defm _WV : VPseudoTiedBinary, Sched<[WriteVFClassV, ReadVFClassV, ReadVMask]>; @@ -1797,7 +1807,7 @@ multiclass VPseudoVCLS_V { } multiclass VPseudoVSQR_V { - foreach m = MxList in { + foreach m = MxListF in { let VLMul = m.value in { def "_V_" # m.MX : VPseudoUnaryNoMask, Sched<[WriteVFSqrtV, ReadVFSqrtV, ReadVMask]>; @@ -1808,7 +1818,7 @@ multiclass VPseudoVSQR_V { } multiclass VPseudoVRCP_V { - foreach m = MxList in { + foreach m = MxListF in { let VLMul = m.value in { def "_V_" # m.MX : VPseudoUnaryNoMask, Sched<[WriteVFRecpV, ReadVFRecpV, ReadVMask]>; @@ -1871,8 +1881,8 @@ multiclass PseudoVEXT_VF8 { // lowest-numbered part of the source register group". // With LMUL<=1 the source and dest occupy a single register so any overlap // is in the lowest-numbered part. -multiclass VPseudoBinaryM_VV { - foreach m = MxList in +multiclass VPseudoBinaryM_VV mxlist = MxList> { + foreach m = mxlist in defm _VV : VPseudoBinaryM; } @@ -1987,14 +1997,14 @@ multiclass VPseudoVDIV_VV_VX { } multiclass VPseudoVFMUL_VV_VF { - defm "" : VPseudoBinaryV_VV, + defm "" : VPseudoBinaryFV_VV, Sched<[WriteVFMulV, ReadVFMulV, ReadVFMulV, ReadVMask]>; defm "" : VPseudoBinaryV_VF, Sched<[WriteVFMulF, ReadVFMulV, ReadVFMulF, ReadVMask]>; } multiclass VPseudoVFDIV_VV_VF { - defm "" : VPseudoBinaryV_VV, + defm "" : VPseudoBinaryFV_VV, Sched<[WriteVFDivV, ReadVFDivV, ReadVFDivV, ReadVMask]>; defm "" : VPseudoBinaryV_VF, Sched<[WriteVFDivF, ReadVFDivV, ReadVFDivF, ReadVMask]>; @@ -2013,21 +2023,21 @@ multiclass VPseudoVALU_VV_VX { } multiclass VPseudoVSGNJ_VV_VF { - defm "" : VPseudoBinaryV_VV, + defm "" : VPseudoBinaryFV_VV, Sched<[WriteVFSgnjV, ReadVFSgnjV, ReadVFSgnjV, ReadVMask]>; defm "" : VPseudoBinaryV_VF, Sched<[WriteVFSgnjF, ReadVFSgnjV, ReadVFSgnjF, ReadVMask]>; } multiclass VPseudoVMAX_VV_VF { - defm "" : VPseudoBinaryV_VV, + defm "" : VPseudoBinaryFV_VV, Sched<[WriteVFCmpV, ReadVFCmpV, ReadVFCmpV, ReadVMask]>; defm "" : VPseudoBinaryV_VF, Sched<[WriteVFCmpF, ReadVFCmpV, ReadVFCmpF, ReadVMask]>; } multiclass VPseudoVALU_VV_VF { - defm "" : VPseudoBinaryV_VV, + defm "" : VPseudoBinaryFV_VV, Sched<[WriteVFALUV, ReadVFALUV, ReadVFALUV, ReadVMask]>; defm "" : VPseudoBinaryV_VF, Sched<[WriteVFALUF, ReadVFALUV, ReadVFALUF, ReadVMask]>; @@ -2060,17 +2070,12 @@ multiclass VPseudoVWMUL_VV_VX { } multiclass VPseudoVWMUL_VV_VF { - defm "" : VPseudoBinaryW_VV, + defm "" : VPseudoBinaryW_VV, Sched<[WriteVFWMulV, ReadVFWMulV, ReadVFWMulV, ReadVMask]>; defm "" : VPseudoBinaryW_VF, Sched<[WriteVFWMulF, ReadVFWMulV, ReadVFWMulF, ReadVMask]>; } -multiclass VPseudoBinaryW_VV_VF { - defm "" : VPseudoBinaryW_VV; - defm "" : VPseudoBinaryW_VF; -} - multiclass VPseudoVWALU_WV_WX { defm "" : VPseudoBinaryW_WV, Sched<[WriteVIWALUV, ReadVIWALUV, ReadVIWALUV, ReadVMask]>; @@ -2079,14 +2084,14 @@ multiclass VPseudoVWALU_WV_WX { } multiclass VPseudoVFWALU_VV_VF { - defm "" : VPseudoBinaryW_VV, + defm "" : VPseudoBinaryW_VV, Sched<[WriteVFWALUV, ReadVFWALUV, ReadVFWALUV, ReadVMask]>; defm "" : VPseudoBinaryW_VF, Sched<[WriteVFWALUF, ReadVFWALUV, ReadVFWALUF, ReadVMask]>; } multiclass VPseudoVFWALU_WV_WF { - defm "" : VPseudoBinaryW_WV, + defm "" : VPseudoBinaryW_WV, Sched<[WriteVFWALUV, ReadVFWALUV, ReadVFWALUV, ReadVMask]>; defm "" : VPseudoBinaryW_WF, Sched<[WriteVFWALUF, ReadVFWALUV, ReadVFWALUF, ReadVMask]>; @@ -2191,8 +2196,9 @@ multiclass VPseudoTernaryWithPolicy { - foreach m = MxList in { +multiclass VPseudoTernaryV_VV_AAXA mxlist = MxList> { + foreach m = mxlist in { defm _VV : VPseudoTernaryWithPolicy; } @@ -2217,9 +2223,9 @@ multiclass VPseudoTernaryV_VF_AAXA { /*Commutable*/1>; } -multiclass VPseudoTernaryW_VV { +multiclass VPseudoTernaryW_VV mxlist = MxListW> { defvar constraint = "@earlyclobber $rd"; - foreach m = MxListW in + foreach m = mxlist in defm _VV : VPseudoTernaryWithPolicy; } @@ -2252,7 +2258,7 @@ multiclass VPseudoVMAC_VV_VX_AAXA { } multiclass VPseudoVMAC_VV_VF_AAXA { - defm "" : VPseudoTernaryV_VV_AAXA, + defm "" : VPseudoTernaryV_VV_AAXA, Sched<[WriteVFMulAddV, ReadVFMulAddV, ReadVFMulAddV, ReadVFMulAddV, ReadVMask]>; defm "" : VPseudoTernaryV_VF_AAXA, Sched<[WriteVFMulAddF, ReadVFMulAddV, ReadVFMulAddV, ReadVFMulAddF, ReadVMask]>; @@ -2278,7 +2284,7 @@ multiclass VPseudoVWMAC_VX { } multiclass VPseudoVWMAC_VV_VF { - defm "" : VPseudoTernaryW_VV, + defm "" : VPseudoTernaryW_VV, Sched<[WriteVFWMulAddV, ReadVFWMulAddV, ReadVFWMulAddV, ReadVFWMulAddV, ReadVMask]>; defm "" : VPseudoTernaryW_VF, Sched<[WriteVFWMulAddF, ReadVFWMulAddV, ReadVFWMulAddV, ReadVFWMulAddF, ReadVMask]>; @@ -2301,7 +2307,7 @@ multiclass VPseudoVCMPM_VV_VX { } multiclass VPseudoVCMPM_VV_VF { - defm "" : VPseudoBinaryM_VV, + defm "" : VPseudoBinaryM_VV, Sched<[WriteVFCmpV, ReadVFCmpV, ReadVFCmpV, ReadVMask]>; defm "" : VPseudoBinaryM_VF, Sched<[WriteVFCmpF, ReadVFCmpV, ReadVFCmpF, ReadVMask]>; @@ -2334,21 +2340,21 @@ multiclass VPseudoVWRED_VS { } multiclass VPseudoVFRED_VS { - foreach m = MxList in { + foreach m = MxListF in { defm _VS : VPseudoTernary, Sched<[WriteVFRedV, ReadVFRedV, ReadVFRedV, ReadVFRedV, ReadVMask]>; } } multiclass VPseudoVFREDO_VS { - foreach m = MxList in { + foreach m = MxListF in { defm _VS : VPseudoTernary, Sched<[WriteVFRedOV, ReadVFRedOV, ReadVFRedOV, ReadVFRedOV, ReadVMask]>; } } multiclass VPseudoVFWRED_VS { - foreach m = MxList in { + foreach m = MxListF in { defm _VS : VPseudoTernary, Sched<[WriteVFWRedV, ReadVFWRedV, ReadVFWRedV, ReadVFWRedV, ReadVMask]>; } @@ -2366,13 +2372,13 @@ multiclass VPseudoConversion, Sched<[WriteVFCvtFToIV, ReadVFCvtFToIV, ReadVMask]>; } multiclass VPseudoVCVTF_V { - foreach m = MxList in + foreach m = MxListF in defm _V : VPseudoConversion, Sched<[WriteVFCvtIToFV, ReadVFCvtIToFV, ReadVMask]>; } @@ -2385,7 +2391,7 @@ multiclass VPseudoConversionW_V { multiclass VPseudoVWCVTI_V { defvar constraint = "@earlyclobber $rd"; - foreach m = MxListW in + foreach m = MxListFW in defm _V : VPseudoConversion, Sched<[WriteVFWCvtFToIV, ReadVFWCvtFToIV, ReadVMask]>; } @@ -2399,7 +2405,7 @@ multiclass VPseudoVWCVTF_V { multiclass VPseudoVWCVTD_V { defvar constraint = "@earlyclobber $rd"; - foreach m = MxListW in + foreach m = MxListFW in defm _V : VPseudoConversion, Sched<[WriteVFWCvtFToFV, ReadVFWCvtFToFV, ReadVMask]>; } @@ -2413,14 +2419,14 @@ multiclass VPseudoVNCVTI_W { multiclass VPseudoVNCVTF_W { defvar constraint = "@earlyclobber $rd"; - foreach m = MxListW in + foreach m = MxListFW in defm _W : VPseudoConversion, Sched<[WriteVFNCvtIToFV, ReadVFNCvtIToFV, ReadVMask]>; } multiclass VPseudoVNCVTD_W { defvar constraint = "@earlyclobber $rd"; - foreach m = MxListW in + foreach m = MxListFW in defm _W : VPseudoConversion, Sched<[WriteVFNCvtFToFV, ReadVFNCvtFToFV, ReadVMask]>; }