-
Notifications
You must be signed in to change notification settings - Fork 15k
[mlir][amdgpu][rocdl] Add gfx1250 wmma ops #165064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-amdgpu Author: Jakub Kuderski (kuhar) Changes: Update `amdgpu.wmma` op definition and implement amdgpu to rocdl conversion for new variants. Patch is 22.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165064.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d74abc22acd5e..99e10b231482f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,9 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
// wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
- VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
- VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
+ VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+ VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>,
+ VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>,
VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -990,7 +991,7 @@ def AMDGPU_WMMAOp :
Arguments<(ins
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
- ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32, 64, 128]>]>:$k,
WMMAInTypes:$sourceA,
WMMAInTypes:$sourceB,
WMMAOutTypes:$destC,
@@ -1003,8 +1004,14 @@ def AMDGPU_WMMAOp :
let description = [{
The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
instructions in the AMDGPU architecture, which perform matrix multiplication.
- Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
- dimensions.
+
+ On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions.
+
+ On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for
+ all element types, and K=32 for i4 sources.
+
+ On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128,
+ depending on the element types.
On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
(or 16xbf16) vector containing only 8 valid values:
@@ -1020,7 +1027,13 @@ def AMDGPU_WMMAOp :
Example:
```mlir
- %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+ %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>
+
+ %1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<8xf32>, vector<8xf32>
+
+ %2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf4E2M1FN>, vector<64xf4E2M1FN>, vector<8xf32>
+
+ %3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32>
```
}];
let assemblyFormat = [{
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 478b6aaaec83a..90e731c11da62 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1002,8 +1002,13 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Type elemDestType = destVectorType.getElementType();
const uint32_t k = wmma.getK();
+ const bool isRDNA3 = chipset.majorVersion == 11;
+ const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
if (k == 16) {
+ if (!isRDNA3 && !isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+ return std::nullopt;
+
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1019,34 +1024,124 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
}
}
- if (chipset.majorVersion < 12)
+ if (isRDNA3)
return std::nullopt;
+ using fp8 = Float8E4M3FNType;
+ using bf8 = Float8E5M2Type;
+
// gfx12+
if (k == 16) {
- if (isa<Float8E4M3FNType>(elemSourceType) &&
- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ if (!isRDNA4) // gfx1250 does not have any wmma ops with k=16.
+ return std::nullopt;
+
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
- if (isa<Float8E4M3FNType>(elemSourceType) &&
- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) &&
- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
- if (isa<Float8E5M2Type>(elemSourceType) &&
- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
+ elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
+
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
return std::nullopt;
}
if (k == 32) {
- if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
- return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ if (isRDNA4) {
+ if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
+ return std::nullopt;
+ }
+
+ // gfx1250
+ if (elemSourceType.isF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
+ if (elemSourceType.isF16() && elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
+ if (elemSourceType.isBF16() && elemDestType.isBF16())
+ return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();
+
+ return std::nullopt;
+ }
+
+ if (isRDNA4)
+ return std::nullopt;
+
+ // gfx1250
+ if (k == 4) {
+ if (elemSourceType.isF32() && elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
return std::nullopt;
}
+ if (k == 64) {
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
+ }
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
+ }
+ if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
+ return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();
+
+ return std::nullopt;
+ }
+
+ if (k == 128) {
+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
+ }
+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
+ }
+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
+ if (elemDestType.isF32())
+ return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
+ if (elemDestType.isF16())
+ return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
+ }
+
+ return std::nullopt;
+ }
llvm_unreachable("unhandled WMMA case");
}
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4c4965e67676e..7b4e248dcf5e4 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {
if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
- "source element types much match (except for fp8) but have ")
+ "source element types much match (except for fp8/bf8) but have ")
<< sourceAType << " and " << sourceBType;
}
- if (!sourceAElemType.isInteger(4) && getK() != 16) {
- return emitOpError("K dimension must be 16 for source element type ")
- << sourceAElemType;
+ if (isSrcFloat) {
+ if (getClamp())
+ return emitOpError("clamp flag is not supported for float types");
+ if (getUnsignedA() || getUnsignedB())
+ return emitOpError("unsigned flags are not supported for float types");
}
return success();
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
new file mode 100644
index 0000000000000..bcbdef040ebe3
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -0,0 +1,89 @@
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: @wmma_k4
+func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
+ // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
+ amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
+ func.return
+}
+
+// CHECK-LABEL: @wmma_k32
+func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vector<8xf32>,
+ %arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) {
+ // CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
+ amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
+ amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
+ amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_k64
+func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
+ %arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
+ // CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
+ amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>
+
+ // CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
+ amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
+ amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
+ amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
+ amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_k128
+func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+ %arg2 : vector<8xf32>, %arg3 : vector<8xf16>) {
+ // CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
+ amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
+ amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
+ amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>
+
+ // CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
+ amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>
+
+ // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
+ amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>
+
+ func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 6a2518a40cc99..14a3fe3af60ec 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -144,14 +144,6 @@ func.func @wmma_no_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector
// -----
-func.func @wmma_wrong_m_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
- // expected-error@+1 {{'amdgpu.wmma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
- %0 = amdgpu.wmma 32x16x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
- func.return %0 : vector<8xi32>
-}
-
-// -----
-
func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
// expected-error@+1 {{'amdgpu.wmma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16}}}
%0 = amdgpu.wmma 16x32x16 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
@@ -161,14 +153,62 @@ func.func @wmma_wrong_n_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vec
// -----
func.func @wmma_wrong_k_dim(%arg0 : vector<16xi8>, %arg1 : vector<8xi32>) -> vector<8xi32> {
- // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+ // expected-error@+1 {{'amdgpu.wmma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32, 64, 128}}}
%0 = amdgpu.wmma 16x16x24 %arg0 * %arg0 + %arg1 : vector<16xi8>, vector<16xi8>, vector<8xi32>
func.return %0 : vector<8xi32>
}
// -----
-// Missinng `resetOffset`
+func.func @wmma_source_length_mismatch(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op source vectors have different lengths}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<16xf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_float_types(%arg0 : vector<8xf16>, %arg1 : vector<8xbf16>, %arg2 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xf16>, vector<8xbf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_mismatched_int_types(%arg0 : vector<8xi8>, %arg1 : vector<8xi4>, %arg2 : vector<8xi32>) -> vector<8xi32> {
+ // expected-error@+1 {{'amdgpu.wmma' op source element types much match (except for fp8/bf8)}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg1 + %arg2 : vector<8xi8>, vector<8xi4>, vector<8xi32>
+ func.return %0 : vector<8xi32>
+}
+
+// -----
+
+func.func @wmma_clamp_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op clamp flag is not supported for float types}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {clamp} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedA_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedA} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+func.func @wmma_unsignedB_float(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) -> vector<8xf32> {
+ // expected-error@+1 {{'amdgpu.wmma' op unsigned flags are not supported for float types}}
+ %0 = amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 {unsignedB} : vector<8xf16>, vector<8xf16>, vector<8xf32>
+ func.return %0 : vector<8xf32>
+}
+
+// -----
+
+// Missing `resetOffset`
func.func @fat_raw_buffer_cast_stripped_offset(%m: memref<8xi32, strided<[1], offset: ?>, #gpu.address_space<global>>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
// expected-error@+1 {{'amdgpu.fat_raw_buffer_cast' op expected result type to be 'memref<8xi32, strided<[1], offset: ?>, #amdgpu.address_space<fat_raw_buffer>>' but got 'memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>'}}
%ret = amdgpu.fat_raw_buffer_cast %m : memref<8xi32, strided<...
[truncated]
|
> if (isRDNA3)
>   return std::nullopt;
>
> using fp8 = Float8E4M3FNType;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, that's a good idea. We should do it in more places.
c0e8747 to
f58a79a
Compare
Update `amdgpu.wmma` op definition and implement amdgpu to rocdl conversion for new variants.
f58a79a to
7e67848
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Update `amdgpu.wmma` op definition and implement amdgpu to rocdl conversion for new variants.
Update `amdgpu.wmma` op definition and implement amdgpu to rocdl conversion for new variants.
Update `amdgpu.wmma` op definition and implement amdgpu to rocdl conversion for new variants.