Skip to content

Conversation

@kuhar
Copy link
Member

@kuhar kuhar commented Oct 25, 2025

Update amdgpu.wmma op definition and implement amdgpu to rocdl conversion for new variants.

@llvmbot
Copy link
Member

llvmbot commented Oct 25, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mlir-amdgpu

Author: Jakub Kuderski (kuhar)

Changes

Update amdgpu.wmma op definition and implement amdgpu to rocdl conversion for new variants.


Patch is 22.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165064.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+20-7)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+106-11)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+6-4)
  • (added) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir (+89)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+50-10)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+35)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d74abc22acd5e..99e10b231482f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,9 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
-                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
-                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
+                             VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>,
                              VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -990,7 +991,7 @@ def AMDGPU_WMMAOp :
     Arguments<(ins
                    ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
                    ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
-                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32, 64, 128]>]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
@@ -1003,8 +1004,14 @@ def AMDGPU_WMMAOp :
   let description = [{
     The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
     instructions in the AMDGPU architecture, which perform matrix multiplication.
-    Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
-    dimensions.
+
+    On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions.
+
+    On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for
+    all element types, and K=32 for i4 sources.
+
+    On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128,
+    depending on the element types.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
@@ -1020,7 +1027,13 @@ def AMDGPU_WMMAOp :
 
     Example:
     ```mlir
-      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>
+
+      %1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<8xf32>, vector<8xf32>
+
+      %2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf4E2M1FN>, vector<64xf4E2M1FN>, vector<8xf32>
+
+      %3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32>
     ```
   }];
   let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 478b6aaaec83a..90e731c11da62 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1002,8 +1002,13 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
   Type elemDestType = destVectorType.getElementType();
 
   const uint32_t k = wmma.getK();
+  const bool isRDNA3 = chipset.majorVersion == 11;
+  const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
 
   if (k == 16) {
+    if (!isRDNA3 && !isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+      return std::nullopt;
+
     if (elemSourceType.isF16() && elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
     if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1019,34 +1024,124 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
         return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
     }
   }
-  if (chipset.majorVersion < 12)
+  if (isRDNA3)
     return std::nullopt;
 
+  using fp8 = Float8E4M3FNType;
+  using bf8 = Float8E5M2Type;
+
   // gfx12+
   if (k == 16) {
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (!isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+      return std::nullopt;
+
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
-    if (isa<Float8E4M3FNType>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
-    if (isa<Float8E5M2Type>(elemSourceType) &&
-        isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+        elemDestType.isF32())
       return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+
     if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
       return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
 
     return std::nullopt;
   }
   if (k == 32) {
-    if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
-      return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+    if (isRDNA4) {
+      if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+        return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+      return std::nullopt;
+    }
+
+    // gfx1250
+    if (elemSourceType.isF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
+    if (elemSourceType.isF16() && elemDestType.isF16())
+      return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
+    if (elemSourceType.isBF16() && elemDestType.isBF16())
+      return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (isRDNA4)
+    return std::nullopt;
+
+  // gfx1250
+  if (k == 4) {
+    if (elemSourceType.isF32() && elemDestType.isF32())
+      return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
     return std::nullopt;
   }
 
+  if (k == 64) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
+    }
+    if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+      return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
+
+    return std::nullopt;
+  }
+
+  if (k == 128) {
+    if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
+    }
+    if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
+    }
+    if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+      if (elemDestType.isF32())
+        return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
+      if (elemDestType.isF16())
+        return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
+    }
+
+    return std::nullopt;
+  }
   llvm_unreachable("unhandled WMMA case");
 }
 
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4c4965e67676e..7b4e248dcf5e4 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {
 
   if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
     return emitOpError(
-               "source element types much match (except for fp8) but have ")
+               "source element types much match (except for fp8/bf8) but have ")
            << sourceAType << " and " << sourceBType;
   }
 
-  if (!sourceAElemType.isInteger(4) && getK() != 16) {
-    return emitOpError("K dimension must be 16 for source element type ")
-           << sourceAElemType;
+  if (isSrcFloat) {
+    if (getClamp())
+      return emitOpError("clamp flag is not supported for float types");
+    if (getUnsignedA() || getUnsignedB())
+      return emitOpError("unsigned flags are not supported for float types");
   }
   return success();
 }
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
new file mode 100644
index 0000000000000..bcbdef040ebe3
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @wmma_k4
+func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
+  // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
+  amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k32
+func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vector<8xf32>,
+                    %arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) {
+  // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
+  amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
+  amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k64
+func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
+                    %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
+  // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
+  amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
+  amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_k128
+func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+                     %arg2 : vector<8xf32>, %arg3 : vector<8xf16>) {
+  // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
+
+  // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
+  amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
+
+  // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+  amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
+
+  func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 6a2518a40cc99..14a3fe3af60ec 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -144,14 +144,6 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector
 
 // -----
 
-func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
-  %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
-  func.return %0 : vector<8xi32>
-}
-
-// -----
-
 func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
   // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
   %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
@@ -161,14 +153,62 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec
 // -----
 
 func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
-  // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+  // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32, 64, 128}}}
   %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
   func.return %0 : vector<8xi32>
 }
 
 // -----
 
-// Missinng `resetOffset`
+func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error@+1 {{'amdgpu.wmma' op source vectors have different lengths}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<16xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error@+1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_clamp_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error@+1 {{'amdgpu.wmma' op clamp flag is not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {clamp} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedA_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedA} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedB_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+  // expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+  %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedB} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+  func.return %0 : vector<8xf32>
+}
+
+// -----
+
+// Missing `resetOffset`
 func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
   // expected-error@+1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}}
   %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<...
[truncated]

if (isRDNA3)
return std::nullopt;

using fp8 = Float8E4M3FNType;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that's a good idea. We should do it in more places.

@kuhar kuhar force-pushed the amdgpu-wmma-gfx1250-v2 branch from c0e8747 to f58a79a Compare October 25, 2025 11:36
@kuhar kuhar requested a review from krzysz00 October 25, 2025 11:37
kuhar added 3 commits October 28, 2025 11:15
Update `amdgpu.wmma` op definition and implement amdgpu to rocdl
conversion for new variants.
@kuhar kuhar force-pushed the amdgpu-wmma-gfx1250-v2 branch from f58a79a to 7e67848 Compare October 28, 2025 15:35
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@kuhar kuhar merged commit 466c526 into llvm:main Oct 28, 2025
10 checks passed
kuhar added a commit to iree-org/llvm-project that referenced this pull request Oct 28, 2025
Update `amdgpu.wmma` op definition and implement amdgpu to rocdl
conversion for new variants.
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
Update `amdgpu.wmma` op definition and implement amdgpu to rocdl
conversion for new variants.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants