[AMD][Navi31] Convert WMMA dot op to LLVM
Add WMMA conversion logic for the dot operation.

Signed-off-by: joviliast <iveselov.nn@gmail.com>
joviliast committed Feb 29, 2024
1 parent 57ae184 commit f2f21c1
Showing 3 changed files with 286 additions and 42 deletions.
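
For orientation, below is a minimal sketch of the kind of lowering the new DotOpToLLVM/WMMA.cpp introduces. That file is not reproduced in this excerpt, so everything here is an assumption modeled on the MFMA helpers shown further down: the helper name generateWMMAOp is illustrative only, and the sketch presumes the MLIR ROCDL dialect's gfx11 WMMA wrapper rocdl.wmma.f32.16x16x16.f16.

// Sketch only, not the committed WMMA.cpp: emit one RDNA3 WMMA intrinsic for
// D = A * B + C on a 16x16x16 tile, with fp16 inputs accumulating into fp32.
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static Value generateWMMAOp(ConversionPatternRewriter &rewriter, Location loc,
                            Value valA, Value valB, Value valC) {
  // As in the MFMA helpers below, the accumulator type doubles as the result type.
  auto resType = valC.getType();
  return rewriter.create<ROCDL::wmma_f32_16x16x16_f16>(
      loc, TypeRange{resType}, ValueRange{valA, valB, valC});
}

Unlike MFMA, which offers 32x32, 16x16, and 4x4 variants selected via MFMAInstrType, RDNA3 WMMA exposes a single 16x16x16 tile shape, so the corresponding instruction-selection logic can stay much smaller.
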
1 change: 1 addition & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt
@@ -2,6 +2,7 @@ add_triton_library(TritonAMDGPUToLLVM
   ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
   ConvertLayoutOpToLLVM.cpp
   DotOpToLLVM/MFMA.cpp
+  DotOpToLLVM/WMMA.cpp
   DotOpToLLVM.cpp
   ElementwiseOpToLLVM.cpp
   LoadStoreOpToLLVM.cpp
84 changes: 42 additions & 42 deletions third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp
@@ -38,7 +38,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
 using ::mlir::triton::gpu::MfmaEncodingAttr;
 using ::mlir::triton::gpu::SharedEncodingAttr;

-enum class MatrixCoreType : uint8_t {
+enum class MFMAInstrType : uint8_t {
   // D = AB + C
   FP32_FP8_FP8_FP32,
   FP32_FP8_BF8_FP32,
@@ -55,7 +55,7 @@ enum class MatrixCoreType : uint8_t {
 };

 struct MFMAInstrDescr {
-  MatrixCoreType coreType;
+  MFMAInstrType coreType;
   unsigned size;
 };

@@ -82,52 +82,52 @@ struct DotOpMFMAConversionHelper {
     return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
   }

-  Value generateMFMA32Op(MatrixCoreType coreType, Value valA, Value valB,
+  Value generateMFMA32Op(MFMAInstrType coreType, Value valA, Value valB,
                          Value valC) const {
     auto resType = valC.getType();
     Value zeroFlag = i32_val(0);
     switch (coreType) {
-    case MatrixCoreType::FP32_FP8_FP8_FP32:
+    case MFMAInstrType::FP32_FP8_FP8_FP32:
       return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_fp8>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_FP8_BF8_FP32:
+    case MFMAInstrType::FP32_FP8_BF8_FP32:
       return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_bf8>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_BF8_FP8_FP32:
+    case MFMAInstrType::FP32_BF8_FP8_FP32:
       return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_fp8>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_BF8_BF8_FP32:
+    case MFMAInstrType::FP32_BF8_BF8_FP32:
       return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_bf8>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_FP16_FP16_FP32:
+    case MFMAInstrType::FP32_FP16_FP16_FP32:
       return rewriter.create<ROCDL::mfma_f32_32x32x8f16>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_BF16_BF16_FP32:
+    case MFMAInstrType::FP32_BF16_BF16_FP32:
       return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
+    case MFMAInstrType::FP32_BF16_BF16_FP32_1K:
       return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_FP32_FP32_FP32:
+    case MFMAInstrType::FP32_FP32_FP32_FP32:
       return rewriter.create<ROCDL::mfma_f32_32x32x2f32>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::INT32_INT8_INT8_INT32:
+    case MFMAInstrType::INT32_INT8_INT8_INT32:
       return rewriter.create<ROCDL::mfma_i32_32x32x8i8>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3:
+    case MFMAInstrType::INT32_INT8_INT8_INT32_CDNA3:
       return rewriter.create<ROCDL::mfma_i32_32x32x16_i8>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP64_FP64_FP64_FP64:
+    case MFMAInstrType::FP64_FP64_FP64_FP64:
       return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
@@ -136,36 +136,36 @@ struct DotOpMFMAConversionHelper {
     }
   }

-  Value generateMFMA16Op(MatrixCoreType coreType, Value valA, Value valB,
+  Value generateMFMA16Op(MFMAInstrType coreType, Value valA, Value valB,
                          Value valC) const {
     auto resType = valC.getType();
     Value zeroFlag = i32_val(0);
     switch (coreType) {
-    case MatrixCoreType::FP32_FP16_FP16_FP32:
+    case MFMAInstrType::FP32_FP16_FP16_FP32:
       return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_BF16_BF16_FP32:
+    case MFMAInstrType::FP32_BF16_BF16_FP32:
       return rewriter.create<ROCDL::mfma_f32_16x16x8bf16>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
+    case MFMAInstrType::FP32_BF16_BF16_FP32_1K:
       return rewriter.create<ROCDL::mfma_f32_16x16x16bf16_1k>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_FP32_FP32_FP32:
+    case MFMAInstrType::FP32_FP32_FP32_FP32:
       return rewriter.create<ROCDL::mfma_f32_16x16x4f32>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::INT32_INT8_INT8_INT32:
+    case MFMAInstrType::INT32_INT8_INT8_INT32:
       return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3:
+    case MFMAInstrType::INT32_INT8_INT8_INT32_CDNA3:
       return rewriter.create<ROCDL::mfma_i32_16x16x32_i8>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP64_FP64_FP64_FP64:
+    case MFMAInstrType::FP64_FP64_FP64_FP64:
       return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
@@ -174,28 +174,28 @@ struct DotOpMFMAConversionHelper {
     }
   }

-  Value generateMFMA4Op(MatrixCoreType coreType, Value valA, Value valB,
+  Value generateMFMA4Op(MFMAInstrType coreType, Value valA, Value valB,
                         Value valC) const {
     auto resType = valC.getType();
     Value zeroFlag = i32_val(0);
     switch (coreType) {
-    case MatrixCoreType::FP32_FP16_FP16_FP32:
+    case MFMAInstrType::FP32_FP16_FP16_FP32:
       return rewriter.create<ROCDL::mfma_f32_4x4x4f16>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_BF16_BF16_FP32:
+    case MFMAInstrType::FP32_BF16_BF16_FP32:
       return rewriter.create<ROCDL::mfma_f32_4x4x2bf16>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
+    case MFMAInstrType::FP32_BF16_BF16_FP32_1K:
       return rewriter.create<ROCDL::mfma_f32_4x4x4bf16_1k>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::FP32_FP32_FP32_FP32:
+    case MFMAInstrType::FP32_FP32_FP32_FP32:
       return rewriter.create<ROCDL::mfma_f32_4x4x1f32>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
-    case MatrixCoreType::INT32_INT8_INT8_INT32:
+    case MFMAInstrType::INT32_INT8_INT8_INT32:
       return rewriter.create<ROCDL::mfma_i32_4x4x4i8>(
           loc, TypeRange{resType},
           ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
@@ -241,7 +241,7 @@ struct DotOpMFMAConversionHelper {
   }

   // TODO unify this function with Utility.cpp:supportMFMATypes
-  static MatrixCoreType getMatrixCoreTypeFromDot(DotOp op) {
+  static MFMAInstrType getMFMAInstrTypeFromDot(DotOp op) {
     auto aOperandTy = op.getA().getType();
     auto aTensorTy = aOperandTy.cast<RankedTensorType>();
     auto aElemTy = aTensorTy.getElementType();
@@ -252,48 +252,48 @@ struct DotOpMFMAConversionHelper {
     auto dotOpEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
     auto mfmaEncoding = dotOpEncoding.getParent().cast<MfmaEncodingAttr>();
     if (aElemTy.isFloat8E4M3FNUZ() && bElemTy.isFloat8E4M3FNUZ())
-      return MatrixCoreType::FP32_FP8_FP8_FP32;
+      return MFMAInstrType::FP32_FP8_FP8_FP32;
     if (aElemTy.isFloat8E4M3FNUZ() && bElemTy.isFloat8E5M2FNUZ())
-      return MatrixCoreType::FP32_FP8_BF8_FP32;
+      return MFMAInstrType::FP32_FP8_BF8_FP32;
     if (aElemTy.isFloat8E5M2FNUZ() && bElemTy.isFloat8E4M3FNUZ())
-      return MatrixCoreType::FP32_BF8_FP8_FP32;
+      return MFMAInstrType::FP32_BF8_FP8_FP32;
     if (aElemTy.isFloat8E5M2FNUZ() && bElemTy.isFloat8E5M2FNUZ())
-      return MatrixCoreType::FP32_BF8_BF8_FP32;
+      return MFMAInstrType::FP32_BF8_BF8_FP32;
     if (aElemTy.isF16())
-      return MatrixCoreType::FP32_FP16_FP16_FP32;
+      return MFMAInstrType::FP32_FP16_FP16_FP32;
     if (aElemTy.isF32())
-      return MatrixCoreType::FP32_FP32_FP32_FP32;
+      return MFMAInstrType::FP32_FP32_FP32_FP32;
     if (aElemTy.isBF16()) {
       auto nonKDim = mfmaEncoding.getNonKDim();
       auto kWidth = dotOpEncoding.getKWidth();
       if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4) {
-        return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
+        return MFMAInstrType::FP32_BF16_BF16_FP32_1K;
       } else {
         assert((nonKDim == 32 && kWidth == 2) ||
                (nonKDim == 16 && kWidth == 2) || (nonKDim == 4 && kWidth == 2));
-        return MatrixCoreType::FP32_BF16_BF16_FP32;
+        return MFMAInstrType::FP32_BF16_BF16_FP32;
       }
     }
     if (aElemTy.isInteger(8)) {
       auto nonKDim = mfmaEncoding.getNonKDim();
       auto kWidth = dotOpEncoding.getKWidth();
       if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 8) {
-        return MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3;
+        return MFMAInstrType::INT32_INT8_INT8_INT32_CDNA3;
       } else {
         assert((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4);
-        return MatrixCoreType::INT32_INT8_INT8_INT32;
+        return MFMAInstrType::INT32_INT8_INT8_INT32;
       }
     }
     if (aElemTy.isF64())
-      return MatrixCoreType::FP64_FP64_FP64_FP64;
-    return MatrixCoreType::NOT_APPLICABLE;
+      return MFMAInstrType::FP64_FP64_FP64_FP64;
+    return MFMAInstrType::NOT_APPLICABLE;
   }

   static MFMAInstrDescr getMatrixInstrDescr(DotOp op) {
     MFMAInstrDescr descr;
     auto tensorTy = op.getD().getType().cast<RankedTensorType>();
     auto encoding = tensorTy.getEncoding().cast<MfmaEncodingAttr>();
-    descr.coreType = getMatrixCoreTypeFromDot(op);
+    descr.coreType = getMFMAInstrTypeFromDot(op);
     descr.size = encoding.getNonKDim();
     return descr;
   }