From 2d9f5fe167dbb2686bc67328b0ddac76ea579b6e Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 23 Oct 2025 20:12:41 -0400 Subject: [PATCH 1/3] [mlir][amdgpu] Add explicit intrinsic shape to wmma This is in preparation for adding support for gfx1250 wmma intrinsics that include much more possible shapes. Instead of guessing the wave32/wave64 mode based on element types and vector sizes, require the intrinsic shapes to be set explicitly as attributes. --- mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 45 +++++++----- .../mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h | 25 ++++++- .../AMDGPUToROCDL/AMDGPUToROCDL.cpp | 70 +++++++++++-------- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 61 +++++++++------- .../{wmma.mlir => wmma-gfx11.mlir} | 27 +++---- .../Conversion/AMDGPUToROCDL/wmma-gfx12.mlir | 46 ++++++------ mlir/test/Dialect/AMDGPU/invalid.mlir | 46 +++++++++++- mlir/test/Dialect/AMDGPU/ops.mlir | 15 ++-- 8 files changed, 218 insertions(+), 117 deletions(-) rename mlir/test/Conversion/AMDGPUToROCDL/{wmma.mlir => wmma-gfx11.mlir} (59%) diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 7184de93bfacb..3a808ff3a01e4 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>; def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>; // wmma -def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType< - [4, 8, 16], - [F16, BF16, - I8, SI8, UI8, - I<4>, SI<4>, UI<4>, - F8E4M3FN, F8E5M2]>]>; +def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>, + VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>, + VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>, + VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>; def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>, VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>; @@ -968,6 +966,14 @@ def AMDGPU_MFMAOp : The negateA, negateB, and negateC flags are only supported for double-precision operations on gfx94x. + + Example: + ```mlir + %0 = amdgpu.mfma %matA * %matB + %matC + { abid = 1 : i32, cbsz = 1 : i32, + m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 } + blgp = bcast_second_32 : f32, f32, vector<32xf32> + ``` }]; let assemblyFormat = [{ $sourceA `*` $sourceB `+` $destC @@ -982,6 +988,9 @@ def AMDGPU_WMMAOp : AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>, Pure]>, Arguments<(ins + ConfinedAttr, IntMaxValue<16>]>:$m, + ConfinedAttr, IntMaxValue<16>]>:$n, + ConfinedAttr, IntMaxValue<32>, IntPowerOf2]>:$k, WMMAInTypes:$sourceA, WMMAInTypes:$sourceB, WMMAOutTypes:$destC, @@ -990,28 +999,32 @@ def AMDGPU_WMMAOp : UnitAttr:$unsignedB, UnitAttr:$clamp)>, Results<(outs WMMAOutTypes: $destD)> { - let summary = "MLIR wrapper for RDNA3 wmma instructions"; + let summary = "MLIR wrapper for wmma instructions"; let description = [{ - The `amdgpu.wmma` op is an MLIR wrapper around intrinsics - for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which - perform a 16x16 * 16x16 matrix multiplication for different data types. - Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit - integer inputs. + The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma` + instructions in the AMDGPU architecture, which perform matrix multiplication. + Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K + dimensions. On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 (or 16xbf16) vector containing only 8 valid values: - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14. - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15. - On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where - all values are valid and the `subwordOffset` must be `0`, as it cannot be used. + On gfx12/RDNA4 and gfx1250, the result is instead returned as vector where all + the values are valid and the `subwordOffset` must be `0`, as it cannot be used. `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned. - The `clamp` flag is used to saturate the output of type T to numeric_limits::max() + The `clamp` flag is used to saturate the output of type T to `numeric_limits::max()` in case of overflow. + + Example: + ```mlir + %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16> + ``` }]; let assemblyFormat = [{ - $sourceA `*` $sourceB `+` $destC + custom($m, $n, $k) $sourceA `*` $sourceB `+` $destC attr-dict `:` type($sourceA) `,` type($sourceB) `,` type($destC) }]; diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h index 3de57c923178a..b6fe61ff1afa2 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // // This file declares a dialect for MLIR wrappers around AMDGPU-specific -// intrinssics and for other AMD GPU-specific functionality. +// intrinsics and for other AMD GPU-specific functionality. // //===----------------------------------------------------------------------===// @@ -26,6 +26,29 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc" +namespace mlir { +/// Parser for the `custom` custom assembly format used by +/// WMMAOp. +ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m, + IntegerAttr &n, IntegerAttr &k); +inline ParseResult parseMNKDimensionList(OpAsmParser &parser, Operation *, + IntegerAttr &m, IntegerAttr &n, + IntegerAttr &k) { + return parseMNKDimensionList(parser, m, n, k); +} + +/// Printer for the `custom` custom assembly format used by +/// WMMAOp. +inline void printMNKDimensionList(OpAsmPrinter &printer, IntegerAttr m, + IntegerAttr n, IntegerAttr k) { + printer.printDimensionList(ArrayRef{m.getInt(), n.getInt(), k.getInt()}); +} +inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *, + IntegerAttr m, IntegerAttr n, IntegerAttr k) { + printMNKDimensionList(printer, m, n, k); +} +} // namespace mlir + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc" diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 9b154350cd913..478b6aaaec83a 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" @@ -993,28 +994,36 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) { /// on the architecture you are compiling for. static std::optional wmmaOpToIntrinsic(WMMAOp wmma, Chipset chipset) { - auto sourceVectorType = dyn_cast(wmma.getSourceA().getType()); - auto sourceBVectorType = dyn_cast(wmma.getSourceB().getType()); - auto destVectorType = dyn_cast(wmma.getDestC().getType()); - auto elemSourceType = sourceVectorType.getElementType(); - auto elemBSourceType = sourceBVectorType.getElementType(); - auto elemDestType = destVectorType.getElementType(); - - if (elemSourceType.isF16() && elemDestType.isF32()) - return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); - if (elemSourceType.isBF16() && elemDestType.isF32()) - return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); - if (elemSourceType.isF16() && elemDestType.isF16()) - return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); - if (elemSourceType.isBF16() && elemDestType.isBF16()) - return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); - if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) - return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); - if (chipset.majorVersion == 11) { - if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) - return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); + auto sourceVectorType = cast(wmma.getSourceA().getType()); + auto sourceBVectorType = cast(wmma.getSourceB().getType()); + auto destVectorType = cast(wmma.getDestC().getType()); + Type elemSourceType = sourceVectorType.getElementType(); + Type elemBSourceType = sourceBVectorType.getElementType(); + Type elemDestType = destVectorType.getElementType(); + + const uint32_t k = wmma.getK(); + + if (k == 16) { + if (elemSourceType.isF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x16_bf16::getOperationName(); + if (elemSourceType.isF16() && elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x16_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isBF16()) + return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); + if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); + if (chipset.majorVersion == 11) { + if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); + } } - if (chipset.majorVersion >= 12) { + if (chipset.majorVersion < 12) + return std::nullopt; + + // gfx12+ + if (k == 16) { if (isa(elemSourceType) && isa(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); @@ -1027,17 +1036,18 @@ static std::optional wmmaOpToIntrinsic(WMMAOp wmma, if (isa(elemSourceType) && isa(elemBSourceType) && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName(); - if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) { - bool isWave64 = destVectorType.getNumElements() == 4; - // This is the ambiguous case. 8 inputs to the wave64 version means that - // we want the 16x16x32 version, but for wave32 they mean the short form. - bool has8Inputs = sourceVectorType.getNumElements() == 8; - if ((isWave64 && has8Inputs) || (!isWave64 && !has8Inputs)) - return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); + if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); - } + + return std::nullopt; } - return std::nullopt; + if (k == 32) { + if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); + return std::nullopt; + } + + llvm_unreachable("unhandled WMMA case"); } namespace { diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 61166db0ff210..eb40374d61303 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -360,45 +360,52 @@ LogicalResult ScaledExtPacked816Op::verify() { //===----------------------------------------------------------------------===// // WMMAOp //===----------------------------------------------------------------------===// -LogicalResult WMMAOp::verify() { - Type sourceAType = getSourceA().getType(); - Type sourceBType = getSourceB().getType(); - Type destType = getDestC().getType(); - VectorType sourceVectorAType = dyn_cast(sourceAType); - VectorType sourceVectorBType = dyn_cast(sourceBType); - VectorType destVectorType = dyn_cast(destType); +ParseResult mlir::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m, + IntegerAttr &n, IntegerAttr &k) { + SmallVector dimensions; + if (parser.parseDimensionList(dimensions, false, false)) + return failure(); + if (dimensions.size() != 3) + return parser.emitError(parser.getCurrentLocation()) + << "expected 3 dimensions in MNK dimension list"; - Type sourceAElemType = sourceVectorAType.getElementType(); - Type sourceBElemType = sourceVectorBType.getElementType(); - Type destElemType = destVectorType.getElementType(); + m = parser.getBuilder().getI32IntegerAttr(dimensions[0]); + n = parser.getBuilder().getI32IntegerAttr(dimensions[1]); + k = parser.getBuilder().getI32IntegerAttr(dimensions[2]); + return success(); +} - if (sourceVectorAType.getNumElements() != - sourceVectorBType.getNumElements()) { +LogicalResult WMMAOp::verify() { + auto sourceAType = cast(getSourceA().getType()); + auto sourceBType = cast(getSourceB().getType()); + auto destType = cast(getDestC().getType()); + + Type sourceAElemType = sourceAType.getElementType(); + Type sourceBElemType = sourceBType.getElementType(); + if (sourceAType.getNumElements() != sourceBType.getNumElements()) { return emitOpError("source vectors have different lengths: ") - << sourceVectorAType << " vs. " << sourceVectorBType; + << sourceAType << " vs. " << sourceBType; } - bool isDestFloat = isa(destElemType); - bool isSrcFloat = - isa( - sourceAElemType); - - if (isDestFloat && !isSrcFloat) { - return emitOpError("Expected float sources with float destination"); - } + bool isDestFloat = destType.getElementType().isFloat(); + bool isSrcFloat = sourceAElemType.isFloat(); - if (!isDestFloat && isSrcFloat) { - return emitOpError("Expected int sources with int destination"); - } + if (isDestFloat && !isSrcFloat) + return emitOpError("expected float sources with float destination"); + if (!isDestFloat && isSrcFloat) + return emitOpError("expected int sources with int destination"); - if (sourceAElemType != sourceBElemType && - !(isa(sourceAElemType) && - isa(sourceBElemType))) { + if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) { return emitOpError( "source element types much match (except for fp8) but have ") << sourceAType << " and " << sourceBType; } + + if (!sourceAElemType.isInteger(4) && getK() != 16) { + return emitOpError("K dimension must be 16 for source element type ") + << sourceAElemType; + } return success(); } diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir similarity index 59% rename from mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir rename to mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir index 638a7c3f8c1c5..d1301d0089220 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx11.mlir @@ -1,35 +1,36 @@ -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s + // CHECK-LABEL: @wmma_to_rocdl func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>, %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>, %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>, %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) { // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32> // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> - amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16> // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16> - amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16> // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16> // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16> - amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16> + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16> // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16> // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16> - amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16> + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16> // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> - amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32> + amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32> // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32> - amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32> + amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32> // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> - amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32> + amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32> // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> - amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32> + amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32> func.return } diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir index 94a1b78d5f040..b897323340402 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s // CHECK-LABEL: @wmma_to_rocdl func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>, %arg2 : vector<8xf32>, %arg3 : vector<4xf32>, @@ -9,60 +9,60 @@ func.func @wmma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<4xf16>, %arg12 : vector<8xi32>, %arg13 : vector<4xi32>, %arg14 : vector<16xi4>, %arg15 : vector<8xi4>, %arg16 : vector<4xi4>) { // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<8xf16>, vector<8xf16>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg3 : vector<4xf16>, vector<4xf16>, vector<4xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg2 : vector<8xbf16>, vector<8xbf16>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg3 : vector<4xbf16>, vector<4xbf16>, vector<4xf32> // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<8xf16>, i1) -> vector<8xf16> - amdgpu.wmma %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16> + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 : vector<8xf16>, vector<8xf16>, vector<8xf16> // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf16>, i1) -> vector<4xf16> - amdgpu.wmma %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16> + amdgpu.wmma 16x16x16 %arg1 * %arg1 + %arg1 : vector<4xf16>, vector<4xf16>, vector<4xf16> // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<8xi16>, vector<8xi16>, vector<8xi16>, i1) -> vector<8xi16> // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16> - amdgpu.wmma %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16> + amdgpu.wmma 16x16x16 %arg4 * %arg4 + %arg4 : vector<8xbf16>, vector<8xbf16>, vector<8xbf16> // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xi16>, i1) -> vector<4xi16> - amdgpu.wmma %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16> + amdgpu.wmma 16x16x16 %arg5 * %arg5 + %arg5 : vector<4xbf16>, vector<4xbf16>, vector<4xbf16> // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.fp8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg7 * %arg7 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E4M3FN>, vector<4xf32> // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg6 * %arg8 + %arg2 : vector<8xf8E4M3FN>, vector<8xf8E5M2>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.fp8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg7 * %arg9 + %arg3 : vector<4xf8E4M3FN>, vector<4xf8E5M2>, vector<4xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg8 * %arg8 + %arg2 : vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf8_bf8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg3 : vector<4xf8E5M2>, vector<4xf8E5M2>, vector<4xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32> - amdgpu.wmma %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32> + amdgpu.wmma 16x16x16 %arg8 * %arg6 + %arg2 : vector<8xf8E5M2>, vector<8xf8E4M3FN>, vector<8xf32> // CHECK: rocdl.wmma.f32.16x16x16.bf8_fp8{{.*}}: (i32, i32, vector<4xf32>) -> vector<4xf32> - amdgpu.wmma %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32> + amdgpu.wmma 16x16x16 %arg9 * %arg7 + %arg3 : vector<4xf8E5M2>, vector<4xf8E4M3FN>, vector<4xf32> // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> - amdgpu.wmma %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32> + amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg12 {clamp} : vector<8xi8>, vector<8xi8>, vector<8xi32> // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> - amdgpu.wmma %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32> + amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg13 {unsignedA, unsignedB, clamp}: vector<4xi8>, vector<4xi8>, vector<4xi32> // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> - amdgpu.wmma %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32> + amdgpu.wmma 16x16x32 %arg14 * %arg14 + %arg12 {clamp} : vector<16xi4>, vector<16xi4>, vector<8xi32> // CHECK: rocdl.wmma.i32.16x16x32.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> - amdgpu.wmma %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32> + amdgpu.wmma 16x16x32 %arg15 * %arg15 + %arg13 {clamp} : vector<8xi4>, vector<8xi4>, vector<4xi32> // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<8xi32>, i1) -> vector<8xi32> - amdgpu.wmma %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32> + amdgpu.wmma 16x16x16 %arg15 * %arg15 + %arg12 {clamp} : vector<8xi4>, vector<8xi4>, vector<8xi32> // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> - amdgpu.wmma %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32> + amdgpu.wmma 16x16x16 %arg16 * %arg16 + %arg13 {clamp} : vector<4xi4>, vector<4xi4>, vector<4xi32> func.return } diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index a8256b16ed8a1..aee4705b26958 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -120,9 +120,49 @@ func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> { // ----- -func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> { - // expected-error@+1 {{'amdgpu.wmma' op Expected int sources with int destination}} - %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32> +func.func @wmma_f16_i32(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> { + // expected-error@+1 {{'amdgpu.wmma' op expected int sources with int destination}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32> + func.return %0 : vector<8xi32> +} + +// ----- + +func.func @wmma_i16_f32(%arg0 : vector<16xi8>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op expected float sources with float destination}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { + // expected-error@+1 {{'amdgpu.wmma' expected 3 dimensions in MNK dimension list}} + %0 = amdgpu.wmma 16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> + func.return %0 : vector<8xi32> +} + +// ----- + +func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { + // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 16}} + %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> + func.return %0 : vector<8xi32> +} + +// ----- + +func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { + // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 16}} + %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> + func.return %0 : vector<8xi32> +} + +// ----- + +func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { + // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 32 whose value is a power of two}} + %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> func.return %0 : vector<8xi32> } diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index f9c6899dadfc1..a185eb612c9ac 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -565,13 +565,20 @@ func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> { func.return %0 : vector<32xf32> } -// CHECK-LABEL: func @wmma -func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> { - // CHECK: amdgpu.wmma - %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16> +// CHECK-LABEL: func @wmma_f16_16x16x16_f16 +func.func @wmma_f16_16x16x16_f16(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> { + // CHECK: amdgpu.wmma 16x16x16 + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16> func.return %0 : vector<8xf16> } +// CHECK-LABEL: func @wmma_i32_16x16x32_i4 +func.func @wmma_i32_16x16x32_i4(%arg0 : vector<16xi4>, %arg1 : vector<8xi32>) -> vector<8xi32> { + // CHECK: amdgpu.wmma 16x16x32 + %0 = amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg1 : vector<16xi4>, vector<16xi4>, vector<8xi32> + func.return %0 : vector<8xi32> +} + // CHECK-LABEL: func @swizzle_bitmode func.func @swizzle_bitmode(%arg0 : f32) -> f32 { // CHECK: amdgpu.swizzle_bitmode From f8d697a1b426d5ab8a279d4457485607a1b2c802 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 23 Oct 2025 21:00:26 -0400 Subject: [PATCH 2/3] Fix namespace --- mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h | 4 ++-- mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h index b6fe61ff1afa2..dcd9f95a7561f 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h @@ -26,7 +26,7 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.h.inc" -namespace mlir { +namespace mlir::amdgpu { /// Parser for the `custom` custom assembly format used by /// WMMAOp. ParseResult parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m, @@ -47,7 +47,7 @@ inline void printMNKDimensionList(OpAsmPrinter &printer, Operation *, IntegerAttr m, IntegerAttr n, IntegerAttr k) { printMNKDimensionList(printer, m, n, k); } -} // namespace mlir +} // namespace mlir::amdgpu #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/AMDGPU/IR/AMDGPUAttributes.h.inc" diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index eb40374d61303..4c4965e67676e 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -361,8 +361,9 @@ LogicalResult ScaledExtPacked816Op::verify() { // WMMAOp //===----------------------------------------------------------------------===// -ParseResult mlir::parseMNKDimensionList(OpAsmParser &parser, IntegerAttr &m, - IntegerAttr &n, IntegerAttr &k) { +ParseResult mlir::amdgpu::parseMNKDimensionList(OpAsmParser &parser, + IntegerAttr &m, IntegerAttr &n, + IntegerAttr &k) { SmallVector dimensions; if (parser.parseDimensionList(dimensions, false, false)) return failure(); From 050028521e00e36297966552e3d8278b6b4a4749 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Fri, 24 Oct 2025 10:20:16 -0400 Subject: [PATCH 3/3] Simplify confined attrs --- mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 8 ++++---- mlir/include/mlir/IR/CommonAttrConstraints.td | 5 +++++ mlir/test/Dialect/AMDGPU/invalid.mlir | 6 +++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 3a808ff3a01e4..d74abc22acd5e 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -988,13 +988,13 @@ def AMDGPU_WMMAOp : AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>, Pure]>, Arguments<(ins - ConfinedAttr, IntMaxValue<16>]>:$m, - ConfinedAttr, IntMaxValue<16>]>:$n, - ConfinedAttr, IntMaxValue<32>, IntPowerOf2]>:$k, + ConfinedAttr]>:$m, + ConfinedAttr]>:$n, + ConfinedAttr]>:$k, WMMAInTypes:$sourceA, WMMAInTypes:$sourceB, WMMAOutTypes:$destC, - DefaultValuedAttr, IntMaxValue<1>]>, "0">:$subwordOffset, + DefaultValuedAttr]>, "0">:$subwordOffset, UnitAttr:$unsignedA, UnitAttr:$unsignedB, UnitAttr:$clamp)>, diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td index e1869c1821b11..b7e168a3e6f86 100644 --- a/mlir/include/mlir/IR/CommonAttrConstraints.td +++ b/mlir/include/mlir/IR/CommonAttrConstraints.td @@ -804,6 +804,11 @@ def IntPositivePowerOf2 : AllAttrOf<[IntPositive, IntPowerOf2]>; class IntValidAlignment: ConfinedAttr; +class IntIsOneOf values> : AttrConstraint< + CPred<"::llvm::is_contained({" # !interleave(!foreach(val, values, val), ", ") # + "}, ::llvm::cast<::mlir::IntegerAttr>($_self).getInt())">, + "whose value is one of {" # !interleave(!foreach(val, values, val), ", ") # "}">; + class ArrayMaxCount : AttrConstraint< CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() <= " # n>, "with at most " # n # " elements">; diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index aee4705b26958..6a2518a40cc99 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -145,7 +145,7 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector // ----- func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { - // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 16}} + // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}} %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> func.return %0 : vector<8xi32> } @@ -153,7 +153,7 @@ func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec // ----- func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { - // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 16}} + // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}} %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> func.return %0 : vector<8xi32> } @@ -161,7 +161,7 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec // ----- func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { - // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 16 whose maximum value is 32 whose value is a power of two}} + // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> func.return %0 : vector<8xi32> }