Skip to content

Conversation

@kuhar
Copy link
Member

@kuhar kuhar commented Oct 24, 2025

This is in preparation for adding support for gfx1250 wmma intrinsics that include much more possible shapes.

Instead of guessing the wave32/wave64 mode based on element types and vector sizes, require the intrinsic shapes to be set explicitly as attributes.

This is in preparation for adding support for gfx1250 wmma intrinsics
that include much more possible shapes.

Instead of guessing the wave32/wave64 mode based on element types and
vector sizes, require the intrinsic shapes to be set explicitly as
attributes.
@llvmbot
Copy link
Member

llvmbot commented Oct 24, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-ods

@llvm/pr-subscribers-backend-amdgpu

Author: Jakub Kuderski (kuhar)

Changes

This is in preparation for adding support for gfx1250 wmma intrinsics that include much more possible shapes.

Instead of guessing the wave32/wave64 mode based on element types and vector sizes, require the intrinsic shapes to be set explicitly as attributes.


Patch is 31.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164920.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+29-16)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h (+24-1)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+40-30)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+34-27)
  • (renamed) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir (+14-13)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+23-23)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+43-3)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+11-4)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 7184de93bfacb..3a808ff3a01e4 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
-                             [4, 8, 16],
-                             [F16, BF16,
-                              I8, SI8, UI8,
-                              I<4>, SI<4>, UI<4>,
-                              F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :
 
     The negateA, negateB, and negateC flags are only supported for double-precision
     operations on gfx94x.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.mfma %matA * %matB + %matC
+        { abid = 1 : i32, cbsz = 1 : i32,
+          m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+        blgp = bcast_second_32 : f32, f32, vector<32xf32>
+    ```
   }];
   let assemblyFormat = [{
     $sourceA `*` $sourceB `+` $destC
@@ -982,6 +988,9 @@ def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
                        Pure]>,
     Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$m,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$n,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<32>, IntPowerOf2]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
@@ -990,28 +999,32 @@ def AMDGPU_WMMAOp :
                    UnitAttr:$unsignedB,
                    UnitAttr:$clamp)>,
     Results<(outs WMMAOutTypes: $destD)> {
-  let summary = "MLIR wrapper for RDNA3 wmma instructions";
+  let summary = "MLIR wrapper for wmma instructions";
   let description = [{
-    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
-    perform a 16x16 * 16x16 matrix multiplication for different data types.
-    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
-    integer inputs.
+    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
+    instructions in the AMDGPU architecture, which perform matrix multiplication.
+    Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
+    dimensions.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
-    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
-    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
+    On gfx12/RDNA4 and gfx1250, the result is instead returned as vector where all
+    the values are valid and the `subwordOffset` must be `0`, as it cannot be used.
 
     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
 
-    The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+    The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
     in case of overflow.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+    ```
   }];
   let assemblyFormat = [{
-    $sourceA `*` $sourceB `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
     attr-dict
     `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 3de57c923178a..b6fe61ff1afa2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file declares a dialect for MLIR wrappers around AMDGPU-specific
-// intrinssics and for other AMD GPU-specific functionality.
+// intrinsics and for other AMD GPU-specific functionality.
 //
 //===----------------------------------------------------------------------===//
 
@@ -26,6 +26,29 @@
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
 
+namespace mlir {
+/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                  IntegerAttr &n, IntegerAttr &k);
+inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *,
+                                         IntegerAttr &m, IntegerAttr &n,
+                                         IntegerAttr &k) {
+  return parseMNKDimensionList(parser, m, n, k);
+}
+
+/// Printer for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m,
+                                  IntegerAttr n, IntegerAttr k) {
+  printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()});
+}
+inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
+                                  IntegerAttr m, IntegerAttr n, IntegerAttr k) {
+  printMNKDimensionList(printer, m, n, k);
+}
+} // namespace mlir
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 9b154350cd913..478b6aaaec83a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
 /// on the architecture you are compiling for.
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
-  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
-  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
-  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
-  auto elemSourceType = sourceVectorType.getElementType();
-  auto elemBSourceType = sourceBVectorType.getElementType();
-  auto elemDestType = destVectorType.getElementType();
-
-  if (elemSourceType.isF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isF16() && elemDestType.isF16())
-    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isBF16())
-    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
-    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (chipset.majorVersion == 11) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+  Type elemSourceType = sourceVectorType.getElementType();
+  Type elemBSourceType = sourceBVectorType.getElementType();
+  Type elemDestType = destVectorType.getElementType();
+
+  const uint32_t k = wmma.getK();
+
+  if (k == 16) {
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+    if (chipset.majorVersion == 11) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
   }
-  if (chipset.majorVersion >= 12) {
+  if (chipset.majorVersion < 12)
+    return std::nullopt;
+
+  // gfx12+
+  if (k == 16) {
     if (isa<Float8E4M3FNType>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     if (isa<Float8E5M2Type>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
-      bool isWave64 = destVectorType.getNumElements() == 4;
-      // This is the ambiguous case. 8 inputs to the wave64 version means that
-      // we want the 16x16x32 version, but for wave32 they mean the short form.
-      bool has8Inputs = sourceVectorType.getNumElements() == 8;
-      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
-        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
-    }
+
+    return std::nullopt;
   }
-  return std::nullopt;
+  if (k == 32) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    return std::nullopt;
+  }
+
+  llvm_unreachable("unhandled WMMA case");
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db0ff210..eb40374d61303 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,52 @@ LogicalResult ScaledExtPacked816Op::verify() {
 //===----------------------------------------------------------------------===//
 // WMMAOp
 //===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
-  Type sourceAType = getSourceA().getType();
-  Type sourceBType = getSourceB().getType();
-  Type destType = getDestC().getType();
 
-  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
-  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
-  VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                        IntegerAttr &n, IntegerAttr &k) {
+  SmallVector<int64_t, 3> dimensions;
+  if (parser.parseDimensionList(dimensions, false, false))
+    return failure();
+  if (dimensions.size() != 3)
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 3 dimensions in MNK dimension list";
 
-  Type sourceAElemType = sourceVectorAType.getElementType();
-  Type sourceBElemType = sourceVectorBType.getElementType();
-  Type destElemType = destVectorType.getElementType();
+  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+  return success();
+}
 
-  if (sourceVectorAType.getNumElements() !=
-      sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sourceAElemType = sourceAType.getElementType();
+  Type sourceBElemType = sourceBType.getElementType();
+  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
     return emitOpError("source vectors have different lengths: ")
-           << sourceVectorAType << " vs. " << sourceVectorBType;
+           << sourceAType << " vs. " << sourceBType;
   }
 
-  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
-  bool isSrcFloat =
-      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
-          sourceAElemType);
-
-  if (isDestFloat && !isSrcFloat) {
-    return emitOpError("Expected float sources with float destination");
-  }
+  bool isDestFloat = destType.getElementType().isFloat();
+  bool isSrcFloat = sourceAElemType.isFloat();
 
-  if (!isDestFloat && isSrcFloat) {
-    return emitOpError("Expected int sources with int destination");
-  }
+  if (isDestFloat && !isSrcFloat)
+    return emitOpError("expected float sources with float destination");
+  if (!isDestFloat && isSrcFloat)
+    return emitOpError("expected int sources with int destination");
 
-  if (sourceAElemType != sourceBElemType &&
-      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
-        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
                "source element types much match (except for fp8) but have ")
            << sourceAType << " and " << sourceBType;
   }
+
+  if (!sourceAElemType.isInteger(4) && getK() != 16) {
+    return emitOpError("K dimension must be 16 for source element type ")
+           << sourceAElemType;
+  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
similarity index 59%
rename from mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
rename to mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 638a7c3f8c1c5..d1301d0089220 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
                          %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
                          %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
                          %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
   // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
   // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
 
   func.return
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 94a1b78d5f040..b897323340402 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 24, 2025

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes

This is in preparation for adding support for gfx1250 wmma intrinsics that include much more possible shapes.

Instead of guessing the wave32/wave64 mode based on element types and vector sizes, require the intrinsic shapes to be set explicitly as attributes.


Patch is 31.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164920.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+29-16)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h (+24-1)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+40-30)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+34-27)
  • (renamed) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir (+14-13)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+23-23)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+43-3)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+11-4)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 7184de93bfacb..3a808ff3a01e4 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
-                             [4, 8, 16],
-                             [F16, BF16,
-                              I8, SI8, UI8,
-                              I<4>, SI<4>, UI<4>,
-                              F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :
 
     The negateA, negateB, and negateC flags are only supported for double-precision
     operations on gfx94x.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.mfma %matA * %matB + %matC
+        { abid = 1 : i32, cbsz = 1 : i32,
+          m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+        blgp = bcast_second_32 : f32, f32, vector<32xf32>
+    ```
   }];
   let assemblyFormat = [{
     $sourceA `*` $sourceB `+` $destC
@@ -982,6 +988,9 @@ def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
                        Pure]>,
     Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$m,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$n,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<32>, IntPowerOf2]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
@@ -990,28 +999,32 @@ def AMDGPU_WMMAOp :
                    UnitAttr:$unsignedB,
                    UnitAttr:$clamp)>,
     Results<(outs WMMAOutTypes: $destD)> {
-  let summary = "MLIR wrapper for RDNA3 wmma instructions";
+  let summary = "MLIR wrapper for wmma instructions";
   let description = [{
-    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
-    perform a 16x16 * 16x16 matrix multiplication for different data types.
-    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
-    integer inputs.
+    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
+    instructions in the AMDGPU architecture, which perform matrix multiplication.
+    Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
+    dimensions.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
-    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
-    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
+    On gfx12/RDNA4 and gfx1250, the result is instead returned as vector where all
+    the values are valid and the `subwordOffset` must be `0`, as it cannot be used.
 
     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
 
-    The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+    The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
     in case of overflow.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+    ```
   }];
   let assemblyFormat = [{
-    $sourceA `*` $sourceB `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
     attr-dict
     `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 3de57c923178a..b6fe61ff1afa2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file declares a dialect for MLIR wrappers around AMDGPU-specific
-// intrinssics and for other AMD GPU-specific functionality.
+// intrinsics and for other AMD GPU-specific functionality.
 //
 //===----------------------------------------------------------------------===//
 
@@ -26,6 +26,29 @@
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
 
+namespace mlir {
+/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                  IntegerAttr &n, IntegerAttr &k);
+inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *,
+                                         IntegerAttr &m, IntegerAttr &n,
+                                         IntegerAttr &k) {
+  return parseMNKDimensionList(parser, m, n, k);
+}
+
+/// Printer for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m,
+                                  IntegerAttr n, IntegerAttr k) {
+  printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()});
+}
+inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
+                                  IntegerAttr m, IntegerAttr n, IntegerAttr k) {
+  printMNKDimensionList(printer, m, n, k);
+}
+} // namespace mlir
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 9b154350cd913..478b6aaaec83a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
 /// on the architecture you are compiling for.
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
-  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
-  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
-  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
-  auto elemSourceType = sourceVectorType.getElementType();
-  auto elemBSourceType = sourceBVectorType.getElementType();
-  auto elemDestType = destVectorType.getElementType();
-
-  if (elemSourceType.isF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isF16() && elemDestType.isF16())
-    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isBF16())
-    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
-    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (chipset.majorVersion == 11) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+  Type elemSourceType = sourceVectorType.getElementType();
+  Type elemBSourceType = sourceBVectorType.getElementType();
+  Type elemDestType = destVectorType.getElementType();
+
+  const uint32_t k = wmma.getK();
+
+  if (k == 16) {
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+    if (chipset.majorVersion == 11) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
   }
-  if (chipset.majorVersion >= 12) {
+  if (chipset.majorVersion < 12)
+    return std::nullopt;
+
+  // gfx12+
+  if (k == 16) {
     if (isa<Float8E4M3FNType>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     if (isa<Float8E5M2Type>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
-      bool isWave64 = destVectorType.getNumElements() == 4;
-      // This is the ambiguous case. 8 inputs to the wave64 version means that
-      // we want the 16x16x32 version, but for wave32 they mean the short form.
-      bool has8Inputs = sourceVectorType.getNumElements() == 8;
-      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
-        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
-    }
+
+    return std::nullopt;
   }
-  return std::nullopt;
+  if (k == 32) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    return std::nullopt;
+  }
+
+  llvm_unreachable("unhandled WMMA case");
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db0ff210..eb40374d61303 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,52 @@ LogicalResult ScaledExtPacked816Op::verify() {
 //===----------------------------------------------------------------------===//
 // WMMAOp
 //===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
-  Type sourceAType = getSourceA().getType();
-  Type sourceBType = getSourceB().getType();
-  Type destType = getDestC().getType();
 
-  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
-  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
-  VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                        IntegerAttr &n, IntegerAttr &k) {
+  SmallVector<int64_t, 3> dimensions;
+  if (parser.parseDimensionList(dimensions, false, false))
+    return failure();
+  if (dimensions.size() != 3)
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 3 dimensions in MNK dimension list";
 
-  Type sourceAElemType = sourceVectorAType.getElementType();
-  Type sourceBElemType = sourceVectorBType.getElementType();
-  Type destElemType = destVectorType.getElementType();
+  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+  return success();
+}
 
-  if (sourceVectorAType.getNumElements() !=
-      sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sourceAElemType = sourceAType.getElementType();
+  Type sourceBElemType = sourceBType.getElementType();
+  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
     return emitOpError("source vectors have different lengths: ")
-           << sourceVectorAType << " vs. " << sourceVectorBType;
+           << sourceAType << " vs. " << sourceBType;
   }
 
-  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
-  bool isSrcFloat =
-      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
-          sourceAElemType);
-
-  if (isDestFloat && !isSrcFloat) {
-    return emitOpError("Expected float sources with float destination");
-  }
+  bool isDestFloat = destType.getElementType().isFloat();
+  bool isSrcFloat = sourceAElemType.isFloat();
 
-  if (!isDestFloat && isSrcFloat) {
-    return emitOpError("Expected int sources with int destination");
-  }
+  if (isDestFloat && !isSrcFloat)
+    return emitOpError("expected float sources with float destination");
+  if (!isDestFloat && isSrcFloat)
+    return emitOpError("expected int sources with int destination");
 
-  if (sourceAElemType != sourceBElemType &&
-      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
-        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
                "source element types much match (except for fp8) but have ")
            << sourceAType << " and " << sourceBType;
   }
+
+  if (!sourceAElemType.isInteger(4) && getK() != 16) {
+    return emitOpError("K dimension must be 16 for source element type ")
+           << sourceAElemType;
+  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
similarity index 59%
rename from mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
rename to mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 638a7c3f8c1c5..d1301d0089220 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
                          %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
                          %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
                          %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
   // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
   // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
 
   func.return
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 94a1b78d5f040..b897323340402 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 24, 2025

@llvm/pr-subscribers-mlir-amdgpu

Author: Jakub Kuderski (kuhar)

Changes

This is in preparation for adding support for gfx1250 wmma intrinsics that include much more possible shapes.

Instead of guessing the wave32/wave64 mode based on element types and vector sizes, require the intrinsic shapes to be set explicitly as attributes.


Patch is 31.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164920.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+29-16)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h (+24-1)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+40-30)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+34-27)
  • (renamed) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir (+14-13)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+23-23)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+43-3)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+11-4)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 7184de93bfacb..3a808ff3a01e4 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
-                             [4, 8, 16],
-                             [F16, BF16,
-                              I8, SI8, UI8,
-                              I<4>, SI<4>, UI<4>,
-                              F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :
 
     The negateA, negateB, and negateC flags are only supported for double-precision
     operations on gfx94x.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.mfma %matA * %matB + %matC
+        { abid = 1 : i32, cbsz = 1 : i32,
+          m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+        blgp = bcast_second_32 : f32, f32, vector<32xf32>
+    ```
   }];
   let assemblyFormat = [{
     $sourceA `*` $sourceB `+` $destC
@@ -982,6 +988,9 @@ def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
                        Pure]>,
     Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$m,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$n,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<32>, IntPowerOf2]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
@@ -990,28 +999,32 @@ def AMDGPU_WMMAOp :
                    UnitAttr:$unsignedB,
                    UnitAttr:$clamp)>,
     Results<(outs WMMAOutTypes: $destD)> {
-  let summary = "MLIR wrapper for RDNA3 wmma instructions";
+  let summary = "MLIR wrapper for wmma instructions";
   let description = [{
-    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
-    perform a 16x16 * 16x16 matrix multiplication for different data types.
-    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
-    integer inputs.
+    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
+    instructions in the AMDGPU architecture, which perform matrix multiplication.
+    Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
+    dimensions.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
-    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
-    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
+    On gfx12/RDNA4 and gfx1250, the result is instead returned as vector where all
+    the values are valid and the `subwordOffset` must be `0`, as it cannot be used.
 
     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
 
-    The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+    The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
     in case of overflow.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+    ```
   }];
   let assemblyFormat = [{
-    $sourceA `*` $sourceB `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
     attr-dict
     `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 3de57c923178a..b6fe61ff1afa2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file declares a dialect for MLIR wrappers around AMDGPU-specific
-// intrinssics and for other AMD GPU-specific functionality.
+// intrinsics and for other AMD GPU-specific functionality.
 //
 //===----------------------------------------------------------------------===//
 
@@ -26,6 +26,29 @@
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
 
+namespace mlir {
+/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                  IntegerAttr &n, IntegerAttr &k);
+inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *,
+                                         IntegerAttr &m, IntegerAttr &n,
+                                         IntegerAttr &k) {
+  return parseMNKDimensionList(parser, m, n, k);
+}
+
+/// Printer for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m,
+                                  IntegerAttr n, IntegerAttr k) {
+  printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()});
+}
+inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
+                                  IntegerAttr m, IntegerAttr n, IntegerAttr k) {
+  printMNKDimensionList(printer, m, n, k);
+}
+} // namespace mlir
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 9b154350cd913..478b6aaaec83a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
 /// on the architecture you are compiling for.
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
-  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
-  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
-  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
-  auto elemSourceType = sourceVectorType.getElementType();
-  auto elemBSourceType = sourceBVectorType.getElementType();
-  auto elemDestType = destVectorType.getElementType();
-
-  if (elemSourceType.isF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isF16() && elemDestType.isF16())
-    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isBF16())
-    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
-    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (chipset.majorVersion == 11) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+  Type elemSourceType = sourceVectorType.getElementType();
+  Type elemBSourceType = sourceBVectorType.getElementType();
+  Type elemDestType = destVectorType.getElementType();
+
+  const uint32_t k = wmma.getK();
+
+  if (k == 16) {
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+    if (chipset.majorVersion == 11) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
   }
-  if (chipset.majorVersion >= 12) {
+  if (chipset.majorVersion < 12)
+    return std::nullopt;
+
+  // gfx12+
+  if (k == 16) {
     if (isa<Float8E4M3FNType>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     if (isa<Float8E5M2Type>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
-      bool isWave64 = destVectorType.getNumElements() == 4;
-      // This is the ambiguous case. 8 inputs to the wave64 version means that
-      // we want the 16x16x32 version, but for wave32 they mean the short form.
-      bool has8Inputs = sourceVectorType.getNumElements() == 8;
-      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
-        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
-    }
+
+    return std::nullopt;
   }
-  return std::nullopt;
+  if (k == 32) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    return std::nullopt;
+  }
+
+  llvm_unreachable("unhandled WMMA case");
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db0ff210..eb40374d61303 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,52 @@ LogicalResult ScaledExtPacked816Op::verify() {
 //===----------------------------------------------------------------------===//
 // WMMAOp
 //===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
-  Type sourceAType = getSourceA().getType();
-  Type sourceBType = getSourceB().getType();
-  Type destType = getDestC().getType();
 
-  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
-  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
-  VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                        IntegerAttr &n, IntegerAttr &k) {
+  SmallVector<int64_t, 3> dimensions;
+  if (parser.parseDimensionList(dimensions, false, false))
+    return failure();
+  if (dimensions.size() != 3)
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 3 dimensions in MNK dimension list";
 
-  Type sourceAElemType = sourceVectorAType.getElementType();
-  Type sourceBElemType = sourceVectorBType.getElementType();
-  Type destElemType = destVectorType.getElementType();
+  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+  return success();
+}
 
-  if (sourceVectorAType.getNumElements() !=
-      sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sourceAElemType = sourceAType.getElementType();
+  Type sourceBElemType = sourceBType.getElementType();
+  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
     return emitOpError("source vectors have different lengths: ")
-           << sourceVectorAType << " vs. " << sourceVectorBType;
+           << sourceAType << " vs. " << sourceBType;
   }
 
-  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
-  bool isSrcFloat =
-      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
-          sourceAElemType);
-
-  if (isDestFloat && !isSrcFloat) {
-    return emitOpError("Expected float sources with float destination");
-  }
+  bool isDestFloat = destType.getElementType().isFloat();
+  bool isSrcFloat = sourceAElemType.isFloat();
 
-  if (!isDestFloat && isSrcFloat) {
-    return emitOpError("Expected int sources with int destination");
-  }
+  if (isDestFloat && !isSrcFloat)
+    return emitOpError("expected float sources with float destination");
+  if (!isDestFloat && isSrcFloat)
+    return emitOpError("expected int sources with int destination");
 
-  if (sourceAElemType != sourceBElemType &&
-      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
-        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
                "source element types much match (except for fp8) but have ")
            << sourceAType << " and " << sourceBType;
   }
+
+  if (!sourceAElemType.isInteger(4) && getK() != 16) {
+    return emitOpError("K dimension must be 16 for source element type ")
+           << sourceAElemType;
+  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
similarity index 59%
rename from mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
rename to mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 638a7c3f8c1c5..d1301d0089220 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
                          %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
                          %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
                          %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
   // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
   // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
 
   func.return
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 94a1b78d5f040..b897323340402 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 24, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Jakub Kuderski (kuhar)

Changes

This is in preparation for adding support for gfx1250 wmma intrinsics that include much more possible shapes.

Instead of guessing the wave32/wave64 mode based on element types and vector sizes, require the intrinsic shapes to be set explicitly as attributes.


Patch is 31.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/164920.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+29-16)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h (+24-1)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+40-30)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+34-27)
  • (renamed) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir (+14-13)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir (+23-23)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+43-3)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+11-4)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 7184de93bfacb..3a808ff3a01e4 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
-                             [4, 8, 16],
-                             [F16, BF16,
-                              I8, SI8, UI8,
-                              I<4>, SI<4>, UI<4>,
-                              F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
 
@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :
 
     The negateA, negateB, and negateC flags are only supported for double-precision
     operations on gfx94x.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.mfma %matA * %matB + %matC
+        { abid = 1 : i32, cbsz = 1 : i32,
+          m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+        blgp = bcast_second_32 : f32, f32, vector<32xf32>
+    ```
   }];
   let assemblyFormat = [{
     $sourceA `*` $sourceB `+` $destC
@@ -982,6 +988,9 @@ def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
                        Pure]>,
     Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$m,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<16>]>:$n,
+                   ConfinedAttr<I32Attr, [IntMinValue<16>, IntMaxValue<32>, IntPowerOf2]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
@@ -990,28 +999,32 @@ def AMDGPU_WMMAOp :
                    UnitAttr:$unsignedB,
                    UnitAttr:$clamp)>,
     Results<(outs WMMAOutTypes: $destD)> {
-  let summary = "MLIR wrapper for RDNA3 wmma instructions";
+  let summary = "MLIR wrapper for wmma instructions";
   let description = [{
-    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
-    perform a 16x16 * 16x16 matrix multiplication for different data types.
-    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
-    integer inputs.
+    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
+    instructions in the AMDGPU architecture, which perform matrix multiplication.
+    Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
+    dimensions.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
-    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
-    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
+    On gfx12/RDNA4 and gfx1250, the result is instead returned as vector where all
+    the values are valid and the `subwordOffset` must be `0`, as it cannot be used.
 
     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
 
-    The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+    The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
     in case of overflow.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+    ```
   }];
   let assemblyFormat = [{
-    $sourceA `*` $sourceB `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
     attr-dict
     `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
index 3de57c923178a..b6fe61ff1afa2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // This file declares a dialect for MLIR wrappers around AMDGPU-specific
-// intrinssics and for other AMD GPU-specific functionality.
+// intrinsics and for other AMD GPU-specific functionality.
 //
 //===----------------------------------------------------------------------===//
 
@@ -26,6 +26,29 @@
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc"
 
+namespace mlir {
+/// Parser for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                  IntegerAttr &n, IntegerAttr &k);
+inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *,
+                                         IntegerAttr &m, IntegerAttr &n,
+                                         IntegerAttr &k) {
+  return parseMNKDimensionList(parser, m, n, k);
+}
+
+/// Printer for the `custom<MNKDimensionList>` custom assembly format used by
+/// WMMAOp.
+inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m,
+                                  IntegerAttr n, IntegerAttr k) {
+  printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()});
+}
+inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *,
+                                  IntegerAttr m, IntegerAttr n, IntegerAttr k) {
+  printMNKDimensionList(printer, m, n, k);
+}
+} // namespace mlir
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc"
 
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 9b154350cd913..478b6aaaec83a 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
@@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
 /// on the architecture you are compiling for.
 static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
                                                   Chipset chipset) {
-  auto sourceVectorType = dyn_cast<VectorType>(wmma.getSourceA().getType());
-  auto sourceBVectorType = dyn_cast<VectorType>(wmma.getSourceB().getType());
-  auto destVectorType = dyn_cast<VectorType>(wmma.getDestC().getType());
-  auto elemSourceType = sourceVectorType.getElementType();
-  auto elemBSourceType = sourceBVectorType.getElementType();
-  auto elemDestType = destVectorType.getElementType();
-
-  if (elemSourceType.isF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isF32())
-    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isF16() && elemDestType.isF16())
-    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
-  if (elemSourceType.isBF16() && elemDestType.isBF16())
-    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
-  if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
-    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
-  if (chipset.majorVersion == 11) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+  auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
+  auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
+  auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
+  Type elemSourceType = sourceVectorType.getElementType();
+  Type elemBSourceType = sourceBVectorType.getElementType();
+  Type elemDestType = destVectorType.getElementType();
+
+  const uint32_t k = wmma.getK();
+
+  if (k == 16) {
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+    if (chipset.majorVersion == 11) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
+    }
   }
-  if (chipset.majorVersion >= 12) {
+  if (chipset.majorVersion < 12)
+    return std::nullopt;
+
+  // gfx12+
+  if (k == 16) {
     if (isa<Float8E4M3FNType>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1027,17 +1036,18 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
     if (isa<Float8E5M2Type>(elemSourceType) &&
         isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) {
-      bool isWave64 = destVectorType.getNumElements() == 4;
-      // This is the ambiguous case. 8 inputs to the wave64 version means that
-      // we want the 16x16x32 version, but for wave32 they mean the short form.
-      bool has8Inputs = sourceVectorType.getNumElements() == 8;
-      if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs))
-        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
-    }
+
+    return std::nullopt;
   }
-  return std::nullopt;
+  if (k == 32) {
+    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    return std::nullopt;
+  }
+
+  llvm_unreachable("unhandled WMMA case");
 }
 
 namespace {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 61166db0ff210..eb40374d61303 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -360,45 +360,52 @@ LogicalResult ScaledExtPacked816Op::verify() {
 //===----------------------------------------------------------------------===//
 // WMMAOp
 //===----------------------------------------------------------------------===//
-LogicalResult WMMAOp::verify() {
-  Type sourceAType = getSourceA().getType();
-  Type sourceBType = getSourceB().getType();
-  Type destType = getDestC().getType();
 
-  VectorType sourceVectorAType = dyn_cast<VectorType>(sourceAType);
-  VectorType sourceVectorBType = dyn_cast<VectorType>(sourceBType);
-  VectorType destVectorType = dyn_cast<VectorType>(destType);
+ParseResult mlir::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m,
+                                        IntegerAttr &n, IntegerAttr &k) {
+  SmallVector<int64_t, 3> dimensions;
+  if (parser.parseDimensionList(dimensions, false, false))
+    return failure();
+  if (dimensions.size() != 3)
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 3 dimensions in MNK dimension list";
 
-  Type sourceAElemType = sourceVectorAType.getElementType();
-  Type sourceBElemType = sourceVectorBType.getElementType();
-  Type destElemType = destVectorType.getElementType();
+  m = parser.getBuilder().getI32IntegerAttr(dimensions[0]);
+  n = parser.getBuilder().getI32IntegerAttr(dimensions[1]);
+  k = parser.getBuilder().getI32IntegerAttr(dimensions[2]);
+  return success();
+}
 
-  if (sourceVectorAType.getNumElements() !=
-      sourceVectorBType.getNumElements()) {
+LogicalResult WMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+
+  Type sourceAElemType = sourceAType.getElementType();
+  Type sourceBElemType = sourceBType.getElementType();
+  if (sourceAType.getNumElements() != sourceBType.getNumElements()) {
     return emitOpError("source vectors have different lengths: ")
-           << sourceVectorAType << " vs. " << sourceVectorBType;
+           << sourceAType << " vs. " << sourceBType;
   }
 
-  bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
-  bool isSrcFloat =
-      isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
-          sourceAElemType);
-
-  if (isDestFloat && !isSrcFloat) {
-    return emitOpError("Expected float sources with float destination");
-  }
+  bool isDestFloat = destType.getElementType().isFloat();
+  bool isSrcFloat = sourceAElemType.isFloat();
 
-  if (!isDestFloat && isSrcFloat) {
-    return emitOpError("Expected int sources with int destination");
-  }
+  if (isDestFloat && !isSrcFloat)
+    return emitOpError("expected float sources with float destination");
+  if (!isDestFloat && isSrcFloat)
+    return emitOpError("expected int sources with int destination");
 
-  if (sourceAElemType != sourceBElemType &&
-      !(isa<Float8E5M2Type, Float8E4M3FNType>(sourceAElemType) &&
-        isa<Float8E5M2Type, Float8E4M3FNType>(sourceBElemType))) {
+  if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
                "source element types much match (except for fp8) but have ")
            << sourceAType << " and " << sourceBType;
   }
+
+  if (!sourceAElemType.isInteger(4) && getK() != 16) {
+    return emitOpError("K dimension must be 16 for source element type ")
+           << sourceAElemType;
+  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
similarity index 59%
rename from mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
rename to mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
index 638a7c3f8c1c5..d1301d0089220 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
                          %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
                          %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
                          %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
   // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
   // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
 
   func.return
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
index 94a1b78d5f040..b897323340402 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 ...
[truncated]

@Muzammiluddin-Syed-ECE
Copy link
Contributor

Do we need another attribute to represent the fact that the accumulator and the result types can be different?
Specifically, I'm thinking about the intrinsic with shape 16x16x32 and source types BF16 and result type BF16. There are two different intrinsics that can be used here that differ in what accumulator type they support.

@Muzammiluddin-Syed-ECE
Copy link
Contributor

Do we need another attribute to represent the fact that the accumulator and the result types can be different? Specifically, I'm thinking about the intrinsic with shape 16x16x32 and source types BF16 and result type BF16. There are two different intrinsics that can be used here that differ in what accumulator type they support.

So for instance including both type(destd) and type(destc) in the assembly format

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM


Example:
```mlir
%0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While the syntax is okay, it is weird that the mfma instructions encode this stuff as an attribute dict while wmma does it as a custom parser

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can give mfma the same syntax, I didn't want to make too many changes in the same PR though

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:ods labels Oct 24, 2025
Copy link
Contributor

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kuhar kuhar merged commit dc5f274 into llvm:main Oct 24, 2025
12 checks passed
kuhar added a commit to kuhar/llvm-project that referenced this pull request Oct 24, 2025
Use the same format as introduced for wmma by
llvm#164920.

Also make `blocks` default to 1.
kuhar added a commit to kuhar/llvm-project that referenced this pull request Oct 24, 2025
Use the same format as introduced for wmma by
llvm#164920.

Also make `blocks` default to 1.
kuhar added a commit to kuhar/llvm-project that referenced this pull request Oct 24, 2025
Use the same format as introduced for wmma by
llvm#164920 and for mfma by
llvm#165037.
kuhar added a commit that referenced this pull request Oct 25, 2025
)

Use the same format as introduced for wmma by
#164920.

Also make `blocks` default to 1.
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Oct 25, 2025
…shape (#165037)

Use the same format as introduced for wmma by
llvm/llvm-project#164920.

Also make `blocks` default to 1.
kuhar added a commit to kuhar/llvm-project that referenced this pull request Oct 25, 2025
Use the same format as introduced for wmma by
llvm#164920 and for mfma by
llvm#165037.
kuhar added a commit that referenced this pull request Oct 25, 2025
#165044)

Use the same format as introduced for wmma by
#164920 and for mfma by
#165037.
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Oct 25, 2025
…rinsic shape (#165044)

Use the same format as introduced for wmma by
llvm/llvm-project#164920 and for mfma by
llvm/llvm-project#165037.
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
This is in preparation for adding support for gfx1250 wmma intrinsics
that include much more possible shapes.

Instead of guessing the wave32/wave64 mode based on element types and
vector sizes, require the intrinsic shapes to be set explicitly as
attributes.
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
…#165037)

Use the same format as introduced for wmma by
llvm#164920.

Also make `blocks` default to 1.
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants