diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 37db096f1ba75..45cb67f0eee4a 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -912,9 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>; def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>; // wmma -def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>, - VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>, - VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>, +def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>, + VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>, + VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>, + VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>, VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>; def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>, VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>; @@ -992,7 +993,7 @@ def AMDGPU_WMMAOp : Arguments<(ins ConfinedAttr]>:$m, ConfinedAttr]>:$n, - ConfinedAttr]>:$k, + ConfinedAttr]>:$k, WMMAInTypes:$sourceA, WMMAInTypes:$sourceB, WMMAOutTypes:$destC, @@ -1005,8 +1006,14 @@ def AMDGPU_WMMAOp : let description = [{ The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma` instructions in the AMDGPU architecture, which perform matrix multiplication. - Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K - dimensions. + + On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions. + + On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for + all element types, and K=32 for i4 sources. + + On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128, + depending on the element types. 
On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 (or 16xbf16) vector containing only 8 valid values: @@ -1022,7 +1029,13 @@ def AMDGPU_WMMAOp : Example: ```mlir - %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16> + %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16> + + %1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<32xi8>, vector<8xi32> + + %2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32> + + %3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32> ``` }]; let assemblyFormat = [{ diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 478b6aaaec83a..1eca43d96fe85 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) { smfma.getN(), smfma.getK(), 1u, chipset); } -/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` -/// if one exists. This includes checking to ensure the intrinsic is supported -/// on the architecture you are compiling for. -static std::optional wmmaOpToIntrinsic(WMMAOp wmma, - Chipset chipset) { - auto sourceVectorType = cast(wmma.getSourceA().getType()); - auto sourceBVectorType = cast(wmma.getSourceB().getType()); - auto destVectorType = cast(wmma.getDestC().getType()); - Type elemSourceType = sourceVectorType.getElementType(); - Type elemBSourceType = sourceBVectorType.getElementType(); - Type elemDestType = destVectorType.getElementType(); - - const uint32_t k = wmma.getK(); - +/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// for RDNA3/4 architectures. 
+static std::optional +wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType, + Type elemDestType, uint32_t k, bool isRDNA3) { + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + + // Handle k == 16 for RDNA3/4. if (k == 16) { + // Common patterns for RDNA3 and RDNA4. if (elemSourceType.isF16() && elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_f16::getOperationName(); if (elemSourceType.isBF16() && elemDestType.isF32()) @@ -1014,39 +1010,160 @@ static std::optional wmmaOpToIntrinsic(WMMAOp wmma, return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName(); if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu8::getOperationName(); - if (chipset.majorVersion == 11) { + + // RDNA3 specific patterns. + if (isRDNA3) { if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); + return std::nullopt; } - } - if (chipset.majorVersion < 12) - return std::nullopt; - // gfx12+ - if (k == 16) { - if (isa(elemSourceType) && - isa(elemBSourceType) && elemDestType.isF32()) + // RDNA4 specific patterns (fp8/bf8). 
+ if (isa(elemSourceType) && isa(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName(); - if (isa(elemSourceType) && - isa(elemBSourceType) && elemDestType.isF32()) + if (isa(elemSourceType) && isa(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName(); - if (isa(elemSourceType) && - isa(elemBSourceType) && elemDestType.isF32()) + if (isa(elemSourceType) && isa(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName(); - if (isa(elemSourceType) && - isa(elemBSourceType) && elemDestType.isF32()) + if (isa(elemSourceType) && isa(elemBSourceType) && + elemDestType.isF32()) return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName(); if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x16_iu4::getOperationName(); return std::nullopt; } - if (k == 32) { + + // Handle k == 32 for RDNA4. + if (k == 32 && !isRDNA3) { if (elemSourceType.isInteger(4) && elemDestType.isInteger(32)) return ROCDL::wmma_i32_16x16x32_iu4::getOperationName(); + } + + llvm_unreachable("Unsupported k value"); +} + +/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// for the gfx1250 architecture. 
+static std::optional wmmaOpToIntrinsicGfx1250(Type elemSourceType, + Type elemBSourceType, + Type elemDestType, + uint32_t k) { + using fp8 = Float8E4M3FNType; + using bf8 = Float8E5M2Type; + + if (k == 4) { + if (elemSourceType.isF32() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x4_f32::getOperationName(); + return std::nullopt; } + if (k == 32) { + if (elemSourceType.isF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x32_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x32_bf16::getOperationName(); + if (elemSourceType.isF16() && elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x32_f16::getOperationName(); + if (elemSourceType.isBF16() && elemDestType.isBF16()) + return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName(); + + return std::nullopt; + } + + if (k == 64) { + if (isa(elemSourceType) && isa(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName(); + } + if (isa(elemSourceType) && isa(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName(); + } + if (isa(elemSourceType) && isa(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName(); + } + if (isa(elemSourceType) && isa(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName(); + } + if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) + return ROCDL::wmma_i32_16x16x64_iu8::getOperationName(); + + return std::nullopt; + } + + if (k == 128) { + 
if (isa(elemSourceType) && isa(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName(); + } + if (isa(elemSourceType) && isa(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName(); + } + if (isa(elemSourceType) && isa(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName(); + } + if (isa(elemSourceType) && isa(elemBSourceType)) { + if (elemDestType.isF32()) + return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName(); + if (elemDestType.isF16()) + return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName(); + } + + return std::nullopt; + } + + llvm_unreachable("Unsupported k value"); +} + +/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` +/// if one exists. This includes checking to ensure the intrinsic is supported +/// on the architecture you are compiling for. +static std::optional wmmaOpToIntrinsic(WMMAOp wmma, + Chipset chipset) { + auto sourceVectorType = cast(wmma.getSourceA().getType()); + auto sourceBVectorType = cast(wmma.getSourceB().getType()); + auto destVectorType = cast(wmma.getDestC().getType()); + Type elemSourceType = sourceVectorType.getElementType(); + Type elemBSourceType = sourceBVectorType.getElementType(); + Type elemDestType = destVectorType.getElementType(); + + const uint32_t k = wmma.getK(); + const bool isRDNA3 = chipset.majorVersion == 11; + const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0; + + // Handle RDNA3 and RDNA4. 
+ if (isRDNA3 || isRDNA4) + return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType, + k, isRDNA3); + + // Handle gfx1250. + if (chipset == Chipset{12, 5, 0}) + return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType, + elemDestType, k); + llvm_unreachable("unhandled WMMA case"); } diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 585b6dacfa648..df955fc90b45f 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() { if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) { return emitOpError( - "source element types much match (except for fp8) but have ") + "source element types must match (except for fp8/bf8) but have ") << sourceAType << " and " << sourceBType; } - if (!sourceAElemType.isInteger(4) && getK() != 16) { - return emitOpError("K dimension must be 16 for source element type ") - << sourceAElemType; + if (isSrcFloat) { + if (getClamp()) + return emitOpError("clamp flag is not supported for float types"); + if (getUnsignedA() || getUnsignedB()) + return emitOpError("unsigned flags are not supported for float types"); } return success(); } diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir new file mode 100644 index 0000000000000..bcbdef040ebe3 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir @@ -0,0 +1,89 @@ +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: @wmma_k4 +func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) { + // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1 + amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32> + func.return +} + +// CHECK-LABEL: @wmma_k32 +func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : 
vector<16xbf16>, %arg2 : vector<8xf32>, + %arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) { + // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2 + amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) + amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2 + amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> + + // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) + amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16> + + func.return +} + +// CHECK-LABEL: @wmma_k64 +func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>, + %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) { + // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}} + amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32> + + // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4 + amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4 + amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : 
vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4 + amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4 + amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16> + + func.return +} + +// CHECK-LABEL: @wmma_k128 +func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>, + %arg2 : vector<8xf32>, %arg3 : vector<8xf16>) { + // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2 + amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2 + amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2 + amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, 
vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16> + + // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2 + amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32> + + // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1) + amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16> + + func.return +} diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 57847641a2d03..4c6f62a045405 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -156,14 +156,6 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector // ----- -func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { - // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}} - %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> - func.return %0 : vector<8xi32> -} - -// ----- - func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { // expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}} %0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> @@ -173,14 +165,62 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec // ----- func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> { - // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to 
satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}} + // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32, 64, 128}}} %0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32> func.return %0 : vector<8xi32> } // ----- -// Missinng `resetOffset` +func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, %arg2 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op source vectors have different lengths}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<16xf16>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op source element types must match (except for fp8/bf8)}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> { + // expected-error@+1 {{'amdgpu.wmma' op source element types must match (except for fp8/bf8)}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32> + func.return %0 : vector<8xi32> +} + +// ----- + +func.func @wmma_clamp_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op clamp flag is not supported for float types}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {clamp} : vector<8xf16>, vector<8xf16>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +func.func @wmma_unsignedA_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op unsigned 
flags are not supported for float types}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedA} : vector<8xf16>, vector<8xf16>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +func.func @wmma_unsignedB_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}} + %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedB} : vector<8xf16>, vector<8xf16>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// ----- + +// Missing `resetOffset` func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space>) -> memref<8xi32, #amdgpu.address_space> { // expected-error@+1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space>' but got 'memref<8xi32, #amdgpu.address_space>'}} %ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<[1], offset: ?>, #gpu.address_space> to memref<8xi32, #amdgpu.address_space> diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index a33096750ee23..09134cb4704bb 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -586,6 +586,41 @@ func.func @wmma_i32_16x16x32_i4(%arg0 : vector<16xi4>, %arg1 : vector<8xi32>) -> func.return %0 : vector<8xi32> } +// CHECK-LABEL: func @wmma_f32_16x16x4_f32 +func.func @wmma_f32_16x16x4_f32(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // CHECK: amdgpu.wmma 16x16x4 + %0 = amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @wmma_f32_16x16x64_f8 +func.func @wmma_f32_16x16x64_f8(%arg0 : vector<32xf8E4M3FN>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // CHECK: amdgpu.wmma 16x16x64 + %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32> + 
func.return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @wmma_f32_16x16x64_bf8 +func.func @wmma_f32_16x16x64_bf8(%arg0 : vector<32xf8E5M2>, %arg1 : vector<8xf32>) -> vector<8xf32> { + // CHECK: amdgpu.wmma 16x16x64 + %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32> + func.return %0 : vector<8xf32> +} + +// CHECK-LABEL: func @wmma_f16_16x16x64_bf8 +func.func @wmma_f16_16x16x64_bf8(%arg0 : vector<32xf8E5M2>, %arg1 : vector<8xf16>) -> vector<8xf16> { + // CHECK: amdgpu.wmma 16x16x64 + %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16> + func.return %0 : vector<8xf16> +} + +// CHECK-LABEL: func @wmma_f16_16x16x64_f8 +func.func @wmma_f16_16x16x64_f8(%arg0 : vector<32xf8E4M3FN>, %arg1 : vector<8xf16>) -> vector<8xf16> { + // CHECK: amdgpu.wmma 16x16x64 + %0 = amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg1 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16> + func.return %0 : vector<8xf16> +} + // CHECK-LABEL: func @swizzle_bitmode func.func @swizzle_bitmode(%arg0 : f32) -> f32 { // CHECK: amdgpu.swizzle_bitmode