diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
index fe97f8e9e9368..040af3b9b27c3 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
@@ -10,6 +10,7 @@
 #define AMDGPU

 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"

 def AMDGPU_Dialect : Dialect {
@@ -23,6 +24,7 @@ def AMDGPU_Dialect : Dialect {
   }];

   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+  let useDefaultAttributePrinterParser = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -182,5 +184,92 @@ def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
   let assemblyFormat = "attr-dict";
 }

+def AMDGPU_MFMAPermB : I32EnumAttr<"MFMAPermB",
+    "The possible permutations of the lanes storing B available in an MFMA",
+    [
+      I32EnumAttrCase<"none", 0>,
+      I32EnumAttrCase<"bcast_first_32", 1>,
+      I32EnumAttrCase<"bcast_second_32", 2>,
+      I32EnumAttrCase<"rotate_16_right", 3>,
+      I32EnumAttrCase<"bcast_first_16", 4>,
+      I32EnumAttrCase<"bcast_second_16", 5>,
+      I32EnumAttrCase<"bcast_third_16", 6>,
+      I32EnumAttrCase<"bcast_fourth_16", 7>
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_MFMAPermBAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_MFMAPermB,
+                                    "mfma_perm_b">;
+
+// mfma
+def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
+                             VectorOfLengthAndType<[2], [F32]>,
+                             VectorOfLengthAndType<[4], [F16]>,
+                             VectorOfLengthAndType<[2, 4], [BF16]>,
+                             VectorOfLengthAndType<[4, 8], [I8]>]>;
+def MFMAOutTypes : AnyTypeOf<[F64,
+                              VectorOfLengthAndType<[4, 16, 32], [F32]>,
+                              VectorOfLengthAndType<[4, 16, 32], [I32]>,
+                              VectorOfLengthAndType<[4], [F64]>]>;
+
+def AMDGPU_MFMAOp :
+    AMDGPU_Op<"mfma", [AllTypesMatch<["sourceA", "sourceB"]>,
+                       AllTypesMatch<["destC", "destD"]>,
+                       NoSideEffect]>,
+    Arguments<(ins
+                   I32Attr:$m,
+                   I32Attr:$n,
+                   I32Attr:$k,
+                   I32Attr:$blocks,
+                   MFMAInTypes:$sourceA,
+                   MFMAInTypes:$sourceB,
+                   MFMAOutTypes:$destC,
+                   DefaultValuedAttr<I32Attr, "0">:$cbsz,
+                   DefaultValuedAttr<I32Attr, "0">:$abid,
+                   DefaultValuedAttr<AMDGPU_MFMAPermBAttr,
+                     "::mlir::amdgpu::MFMAPermB::none">:$blgp,
+                   UnitAttr:$reducePrecision,
+                   UnitAttr:$negateA,
+                   UnitAttr:$negateB,
+                   UnitAttr:$negateC)>,
+    Results<(outs MFMAOutTypes: $destD)> {
+  let summary = "MLIR wrapper for CDNA mfma instructions";
+  let description = [{
+    The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
+    for various `mfma` instructions in the CDNA architecture, which perform
+    multiple outer products in order to allow fast matrix multiplication.
+
+    The wrapper will select an appropriate `mfma` instruction, if one is available,
+    based on the provided `m`, `k`, `n`, and `blocks` attributes, along with the
+    types of the source and destination arguments.
+
+    For information on the layouts of the input and output matrices (which are stored
+    in `sourceA`, `sourceB`, `destC`, and `destD`), see the CDNA ISA documentation.
+
+    The `cbsz`, `abid`, and `blgp` parameters control how the lanes of the wave
+    are permuted when matrix data is being loaded: `blgp` can be any of a number of
+    fixed permutations, `cbsz` specifies the log_2 of the number of chunks the lanes
+    holding sourceA are split into, and `abid` selects one of those chunks.
+
+    Note that this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
+    intrinsics that take an integer type of width `4K`. For example,
+    one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
+    logically takes 4 i8s but whose intrinsics are specified to take an i32.
+    In these cases, the bytes in the vector will be concatenated in little-endian
+    order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
+
+    The negateA, negateB, and negateC flags are only supported for double-precision
+    operations on gfx940+.
+  }];
+  let assemblyFormat = [{
+    $sourceA `*` $sourceB `+` $destC
+    attr-dict
+    `blgp` `=` $blgp
+    `:` type($sourceA) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}

 #endif // AMDGPU
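For reference, the custom assembly format above produces IR of the following shape (a hand-written sketch that mirrors the 16x16x16 f16 case exercised by the tests later in this patch, not output captured from mlir-opt):

```mlir
// A 16x16x16 f16 MFMA accumulating into a vector<4xf32>, with no lane permutation.
%d = amdgpu.mfma %a * %b + %c {
  abid = 0 : i32, cbsz = 0 : i32,
  m = 16 : i32, n = 16 : i32, k = 16 : i32, blocks = 1 : i32
} blgp = none : vector<4xf16>, vector<4xf32>
```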
diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
index 92193736d3297..7a7b86695ccf2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h
@@ -21,6 +21,11 @@

 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h.inc"

+#include "mlir/Dialect/AMDGPU/AMDGPUEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.h.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/AMDGPU/AMDGPU.h.inc"

diff --git a/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
index a2d56b067d721..ed074c205a551 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt
@@ -2,3 +2,11 @@ add_mlir_dialect(AMDGPU amdgpu)
 add_mlir_doc(AMDGPU AMDGPU Dialects/ -gen-dialect-doc)

 set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
+mlir_tablegen(AMDGPUEnums.h.inc -gen-enum-decls)
+mlir_tablegen(AMDGPUEnums.cpp.inc -gen-enum-defs)
+add_public_tablegen_target(MLIRAMDGPUEnumsGen)
+
+set(LLVM_TARGET_DEFINITIONS AMDGPU.td)
+mlir_tablegen(AMDGPUAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=amdgpu)
+mlir_tablegen(AMDGPUAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=amdgpu)
+add_public_tablegen_target(MLIRAMDGPUAttributesIncGen)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 81b8f6e7a03a5..7a3bf22327834 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -11,9 +11,12 @@
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Pass/Pass.h"

+#include "llvm/ADT/STLExtras.h"
+
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTAMDGPUTOROCDL
 #include "mlir/Conversion/Passes.h.inc"
@@ -25,7 +28,7 @@ using namespace mlir::amdgpu;
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
   Type llvmI32 = rewriter.getI32Type();
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
+  return rewriter.createOrFold<LLVM::ConstantOp>(loc, llvmI32, value);
 }

 namespace {
@@ -123,7 +126,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     MemRefDescriptor memrefDescriptor(memref);
     Type llvmI64 = this->typeConverter->convertType(rewriter.getI64Type());
     Type llvm2xI32 = this->typeConverter->convertType(VectorType::get(2, i32));
-    Value c32I64 = rewriter.create<LLVM::ConstantOp>(loc, llvmI64, 32);
+    Value c32I64 = rewriter.create<LLVM::ConstantOp>(
+        loc, llvmI64, rewriter.getI64IntegerAttr(32));

     Value resource = rewriter.create<LLVM::UndefOp>(loc, llvm4xI32);

@@ -273,6 +277,173 @@ struct LDSBarrierOpLowering : public ConvertOpToLLVMPattern<LDSBarrierOp> {
     return success();
   }
 };
+} // namespace
+
+/// If `input` is a vector of bytes, concatenate those bytes in little-endian
+/// order to form a single integer of size 8 * [vector length]. This works
+/// around a wart in the AMDGPU intrinsics where operations that logically take
+/// vectors of bytes instead take integers. Since we do not want to expose this
+/// implementation detail to MLIR, we correct for it here.
+static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
+                                Location loc, Value input) {
+  Type inputType = input.getType();
+  if (auto vectorType = inputType.dyn_cast<VectorType>()) {
+    if (!vectorType.getElementType().isInteger(8))
+      return input;
+    int64_t numBytes = vectorType.getNumElements();
+    Type destType = rewriter.getIntegerType(numBytes * 8);
+    Value result = rewriter.createOrFold<LLVM::ConstantOp>(
+        loc, destType, rewriter.getIntegerAttr(destType, 0));
+    for (int64_t i = 0; i < numBytes; ++i) {
+      Value idxConst = createI32Constant(rewriter, loc, i);
+      Value element =
+          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
+      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
+      Value shiftConst = rewriter.createOrFold<LLVM::ConstantOp>(
+          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
+      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
+      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+    }
+    return result;
+  }
+  return input;
+}
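+
+// For example, a vector<4xi8> input holding {0x01, 0x02, 0x03, 0x04} is packed
+// by the loop above into the i32 value 0x04030201: v[0] lands in bits [7:0],
+// v[1] in bits [15:8], v[2] in bits [23:16], and v[3] in bits [31:24].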
+
+/// Return the `rocdl` intrinsic corresponding to an MFMA operation `mfma`
+/// if one exists. This includes checking to ensure the intrinsic is supported
+/// on the architecture you are compiling for.
+static Optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma, Chipset chipset) {
+  uint32_t m = mfma.getM(), n = mfma.getN(), k = mfma.getK(),
+           b = mfma.getBlocks();
+  Type sourceElem = mfma.getSourceA().getType();
+  if (auto sourceType = sourceElem.dyn_cast<VectorType>())
+    sourceElem = sourceType.getElementType();
+  Type destElem = mfma.getDestC().getType();
+  if (auto destType = destElem.dyn_cast<VectorType>())
+    destElem = destType.getElementType();
+
+  if (sourceElem.isF32() && destElem.isF32()) {
+    if (mfma.getReducePrecision() && chipset.minorVersion >= 0x40) {
+      if (m == 32 && n == 32 && k == 4 && b == 1)
+        return ROCDL::mfma_f32_32x32x4_xf32::getOperationName();
+      if (m == 16 && n == 16 && k == 8 && b == 1)
+        return ROCDL::mfma_f32_16x16x8_xf32::getOperationName();
+    }
+    if (m == 32 && n == 32 && k == 1 && b == 2)
+      return ROCDL::mfma_f32_32x32x1f32::getOperationName();
+    if (m == 16 && n == 16 && k == 1 && b == 4)
+      return ROCDL::mfma_f32_16x16x1f32::getOperationName();
+    if (m == 4 && n == 4 && k == 1 && b == 16)
+      return ROCDL::mfma_f32_4x4x1f32::getOperationName();
+    if (m == 32 && n == 32 && k == 2 && b == 1)
+      return ROCDL::mfma_f32_32x32x2f32::getOperationName();
+    if (m == 16 && n == 16 && k == 4 && b == 1)
+      return ROCDL::mfma_f32_16x16x4f32::getOperationName();
+  }
+
+  if (sourceElem.isF16() && destElem.isF32()) {
+    if (m == 32 && n == 32 && k == 4 && b == 2)
+      return ROCDL::mfma_f32_32x32x4f16::getOperationName();
+    if (m == 16 && n == 16 && k == 4 && b == 4)
+      return ROCDL::mfma_f32_16x16x4f16::getOperationName();
+    if (m == 4 && n == 4 && k == 4 && b == 16)
+      return ROCDL::mfma_f32_4x4x4f16::getOperationName();
+    if (m == 32 && n == 32 && k == 8 && b == 1)
+      return ROCDL::mfma_f32_32x32x8f16::getOperationName();
+    if (m == 16 && n == 16 && k == 16 && b == 1)
+      return ROCDL::mfma_f32_16x16x16f16::getOperationName();
+  }
+
+  if (sourceElem.isBF16() && destElem.isF32() && chipset.minorVersion >= 0x0a) {
+    if (m == 32 && n == 32 && k == 4 && b == 2)
+      return ROCDL::mfma_f32_32x32x4bf16_1k::getOperationName();
+    if (m == 16 && n == 16 && k == 4 && b == 4)
+      return ROCDL::mfma_f32_16x16x4bf16_1k::getOperationName();
+    if (m == 4 && n == 4 && k == 4 && b == 16)
+      return ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName();
+    if (m == 32 && n == 32 && k == 8 && b == 1)
+      return ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName();
+    if (m == 16 && n == 16 && k == 16 && b == 1)
+      return ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName();
+  }
+
+  if (sourceElem.isBF16() && destElem.isF32()) {
+    if (m == 32 && n == 32 && k == 2 && b == 2)
+      return ROCDL::mfma_f32_32x32x2bf16::getOperationName();
+    if (m == 16 && n == 16 && k == 2 && b == 4)
+      return ROCDL::mfma_f32_16x16x2bf16::getOperationName();
+    if (m == 4 && n == 4 && k == 2 && b == 16)
+      return ROCDL::mfma_f32_4x4x2bf16::getOperationName();
+    if (m == 32 && n == 32 && k == 4 && b == 1)
+      return ROCDL::mfma_f32_32x32x4bf16::getOperationName();
+    if (m == 16 && n == 16 && k == 8 && b == 1)
+      return ROCDL::mfma_f32_16x16x8bf16::getOperationName();
+  }
+
+  if (sourceElem.isa<IntegerType>() && destElem.isInteger(32)) {
+    if (m == 32 && n == 32 && k == 4 && b == 2)
+      return ROCDL::mfma_i32_32x32x4i8::getOperationName();
+    if (m == 16 && n == 16 && k == 4 && b == 4)
+      return ROCDL::mfma_i32_16x16x4i8::getOperationName();
+    if (m == 4 && n == 4 && k == 4 && b == 16)
+      return ROCDL::mfma_i32_4x4x4i8::getOperationName();
+    if (m == 32 && n == 32 && k == 8 && b == 1)
+      return ROCDL::mfma_i32_32x32x8i8::getOperationName();
+    if (m == 16 && n == 16 && k == 16 && b == 1)
+      return ROCDL::mfma_i32_16x16x16i8::getOperationName();
+    if (m == 32 && n == 32 && k == 16 && b == 1 && chipset.minorVersion >= 0x40)
+      return ROCDL::mfma_i32_32x32x16_i8::getOperationName();
+    if (m == 16 && n == 16 && k == 32 && b == 1 && chipset.minorVersion >= 0x40)
+      return ROCDL::mfma_i32_16x16x32_i8::getOperationName();
+  }
+
+  if (sourceElem.isF64() && destElem.isF64() && chipset.minorVersion >= 0x0a) {
+    if (m == 16 && n == 16 && k == 4 && b == 1)
+      return ROCDL::mfma_f64_16x16x4f64::getOperationName();
+    if (m == 4 && n == 4 && k == 4 && b == 4)
+      return ROCDL::mfma_f64_4x4x4f64::getOperationName();
+  }
+
+  return None;
+}
+
+namespace {
+struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
+  MFMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type outType = typeConverter->convertType(op.getDestD().getType());
+
+    if (chipset.majorVersion != 9 || chipset.minorVersion < 0x08)
+      return op->emitOpError("MFMA only supported on gfx908+");
+    uint32_t getBlgpField = static_cast<uint32_t>(op.getBlgp());
+    if (op.getNegateA() || op.getNegateB() || op.getNegateC()) {
+      if (chipset.minorVersion < 0x40)
+        return op.emitOpError("negation unsupported on older than gfx940");
+      getBlgpField |=
+          op.getNegateA() | (op.getNegateB() << 1) | (op.getNegateC() << 2);
+    }
+    Optional<StringRef> maybeIntrinsic = mfmaOpToIntrinsic(op, chipset);
+    if (!maybeIntrinsic.has_value())
+      return op.emitOpError("no intrinsic matching MFMA size on given chipset");
+    OperationState loweredOp(loc, *maybeIntrinsic);
+    loweredOp.addTypes(outType);
+    loweredOp.addOperands(
+        {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
+         mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
+         adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
+         createI32Constant(rewriter, loc, op.getAbid()),
+         createI32Constant(rewriter, loc, getBlgpField)});
+    Operation *lowered = rewriter.create(loweredOp);
+    rewriter.replaceOp(op, lowered->getResults());
+    return success();
+  }
+};

 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
@@ -290,6 +461,7 @@ struct ConvertAMDGPUToROCDLPass
     LLVMTypeConverter converter(ctx);
     populateAMDGPUToROCDLConversionPatterns(converter, patterns, *maybeChipset);
     LLVMConversionTarget target(getContext());
+    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
     target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
     if (failed(applyPartialConversion(getOperation(), target,
@@ -306,8 +478,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
   patterns.add<
       RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
       RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
-      RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>>(
-      converter, chipset);
+      RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>,
+      MFMAOpLowering>(converter, chipset);
 }

 std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
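To illustrate what MFMAOpLowering produces, here is a rough before/after sketch for the simplest f32 case; the SSA names and the printed form of the ROCDL op are approximations, but the operand order (A, B, C, then the cbsz, abid, and blgp-with-negate-bits constants) follows the pattern above:

```mlir
// Before conversion:
%d = amdgpu.mfma %a * %b + %c {
  abid = 0 : i32, cbsz = 0 : i32,
  m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32
} blgp = none : f32, vector<32xf32>

// After conversion (sketch):
%zero = llvm.mlir.constant(0 : i32) : i32
%d = rocdl.mfma.f32.32x32x1f32 %a, %b, %c, %zero, %zero, %zero
  : (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
```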
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 68639da21428b..05b0e621cf0f9 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -11,19 +11,29 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
+
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/TypeSwitch.h"

 using namespace mlir;
+using namespace mlir::amdgpu;

 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.cpp.inc"

-void amdgpu::AMDGPUDialect::initialize() {
+void AMDGPUDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc"
       >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc"
+      >();
 }

 //===----------------------------------------------------------------------===//
@@ -44,17 +54,78 @@ static LogicalResult verifyRawBufferOp(T &op) {
   return success();
 }

-LogicalResult amdgpu::RawBufferLoadOp::verify() {
-  return verifyRawBufferOp(*this);
-}
+LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); }
+
+LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); }

-LogicalResult amdgpu::RawBufferStoreOp::verify() {
+LogicalResult RawBufferAtomicFaddOp::verify() {
   return verifyRawBufferOp(*this);
 }

-LogicalResult amdgpu::RawBufferAtomicFaddOp::verify() {
-  return verifyRawBufferOp(*this);
+//===----------------------------------------------------------------------===//
+// MFMAOp
+//===----------------------------------------------------------------------===//
+LogicalResult MFMAOp::verify() {
+  constexpr uint32_t waveSize = 64;
+  Builder b(getContext());
+
+  Type sourceType = getSourceA().getType();
+  Type destType = getDestC().getType();
+
+  Type sourceElem = sourceType, destElem = destType;
+  uint32_t sourceLen = 1, destLen = 1;
+  if (auto sourceVector = sourceType.dyn_cast<VectorType>()) {
+    sourceLen = sourceVector.getNumElements();
+    sourceElem = sourceVector.getElementType();
+  }
+  if (auto destVector = destType.dyn_cast<VectorType>()) {
+    destLen = destVector.getNumElements();
+    destElem = destVector.getElementType();
+  }
+
+  // Normalize the wider integer types the compiler expects to i8
+  if (sourceElem.isInteger(32)) {
+    sourceLen *= 4;
+    sourceElem = b.getI8Type();
+  }
+  if (sourceElem.isInteger(64)) {
+    sourceLen *= 8;
+    sourceElem = b.getI8Type();
+  }
+
+  int64_t numSourceElems = (getM() * getK() * 
getBlocks()) / waveSize; + if (sourceLen != numSourceElems) + return emitOpError("expected " + Twine(numSourceElems) + + " source values for this operation but got " + + Twine(sourceLen)); + + int64_t numDestElems = (getM() * getN() * getBlocks()) / waveSize; + if (destLen != numDestElems) + return emitOpError("expected " + Twine(numDestElems) + + " result values for this operation but got " + + Twine(destLen)); + + if (destElem.isF64() && getBlgp() != MFMAPermB::none) + return emitOpError( + "double-precision ops do not support permuting lanes of B"); + if (destElem.isF64() && getCbsz() != 0) + return emitOpError( + "double-precision ops do not support permuting lanes of A"); + if (getAbid() >= (1 << getCbsz())) + return emitOpError( + "block ID for permuting A (abid) must be below 2 ** cbsz"); + + if ((getNegateA() || getNegateB() || getNegateC()) && !destElem.isF64()) + return emitOpError( + "negation flags only available for double-precision operations"); + + return success(); } +#include "mlir/Dialect/AMDGPU/AMDGPUEnums.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc" diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt index 60834c45babf1..1b80265baa90b 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt @@ -5,6 +5,8 @@ add_mlir_dialect_library(MLIRAMDGPUDialect ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU DEPENDS + MLIRAMDGPUEnumsGen + MLIRAMDGPUAttributesIncGen MLIRAMDGPUIncGen LINK_LIBS PUBLIC diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir new file mode 100644 index 0000000000000..9117cd7b2126c --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir @@ -0,0 +1,73 @@ +// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s +func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>, + %arg2 : vector<16xf32>, %arg3 : vector<4xf32>, + %arg4 : vector<4xf16>, %arg5 : vector<4xi8>, + %arg6 : vector<32xi32>, %arg7 : vector<16xi32>, + %arg8 : vector<4xi32>, %arg9 : vector<2xbf16>, + %arg10 : vector<4xbf16>, %arg11 : f64, + %arg12 : vector<4xf64>, %arg13 : vector<8xi8>, + %arg14 : vector<2xf32>) { + // CHECK: rocdl.mfma.f32.32x32x1f32{{.*}}: (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : f32, vector<32xf32> + // CHECK: rocdl.mfma.f32.16x16x1f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : f32, vector<16xf32> + // CHECK: rocdl.mfma.f32.4x4x1f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : f32, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x2f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : f32, vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x4f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 
0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f32, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + amdgpu.mfma %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xf16>, vector<32xf32> + // CHECK: rocdl.mfma.f32.16x16x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.4x4x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xf16>, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x8f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf32> + // CHECK: rocdl.mfma.i32.32x32x4i8{{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32> + amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xi8>, vector<32xi32> + // CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xi8>, vector<16xi32> + // CHECK: rocdl.mfma.i32.4x4x4i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xi8>, vector<4xi32> + // CHECK: rocdl.mfma.i32.32x32x8i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<16xi32> + // CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi32> + // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<32xf32> + // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = 
none : vector<2xbf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<2xbf16>, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<32xf32> + // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xbf16>, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xf32> + // CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64> + amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f64, vector<4xf64> + // CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64 + amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 } blgp = none : f64, f64 + // CHECK: rocdl.mfma.i32.16x16x32.i8{{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<4xi32> + // CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<16xi32> + // CHECK: 
rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<16xf32> + func.return +} diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir new file mode 100644 index 0000000000000..9ac8038655dd6 --- /dev/null +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// ----- + +func.func @bad_source_arguments(%a: vector<2xf32>, %b: vector<2xf32>, + %c: vector<32xf32>) -> vector<32xf32> { + // expected-error@+1 {{'amdgpu.mfma' op expected 1 source values for this operation but got 2}} + %d = amdgpu.mfma %a * %b + %c { + m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32, + abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<32xf32> + func.return %d : vector<32xf32> +} + +// ----- + +func.func @bad_source_arguments_i8(%a: vector<8xi8>, %b: vector<8xi8>, + %c: vector<4xi32>) -> vector<4xi32> { + // expected-error@+1 {{'amdgpu.mfma' op expected 4 source values for this operation but got 8}} + %d = amdgpu.mfma %a * %b + %c { + m = 32 : i32, n = 32 : i32, k = 4 : i32, blocks = 2 : i32, + abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xi8>, vector<4xi32> + func.return %d : vector<4xi32> +} + +// ----- + +func.func @bad_dest_type(%a: f32, %b: f32, %c: vector<16xf32>) -> vector<16xf32> { + // expected-error@+1 {{'amdgpu.mfma' op expected 32 result values for this operation but got 16}} + %d = amdgpu.mfma %a * %b + %c { + m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32, + abid = 0 : i32, cbsz = 0 : i32} blgp = none : f32, vector<16xf32> + return %d : vector<16xf32> +} + +// ----- + +func.func @f64_permuting_b(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64> { + // expected-error@+1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of B}} + %d = amdgpu.mfma %a * %b + %c { + m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32, + abid = 0 : i32, cbsz = 0 : i32} blgp = bcast_first_32 : f64, vector<4xf64> + return %d : vector<4xf64> +} + +// ----- + +func.func @f64_permuting_a(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64> { + // expected-error@+1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of A}} + %d = amdgpu.mfma %a * %b + %c { + m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32, + abid = 0 : i32, cbsz = 1 : i32} blgp = none : f64, vector<4xf64> + return %d : vector<4xf64> +} + +// ----- + +func.func @abid_without_bradcast(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> { + // expected-error@+1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}} + %d = amdgpu.mfma %a * %b + %c { + m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32, + abid = 1 : i32, cbsz = 0 : i32} blgp = none : f32, vector<32xf32> + func.return %d : vector<32xf32> +} + +// ----- + +func.func @abid_too_large(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> { + // expected-error@+1 {{'amdgpu.mfma' op block ID for permuting A (abid) 
must be below 2 ** cbsz}} + %d = amdgpu.mfma %a * %b + %c { + m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32, + abid = 2 : i32, cbsz = 1 : i32} blgp = none : f32, vector<32xf32> + func.return %d : vector<32xf32> +} + +// ----- + +func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> { + // expected-error@+1 {{'amdgpu.mfma' op negation flags only available for double-precision operations}} + %d = amdgpu.mfma %a * %b + %c { + m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32, + abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, vector<32xf32> + func.return %d : vector<32xf32> +} diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir index 3fff10c666ba2..40a88e25c4dd8 100644 --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -66,3 +66,10 @@ func.func @lds_barrier() { amdgpu.lds_barrier func.return } + +// CHECK-LABEL: func @mfma +func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> { + // CHECK: amdgpu.mfma + %0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, vector<32xf32> + func.return %0 : vector<32xf32> +}
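
A closing note on the verifier arithmetic: with the fixed wave size of 64, a 32x32x1 f32 MFMA with blocks = 2 needs (32 * 1 * 2) / 64 = 1 source value and (32 * 32 * 2) / 64 = 32 accumulator values per lane, which is why the roundtrip test above takes a scalar f32 and a vector<32xf32>. The sketch below is not part of the test files, though it mirrors the 16x16x32 i8 case from mfma.mlir, and spells out the same arithmetic for a packed i8 source:

```mlir
// Source: (16 * 32 * 1) / 64 = 8 bytes per lane -> vector<8xi8>.
// Accumulator/result: (16 * 16 * 1) / 64 = 4 values per lane -> vector<4xi32>.
func.func @mfma_i8(%a : vector<8xi8>, %c : vector<4xi32>) -> vector<4xi32> {
  %d = amdgpu.mfma %a * %a + %c {
    abid = 0 : i32, cbsz = 0 : i32,
    m = 16 : i32, n = 16 : i32, k = 32 : i32, blocks = 1 : i32
  } blgp = none : vector<8xi8>, vector<4xi32>
  func.return %d : vector<4xi32>
}
```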