Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMDGPU][wmma] - Add tied wmma intrinsic #69903

Merged
merged 1 commit into from
Oct 30, 2023

Conversation

OutOfCache
Copy link
Contributor

New PR based on D158059.

These new intrinsics, amdgcn_wmma_tied_f16_16x16x16_f16 and amdgcn_wmma_tied_f16_16x16x16_f16,
explicitly tie the destination accumulator matrix to the input accumulator matrix.

The wmma_f16 and wmma_bf16 intrinsics only write to 16-bit of the 32-bit destination VGPRs.
Which half is determined via the op_sel argument. The other half of the destination registers remains unchanged.

In some cases however, we expect the destination to copy the other halves from the input accumulator.
For instance, when packing two separate accumulator matrices into one. In that case, the two matrices
are tied into the same registers, but separate halves. Then it is important to copy the other matrix values
to the new destination.

@llvmbot
Copy link
Collaborator

llvmbot commented Oct 23, 2023

@llvm/pr-subscribers-backend-amdgpu
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-globalisel

Author: Jessica Del (OutOfCache)

Changes

New PR based on D158059.

These new intrinsics, amdgcn_wmma_tied_f16_16x16x16_f16 and amdgcn_wmma_tied_f16_16x16x16_f16,
explicitly tie the destination accumulator matrix to the input accumulator matrix.

The wmma_f16 and wmma_bf16 intrinsics only write to 16-bit of the 32-bit destination VGPRs.
Which half is determined via the op_sel argument. The other half of the destination registers remains unchanged.

In some cases however, we expect the destination to copy the other halves from the input accumulator.
For instance, when packing two separate accumulator matrices into one. In that case, the two matrices
are tied into the same registers, but separate halves. Then it is important to copy the other matrix values
to the new destination.


Patch is 35.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69903.diff

7 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+2)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp (+2)
  • (modified) llvm/lib/Target/AMDGPU/VOP3PInstructions.td (+26-31)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll (+100)
  • (modified) llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_64.ll (+84)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma_32.ll (+100)
  • (modified) llvm/test/CodeGen/AMDGPU/llvm.amdgcn.wmma_64.ll (+84)
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index 5f1d1d932f74cbd..b1f2b512628bb0d 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -2363,6 +2363,8 @@ def int_amdgcn_wmma_f32_16x16x16_f16   : AMDGPUWmmaIntrinsic<llvm_v16f16_ty, llv
 def int_amdgcn_wmma_f32_16x16x16_bf16  : AMDGPUWmmaIntrinsic<llvm_v16i16_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f16_16x16x16_f16   : AMDGPUWmmaIntrinsicOPSEL<llvm_v16f16_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_bf16_16x16x16_bf16 : AMDGPUWmmaIntrinsicOPSEL<llvm_v16i16_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_tied_f16_16x16x16_f16   : AMDGPUWmmaIntrinsicOPSEL<llvm_v16f16_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_tied_bf16_16x16x16_bf16 : AMDGPUWmmaIntrinsicOPSEL<llvm_v16i16_ty, llvm_anyint_ty>;
 def int_amdgcn_wmma_i32_16x16x16_iu8   : AMDGPUWmmaIntrinsicIU<llvm_v4i32_ty, llvm_anyint_ty>;
 def int_amdgcn_wmma_i32_16x16x16_iu4   : AMDGPUWmmaIntrinsicIU<llvm_v2i32_ty, llvm_anyint_ty>;
 
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
index 5b056bd9e5dba2c..f4f9aa5903c458c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp
@@ -4279,6 +4279,8 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
     case Intrinsic::amdgcn_sudot8:
     case Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16:
     case Intrinsic::amdgcn_wmma_f16_16x16x16_f16:
+    case Intrinsic::amdgcn_wmma_tied_bf16_16x16x16_bf16:
+    case Intrinsic::amdgcn_wmma_tied_f16_16x16x16_f16:
     case Intrinsic::amdgcn_wmma_f32_16x16x16_bf16:
     case Intrinsic::amdgcn_wmma_f32_16x16x16_f16:
     case Intrinsic::amdgcn_wmma_i32_16x16x16_iu4:
diff --git a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
index 05e68f46b32605d..c406e57f9de2654 100644
--- a/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
+++ b/llvm/lib/Target/AMDGPU/VOP3PInstructions.td
@@ -865,35 +865,26 @@ def WMMAOpcode3AddrMappingTable : WMMAMappingTable {
 //    it converts the default pseudo to the pseudo where src2 is not the same as vdst.
 // 3) @earlyclobber on the destination satisfies the constraint during RA.
 
-multiclass WMMAInst<string Suffix, string Instr, VOPProfile P, SDPatternOperator node = null_frag, RegisterOperand _Src01RC64 = VRegSrc_256, WMMAType Type> {
+multiclass WMMAInst<string Suffix, string Instr, VOPProfile P, SDPatternOperator node = null_frag, RegisterOperand _Src01RC64 = VRegSrc_256, WMMAType Type, bit convertibleTo3Addr> {
 
   defvar WMMAConstraints2Addr = "@earlyclobber $vdst,$vdst = $src2";
   defvar WMMAConstraints3Addr = "@earlyclobber $vdst";
 
   defvar WMMAProfile = VOPProfileWMMA<P, Suffix, _Src01RC64, Type.hasClamp, Type.hasOpsel>;
-  if !eq(Suffix, "_w32") then {
     let Mnemonic = Instr, mayRaiseFPException = 0, ReadsModeReg = 0 in {
-      let Constraints = WMMAConstraints2Addr, isConvertibleToThreeAddress = 1 in {
-        def _twoaddr_w32 : VOP3P_Pseudo<Instr # Suffix, WMMAProfile>;
-      }
-      let Constraints = WMMAConstraints3Addr, SchedRW = [Write32Bit, Write32Bit] in {
-        def _threeaddr_w32 : VOP3P_Pseudo<Instr # Suffix, WMMAProfile>;
+      let Constraints = WMMAConstraints2Addr, isConvertibleToThreeAddress = convertibleTo3Addr in {
+        def _twoaddr # Suffix : VOP3P_Pseudo<Instr # Suffix, WMMAProfile>;
       }
     }
-    def : WMMAOpcodeMapping<!cast<Instruction>(NAME # _twoaddr_w32),
-                            !cast<Instruction>(NAME # _threeaddr_w32)>;
-  } else if !eq(Suffix, "_w64") then {
-    let Mnemonic = Instr, mayRaiseFPException = 0, ReadsModeReg = 0 in {
-      let Constraints = WMMAConstraints2Addr, isConvertibleToThreeAddress = 1 in {
-        def _twoaddr_w64 : VOP3P_Pseudo<Instr # Suffix, WMMAProfile>;
-      }
-      let Constraints = WMMAConstraints3Addr, SchedRW = [Write32Bit, Write32Bit] in {
-        def _threeaddr_w64 : VOP3P_Pseudo<Instr # Suffix, WMMAProfile>;
+    if !eq(convertibleTo3Addr, 1) then {
+      let Mnemonic = Instr, mayRaiseFPException = 0, ReadsModeReg = 0 in {
+        let Constraints = WMMAConstraints3Addr, SchedRW = [Write32Bit, Write32Bit] in {
+          def _threeaddr # Suffix : VOP3P_Pseudo<Instr # Suffix, WMMAProfile>;
+        }
       }
+      def : WMMAOpcodeMapping<!cast<Instruction>(NAME # _twoaddr # Suffix),
+                            !cast<Instruction>(NAME # _threeaddr # Suffix)>;
     }
-    def : WMMAOpcodeMapping<!cast<Instruction>(NAME # _twoaddr_w64),
-                            !cast<Instruction>(NAME # _threeaddr_w64)>;
-  }
 
   if !eq(Type, WMMAOpSel) then {
     def : WMMAOpSelPat<!cast<Instruction>(NAME # _twoaddr # Suffix), node, P>;
@@ -906,21 +897,25 @@ multiclass WMMAInst<string Suffix, string Instr, VOPProfile P, SDPatternOperator
 
 
 let WaveSizePredicate = isWave32 in {
-  defm V_WMMA_F32_16X16X16_F16   : WMMAInst<"_w32", "v_wmma_f32_16x16x16_f16",  VOP_V8F32_V16F16_V16F16_V8F32, int_amdgcn_wmma_f32_16x16x16_f16, VRegSrc_256, WMMARegular>;
-  defm V_WMMA_F32_16X16X16_BF16  : WMMAInst<"_w32", "v_wmma_f32_16x16x16_bf16", VOP_V8F32_V16I16_V16I16_V8F32, int_amdgcn_wmma_f32_16x16x16_bf16, VRegSrc_256, WMMARegular>;
-  defm V_WMMA_F16_16X16X16_F16   : WMMAInst<"_w32", "v_wmma_f16_16x16x16_f16",   VOP_V16F16_V16F16_V16F16_V16F16, int_amdgcn_wmma_f16_16x16x16_f16, VRegSrc_256, WMMAOpSel>;
-  defm V_WMMA_BF16_16X16X16_BF16 : WMMAInst<"_w32", "v_wmma_bf16_16x16x16_bf16", VOP_V16I16_V16I16_V16I16_V16I16, int_amdgcn_wmma_bf16_16x16x16_bf16, VRegSrc_256, WMMAOpSel>;
-  defm V_WMMA_I32_16X16X16_IU8   : WMMAInst<"_w32", "v_wmma_i32_16x16x16_iu8",   VOP_V8I32_V4I32_V4I32_V8I32, int_amdgcn_wmma_i32_16x16x16_iu8, VRegSrc_128, WMMAUIClamp>;
-  defm V_WMMA_I32_16X16X16_IU4   : WMMAInst<"_w32", "v_wmma_i32_16x16x16_iu4",   VOP_V8I32_V2I32_V2I32_V8I32, int_amdgcn_wmma_i32_16x16x16_iu4, VRegSrc_64,  WMMAUIClamp>;
+  defm V_WMMA_F32_16X16X16_F16   : WMMAInst<"_w32", "v_wmma_f32_16x16x16_f16",  VOP_V8F32_V16F16_V16F16_V8F32, int_amdgcn_wmma_f32_16x16x16_f16, VRegSrc_256, WMMARegular, 1>;
+  defm V_WMMA_F32_16X16X16_BF16  : WMMAInst<"_w32", "v_wmma_f32_16x16x16_bf16", VOP_V8F32_V16I16_V16I16_V8F32, int_amdgcn_wmma_f32_16x16x16_bf16, VRegSrc_256, WMMARegular, 1>;
+  defm V_WMMA_F16_16X16X16_F16   : WMMAInst<"_w32", "v_wmma_f16_16x16x16_f16",   VOP_V16F16_V16F16_V16F16_V16F16, int_amdgcn_wmma_f16_16x16x16_f16, VRegSrc_256, WMMAOpSel, 1>;
+  defm V_WMMA_BF16_16X16X16_BF16 : WMMAInst<"_w32", "v_wmma_bf16_16x16x16_bf16", VOP_V16I16_V16I16_V16I16_V16I16, int_amdgcn_wmma_bf16_16x16x16_bf16, VRegSrc_256, WMMAOpSel, 1>;
+  defm V_WMMA_TIED_F16_16X16X16_F16   : WMMAInst<"_w32", "v_wmma_f16_16x16x16_f16",   VOP_V16F16_V16F16_V16F16_V16F16, int_amdgcn_wmma_tied_f16_16x16x16_f16, VRegSrc_256, WMMAOpSel, 0>;
+  defm V_WMMA_TIED_BF16_16X16X16_BF16 : WMMAInst<"_w32", "v_wmma_bf16_16x16x16_bf16", VOP_V16I16_V16I16_V16I16_V16I16, int_amdgcn_wmma_tied_bf16_16x16x16_bf16, VRegSrc_256, WMMAOpSel, 0>;
+  defm V_WMMA_I32_16X16X16_IU8   : WMMAInst<"_w32", "v_wmma_i32_16x16x16_iu8",   VOP_V8I32_V4I32_V4I32_V8I32, int_amdgcn_wmma_i32_16x16x16_iu8, VRegSrc_128, WMMAUIClamp, 1>;
+  defm V_WMMA_I32_16X16X16_IU4   : WMMAInst<"_w32", "v_wmma_i32_16x16x16_iu4",   VOP_V8I32_V2I32_V2I32_V8I32, int_amdgcn_wmma_i32_16x16x16_iu4, VRegSrc_64,  WMMAUIClamp, 1>;
 }
 
 let WaveSizePredicate = isWave64 in {
-  defm V_WMMA_F32_16X16X16_F16   : WMMAInst<"_w64", "v_wmma_f32_16x16x16_f16",   VOP_V4F32_V16F16_V16F16_V4F32, int_amdgcn_wmma_f32_16x16x16_f16, VRegSrc_256, WMMARegular>;
-  defm V_WMMA_F32_16X16X16_BF16  : WMMAInst<"_w64", "v_wmma_f32_16x16x16_bf16",  VOP_V4F32_V16I16_V16I16_V4F32, int_amdgcn_wmma_f32_16x16x16_bf16, VRegSrc_256, WMMARegular>;
-  defm V_WMMA_F16_16X16X16_F16   : WMMAInst<"_w64", "v_wmma_f16_16x16x16_f16",   VOP_V8F16_V16F16_V16F16_V8F16, int_amdgcn_wmma_f16_16x16x16_f16, VRegSrc_256, WMMAOpSel>;
-  defm V_WMMA_BF16_16X16X16_BF16 : WMMAInst<"_w64", "v_wmma_bf16_16x16x16_bf16", VOP_V8I16_V16I16_V16I16_V8I16, int_amdgcn_wmma_bf16_16x16x16_bf16, VRegSrc_256, WMMAOpSel>;
-  defm V_WMMA_I32_16X16X16_IU8   : WMMAInst<"_w64", "v_wmma_i32_16x16x16_iu8",   VOP_V4I32_V4I32_V4I32_V4I32, int_amdgcn_wmma_i32_16x16x16_iu8, VRegSrc_128, WMMAUIClamp>;
-  defm V_WMMA_I32_16X16X16_IU4   : WMMAInst<"_w64", "v_wmma_i32_16x16x16_iu4",   VOP_V4I32_V2I32_V2I32_V4I32, int_amdgcn_wmma_i32_16x16x16_iu4, VRegSrc_64, WMMAUIClamp>;
+  defm V_WMMA_F32_16X16X16_F16   : WMMAInst<"_w64", "v_wmma_f32_16x16x16_f16",   VOP_V4F32_V16F16_V16F16_V4F32, int_amdgcn_wmma_f32_16x16x16_f16, VRegSrc_256, WMMARegular, 1>;
+  defm V_WMMA_F32_16X16X16_BF16  : WMMAInst<"_w64", "v_wmma_f32_16x16x16_bf16",  VOP_V4F32_V16I16_V16I16_V4F32, int_amdgcn_wmma_f32_16x16x16_bf16, VRegSrc_256, WMMARegular, 1>;
+  defm V_WMMA_F16_16X16X16_F16   : WMMAInst<"_w64", "v_wmma_f16_16x16x16_f16",   VOP_V8F16_V16F16_V16F16_V8F16, int_amdgcn_wmma_f16_16x16x16_f16, VRegSrc_256, WMMAOpSel, 1>;
+  defm V_WMMA_BF16_16X16X16_BF16 : WMMAInst<"_w64", "v_wmma_bf16_16x16x16_bf16", VOP_V8I16_V16I16_V16I16_V8I16, int_amdgcn_wmma_bf16_16x16x16_bf16, VRegSrc_256, WMMAOpSel, 1>;
+  defm V_WMMA_TIED_F16_16X16X16_F16   : WMMAInst<"_w64", "v_wmma_f16_16x16x16_f16",   VOP_V8F16_V16F16_V16F16_V8F16, int_amdgcn_wmma_tied_f16_16x16x16_f16, VRegSrc_256, WMMAOpSel, 0>;
+  defm V_WMMA_TIED_BF16_16X16X16_BF16 : WMMAInst<"_w64", "v_wmma_bf16_16x16x16_bf16", VOP_V8I16_V16I16_V16I16_V8I16, int_amdgcn_wmma_tied_bf16_16x16x16_bf16, VRegSrc_256, WMMAOpSel, 0>;
+  defm V_WMMA_I32_16X16X16_IU8   : WMMAInst<"_w64", "v_wmma_i32_16x16x16_iu8",   VOP_V4I32_V4I32_V4I32_V4I32, int_amdgcn_wmma_i32_16x16x16_iu8, VRegSrc_128, WMMAUIClamp, 1>;
+  defm V_WMMA_I32_16X16X16_IU4   : WMMAInst<"_w64", "v_wmma_i32_16x16x16_iu4",   VOP_V4I32_V2I32_V2I32_V4I32, int_amdgcn_wmma_i32_16x16x16_iu4, VRegSrc_64, WMMAUIClamp, 1>;
 
 }
 
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll b/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
index 6ca2dd838d37ac9..1cc1c6d7d46e29d 100644
--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_32.ll
@@ -4,7 +4,9 @@
 declare <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half>, <16 x half> , <8 x float>)
 declare <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16(<16 x i16>, <16 x i16> , <8 x float>)
 declare <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half>, <16 x half> , <16 x half>, i1 immarg)
+declare <16 x half> @llvm.amdgcn.wmma.tied.f16.16x16x16.f16(<16 x half>, <16 x half> , <16 x half>, i1 immarg)
 declare <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16(<16 x i16>, <16 x i16> , <16 x i16>, i1 immarg)
+declare <16 x i16> @llvm.amdgcn.wmma.tied.bf16.16x16x16.bf16(<16 x i16>, <16 x i16> , <16 x i16>, i1 immarg)
 declare <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8(i1 immarg, <4 x i32>, i1 immarg, <4 x i32> , <8 x i32>, i1 immarg)
 declare <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4(i1 immarg, <2 x i32>, i1 immarg, <2 x i32> , <8 x i32>, i1 immarg)
 
@@ -78,6 +80,55 @@ bb:
   ret void
 }
 
+define amdgpu_ps void @test_wmma_f16_16x16x16_f16_untied(<16 x half> %A.0, <16 x half> %B.0, <16 x half> %A.1, <16 x half> %B.1, <16 x half> %C, ptr addrspace(1) %out.0, ptr addrspace(1) %out.1) {
+; W32-LABEL: test_wmma_f16_16x16x16_f16_untied:
+; W32:       ; %bb.0: ; %bb
+; W32-NEXT:    v_wmma_f16_16x16x16_f16 v[44:51], v[0:7], v[8:15], v[32:39]
+; W32-NEXT:    v_wmma_f16_16x16x16_f16 v[32:39], v[16:23], v[24:31], v[32:39]
+; W32-NEXT:    s_clause 0x1
+; W32-NEXT:    global_store_b128 v[40:41], v[44:47], off
+; W32-NEXT:    global_store_b128 v[40:41], v[48:51], off offset:16
+; W32-NEXT:    s_clause 0x1
+; W32-NEXT:    global_store_b128 v[42:43], v[32:35], off
+; W32-NEXT:    global_store_b128 v[42:43], v[36:39], off offset:16
+; W32-NEXT:    s_nop 0
+; W32-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
+; W32-NEXT:    s_endpgm
+bb:
+  %res.0 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %A.0, <16 x half> %B.0, <16 x half> %C, i1 0)
+  %res.1 = call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %A.1, <16 x half> %B.1, <16 x half> %C, i1 0)
+  store <16 x half> %res.0, ptr addrspace(1) %out.0, align 32
+  store <16 x half> %res.1, ptr addrspace(1) %out.1, align 32
+  ret void
+}
+
+define amdgpu_ps void @test_wmma_f16_16x16x16_f16_tied(<16 x half> %A.0, <16 x half> %B.0, <16 x half> %A.1, <16 x half> %B.1, <16 x half> %C, ptr addrspace(1) %out.0, ptr addrspace(1) %out.1) {
+; W32-LABEL: test_wmma_f16_16x16x16_f16_tied:
+; W32:       ; %bb.0: ; %bb
+; W32-NEXT:    v_dual_mov_b32 v51, v39 :: v_dual_mov_b32 v50, v38
+; W32-NEXT:    v_dual_mov_b32 v49, v37 :: v_dual_mov_b32 v48, v36
+; W32-NEXT:    v_dual_mov_b32 v47, v35 :: v_dual_mov_b32 v46, v34
+; W32-NEXT:    v_dual_mov_b32 v45, v33 :: v_dual_mov_b32 v44, v32
+; W32-NEXT:    v_wmma_f16_16x16x16_f16 v[32:39], v[16:23], v[24:31], v[32:39]
+; W32-NEXT:    s_delay_alu instid0(VALU_DEP_2)
+; W32-NEXT:    v_wmma_f16_16x16x16_f16 v[44:51], v[0:7], v[8:15], v[44:51]
+; W32-NEXT:    s_clause 0x1
+; W32-NEXT:    global_store_b128 v[40:41], v[44:47], off
+; W32-NEXT:    global_store_b128 v[40:41], v[48:51], off offset:16
+; W32-NEXT:    s_clause 0x1
+; W32-NEXT:    global_store_b128 v[42:43], v[32:35], off
+; W32-NEXT:    global_store_b128 v[42:43], v[36:39], off offset:16
+; W32-NEXT:    s_nop 0
+; W32-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
+; W32-NEXT:    s_endpgm
+bb:
+  %res.0 = call <16 x half> @llvm.amdgcn.wmma.tied.f16.16x16x16.f16(<16 x half> %A.0, <16 x half> %B.0, <16 x half> %C, i1 0)
+  %res.1 = call <16 x half> @llvm.amdgcn.wmma.tied.f16.16x16x16.f16(<16 x half> %A.1, <16 x half> %B.1, <16 x half> %C, i1 0)
+  store <16 x half> %res.0, ptr addrspace(1) %out.0, align 32
+  store <16 x half> %res.1, ptr addrspace(1) %out.1, align 32
+  ret void
+}
+
 ; @llvm.amdgcn.wmma.bf16.16x16x16.bf16
 
 define amdgpu_ps void @test_wmma_bf16_16x16x16_bf16_lo(<16 x i16> %A, <16 x i16> %B, <16 x i16> %C, ptr addrspace(1) %out) {
@@ -112,6 +163,55 @@ bb:
   ret void
 }
 
+define amdgpu_ps void @test_wmma_bf16_16x16x16_bf16_untied(<16 x i16> %A.0, <16 x i16> %B.0, <16 x i16> %A.1, <16 x i16> %B.1, <16 x i16> %C, ptr addrspace(1) %out.0, ptr addrspace(1) %out.1) {
+; W32-LABEL: test_wmma_bf16_16x16x16_bf16_untied:
+; W32:       ; %bb.0: ; %bb
+; W32-NEXT:    v_wmma_bf16_16x16x16_bf16 v[44:51], v[0:7], v[8:15], v[32:39]
+; W32-NEXT:    v_wmma_bf16_16x16x16_bf16 v[32:39], v[16:23], v[24:31], v[32:39]
+; W32-NEXT:    s_clause 0x1
+; W32-NEXT:    global_store_b128 v[40:41], v[44:47], off
+; W32-NEXT:    global_store_b128 v[40:41], v[48:51], off offset:16
+; W32-NEXT:    s_clause 0x1
+; W32-NEXT:    global_store_b128 v[42:43], v[32:35], off
+; W32-NEXT:    global_store_b128 v[42:43], v[36:39], off offset:16
+; W32-NEXT:    s_nop 0
+; W32-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
+; W32-NEXT:    s_endpgm
+bb:
+  %res.0 = call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16(<16 x i16> %A.0, <16 x i16> %B.0, <16 x i16> %C, i1 0)
+  %res.1 = call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16(<16 x i16> %A.1, <16 x i16> %B.1, <16 x i16> %C, i1 0)
+  store <16 x i16> %res.0, ptr addrspace(1) %out.0, align 32
+  store <16 x i16> %res.1, ptr addrspace(1) %out.1, align 32
+  ret void
+}
+
+define amdgpu_ps void @test_wmma_bf16_16x16x16_bf16_tied(<16 x i16> %A.0, <16 x i16> %B.0, <16 x i16> %A.1, <16 x i16> %B.1, <16 x i16> %C, ptr addrspace(1) %out.0, ptr addrspace(1) %out.1) {
+; W32-LABEL: test_wmma_bf16_16x16x16_bf16_tied:
+; W32:       ; %bb.0: ; %bb
+; W32-NEXT:    v_dual_mov_b32 v51, v39 :: v_dual_mov_b32 v50, v38
+; W32-NEXT:    v_dual_mov_b32 v49, v37 :: v_dual_mov_b32 v48, v36
+; W32-NEXT:    v_dual_mov_b32 v47, v35 :: v_dual_mov_b32 v46, v34
+; W32-NEXT:    v_dual_mov_b32 v45, v33 :: v_dual_mov_b32 v44, v32
+; W32-NEXT:    v_wmma_bf16_16x16x16_bf16 v[32:39], v[16:23], v[24:31], v[32:39]
+; W32-NEXT:    s_delay_alu instid0(VALU_DEP_2)
+; W32-NEXT:    v_wmma_bf16_16x16x16_bf16 v[44:51], v[0:7], v[8:15], v[44:51]
+; W32-NEXT:    s_clause 0x1
+; W32-NEXT:    global_store_b128 v[40:41], v[44:47], off
+; W32-NEXT:    global_store_b128 v[40:41], v[48:51], off offset:16
+; W32-NEXT:    s_clause 0x1
+; W32-NEXT:    global_store_b128 v[42:43], v[32:35], off
+; W32-NEXT:    global_store_b128 v[42:43], v[36:39], off offset:16
+; W32-NEXT:    s_nop 0
+; W32-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
+; W32-NEXT:    s_endpgm
+bb:
+  %res.0 = call <16 x i16> @llvm.amdgcn.wmma.tied.bf16.16x16x16.bf16(<16 x i16> %A.0, <16 x i16> %B.0, <16 x i16> %C, i1 0)
+  %res.1 = call <16 x i16> @llvm.amdgcn.wmma.tied.bf16.16x16x16.bf16(<16 x i16> %A.1, <16 x i16> %B.1, <16 x i16> %C, i1 0)
+  store <16 x i16> %res.0, ptr addrspace(1) %out.0, align 32
+  store <16 x i16> %res.1, ptr addrspace(1) %out.1, align 32
+  ret void
+}
+
 ; @llvm.amdgcn.wmma.i32.16x16x16.iu8
 
 define amdgpu_ps void @test_wmma_i32_16x16x16_ui8_unsigned_unsigned(<4 x i32> %A, <4 x i32> %B, <8 x i32> %C, ptr addrspace(1) %out) {
diff --git a/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_64.ll b/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_64.ll
index a18d0a569bfb6ef..66655a3f2d16c98 100644
--- a/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_64.ll
+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.wmma_64.ll
@@ -4,7 +4,9 @@
 declare <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16(<16 x half>, <16 x half>, <4 x float>)
 declare <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16(<16 x i16>, <16 x i16>, <4 x float>)
 declare <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half>, <16 x half>, <8 x half>, i1 immarg)
+declare <8 x half> @llvm.amdgcn.wmma.tied.f16.16x16x16.f16(<16 x half>, <16 x half>, <8 x half>, i1 immarg)
 declare <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16(<16 x i16>, <16 x i16>, <8 x i16>, i1 immarg)
+declare <8 x i16> @llvm.amdgcn.wmma.tied.bf16.16x16x16.bf16(<16 x i16>, <16 x i16>, <8 x i16>, i1 immarg)
 declare <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8(i1 immarg, <4 x i32>, i1 immarg, <4 x i32>, <4 x i32>, i1 immarg)
 declare <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4(i1 immarg, <2 x i32>, i1 immarg, <2 x i32>, <4 x i32>, i1 immarg)
 
@@ -70,6 +72,47 @@ bb:
   ret void
 }
 
+define amdgpu_ps void @test_wmma_f16_16x16x16_f16_untied(<16 x half> %A.0, <16 x half> %B.0, <16 x half> %A.1, <16 x half> %B.1, <8 x half> %C, ptr addrspace(1) %out.0, ptr addrspace(1) %out.1) {
+; W64-LABEL: test_wmma_f16_16x16x16_f16_untied:
+; W64:       ; %bb.0: ; %bb
+; W64-NEXT:    v_wmma_f16_16x16x16_f16 v[40:43], v[0:7], v[8:15], v[32:35]
+; W64-NEXT:    v_wmma_f16_16x16x16_f16 v[32:35], v[16:23], v[24:31], v[32:35]
+; W64-NEXT:    global_store_b128 v[36:37], v[40:43], off
+; W64-NEXT:    global_store_b128 v[38:39], v[32:35], off
+; W64-NEXT:    s_nop 0
+; W64-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
+; W64-NEXT:    s_endpgm
+bb:
+  %res.0 = call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %A.0, <16 x half> %B.0, <8 x half> %C, i1 0)
+  %res.1 = call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16(<16 x half> %A.1, <16 x half> %B.1, <8 x half> %C, i1 0)
+  store <8 x half> %res.0, ptr addrspace(1) %out.0, align 32
+  store <8 x half> %res.1, ptr addrspace(1) %out.1, align 32
+  ret void
+}
+
+define amdgpu_ps void @test_wmma_f16_16x16x16_f16_tied(<16 x half> %A.0, <16 x half> %B.0, <16 x half> %A.1, <16 x half> %B.1, <8 x half> %C, ptr addrspace(1) %out.0, ptr addrspace(1) %out.1) {
+; W64-LABEL: test_wmma_f16_16x16x16_f16_tied:
+; W64:       ; %bb.0: ; %bb
+; W64-NEXT:    v_mov_b32_e32 v43, v35
+; W64-NEXT:    v_mov_b32_e32 v42, v34
+; W64-NEXT:    v_mov_b32_e32 v41, v33
+; W64-NEXT:    v_mov_b32_e32 v40, v32
+; W64-NEXT:    v_wmma_f16_16x16x16_f16 v[32:35], v[16:23], v[24:31], v[32:35]
+; W64-NEXT:    s_delay_alu instid0(VALU_DEP_2)
+; W64-NEXT:    v_wmma_f16_16x16x16_f16 v[40:43], v[0:7], v[8:15], v[40:43]
+; W64-NEXT:    global_store_b128 v[36:37], v[40:43], off
+; W64-NEXT:    global_store_b128 v[38:39], v[32:35], off
+; W64-NEXT:    s_nop 0
+; W64-NEXT:    s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)
+; W64-NEXT:    s_endpgm
+bb:
+  %res.0 = call <8 x half> @llvm.amdgcn.wmma.tied.f16.16x16x16.f16(<16 x half> %A.0, <16 x half> %B.0, <8 x half> %C, i1 0)
+  %res.1 = call <8 x half> @llvm.amdgcn.wmma.tied.f16.16x16x16.f16(<16 x half> %A.1, <16 x half> %B...
[truncated]

Copy link
Collaborator

@nhaehnle nhaehnle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

arsenm
arsenm previously requested changes Oct 24, 2023
Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should avoid breaking up the underlying opcode name in the intrinsic with tied. Is there a better suffix you could use?

@OutOfCache
Copy link
Contributor Author

I think you should avoid breaking up the underlying opcode name in the intrinsic with tied. Is there a better suffix you could use?

I don't know about a better suffix, but I could move the 'tied' to the very end (amdgcn_wmma_f16_16x16x16_f16_tied). Would that clear up any confusion?
I initially placed the suffix there because it sounds like an modifier of the accumulator matrices, i.e., there is the intrinsic with any f16 accumulators, and there is the intrinsic with tied f16 accumulators. Moving the suffix to the end still sounds reasonable, then it becomes a modifier of the entire intrinsic itself.

@nhaehnle
Copy link
Collaborator

Thanks. Can you please still add a comment in IntrinsicsAMDGPU.td that for f16 / bf16 versions, the tied version copies the "other" half of the accumulator to the result, while in the non-tied versions, the "other" half of the result is undefined?

@piotrAMD
Copy link
Collaborator

Would "preserve" be a better suffix, as in "preserve the other half"?

Need to add clang builtins too, similarly to D128952 (can be a separate patch).

@arsenm
Copy link
Contributor

arsenm commented Oct 26, 2023

If this is strictly more expressive than what we already have for the same instructions, can we just auto-upgrade the existing intrinsic to the version with the tied operand?

@piotrAMD
Copy link
Collaborator

If this is strictly more expressive than what we already have for the same instructions, can we just auto-upgrade the existing intrinsic to the version with the tied operand?

This is more expressive, but can produce worse code in the cases where you do not care about the other half.

@nhaehnle
Copy link
Collaborator

Would "preserve" be a better suffix, as in "preserve the other half"?

I don't feel strongly about it, but in an SSA world I think "copy" is the better way to think about it. "Preserve" is the way to think about it in a non-SSA world, i.e. at the assembly level after register allocation.

If this is strictly more expressive than what we already have for the same instructions, can we just auto-upgrade the existing intrinsic to the version with the tied operand?

I wouldn't say it's more expressive, it's just different. As Piotr said, using the tied versions when you don't need them can lead to worse codegen.

@jayfoad
Copy link
Contributor

jayfoad commented Oct 27, 2023

The "Simplify wmma instruction defs" commit is clearly good. Maybe commit that one separately first?

@OutOfCache
Copy link
Contributor Author

The "Simplify wmma instruction defs" commit is clearly good. Maybe commit that one separately first?

Did that in this PR. Closing this PR for now. I will open another one for the new intrinsics after that one is merged.

Need to add clang builtins too, similarly to D128952 (can be a separate patch).

I added them locally and will open a PR once these intrinsics are merged.

@OutOfCache OutOfCache closed this Oct 30, 2023
@jayfoad
Copy link
Contributor

jayfoad commented Oct 30, 2023

The "Simplify wmma instruction defs" commit is clearly good. Maybe commit that one separately first?

Did that in this PR. Closing this PR for now. I will open another one for the new intrinsics after that one is merged.

I think it would have been fine to leave this one open it and rebase it after that one was merged. I know some people don't like force pushing because it loses some context, but it's surely better than closing the PR and opening a new one which loses all context.

@OutOfCache OutOfCache reopened this Oct 30, 2023
Add intrinsics for `wmma_f16` and `wmma_bf16`, which stay as two-address
instructions.

This is a requirement for a future optimization
regarding wmma instructions.
The new changes make use of the `op_sel` argument of `wmma` instructions
to read from the upper halves of the input
accumulator and write to the upper halves of the output matrix.
With two-address instructions, we can guarantee that the content
of the upper halves is the same as the input
accumulator.
With three-address instructions, the output
registers do not copy the content of the input
registers. Instead, the upper halves
remain unchanged from their previous values.
This can cause issues if there are unexpected
values remaining in these registers.

For example:
```
v_wmma_f16_16x16x16_f16 v[0:7],   ..., v[24:31]
v_wmma_f16_16x16x16_f16 v[32:30], ..., v[24:31]
```
After these two instructions run, there is
no guarantee that the content of bits 16-31 of
`v[0:7]` are the same as the ones from `v[24:31]`.
If we have another instruction like the following:
```
v_wmma_f16_16x16x16_f16 v[0:7], v[24:31], v[32:49], v[0:7] op_sel:[0,0,1]
```
We read from the upper halves of `v[0:7]`, but
the content is not necessarily correct.

For our purpose, we create new pseudo instructions, while maintaining
the behavior of the original instructions.
}
def : WMMAOpcodeMapping<!cast<Instruction>(NAME # _twoaddr # Suffix),
if convertibleTo3Addr then {
let Mnemonic = Instr, mayRaiseFPException = 0, ReadsModeReg = 0 in {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: instead of repeating this "let" clause here, I think it would be neater to put this new code inside the existing "let" clause, i.e. move it up to just before the closing brace on line 878.

Yes this will mean that the "let" clause now applies to the WMMAOpcodeMapping part too, but that should be harmless.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! I think I did that now, if I understood your comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I attempted this, but this caused an Error. Moving the WMMAOpcodeMapping into the let produces

error: Value 'Mnemonic' unknown!
    let Mnemonic = Instr, mayRaiseFPException = 0, ReadsModeReg = 0 in {
        ^

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, and sorry for my bad idea.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea itself was not bad! It would look better. I am also surprised that it does not work.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's explained by the implementation of if in TableGen being a bit of a hack :)

Copy link
Contributor

@jayfoad jayfoad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall, thanks.

@OutOfCache OutOfCache merged commit 849297c into llvm:main Oct 30, 2023
5 checks passed
@OutOfCache
Copy link
Contributor Author

Here is the PR for the clang builtins: #70669

OutOfCache added a commit that referenced this pull request Nov 13, 2023
Add clang builtins for the new tied wmma intrinsics. 
These variations tie the destination
accumulator matrix to the input
accumulator matrix.

See #69903 for context.
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
Add clang builtins for the new tied wmma intrinsics. 
These variations tie the destination
accumulator matrix to the input
accumulator matrix.

See llvm#69903 for context.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants