Skip to content

Conversation

@efric
Copy link
Contributor

@efric efric commented Dec 10, 2025

No description provided.

Signed-off-by: Eric Feng <Eric.Feng@amd.com>
@llvmbot
Copy link
Member

llvmbot commented Dec 10, 2025

@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mlir

Author: Eric Feng (efric)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/171737.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+15)
  • (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+76-1)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+76-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0edb208a8fcba..fe8a854cd1321 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -592,6 +592,21 @@ def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.b
 def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
 def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
 def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
+// New in gfx950.
+def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf16">;
+def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.f16">;
+def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x128.i8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.fp8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.fp8">;
+def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf16">;
+def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f16">;
+def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x64.i8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.fp8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.fp8">;
 
 
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 1b50feea418b6..745fea8e38955 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -288,7 +288,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                    %arg6 : vector<8 x i16>,
                    %arg7 : vector<2xi32>,
                    %arg8 : vector<4xi32>,
-                   %arg9 : vector<16xi32>) -> vector<4 x f32> {
+                   %arg9 : vector<16xi32>,
+                   %arg10 : vector<8 x f16>,
+                   %arg11 : vector<16 x f16>,
+                   %arg12 : vector<8 x bf16>,
+                   %arg13 : vector<16 x bf16>,
+                   %arg14 : vector<8 x i32>) -> vector<4 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.smfmac
@@ -362,6 +367,76 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                                 (vector<2xi32>, vector<4xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
+  // CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+                                 i32, i32, i32) -> vector<4xi32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+                                 i32, i32, i32) -> vector<16xi32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
   llvm.return %r0 : vector<4 x f32>
 }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 7be6d6ba4d7be..868597fba92a6 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -528,7 +528,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                    %arg6 : vector<8 x i16>,
                    %arg7 : vector<2xi32>,
                    %arg8 : vector<4xi32>,
-                   %arg9 : vector<16xi32>) -> vector<4 x f32> {
+                   %arg9 : vector<16xi32>,
+                   %arg10 : vector<8 x f16>,
+                   %arg11 : vector<16 x f16>,
+                   %arg12 : vector<8 x bf16>,
+                   %arg13 : vector<16 x bf16>,
+                   %arg14 : vector<8 x i32>) -> vector<4 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.smfmac
@@ -598,12 +603,81 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                                 (vector<2xi32>, vector<4xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
-
   // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
   %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<2xi32>, vector<4xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+  %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+                                 i32, i32, i32) -> vector<4xi32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+  %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+                                 i32, i32, i32) -> vector<16xi32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
   llvm.return %r0 : vector<4 x f32>
 }
 

@llvmbot
Copy link
Member

llvmbot commented Dec 10, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Eric Feng (efric)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/171737.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td (+15)
  • (modified) mlir/test/Dialect/LLVMIR/rocdl.mlir (+76-1)
  • (modified) mlir/test/Target/LLVMIR/rocdl.mlir (+76-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index 0edb208a8fcba..fe8a854cd1321 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -592,6 +592,21 @@ def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.b
 def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
 def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
 def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
+// New in gfx950.
+def ROCDL_smfmac_f32_16x16x64_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf16">;
+def ROCDL_smfmac_f32_16x16x64_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.f16">;
+def ROCDL_smfmac_i32_16x16x128_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x128.i8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.bf8.fp8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.bf8">;
+def ROCDL_smfmac_f32_16x16x128_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x128.fp8.fp8">;
+def ROCDL_smfmac_f32_32x32x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf16">;
+def ROCDL_smfmac_f32_32x32x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.f16">;
+def ROCDL_smfmac_i32_32x32x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x64.i8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.bf8.fp8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.bf8">;
+def ROCDL_smfmac_f32_32x32x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x64.fp8.fp8">;
 
 
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
index 1b50feea418b6..745fea8e38955 100644
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -288,7 +288,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                    %arg6 : vector<8 x i16>,
                    %arg7 : vector<2xi32>,
                    %arg8 : vector<4xi32>,
-                   %arg9 : vector<16xi32>) -> vector<4 x f32> {
+                   %arg9 : vector<16xi32>,
+                   %arg10 : vector<8 x f16>,
+                   %arg11 : vector<16 x f16>,
+                   %arg12 : vector<8 x bf16>,
+                   %arg13 : vector<16 x bf16>,
+                   %arg14 : vector<8 x i32>) -> vector<4 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.smfmac
@@ -362,6 +367,76 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                                 (vector<2xi32>, vector<4xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
+  // CHECK: rocdl.smfmac.f32.16x16x64.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.f16 %{{.*}} : (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x64.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x32.bf16 %{{.*}} : (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.i32.16x16x128.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+                                 i32, i32, i32) -> vector<4xi32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: rocdl.smfmac.i32.32x32x64.i8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
+  %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+                                 i32, i32, i32) -> vector<16xi32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.bf8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8 %{{.*}} : (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
   llvm.return %r0 : vector<4 x f32>
 }
 
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
index 7be6d6ba4d7be..868597fba92a6 100644
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -528,7 +528,12 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                    %arg6 : vector<8 x i16>,
                    %arg7 : vector<2xi32>,
                    %arg8 : vector<4xi32>,
-                   %arg9 : vector<16xi32>) -> vector<4 x f32> {
+                   %arg9 : vector<16xi32>,
+                   %arg10 : vector<8 x f16>,
+                   %arg11 : vector<16 x f16>,
+                   %arg12 : vector<8 x bf16>,
+                   %arg13 : vector<16 x bf16>,
+                   %arg14 : vector<8 x i32>) -> vector<4 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.smfmac
@@ -598,12 +603,81 @@ llvm.func @rocdl.smfmac(%arg0 : i32,
                                 (vector<2xi32>, vector<4xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
-
   // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
   %r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
                                 (vector<2xi32>, vector<4xi32>, vector<16xf32>,
                                  i32, i32, i32) -> vector<16xf32>
 
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r14 = rocdl.smfmac.f32.16x16x64.f16 %arg10, %arg11, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.f16(<8 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r15 = rocdl.smfmac.f32.32x32x32.f16 %arg10, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xf16>, vector<16xf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r16 = rocdl.smfmac.f32.16x16x64.bf16 %arg12, %arg13, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf16(<8 x bfloat> %{{.*}}, <16 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r17 = rocdl.smfmac.f32.32x32x32.bf16 %arg12, %arg13, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<8xbf16>, vector<16xbf16>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x128.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+  %r18 = rocdl.smfmac.i32.16x16x128.i8 %arg8, %arg14, %arg8, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xi32>,
+                                 i32, i32, i32) -> vector<4xi32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r19 = rocdl.smfmac.f32.16x16x128.bf8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r20 = rocdl.smfmac.f32.16x16x128.bf8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r21 = rocdl.smfmac.f32.16x16x128.fp8.bf8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x128.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r22 = rocdl.smfmac.f32.16x16x128.fp8.fp8 %arg8, %arg14, %arg3, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<4xf32>,
+                                 i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x64.i8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
+  %r23 = rocdl.smfmac.i32.32x32x64.i8 %arg8, %arg14, %arg9, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xi32>,
+                                 i32, i32, i32) -> vector<16xi32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r24 = rocdl.smfmac.f32.32x32x64.bf8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.bf8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r25 = rocdl.smfmac.f32.32x32x64.bf8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.bf8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r26 = rocdl.smfmac.f32.32x32x64.fp8.bf8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x64.fp8.fp8(<4 x i32> %{{.*}}, <8 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
+  %r27 = rocdl.smfmac.f32.32x32x64.fp8.fp8 %arg8, %arg14, %arg4, %csti32, %csti32, %csti32 :
+                                (vector<4xi32>, vector<8xi32>, vector<16xf32>,
+                                 i32, i32, i32) -> vector<16xf32>
+
   llvm.return %r0 : vector<4 x f32>
 }
 

Signed-off-by: Eric Feng <Eric.Feng@amd.com>
@efric
Copy link
Contributor Author

efric commented Dec 11, 2025

Thanks for reviewing. Could you help me merge this? I don't have the permissions.

@kuhar
Copy link
Member

kuhar commented Dec 11, 2025

Let's wait for @krzysz00 and let him merge if it looks OK.

@krzysz00
Copy link
Contributor

I'll check tomorrow whether the signature is different for the sparse ones.

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, so it looks like the MfmaIntrOp class is doing that thing WMMA was doing until recently where we're just declaring varargs, which means, for example, that immargs aren't being represented as attributes.

Fixing that should happen but is off-topic for this PR. Therefore, approved.

@krzysz00 krzysz00 merged commit 00b3a18 into llvm:main Dec 11, 2025
17 checks passed
Priyanshu3820 pushed a commit to Priyanshu3820/llvm-project that referenced this pull request Dec 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants