Skip to content

Conversation

@kuhar
Copy link
Member

@kuhar kuhar commented Oct 24, 2025

Use the same format as introduced for wmma by
#164920 and for mfma by #165037.

@llvmbot
Copy link
Member

llvmbot commented Oct 24, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-amdgpu

Author: Jakub Kuderski (kuhar)

Changes

Use the same format as introduced for wmma by
#164920 and for mfma by #165037.


Patch is 22.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165044.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+19-11)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+22-22)
  • (modified) mlir/test/Dialect/AMDGPU/canonicalize.mlir (+15-15)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+24)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d74abc22acd5e..91053d5b18461 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1109,9 +1109,9 @@ def AMDGPU_ScaledMFMAOp :
     AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
     Arguments<(ins
-                   I32Attr:$m,
-                   I32Attr:$n,
-                   I32Attr:$k,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[64, 128]>]>:$k,
                    ScaledMFMAInTypes:$sourceA,
                    ScaledMFMAInTypes:$sourceB,
                    ScaledMFMAOutTypes:$destC,
@@ -1124,8 +1124,8 @@ def AMDGPU_ScaledMFMAOp :
   let summary = "MLIR wrapper for CDNA scaled mfma instructions";
   let description = [{
     The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
-    for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
-    multiple outer products in order to allow fast matrix multiplication.
+    for various scaled versions of `mfma` instructions in the CDNA architecture, which
+    perform multiple outer products in order to allow fast matrix multiplication.
 
     The wrapper will select an appropriate `mfma` instruction, if one is available,
     based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the
@@ -1140,15 +1140,23 @@ def AMDGPU_ScaledMFMAOp :
 
     This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
     - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
-    fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
-    size.
+      fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as
+      their tile size.
     - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
-    are omitted from this wrapper.
-    - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
-    double-precision operations on gfx94x and so are not included here.
+      are omitted from this wrapper.
+    - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported
+      for double-precision operations on gfx94x and so are not included here.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.scaled_mfma 32x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2
+        : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
+    ```
   }];
   let assemblyFormat = [{
-    `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) ` `
+    `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*`
+    `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
     attr-dict
     `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
   }];
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index 39c31d5bf2fa3..cc868dc399b34 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -55,50 +55,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
 // CHECK-LABEL: func @scaled_mfma_to_rocdl(
 // CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU>, %[[ARG8:.*]]: f8E8M0FNU
 func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
-                    %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
-                    %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
-                    %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>, 
-                    %arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
-  
+                                %arg1 : vector<4xf32>,       %arg2 : vector<32xf8E4M3FN>,
+                                %arg3 : vector<32xf8E5M2>,   %arg4 : vector<32xf6E2M3FN>,
+                                %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
+                                %arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
+
   // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
   // CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32
 
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
+  amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
-  
+  amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
+
   // CHECK: llvm.bitcast
-  
+
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
+  amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
-  
+  amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
+
   // CHECK: llvm.bitcast
-  
+
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
+  amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
-  
+  amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
+
   // CHECK: llvm.bitcast
   // CHECK: llvm.mlir.constant(3 : i32) : i32
 
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
+  amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
-  
+  amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
+
   // CHECK: llvm.bitcast
   // CHECK: llvm.mlir.constant(4 : i32) : i32
-  
+
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
+  amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
+  amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
 
   func.return
 }
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 52d3275dab43b..fee0c00606ab4 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -165,10 +165,10 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
 // CHECK-LABEL: func @scaled_mfma
 // CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
 // CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
-// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
 // CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
 // CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
-// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
 func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
   %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
@@ -176,12 +176,12 @@ func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %sc
   %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
   %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_0 = amdgpu.scaled_mfma 16x16x128 (%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
   %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
   %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_1 = amdgpu.scaled_mfma 16x16x128 (%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
 }
 
@@ -192,7 +192,7 @@ func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %sc
 // CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
 // CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
 // CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
-// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
+// CHECK: amdgpu.scaled_mfma 16x16x128 ({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
 func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
   %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
@@ -200,17 +200,17 @@ func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
   %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
   %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_0 = amdgpu.scaled_mfma 16x16x128 (%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   return %res_0 : vector<4xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @scaled_mfma_ugly_shapes
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
 func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
   %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
@@ -237,10 +237,10 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
   %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  
-  %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+
+  %res_4 = amdgpu.scaled_mfma 16x16x128 (%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_5 = amdgpu.scaled_mfma 16x16x128 (%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_6 = amdgpu.scaled_mfma 16x16x128 (%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_7 = amdgpu.scaled_mfma 16x16x128 (%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
 }
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 6a2518a40cc99..7084f1547247e 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -302,3 +302,27 @@ func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf
   %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16>
   func.return
 }
+
+// -----
+
+func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 24, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Jakub Kuderski (kuhar)

Changes

Use the same format as introduced for wmma by
#164920 and for mfma by #165037.


Patch is 22.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165044.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+19-11)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+22-22)
  • (modified) mlir/test/Dialect/AMDGPU/canonicalize.mlir (+15-15)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+24)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d74abc22acd5e..91053d5b18461 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1109,9 +1109,9 @@ def AMDGPU_ScaledMFMAOp :
     AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
                         Pure]>,
     Arguments<(ins
-                   I32Attr:$m,
-                   I32Attr:$n,
-                   I32Attr:$k,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[64, 128]>]>:$k,
                    ScaledMFMAInTypes:$sourceA,
                    ScaledMFMAInTypes:$sourceB,
                    ScaledMFMAOutTypes:$destC,
@@ -1124,8 +1124,8 @@ def AMDGPU_ScaledMFMAOp :
   let summary = "MLIR wrapper for CDNA scaled mfma instructions";
   let description = [{
     The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
-    for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
-    multiple outer products in order to allow fast matrix multiplication.
+    for various scaled versions of `mfma` instructions in the CDNA architecture, which
+    perform multiple outer products in order to allow fast matrix multiplication.
 
     The wrapper will select an appropriate `mfma` instruction, if one is available,
     based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the
@@ -1140,15 +1140,23 @@ def AMDGPU_ScaledMFMAOp :
 
     This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
     - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
-    fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
-    size.
+      fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as
+      their tile size.
     - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
-    are omitted from this wrapper.
-    - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
-    double-precision operations on gfx94x and so are not included here.
+      are omitted from this wrapper.
+    - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported
+      for double-precision operations on gfx94x and so are not included here.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.scaled_mfma 32x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2
+        : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
+    ```
   }];
   let assemblyFormat = [{
-    `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) ` `
+    `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*`
+    `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
     attr-dict
     `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
   }];
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index 39c31d5bf2fa3..cc868dc399b34 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -55,50 +55,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
 // CHECK-LABEL: func @scaled_mfma_to_rocdl(
 // CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU>, %[[ARG8:.*]]: f8E8M0FNU
 func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
-                    %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
-                    %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
-                    %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>, 
-                    %arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
-  
+                                %arg1 : vector<4xf32>,       %arg2 : vector<32xf8E4M3FN>,
+                                %arg3 : vector<32xf8E5M2>,   %arg4 : vector<32xf6E2M3FN>,
+                                %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
+                                %arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
+
   // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
   // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
   // CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32
 
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
+  amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
-  
+  amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
+
   // CHECK: llvm.bitcast
-  
+
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
+  amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
-  
+  amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
+
   // CHECK: llvm.bitcast
-  
+
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
+  amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
-  
+  amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
+
   // CHECK: llvm.bitcast
   // CHECK: llvm.mlir.constant(3 : i32) : i32
 
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
+  amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
-  
+  amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
+
   // CHECK: llvm.bitcast
   // CHECK: llvm.mlir.constant(4 : i32) : i32
-  
+
   // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
+  amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
   // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
-  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
+  amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
 
   func.return
 }
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 52d3275dab43b..fee0c00606ab4 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -165,10 +165,10 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
 // CHECK-LABEL: func @scaled_mfma
 // CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
 // CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
-// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
 // CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
 // CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
-// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
 func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
   %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
@@ -176,12 +176,12 @@ func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %sc
   %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
   %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_0 = amdgpu.scaled_mfma 16x16x128 (%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
   %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
   %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_1 = amdgpu.scaled_mfma 16x16x128 (%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
 }
 
@@ -192,7 +192,7 @@ func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %sc
 // CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
 // CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
 // CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
-// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
+// CHECK: amdgpu.scaled_mfma 16x16x128 ({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
 func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
   %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
@@ -200,17 +200,17 @@ func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
   %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
   %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_0 = amdgpu.scaled_mfma 16x16x128 (%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   return %res_0 : vector<4xf32>
 }
 
 // -----
 
 // CHECK-LABEL: func @scaled_mfma_ugly_shapes
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
 func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
   %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
@@ -237,10 +237,10 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
   %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  
-  %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+
+  %res_4 = amdgpu.scaled_mfma 16x16x128 (%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_5 = amdgpu.scaled_mfma 16x16x128 (%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_6 = amdgpu.scaled_mfma 16x16x128 (%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %res_7 = amdgpu.scaled_mfma 16x16x128 (%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
 }
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 6a2518a40cc99..7084f1547247e 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -302,3 +302,27 @@ func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf
   %ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16>
   func.return
 }
+
+// -----
+
+func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
...
[truncated]

Use the same format as introduced for wmma by
llvm#164920 and for mfma by
llvm#165037.
@kuhar kuhar force-pushed the amdgpu-scaled-mfma-syntax branch from af2c29e to 9779a41 Compare October 25, 2025 11:30
@kuhar
Copy link
Member Author

kuhar commented Oct 25, 2025

Rebased

@kuhar kuhar enabled auto-merge (squash) October 25, 2025 11:31
@kuhar kuhar merged commit bbe9209 into llvm:main Oct 25, 2025
10 checks passed
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Oct 29, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants