Skip to content

Conversation

Muzammiluddin-Syed-ECE
Copy link
Contributor

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE commented Oct 7, 2025

No description provided.

Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
@llvmbot
Copy link
Member

llvmbot commented Oct 7, 2025

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Muzammil (Muzammiluddin-Syed-ECE)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/162343.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+3-3)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+29-1)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+41-2)
diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
index ded00b1274670..04ce1aedfdb4d 100644
--- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td
@@ -4085,11 +4085,11 @@ class AMDGPUWmmaScaleF4IntrinsicModsC<LLVMType scale_ty> :
 
 defset list<Intrinsic> AMDGPUWMMAIntrinsicsGFX1250 = {
 def int_amdgcn_wmma_f32_16x16x4_f32       : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
-def int_amdgcn_wmma_f32_16x16x32_bf16     : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_f32_16x16x32_bf16     : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x32_f16      : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f16_16x16x32_f16      : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
-def int_amdgcn_wmma_bf16_16x16x32_bf16    : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyfloat_ty, llvm_anyfloat_ty>;
-def int_amdgcn_wmma_bf16f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllDiff<llvm_anyfloat_ty, llvm_anyfloat_ty, llvm_anyfloat_ty>;
+def int_amdgcn_wmma_bf16_16x16x32_bf16    : AMDGPUWmmaIntrinsicModsAllReuse<llvm_anyint_ty, llvm_anyint_ty>;
+def int_amdgcn_wmma_bf16f32_16x16x32_bf16 : AMDGPUWmmaIntrinsicModsAllDiff<llvm_anyint_ty, llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x64_fp8_fp8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x64_fp8_bf8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
 def int_amdgcn_wmma_f32_16x16x64_bf8_fp8  : AMDGPUWmmaIntrinsicModsC<llvm_anyint_ty, llvm_anyfloat_ty>;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index db1b7e3af62fd..3814a2dae0f3f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -471,7 +471,7 @@ def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f
 
 //===---------------------------------------------------------------------===//
 // WMMA intrinsics
-class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
+class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands = [],
                         list<Trait> traits = []> :
   ROCDL_IntrOp<mnemonic, [0], overloadedOperands, traits, 1>,
   Arguments<(ins Variadic<LLVM_Type>:$args)> {
@@ -492,6 +492,34 @@ def ROCDL_wmma_f32_16x16x16_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_b
 def ROCDL_wmma_f32_16x16x16_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
 def ROCDL_wmma_f32_16x16x16_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_fp8", [1]>;
 def ROCDL_wmma_i32_16x16x32_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x32.iu4", [1]>;
+// Available from gfx1250
+def ROCDL_wmma_f32_16x16x4_f32 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x4.f32", [1]>;
+def ROCDL_wmma_f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.bf16", [1]>;
+def ROCDL_wmma_f32_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x32.f16", [1]>;
+def ROCDL_wmma_f16_16x16x32_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x32.f16", [1]>;
+def ROCDL_wmma_bf16_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x32.bf16", [1]>;
+def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;
+def ROCDL_wmma_f32_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.fp8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x64.bf8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x64_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x64_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.fp8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x64_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x64_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x64.bf8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.fp8_bf8", [0]>;
+def ROCDL_wmma_f32_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_fp8", [0]>;
+def ROCDL_wmma_f32_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x128.bf8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x128_fp8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x128_fp8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.fp8_bf8", [0]>;
+def ROCDL_wmma_f16_16x16x128_bf8_fp8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_fp8", [0]>;
+def ROCDL_wmma_f16_16x16x128_bf8_bf8 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x128.bf8_bf8", [0]>;
+def ROCDL_wmma_i32_16x16x64_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x64.iu8", [1]>;
+def ROCDL_wmma_scale_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.16x16x128.f8f6f4">;
+def ROCDL_wmma_scale16_f32_16x16x128_f8f6f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.16x16x128.f8f6f4">;
+def ROCDL_wmma_scale_f32_32x16x128_f4 : ROCDL_Wmma_IntrOp<"wmma.scale.f32.32x16x128.f4">;
+def ROCDL_wmma_scale16_f32_32x16x128_f4 : ROCDL_Wmma_IntrOp<"wmma.scale16.f32.32x16x128.f4">;
 
 //===---------------------------------------------------------------------===//
 // LDS transpose intrinsics (available in GFX950)
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 1c0c2eba002aa..e0400bf07b563 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -816,9 +816,12 @@ llvm.func @rocdl.mfma.scale.f32.16x16x128.f8f6f4(%arg0 : i32,
 }
 
 llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : vector<16 x i16>, %arg3 : vector<8 x i32>,
-                      %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>) -> vector<8xf32> {
+                      %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>,
+                      %arg9 : vector<32xf16>, %arg10 : vector<16xf32>, %arg11 : vector<4xf32>, %arg12 : vector<32xf32>,
+                      %arg13 : vector<16xi32>, %arg14 : vector<64xf32>, %arg15 : vector<64xi32>, %arg16 : i32) -> vector<8xf32> {
   %zero = llvm.mlir.constant(false) : i1
-
+  %zero_i16 = llvm.mlir.constant(0 : i16) : i16
+  %zero_i32 = llvm.mlir.constant(0 : i32) : i32
   // ---- Wave32 -----
 
   // f16 -> f32
@@ -849,6 +852,42 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
   // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x32.iu4.v8i32.v2i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
   %r6.gfx12 = rocdl.wmma.i32.16x16x32.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
 
+  // f32 -> f32
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x4.f32.v4f32.v16f32(i1 false, <16 x float> %10, i1 false, <16 x float> %10, i16 0, <4 x float> %11, i1 false, i1 false)
+  %r1.gfx1250 = rocdl.wmma.f32.16x16x4.f32 %zero, %arg10, %zero, %arg10, %zero_i16, %arg11, %zero, %zero : (i1, vector<16xf32>, i1, vector<16xf32>, i16, vector<4xf32>, i1, i1) -> vector<4xf32>
+
+  // bf16 -> f32
+  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.bf16.v32f32.v16i16(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <32 x float> %12, i1 false, i1 false)
+  %r2.gfx1250 = rocdl.wmma.f32.16x16x32.bf16 %zero, %arg2, %zero, %arg2, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xi16>, i1, vector<16xi16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+
+  // f16 -> f32
+  // CHECK: call <32 x float> @llvm.amdgcn.wmma.f32.16x16x32.f16.v32f32.v16f16(i1 false, <16 x half> %1, i1 false, <16 x half> %1, i16 0, <32 x float> %12, i1 false, i1 false)
+  %r3.gfx1250 = rocdl.wmma.f32.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf32>, i1, i1) -> vector<32xf32>
+
+  // f16 -> f16
+  // CHECK: call <32 x half> @llvm.amdgcn.wmma.f16.16x16x32.f16.v32f16.v16f16(i1 false, <16 x half> %1, i1 false, <16 x half> %1, i16 0, <32 x half> %9, i1 false, i1 false)
+  %r4.gfx1250 = rocdl.wmma.f16.16x16x32.f16 %zero, %arg1, %zero, %arg1, %zero_i16, %arg9, %zero, %zero : (i1, vector<16xf16>, i1, vector<16xf16>, i16, vector<32xf16>, i1, i1) -> vector<32xf16>
+
+  // bf16 -> bf16
+  // CHECK: call <16 x i32> @llvm.amdgcn.wmma.bf16.16x16x32.bf16.v16i32.v16i16(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <16 x i32> %13, i1 false, i1 false)
+  %r5.gfx1250 = rocdl.wmma.bf16.16x16x32.bf16 %zero, %arg2, %zero, %arg2, %zero_i16, %arg13, %zero, %zero : (i1, vector<16xi16>, i1, vector<16xi16>, i16, vector<16xi32>, i1, i1) -> vector<16xi32>
+
+  // bf16 -> bf16 / f32
+  // CHECK: call <16 x i32> @llvm.amdgcn.wmma.bf16f32.16x16x32.bf16.v16i32.v16i16.v32f32(i1 false, <16 x i16> %2, i1 false, <16 x i16> %2, i16 0, <32 x float> %12, i1 false, i1 false)
+  %r6.gfx1250 = rocdl.wmma.bf16f32.16x16x32.bf16 %zero, %arg2, %zero, %arg2, %zero_i16, %arg12, %zero, %zero : (i1, vector<16xi16>, i1, vector<16xi16>, i16, vector<32xf32>, i1, i1) -> vector<16xi32>
+
+  // f8 -> f32
+  // CHECK: call <64 x float> @llvm.amdgcn.wmma.f32.16x16x64.fp8.fp8.v64f32.v4i32(<4 x i32> %5, <4 x i32> %5, i16 0, <64 x float> %14, i1 false, i1 false)
+  %r7.gfx1250 = rocdl.wmma.f32.16x16x64.fp8_fp8 %arg5, %arg5, %zero_i16, %arg14, %zero, %zero : (vector<4xi32>, vector<4xi32>, i16, vector<64xf32>, i1, i1) -> vector<64xf32>
+
+  // iu8 -> i32
+  // CHECK: call <64 x i32> @llvm.amdgcn.wmma.i32.16x16x64.iu8.v64i32.v4i32(i1 false, <4 x i32> %5, i1 false, <4 x i32> %5, <64 x i32> %15, i1 false, i1 false)
+  %r8.gfx1250 = rocdl.wmma.i32.16x16x64.iu8 %zero, %arg5, %zero, %arg5, %arg15, %zero, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<64xi32>, i1, i1) -> vector<64xi32>
+
+  %r9.gfx1250 = rocdl.wmma.scale.f32.16x16x128.f8f6f4 %zero_i32, %arg5, %zero_i32, %arg5, %zero_i16, %arg11, %zero_i32, %zero_i32, %arg16, %zero_i32, %zero_i32, %arg16, %zero, %zero : (i32, vector<4xi32>, i32, vector<4xi32>, i16, vector<4xf32>, i32, i32, i32, i32, i32, i32, i1, i1) -> vector<4xf32>
+  // %r7.gfx1250 = rocdl.wmma.scale16.f32.16x16x128.f8f6f4
+  // %r7.gfx1250 = rocdl.wmma.scale.f32.32x16x128.f4
+  // %r7.gfx1250 = rocdl.wmma.scale16.f32.32x16x128.f4
   // ---- Wave64 -----
 
   // f16 -> f32

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE marked this pull request as draft October 7, 2025 18:54
@Muzammiluddin-Syed-ECE
Copy link
Contributor Author

Labelled this as draft while I do a sanity check on this line of suspicious code:

def ROCDL_wmma_bf16f32_16x16x32_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16f32.16x16x32.bf16", [1,5]>;

@krzysz00
Copy link
Contributor

krzysz00 commented Oct 7, 2025

@Muzammiluddin-Syed-ECE Yes, that code looks correct

That being said, While You're Here (tm) ... would it be possible to actually list out the WMMA signatures instead of just using a variadic set of arguments? (Yes, that'll mean more than one template - you can probably borrow names off of LLVM. It'll also let us encode the AllTypesMatch<["x", "y"]> that the LLVM definitions imply). You might not need to change the syntax, but it'll make it harder to construct invalid IR. It'll also let us use Attributes for immargs like we should.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants