diff --git a/llvm/lib/Target/AMDGPU/SIInstrInfo.td b/llvm/lib/Target/AMDGPU/SIInstrInfo.td index 04c92155f5aad..441d72cc173f5 100644 --- a/llvm/lib/Target/AMDGPU/SIInstrInfo.td +++ b/llvm/lib/Target/AMDGPU/SIInstrInfo.td @@ -284,12 +284,17 @@ def SIfptrunc_round_downward : SDNode<"AMDGPUISD::FPTRUNC_ROUND_DOWNWARD", // Returns 1 if the source arguments have modifiers, 0 if they do not. class isFloatType { bit ret = !or(!eq(SrcVT.Value, f16.Value), + !eq(SrcVT.Value, bf16.Value), !eq(SrcVT.Value, f32.Value), !eq(SrcVT.Value, f64.Value), !eq(SrcVT.Value, v2f16.Value), + !eq(SrcVT.Value, v2bf16.Value), !eq(SrcVT.Value, v4f16.Value), + !eq(SrcVT.Value, v4bf16.Value), !eq(SrcVT.Value, v8f16.Value), + !eq(SrcVT.Value, v8bf16.Value), !eq(SrcVT.Value, v16f16.Value), + !eq(SrcVT.Value, v16bf16.Value), !eq(SrcVT.Value, v2f32.Value), !eq(SrcVT.Value, v4f32.Value), !eq(SrcVT.Value, v8f32.Value), @@ -314,7 +319,9 @@ class isIntType { class isPackedType { bit ret = !or(!eq(SrcVT.Value, v2i16.Value), !eq(SrcVT.Value, v2f16.Value), + !eq(SrcVT.Value, v2bf16.Value), !eq(SrcVT.Value, v4f16.Value), + !eq(SrcVT.Value, v4bf16.Value), !eq(SrcVT.Value, v2i32.Value), !eq(SrcVT.Value, v2f32.Value), !eq(SrcVT.Value, v4i32.Value), @@ -1495,14 +1502,14 @@ class getVOPSrc0ForVT { !if(isFP, !if(!eq(VT.Size, 64), VSrc_f64, - !if(!eq(VT.Value, f16.Value), + !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), !if(IsTrue16, !if(IsFake16, VSrcFake16_f16_Lo128, VSrcT_f16_Lo128), VSrc_f16 ), - !if(!eq(VT.Value, v2f16.Value), + !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VSrc_v2f16, - !if(!eq(VT.Value, v4f16.Value), + !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), AVSrc_64, VSrc_f32 ) @@ -1576,11 +1583,11 @@ class getVOP3SrcForVT { !if(!eq(VT.Value, i1.Value), SSrc_i1, !if(isFP, - !if(!eq(VT.Value, f16.Value), + !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), !if(IsTrue16, VSrcT_f16, VSrc_f16), - !if(!eq(VT.Value, v2f16.Value), + !if(!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VSrc_v2f16, - !if(!eq(VT.Value, v4f16.Value), + !if(!or(!eq(VT.Value, v4f16.Value), !eq(VT.Value, v4bf16.Value)), AVSrc_64, VSrc_f32 ) @@ -1605,8 +1612,8 @@ class getVOP3DPPSrcForVT { RegisterOperand ret = !if (!eq(VT.Value, i1.Value), SSrc_i1, !if (isFP, - !if (!eq(VT.Value, f16.Value), VCSrc_f16, - !if (!eq(VT.Value, v2f16.Value), VCSrc_v2f16, VCSrc_f32)), + !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), VCSrc_f16, + !if (!or(!eq(VT.Value, v2f16.Value), !eq(VT.Value, v2bf16.Value)), VCSrc_v2f16, VCSrc_f32)), !if (!eq(VT.Value, i16.Value), VCSrc_b16, !if (!eq(VT.Value, v2i16.Value), VCSrc_v2b16, VCSrc_b32)))); @@ -1615,22 +1622,27 @@ class getVOP3DPPSrcForVT { // Float or packed int class isModifierType { bit ret = !or(!eq(SrcVT.Value, f16.Value), + !eq(SrcVT.Value, bf16.Value), !eq(SrcVT.Value, f32.Value), !eq(SrcVT.Value, f64.Value), !eq(SrcVT.Value, v2f16.Value), !eq(SrcVT.Value, v2i16.Value), + !eq(SrcVT.Value, v2bf16.Value), !eq(SrcVT.Value, v2f32.Value), !eq(SrcVT.Value, v2i32.Value), !eq(SrcVT.Value, v4f16.Value), !eq(SrcVT.Value, v4i16.Value), + !eq(SrcVT.Value, v4bf16.Value), !eq(SrcVT.Value, v4f32.Value), !eq(SrcVT.Value, v4i32.Value), !eq(SrcVT.Value, v8f16.Value), !eq(SrcVT.Value, v8i16.Value), + !eq(SrcVT.Value, v8bf16.Value), !eq(SrcVT.Value, v8f32.Value), !eq(SrcVT.Value, v8i32.Value), !eq(SrcVT.Value, v16f16.Value), - !eq(SrcVT.Value, v16i16.Value)); + !eq(SrcVT.Value, v16i16.Value), + !eq(SrcVT.Value, v16bf16.Value)); } // Return type of input modifiers operand for specified input operand @@ -1646,7 +1658,8 @@ class getSrcMod { } class getOpSelMod { - Operand ret = !if(!eq(VT.Value, f16.Value), FP16InputMods, IntOpSelMods); + Operand ret = !if(!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), + FP16InputMods, IntOpSelMods); } // Return type of input modifiers operand specified input operand for DPP @@ -1659,8 +1672,8 @@ class getSrcModDPP_t16 { bit isFP = isFloatType.ret; Operand ret = !if (isFP, - !if (!eq(VT.Value, f16.Value), FPT16VRegInputMods, - FPVRegInputMods), + !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), + FPT16VRegInputMods, FPVRegInputMods), !if (!eq(VT.Value, i16.Value), IntT16VRegInputMods, IntVRegInputMods)); } @@ -1671,8 +1684,8 @@ class getSrcModVOP3DPP { bit isPacked = isPackedType.ret; Operand ret = !if (isFP, - !if (!eq(VT.Value, f16.Value), FP16VCSrcInputMods, - FP32VCSrcInputMods), + !if (!or(!eq(VT.Value, f16.Value), !eq(VT.Value, bf16.Value)), + FP16VCSrcInputMods, FP32VCSrcInputMods), Int32VCSrcInputMods); } @@ -1681,7 +1694,8 @@ class getSrcModSDWA { Operand ret = !if(!eq(VT.Value, f16.Value), FP16SDWAInputMods, !if(!eq(VT.Value, f32.Value), FP32SDWAInputMods, !if(!eq(VT.Value, i16.Value), Int16SDWAInputMods, - Int32SDWAInputMods))); + !if(!eq(VT.Value, bf16.Value), FP16SDWAInputMods, + Int32SDWAInputMods)))); } // Returns the input arguments for VOP[12C] instructions for the given SrcVT.