27 changes: 20 additions & 7 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -912,9 +912,10 @@ def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
// wmma
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
VectorOfLengthAndType<[4, 8, 16], [I8, SI8, UI8]>,
VectorOfLengthAndType<[4, 8], [F8E4M3FN, F8E5M2]>,
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
VectorOfLengthAndType<[4, 8, 16, 32], [I8, SI8, UI8]>,
VectorOfLengthAndType<[4, 8, 32, 64], [F8E4M3FN, F8E5M2]>,
VectorOfLengthAndType<[4, 8, 16], [I<4>, SI<4>, UI<4>]>]>;
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>]>;
@@ -992,7 +993,7 @@ def AMDGPU_WMMAOp :
Arguments<(ins
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$m,
ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$k,
ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32, 64, 128]>]>:$k,
WMMAInTypes:$sourceA,
WMMAInTypes:$sourceB,
WMMAOutTypes:$destC,
@@ -1005,8 +1006,14 @@ def AMDGPU_WMMAOp :
let description = [{
The `amdgpu.wmma` op is an MLIR wrapper around intrinsics for various `wmma`
instructions in the AMDGPU architecture, which perform matrix multiplication.
Note that all wmma intrinsics have M=N=16 dimensions but vary by in allowed K
dimensions.

On gfx11/RDNA3, wmma intrinsics have M=N=K=16 dimensions.

On gfx12/RDNA4, wmma intrinsics have M=N=16 dimensions and support K=16 for
all element types, and K=32 for i4 sources.

On gfx1250, wmma intrinsics have M=N=16 and K dimensions of 4, 32, 64, or 128,
depending on the element types.

On gfx11/RDNA3, emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16
(or 16xbf16) vector containing only 8 valid values:
@@ -1022,7 +1029,13 @@ def AMDGPU_WMMAOp :

Example:
```mlir
%0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<16xf16>, vector<16xf16>, vector<8xf16>
%0 = amdgpu.wmma 16x16x16 %matA * %matB + %matC : vector<8xf16>, vector<8xf16>, vector<8xf16>

%1 = amdgpu.wmma 16x16x64 %matD * %matE + %matF : vector<32xi8>, vector<32xi8>, vector<8xi32>

%2 = amdgpu.wmma 16x16x128 %matG * %matH + %matI : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>

%3 = amdgpu.wmma 16x16x4 %matJ * %matK + %matL : vector<2xf32>, vector<2xf32>, vector<8xf32>
```
}];
let assemblyFormat = [{
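For reference, the new `WMMAInTypes` lengths line up with the per-lane fragment sizes exercised by the gfx1250 examples above and the tests later in this patch: every 16x16xK source fragment carries K/2 elements, and the accumulator carries 8. A minimal sketch of that relationship, assuming the 16*K source elements are spread over a 32-lane wave and using an illustrative helper name that is not part of the patch:

```cpp
#include <cstdint>

// Illustrative helper (not in the patch): per-lane source fragment length for
// a gfx1250 16x16xK wmma, assuming the 16*K source elements are distributed
// across a 32-lane wave. It reproduces the vector lengths used in the
// examples and tests in this patch.
constexpr uint32_t expectedGfx1250SourceLength(uint32_t k) {
  return (16u * k) / 32u; // M * K elements split across 32 lanes.
}

static_assert(expectedGfx1250SourceLength(4) == 2, "K=4 uses vector<2xf32> sources");
static_assert(expectedGfx1250SourceLength(32) == 16, "K=32 uses vector<16xf16> sources");
static_assert(expectedGfx1250SourceLength(64) == 32, "K=64 uses vector<32xi8> or vector<32xf8...> sources");
static_assert(expectedGfx1250SourceLength(128) == 64, "K=128 uses vector<64xf8...> sources");
```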
175 changes: 146 additions & 29 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
smfma.getN(), smfma.getK(), 1u, chipset);
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Chipset chipset) {
auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
Type elemSourceType = sourceVectorType.getElementType();
Type elemBSourceType = sourceBVectorType.getElementType();
Type elemDestType = destVectorType.getElementType();

const uint32_t k = wmma.getK();

/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for RDNA3/4 architectures.
static std::optional<StringRef>
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
Type elemDestType, uint32_t k, bool isRDNA3) {
using fp8 = Float8E4M3FNType;
using bf8 = Float8E5M2Type;

// Handle k == 16 for RDNA3/4.
if (k == 16) {
// Common patterns for RDNA3 and RDNA4.
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1014,39 +1010,160 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
if (chipset.majorVersion == 11) {

// RDNA3 specific patterns.
if (isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
return std::nullopt;
}
}
if (chipset.majorVersion < 12)
return std::nullopt;

// gfx12+
if (k == 16) {
if (isa<Float8E4M3FNType>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
// RDNA4 specific patterns (fp8/bf8).
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
if (isa<Float8E4M3FNType>(elemSourceType) &&
isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName();
if (isa<Float8E5M2Type>(elemSourceType) &&
isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32())
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType) &&
elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName();
if (isa<Float8E5M2Type>(elemSourceType) &&
isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32())
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();

return std::nullopt;
}
if (k == 32) {

// Handle k == 32 for RDNA4.
if (k == 32 && !isRDNA3) {
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
}

llvm_unreachable("Unsupported k value");
}

/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// for the gfx1250 architecture.
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
Type elemBSourceType,
Type elemDestType,
uint32_t k) {
using fp8 = Float8E4M3FNType;
using bf8 = Float8E5M2Type;

if (k == 4) {
if (elemSourceType.isF32() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x4_f32::getOperationName();

return std::nullopt;
}

if (k == 32) {
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x32_bf16::getOperationName();
if (elemSourceType.isF16() && elemDestType.isF16())
return ROCDL::wmma_f16_16x16x32_f16::getOperationName();
if (elemSourceType.isBF16() && elemDestType.isBF16())
return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName();

return std::nullopt;
}

if (k == 64) {
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName();
}
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName();
}
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName();
}
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName();
}
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x64_iu8::getOperationName();

return std::nullopt;
}

if (k == 128) {
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName();
}
if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName();
}
if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName();
}
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
if (elemDestType.isF32())
return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName();
if (elemDestType.isF16())
return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName();
}

return std::nullopt;
}

llvm_unreachable("Unsupported k value");
}

/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
/// if one exists. This includes checking to ensure the intrinsic is supported
/// on the architecture you are compiling for.
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
Chipset chipset) {
auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
Type elemSourceType = sourceVectorType.getElementType();
Type elemBSourceType = sourceBVectorType.getElementType();
Type elemDestType = destVectorType.getElementType();

const uint32_t k = wmma.getK();
const bool isRDNA3 = chipset.majorVersion == 11;
const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;

// Handle RDNA3 and RDNA4.
if (isRDNA3 || isRDNA4)
return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
k, isRDNA3);

// Handle gfx1250.
if (chipset == Chipset{12, 5, 0})
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
elemDestType, k);

llvm_unreachable("unhandled WMMA case");
}

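The dispatch refactoring above keeps the original contract: callers receive a `std::optional<StringRef>` and must handle `std::nullopt` (an unsupported element-type combination) themselves, while `k` values that the op's verifier should never allow hit `llvm_unreachable`. A hedged sketch of the consuming side, using generic MLIR conversion-pattern names rather than this file's exact lowering code:

```cpp
// Illustrative only: how a conversion pattern is expected to consume
// wmmaOpToIntrinsic. The function and variable names here are placeholders,
// not quotes from AMDGPUToROCDL.cpp.
static LogicalResult lowerWmmaSketch(amdgpu::WMMAOp op, Chipset chipset,
                                     ConversionPatternRewriter &rewriter) {
  std::optional<StringRef> intrinsicName = wmmaOpToIntrinsic(op, chipset);
  if (!intrinsicName)
    return rewriter.notifyMatchFailure(op, "no matching WMMA intrinsic");
  // ... construct the rocdl intrinsic named by *intrinsicName ...
  return success();
}
```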
10 changes: 6 additions & 4 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {

if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
"source element types much match (except for fp8) but have ")
"source element types must match (except for fp8/bf8) but have ")
<< sourceAType << " and " << sourceBType;
}

if (!sourceAElemType.isInteger(4) && getK() != 16) {
return emitOpError("K dimension must be 16 for source element type ")
<< sourceAElemType;
if (isSrcFloat) {
if (getClamp())
return emitOpError("clamp flag is not supported for float types");
if (getUnsignedA() || getUnsignedB())
return emitOpError("unsigned flags are not supported for float types");
}
return success();
}
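The hunk above reads `isSrcFloat` without showing its definition, which sits just outside the changed lines. A plausible reconstruction from the element types the verifier already extracts is a simple float-type test on the A operand; the exact expression in the file may differ:

```cpp
// Assumption: isSrcFloat is a float-element check computed earlier in
// WMMAOp::verify(); this is a sketch, not the file's exact code.
bool isSrcFloat = isa<FloatType>(sourceAElemType);
```

With that in place, the new checks simply reject the integer-only modifiers (`clamp`, `unsignedA`, `unsignedB`) when the sources are floating point.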
89 changes: 89 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -0,0 +1,89 @@
// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s

// CHECK-LABEL: @wmma_k4
func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
// CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
func.return
}

// CHECK-LABEL: @wmma_k32
func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vector<8xf32>,
%arg3 : vector<8xf16>, %arg4 : vector<8xbf16>) {
// CHECK: rocdl.wmma.f32.16x16x32.f16 %arg0, %arg0, %arg2
amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<8xf32>

// CHECK: rocdl.wmma.f16.16x16x32.f16 %arg0, %arg0, {{.*}} : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1)
amdgpu.wmma 16x16x32 %arg0 * %arg0 + %arg3 : vector<16xf16>, vector<16xf16>, vector<8xf16>

// CHECK: rocdl.wmma.f32.16x16x32.bf16 {{.*}}, {{.*}}, %arg2
amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>

// CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
amdgpu.wmma 16x16x32 %arg1 * %arg1 + %arg4 : vector<16xbf16>, vector<16xbf16>, vector<8xbf16>

func.return
}

// CHECK-LABEL: @wmma_k64
func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 : vector<32xf8E5M2>,
%arg3 : vector<8xi32>, %arg4 : vector<8xf32>, %arg5 : vector<8xf16>) {
// CHECK: rocdl.wmma.i32.16x16x64.iu8 {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg3, {{.*}}
amdgpu.wmma 16x16x64 %arg0 * %arg0 + %arg3 {clamp} : vector<32xi8>, vector<32xi8>, vector<8xi32>

// CHECK: rocdl.wmma.f32.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg4
amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf32>

// CHECK: rocdl.wmma.f16.16x16x64.fp8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
amdgpu.wmma 16x16x64 %arg1 * %arg1 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<8xf16>

// CHECK: rocdl.wmma.f32.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg4
amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg4 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf32>

// CHECK: rocdl.wmma.f16.16x16x64.fp8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
amdgpu.wmma 16x16x64 %arg1 * %arg2 + %arg5 : vector<32xf8E4M3FN>, vector<32xf8E5M2>, vector<8xf16>

// CHECK: rocdl.wmma.f32.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg4
amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg4 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf32>

// CHECK: rocdl.wmma.f16.16x16x64.bf8_bf8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
amdgpu.wmma 16x16x64 %arg2 * %arg2 + %arg5 : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<8xf16>

// CHECK: rocdl.wmma.f32.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg4
amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg4 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf32>

// CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
amdgpu.wmma 16x16x64 %arg2 * %arg1 + %arg5 : vector<32xf8E5M2>, vector<32xf8E4M3FN>, vector<8xf16>

func.return
}

// CHECK-LABEL: @wmma_k128
func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
%arg2 : vector<8xf32>, %arg3 : vector<8xf16>) {
// CHECK: rocdl.wmma.f32.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg2
amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf32>

// CHECK: rocdl.wmma.f16.16x16x128.fp8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
amdgpu.wmma 16x16x128 %arg0 * %arg0 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E4M3FN>, vector<8xf16>

// CHECK: rocdl.wmma.f32.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg2
amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg2 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf32>

// CHECK: rocdl.wmma.f16.16x16x128.fp8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
amdgpu.wmma 16x16x128 %arg0 * %arg1 + %arg3 : vector<64xf8E4M3FN>, vector<64xf8E5M2>, vector<8xf16>

// CHECK: rocdl.wmma.f32.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg2
amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg2 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf32>

// CHECK: rocdl.wmma.f16.16x16x128.bf8_bf8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
amdgpu.wmma 16x16x128 %arg1 * %arg1 + %arg3 : vector<64xf8E5M2>, vector<64xf8E5M2>, vector<8xf16>

// CHECK: rocdl.wmma.f32.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg2
amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg2 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf32>

// CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
amdgpu.wmma 16x16x128 %arg1 * %arg0 + %arg3 : vector<64xf8E5M2>, vector<64xf8E4M3FN>, vector<8xf16>

func.return
}