[mlir][amdgpu] Add explicit intrinsic shape to wmma #164920
```diff
@@ -912,12 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
                       VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
 // wmma
-def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
-  [4, 8, 16],
-  [F16, BF16,
-   I8, SI8, UI8,
-   I<4>, SI<4>, UI<4>,
-   F8E4M3FN, F8E5M2]>]>;
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
+                             VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
+                             VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
 def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
                               VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
```
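The new `WMMAInTypes` definition splits one combined constraint into per-element-type length lists; in particular, the 8-bit float formats now only accept vectors of length 4 or 8. A hedged Python sketch of the equivalent check (the table and function are illustrative only, not the actual ODS predicate):

```python
# Illustrative restatement of the new WMMAInTypes constraint: each input
# element type is valid only at certain vector lengths.
ALLOWED_WMMA_IN_LENGTHS = {
    "f16": (4, 8, 16), "bf16": (4, 8, 16),
    "i8": (4, 8, 16), "si8": (4, 8, 16), "ui8": (4, 8, 16),
    # 8-bit floats no longer accept length-16 vectors under the new definition.
    "f8E4M3FN": (4, 8), "f8E5M2": (4, 8),
    "i4": (4, 8, 16), "si4": (4, 8, 16), "ui4": (4, 8, 16),
}

def is_valid_wmma_input(length: int, elem: str) -> bool:
    """Return True if vector<length x elem> would satisfy WMMAInTypes."""
    return length in ALLOWED_WMMA_IN_LENGTHS.get(elem, ())
```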
````diff
@@ -968,6 +966,14 @@ def AMDGPU_MFMAOp :
     The negateA, negateB, and negateC flags are only supported for double-precision
     operations on gfx94x.
 
+    Example:
+    ```mlir
+      %0 = amdgpu.mfma %matA * %matB + %matC
+        { abid = 1 : i32, cbsz = 1 : i32,
+          m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+        blgp = bcast_second_32 : f32, f32, vector<32xf32>
+    ```
   }];
   let assemblyFormat = [{
     $sourceA `*` $sourceB `+` $destC
````
````diff
@@ -982,36 +988,43 @@ def AMDGPU_WMMAOp :
     AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
                        Pure]>,
     Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
                    WMMAInTypes:$sourceA,
                    WMMAInTypes:$sourceB,
                    WMMAOutTypes:$destC,
-                   DefaultValuedAttr<ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>, "0">:$subwordOffset,
+                   DefaultValuedAttr<ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>, "0">:$subwordOffset,
                    UnitAttr:$unsignedA,
                    UnitAttr:$unsignedB,
                    UnitAttr:$clamp)>,
     Results<(outs WMMAOutTypes: $destD)> {
-  let summary = "MLIR wrapper for RDNA3 wmma instructions";
+  let summary = "MLIR wrapper for wmma instructions";
   let description = [{
-    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
-    for various `wmma` instructions in the RDNA3 or RDNA4 architecture, which
-    perform a 16x16 * 16x16 matrix multiplication for different data types.
-    Note that in gfx12/RDNA4, there is also a 16x32 * 32x16 instruction for 4-bit
-    integer inputs.
+    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
+    instructions in the AMDGPU architecture, which perform matrix multiplication.
+    Note that all wmma intrinsics have M=N=16 dimensions but vary in allowed K
+    dimensions.
 
     On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
     (or 16xbf16) vector containing only 8 valid values:
       - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
       - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
-    On gfx12/RDNA4, the result is instead returned as a vector<8 x f16/bf16> where
-    all values are valid and the `subwordOffset` must be `0`, as it cannot be used.
+    On gfx12/RDNA4 and gfx1250, the result is instead returned as a vector where all
+    the values are valid and the `subwordOffset` must be `0`, as it cannot be used.
 
     `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
 
-    The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+    The `clamp` flag is used to saturate the output of type T to `numeric_limits<T>::max()`
     in case of overflow.
+
+    Example:
+    ```mlir
+      %0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
+    ```
   }];
   let assemblyFormat = [{
-    $sourceA `*` $sourceB `+` $destC
+    custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
     attr-dict
     `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
````

A reviewer commented on the new example syntax:

> While the syntax is okay, it is weird that the mfma instructions encode this stuff as an attribute dict while wmma does it as a custom parser.

Reply:

> We can give mfma the same syntax, I didn't want to make too many changes in the same PR though.
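The gfx11/RDNA3 `subwordOffset` layout described in the op documentation above can be sketched as follows (a hypothetical helper for illustration, not part of the dialect or its lowering):

```python
def pack_gfx11_f16_result(values, subword_offset):
    """Place 8 wmma results into a 16-element vector at even indices
    (subword_offset == 0) or odd indices (subword_offset == 1), mirroring
    the gfx11/RDNA3 f16->f16 (or bf16->bf16) output layout."""
    assert len(values) == 8 and subword_offset in (0, 1)
    out = [0.0] * 16  # the other 8 slots hold no valid data
    for i, v in enumerate(values):
        out[2 * i + subword_offset] = v
    return out
```

On gfx12/RDNA4, by contrast, all lanes of the (shorter) result vector are valid, which is why `subwordOffset` must be `0` there.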
The AMDGPU-to-ROCDL conversion test is updated to the new syntax:
```diff
@@ -1,35 +1,36 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL: @wmma_to_rocdl
 func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
                          %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
                          %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>,
                          %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) {
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
   // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
-  amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
   // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
-  amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
   // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
   // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
   // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
-  amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+  amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
-  amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
+  amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32>
   // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32>
-  amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
+  amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32>
 
   func.return
 }
```
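The CHECK lines in this test imply a simple mapping from the source and accumulator element types to the ROCDL intrinsic selected on gfx1100. A sketch of that selection (the table is inferred only from this test's CHECK lines; unsigned i8 shares the `iu8` intrinsic, distinguished by the `unsignedA`/`unsignedB` flags rather than by name):

```python
# (source element, accumulator element) -> ROCDL intrinsic, per the CHECK
# lines above (gfx1100, 16x16x16 shape).
ROCDL_WMMA_INTRINSIC = {
    ("f16", "f32"): "rocdl.wmma.f32.16x16x16.f16",
    ("f16", "f16"): "rocdl.wmma.f16.16x16x16.f16",
    ("bf16", "f32"): "rocdl.wmma.f32.16x16x16.bf16",
    ("bf16", "bf16"): "rocdl.wmma.bf16.16x16x16.bf16",
    ("i8", "i32"): "rocdl.wmma.i32.16x16x16.iu8",
    ("ui8", "i32"): "rocdl.wmma.i32.16x16x16.iu8",
    ("i4", "i32"): "rocdl.wmma.i32.16x16x16.iu4",
}

def select_wmma_intrinsic(src_elem: str, acc_elem: str) -> str:
    """Look up the ROCDL intrinsic for a wmma op's element types."""
    return ROCDL_WMMA_INTRINSIC[(src_elem, acc_elem)]
```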