diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index 4865dc13f324b..8565a6b727fd1 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -990,6 +990,34 @@ def MFMAOutTypes : AnyTypeOf<[F64, VectorOfLengthAndType<[4, 16, 32], [F32]>, VectorOfLengthAndType<[4, 16, 32], [I32]>, VectorOfLengthAndType<[4], [F64]>]>; + +// sparse_mfma (smfmac) +def SMFMACSparseInTypes : AnyTypeOf<[ + VectorOfLengthAndType<[4, 8], [F16]>, + VectorOfLengthAndType<[4, 8], [BF16]>, + VectorOfLengthAndType<[8, 16], [I8]>, + VectorOfLengthAndType<[8, 16], [F8E4M3FN, F8E5M2]>, + VectorOfLengthAndType<[8, 16], [F8E4M3FNUZ, F8E5M2FNUZ]> +]>; + +def SMFMACDenseInTypes : AnyTypeOf<[ + VectorOfLengthAndType<[8, 16], [F16]>, + VectorOfLengthAndType<[8, 16], [BF16]>, + VectorOfLengthAndType<[16, 32], [I8]>, + VectorOfLengthAndType<[16, 32], [F8E4M3FN, F8E5M2]>, + VectorOfLengthAndType<[16, 32], [F8E4M3FNUZ, F8E5M2FNUZ]> +]>; + +def SMFMACOutTypes : AnyTypeOf<[ + VectorOfLengthAndType<[4, 16], [F32]>, + VectorOfLengthAndType<[4, 16], [I32]> +]>; + +def SMFMACIdxTypes : AnyTypeOf<[ + FixedVectorOfLengthAndType<[4], [I8]>, + FixedVectorOfLengthAndType<[2], [I16]> +]>; + // scaled_mfma def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>, VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>; @@ -1138,6 +1166,66 @@ def AMDGPU_WMMAOp : let hasVerifier = 1; } +def AMDGPU_SparseMFMAOp : + AMDGPU_Op<"sparse_mfma", [AllTypesMatch<["destC", "destD"]>, + Pure]>, + Arguments<(ins + ConfinedAttr]>:$m, + ConfinedAttr]>:$n, + ConfinedAttr]>:$k, + SMFMACSparseInTypes:$sourceA, + SMFMACDenseInTypes:$sourceB, + SMFMACOutTypes:$destC, + SMFMACIdxTypes:$sparseIdx, + DefaultValuedAttr:$cbsz, + DefaultValuedAttr:$abid)>, + Results<(outs SMFMACOutTypes: $destD)> { + let summary = "MLIR wrapper for CDNA sparse mfma (smfmac) instructions"; + let description = [{ + The 
`amdgpu.sparse_mfma` op is an MLIR wrapper around intrinsics for various + `smfmac` instructions in the AMDGPU architecture, which perform matrix + multiply-accumulate operations using 2:4 structured sparsity on matrix A + with dense matrices B, C, and D. + + On gfx942, smfmac intrinsics support: + - M=N=16, K=32 and M=N=32, K=16 for f16 and bf16 sources + - M=N=16, K=64 and M=N=32, K=32 for i8 and fp8 sources + + On gfx950, smfmac intrinsics additionally support: + - M=N=16, K=64 and M=N=32, K=32 for f16 and bf16 sources + - M=N=16, K=128 and M=N=32, K=64 for i8 and fp8 sources + + The `sparseIdx` parameter contains packed indices identifying the positions + of non-zero elements in the 2:4 sparse matrix A. For 16-bit source data, + use `vector<4xi8>` (four 8-bit index sets). For 8-bit source data, use + `vector<2xi16>` (two 16-bit index sets). + + The `cbsz` and `abid` parameters are repurposed to select the index set. + If `cbsz == 0`, then `abid[1:0]` selects which index set to use. + If `cbsz != 0`, the first index set is always used and `abid` is ignored. 
+ + Example: + ```mlir + %0 = amdgpu.sparse_mfma 16x16x32 %matA * %matB + %matC sparse(%idx : vector<4xi8>) + : vector<4xf16>, vector<8xf16>, vector<4xf32> + + %1 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>) + : vector<8xi8>, vector<16xi8>, vector<4xi32> + + %2 = amdgpu.sparse_mfma 16x16x64 %matA * %matB + %matC sparse(%idx : vector<2xi16>) + { cbsz = 0 : i32, abid = 1 : i32 } + : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32> + ``` + }]; + let assemblyFormat = [{ + custom($m, $n, $k) $sourceA `*` $sourceB `+` $destC + `sparse` `(` $sparseIdx `:` type($sparseIdx) `)` + attr-dict + `:` type($sourceA) `,` type($sourceB) `,` type($destC) + }]; + let hasVerifier = 1; +} + def AMDGPU_GatherToLDSOp : AMDGPU_Op<"gather_to_lds", [AttrSizedOperandSegments]>, Arguments<(ins diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 73d5376f970ae..5dcd24019412a 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -661,6 +661,27 @@ static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter, return input; } +/// Converts sparse MFMA (smfmac) operands to the expected ROCDL types. +static Value convertSparseMFMAVectorOperand(ConversionPatternRewriter &rewriter, + Location loc, Value input, + bool allowBf16 = true) { + Type inputType = input.getType(); + auto vectorType = cast(inputType); + // bf16 -> i16 when not allowed (pre-gfx950). + if (vectorType.getElementType().isBF16() && !allowBf16) + return LLVM::BitcastOp::create( + rewriter, loc, vectorType.clone(rewriter.getI16Type()), input); + // i8/fp8 vectors -> vector. 
+ if (isa(vectorType.getElementType()) && + vectorType.getElementTypeBitWidth() <= 8) { + int64_t numWords = llvm::divideCeil( + vectorType.getNumElements() * vectorType.getElementTypeBitWidth(), 32); + return LLVM::BitcastOp::create( + rewriter, loc, VectorType::get(numWords, rewriter.getI32Type()), input); + } + return input; +} + /// Converts the scaled MFMA/WMMA operands, `scalesA` and `scalesB`, from MLIR /// AMDGPU dialect convention to ROCDL and LLVM AMDGPU intrinsics convention. /// @@ -1171,6 +1192,105 @@ static std::optional wmmaOpToIntrinsicGfx1250(Type elemSourceType, return std::nullopt; } +/// Returns the `rocdl` intrinsic corresponding to a SparseMFMA (smfmac) +/// operation if one exists. This includes checking to ensure the intrinsic is +/// supported on the architecture you are compiling for. +static std::optional smfmacOpToIntrinsic(SparseMFMAOp op, + Chipset chipset) { + bool isGfx950 = chipset >= kGfx950; + auto isFp8 = [&](Type t) { return typeIsExpectedFp8ForChipset(chipset, t); }; + auto isBf8 = [&](Type t) { return typeIsExpectedBf8ForChipset(chipset, t); }; + + uint32_t m = op.getM(), n = op.getN(), k = op.getK(); + Type sourceAElem = getElementTypeOrSelf(op.getSourceA().getType()); + Type sourceBElem = getElementTypeOrSelf(op.getSourceB().getType()); + Type destElem = getElementTypeOrSelf(op.getDestC().getType()); + + if (m == 16 && n == 16 && k == 32) { + if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x32_f16::getOperationName(); + if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x32_bf16::getOperationName(); + } + + if (m == 16 && n == 16 && k == 64) { + if (isGfx950) { + if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x64_f16::getOperationName(); + if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x64_bf16::getOperationName(); + 
} + if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) && + destElem.isInteger(32)) + return ROCDL::smfmac_i32_16x16x64_i8::getOperationName(); + if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x64_fp8_fp8::getOperationName(); + if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x64_fp8_bf8::getOperationName(); + if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x64_bf8_fp8::getOperationName(); + if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x64_bf8_bf8::getOperationName(); + } + + if (m == 16 && n == 16 && k == 128 && isGfx950) { + if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) && + destElem.isInteger(32)) + return ROCDL::smfmac_i32_16x16x128_i8::getOperationName(); + if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x128_fp8_fp8::getOperationName(); + if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x128_fp8_bf8::getOperationName(); + if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x128_bf8_fp8::getOperationName(); + if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_16x16x128_bf8_bf8::getOperationName(); + } + + if (m == 32 && n == 32 && k == 16) { + if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x16_f16::getOperationName(); + if (sourceAElem.isBF16() && sourceBElem.isBF16() && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x16_bf16::getOperationName(); + } + + if (m == 32 && n == 32 && k == 32) { + if (isGfx950) { + if (sourceAElem.isF16() && sourceBElem.isF16() && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x32_f16::getOperationName(); + if (sourceAElem.isBF16() && sourceBElem.isBF16() && 
destElem.isF32()) + return ROCDL::smfmac_f32_32x32x32_bf16::getOperationName(); + } + if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) && + destElem.isInteger(32)) + return ROCDL::smfmac_i32_32x32x32_i8::getOperationName(); + if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x32_fp8_fp8::getOperationName(); + if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x32_fp8_bf8::getOperationName(); + if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x32_bf8_fp8::getOperationName(); + if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x32_bf8_bf8::getOperationName(); + } + + if (m == 32 && n == 32 && k == 64 && isGfx950) { + if (sourceAElem.isInteger(8) && sourceBElem.isInteger(8) && + destElem.isInteger(32)) + return ROCDL::smfmac_i32_32x32x64_i8::getOperationName(); + if (isFp8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x64_fp8_fp8::getOperationName(); + if (isFp8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x64_fp8_bf8::getOperationName(); + if (isBf8(sourceAElem) && isFp8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x64_bf8_fp8::getOperationName(); + if (isBf8(sourceAElem) && isBf8(sourceBElem) && destElem.isF32()) + return ROCDL::smfmac_f32_32x32x64_bf8_bf8::getOperationName(); + } + + return std::nullopt; +} + /// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma` /// if one exists. This includes checking to ensure the intrinsic is supported /// on the architecture you are compiling for. 
@@ -1326,6 +1446,52 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern { } }; +struct SparseMFMAOpLowering : public ConvertOpToLLVMPattern { + SparseMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern(converter), chipset(chipset) {} + + Chipset chipset; + + LogicalResult + matchAndRewrite(SparseMFMAOp op, SparseMFMAOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto outType = + typeConverter->convertType(op.getDestC().getType()); + if (!outType) + return rewriter.notifyMatchFailure(op, "type conversion failed"); + + // smfmac is supported on gfx942 and gfx950. + if (chipset.majorVersion != 9 || chipset < kGfx942) + return op->emitOpError("sparse MFMA (smfmac) only supported on gfx942+"); + bool isGfx950 = chipset >= kGfx950; + + Value a = convertSparseMFMAVectorOperand(rewriter, loc, + adaptor.getSourceA(), isGfx950); + Value b = convertSparseMFMAVectorOperand(rewriter, loc, + adaptor.getSourceB(), isGfx950); + Value c = adaptor.getDestC(); + + std::optional maybeIntrinsic = smfmacOpToIntrinsic(op, chipset); + if (!maybeIntrinsic.has_value()) + return op.emitOpError( + "no intrinsic matching sparse MFMA on the given chipset"); + + // Bitcast sparse indices from vector<4xi8> or vector<2xi16> to i32. 
+ Value sparseIdx = LLVM::BitcastOp::create( + rewriter, loc, rewriter.getI32Type(), adaptor.getSparseIdx()); + + OperationState loweredOp(loc, maybeIntrinsic.value()); + loweredOp.addTypes(outType); + loweredOp.addOperands({a, b, c, sparseIdx, + createI32Constant(rewriter, loc, op.getCbsz()), + createI32Constant(rewriter, loc, op.getAbid())}); + Value lowered = rewriter.create(loweredOp)->getResult(0); + rewriter.replaceOp(op, lowered); + return success(); + } +}; + struct WMMAOpLowering : public ConvertOpToLLVMPattern { WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset) : ConvertOpToLLVMPattern(converter), chipset(chipset) {} @@ -3367,12 +3533,12 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, ROCDL::RawPtrBufferAtomicCmpSwap>, AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering, SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering, - WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering, - ScaledExtPackedMatrixOpLowering, ScaledExtPackedOpLowering, - PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, - PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, - TransposeLoadOpLowering, AMDGPUPermlaneLowering, - AMDGPUMakeDmaBaseLowering, + SparseMFMAOpLowering, WMMAOpLowering, ScaledWMMAOpLowering, + ExtPackedFp8OpLowering, ScaledExtPackedMatrixOpLowering, + ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, + PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, + GatherToLDSOpLowering, TransposeLoadOpLowering, + AMDGPUPermlaneLowering, AMDGPUMakeDmaBaseLowering, AMDGPUMakeDmaBaseLowering, AMDGPULowerDescriptor, AMDGPULowerDescriptor, diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index bef0328c7c73e..e77d131509add 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -632,6 +632,78 @@ LogicalResult MFMAOp::verify() { return success(); } 
+//===----------------------------------------------------------------------===// +// SparseMFMAOp +//===----------------------------------------------------------------------===// + +LogicalResult SparseMFMAOp::verify() { + constexpr uint32_t waveSize = 64; + + auto sparseType = cast(getSourceA().getType()); + auto denseType = cast(getSourceB().getType()); + auto destType = cast(getDestC().getType()); + + Type sparseElem = sparseType.getElementType(); + Type denseElem = denseType.getElementType(); + int64_t sparseLen = sparseType.getNumElements(); + int64_t denseLen = denseType.getNumElements(); + int64_t destLen = destType.getNumElements(); + + if (denseLen != 2 * sparseLen) + return emitOpError("expected dense source operand to have exactly double " + "the number of elements of the sparse source operand"); + + // Check that source element types are compatible. + // For fp8/bf8 mixed operations, element types can differ (e.g., fp8 * bf8). + // For other types, element types must match exactly. + bool bothFloat8 = sparseElem.isFloat(8) && denseElem.isFloat(8); + if (!bothFloat8 && sparseElem != denseElem) + return emitOpError( + "expected source operands to have the same element type"); + + // When CBSZ == 0, ABID selects the index set within the sparse index VGPR. + // When CBSZ != 0, the first index set is always used (ABID ignored). + bool is8BitSource = sparseElem.isFloat(8) || sparseElem.isInteger(8); + // 8-bit source: ABID selects one of two 16-bit index sets. + if (getCbsz() == 0 && is8BitSource && getAbid() > 1) + return emitOpError("ABID must be 0 or 1 for 8-bit source data"); + // 16-bit source: ABID selects one of four 8-bit index sets (0-3 all valid). + if (getCbsz() == 0 && !is8BitSource && getAbid() > 3) + return emitOpError("ABID must be between 0 and 3 for 16-bit source data"); + + // Validate sparseIdx type matches source element type. 
+ auto sparseIdxType = cast(getSparseIdx().getType()); + if (is8BitSource) { + // 8-bit source data requires vector<2xi16> sparse indices. + if (sparseIdxType.getNumElements() != 2 || + !sparseIdxType.getElementType().isInteger(16)) + return emitOpError("expected vector<2xi16> sparse indices for 8-bit " + "source data, but got ") + << getSparseIdx().getType(); + } else { + // 16-bit source data requires vector<4xi8> sparse indices. + if (sparseIdxType.getNumElements() != 4 || + !sparseIdxType.getElementType().isInteger(8)) + return emitOpError("expected vector<4xi8> sparse indices for 16-bit " + "source data, but got ") + << getSparseIdx().getType(); + } + + int64_t expectedSourceElems = (getM() * getK()) / waveSize; + if (denseLen != expectedSourceElems) + return emitOpError("expected " + Twine(expectedSourceElems) + + " source values for this operation but got " + + Twine(denseLen)); + + int64_t expectedDestElems = (getM() * getN()) / waveSize; + if (destLen != expectedDestElems) + return emitOpError("expected " + Twine(expectedDestElems) + + " result values for this operation but got " + + Twine(destLen)); + + return success(); +} + //===----------------------------------------------------------------------===// // DPPOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir new file mode 100644 index 0000000000000..266e0e7e15595 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma-gfx950.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 -cse | FileCheck %s +func.func @sparse_mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf16>, + %arg2 : vector<4xf32>, %arg3 : vector<16xf32>, + %arg4 : vector<8xbf16>, %arg5 : vector<16xbf16>, + %arg6 : vector<16xi8>, %arg7 : vector<32xi8>, + %arg8 : vector<4xi32>, %arg9 : vector<16xi32>, + %arg10 : vector<16xf8E4M3FN>, 
%arg11 : vector<16xf8E5M2>, + %arg12 : vector<32xf8E4M3FN>, %arg13 : vector<32xf8E5M2>, + %arg14 : vector<4xi8>, %arg15 : vector<2xi16>) { + // CHECK: llvm.bitcast %{{.*}} : vector<4xi8> to i32 + // CHECK: rocdl.smfmac.f32.16x16x64.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x64 %arg0 * %arg1 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<4xf32> + + // CHECK: rocdl.smfmac.f32.16x16x64.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x64 %arg4 * %arg5 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<4xf32> + + // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32> + // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32> + // CHECK: llvm.bitcast %{{.*}} : vector<2xi16> to i32 + // CHECK: rocdl.smfmac.i32.16x16x128.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + amdgpu.sparse_mfma 16x16x128 %arg6 * %arg7 + %arg8 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<4xi32> + + // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32> + // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32> + // CHECK: rocdl.smfmac.f32.16x16x128.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x128 %arg10 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32> + + // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32> + // CHECK: llvm.bitcast {{.*}} : vector<32xi8> to vector<8xi32> + // CHECK: rocdl.smfmac.f32.16x16x128.bf8.bf8 {{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 
16x16x128 %arg11 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<4xf32> + + // CHECK: rocdl.smfmac.f32.16x16x128.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x128 %arg10 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<4xf32> + + // CHECK: rocdl.smfmac.f32.16x16x128.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x128 %arg11 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<4xf32> + + // CHECK: rocdl.smfmac.f32.32x32x32.f16{{.*}}: (vector<8xf16>, vector<16xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x32 %arg0 * %arg1 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf16>, vector<16xf16>, vector<16xf32> + + // CHECK: rocdl.smfmac.f32.32x32x32.bf16{{.*}}: (vector<8xbf16>, vector<16xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x32 %arg4 * %arg5 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xbf16>, vector<16xbf16>, vector<16xf32> + + // CHECK: rocdl.smfmac.i32.32x32x64.i8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + amdgpu.sparse_mfma 32x32x64 %arg6 * %arg7 + %arg9 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xi8>, vector<32xi8>, vector<16xi32> + + // CHECK: rocdl.smfmac.f32.32x32x64.fp8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x64 %arg10 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32> + + // CHECK: 
rocdl.smfmac.f32.32x32x64.bf8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x64 %arg11 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E5M2>, vector<16xf32> + + // CHECK: rocdl.smfmac.f32.32x32x64.fp8.bf8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x64 %arg10 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E4M3FN>, vector<32xf8E5M2>, vector<16xf32> + + // CHECK: rocdl.smfmac.f32.32x32x64.bf8.fp8{{.*}}: (vector<4xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x64 %arg11 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<16xf8E5M2>, vector<32xf8E4M3FN>, vector<16xf32> + + func.return +} diff --git a/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir new file mode 100644 index 0000000000000..b2c91c3d9bed1 --- /dev/null +++ b/mlir/test/Conversion/AMDGPUToROCDL/sparse-mfma.mlir @@ -0,0 +1,63 @@ +// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 -cse | FileCheck %s +func.func @sparse_mfma_to_rocdl(%arg0 : vector<4xf16>, %arg1 : vector<8xf16>, + %arg2 : vector<4xf32>, %arg3 : vector<16xf32>, + %arg4 : vector<4xbf16>, %arg5 : vector<8xbf16>, + %arg6 : vector<8xi8>, %arg7 : vector<16xi8>, + %arg8 : vector<4xi32>, %arg9 : vector<16xi32>, + %arg10 : vector<8xf8E4M3FNUZ>, %arg11 : vector<8xf8E5M2FNUZ>, + %arg12 : vector<16xf8E4M3FNUZ>, %arg13 : vector<16xf8E5M2FNUZ>, + %arg14 : vector<4xi8>, %arg15 : vector<2xi16>) { + // CHECK: llvm.bitcast %{{.*}} : vector<4xi8> to i32 + // CHECK: rocdl.smfmac.f32.16x16x32.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x32 %arg0 * %arg1 + %arg2 sparse(%arg14 : 
vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32> + + // CHECK: llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16> + // CHECK: llvm.bitcast {{.*}} : vector<8xbf16> to vector<8xi16> + // CHECK: rocdl.smfmac.f32.16x16x32.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x32 %arg4 * %arg5 + %arg2 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<4xf32> + + // CHECK: rocdl.smfmac.f32.32x32x16.f16{{.*}}: (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x16 %arg0 * %arg1 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<16xf32> + + // CHECK: rocdl.smfmac.f32.32x32x16.bf16 {{.*}}: (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x16 %arg4 * %arg5 + %arg3 sparse(%arg14 : vector<4xi8>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<4xbf16>, vector<8xbf16>, vector<16xf32> + + // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32> + // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32> + // CHECK: llvm.bitcast %{{.*}} : vector<2xi16> to i32 + // CHECK: rocdl.smfmac.i32.16x16x64.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + amdgpu.sparse_mfma 16x16x64 %arg6 * %arg7 + %arg8 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32> + + // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32> + // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32> + // CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x64 %arg10 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, 
vector<16xf8E4M3FNUZ>, vector<4xf32> + + // CHECK: llvm.bitcast {{.*}} : vector<8xi8> to vector<2xi32> + // CHECK: llvm.bitcast {{.*}} : vector<16xi8> to vector<4xi32> + // CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x64 %arg11 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32> + + // CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x64 %arg10 * %arg13 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<4xf32> + + // CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.sparse_mfma 16x16x64 %arg11 * %arg12 + %arg2 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<4xf32> + + // CHECK: rocdl.smfmac.i32.32x32x32.i8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + amdgpu.sparse_mfma 32x32x32 %arg6 * %arg7 + %arg9 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<16xi32> + + // CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x32 %arg10 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32> + + // CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x32 %arg11 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, 
vector<16xf8E5M2FNUZ>, vector<16xf32> + + // CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x32 %arg10 * %arg13 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E4M3FNUZ>, vector<16xf8E5M2FNUZ>, vector<16xf32> + + // CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8{{.*}}: (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.sparse_mfma 32x32x32 %arg11 * %arg12 + %arg3 sparse(%arg15 : vector<2xi16>) { abid = 0 : i32, cbsz = 0 : i32 } : vector<8xf8E5M2FNUZ>, vector<16xf8E4M3FNUZ>, vector<16xf32> + + func.return +} diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir index 9ece57e9ec6a3..1299f3b14b14f 100644 --- a/mlir/test/Dialect/AMDGPU/invalid.mlir +++ b/mlir/test/Dialect/AMDGPU/invalid.mlir @@ -452,3 +452,67 @@ func.func @make_gather_dma_descriptor_invalid_index_types(%base: !amdgpu.tdm_gat amdgpu.make_gather_dma_descriptor %base[%indices] globalSize [4, 4] globalStride [1, 1] sharedSize [1, 2] : !amdgpu.tdm_gather_base, vector<8xi32> -> !amdgpu.tdm_descriptor func.return } + +// ----- + +func.func @sparse_mfma_dense_not_double_sparse(%a: vector<4xf16>, %b: vector<4xf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> { + // expected-error@+1 {{'amdgpu.sparse_mfma' op operand #1 must be vector of 16-bit float values of length 8/16 or vector of bfloat16 type values of length 8/16 or vector of 8-bit signless integer values of length 16/32 or vector of f8E4M3FN type or f8E5M2 type values of length 16/32 or vector of f8E4M3FNUZ type or f8E5M2FNUZ type values of length 16/32, but got 'vector<4xf16>'}} + %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<4xf16>, vector<4xf32> + func.return %d : vector<4xf32> +} + +// ----- + +func.func @sparse_mfma_mismatched_source_types(%a: vector<4xf16>, %b: 
vector<8xbf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> { + // expected-error@+1 {{'amdgpu.sparse_mfma' op expected source operands to have the same element type}} + %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xbf16>, vector<4xf32> + func.return %d : vector<4xf32> +} + +// ----- + +func.func @sparse_mfma_abid_invalid_for_8bit(%a: vector<8xi8>, %b: vector<16xi8>, %c: vector<4xi32>, %idx: vector<2xi16>) -> vector<4xi32> { + // expected-error@+1 {{'amdgpu.sparse_mfma' op ABID must be 0 or 1 for 8-bit source data}} + %d = amdgpu.sparse_mfma 16x16x64 %a * %b + %c sparse(%idx : vector<2xi16>) { abid = 2 : i32, cbsz = 0 : i32 } : vector<8xi8>, vector<16xi8>, vector<4xi32> + func.return %d : vector<4xi32> +} + +// ----- + +func.func @sparse_mfma_abid_invalid_for_16bit(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<4xf32>, %idx: vector<4xi8>) -> vector<4xf32> { + // expected-error@+1 {{'amdgpu.sparse_mfma' op ABID must be between 0 and 3 for 16-bit source data}} + %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) { abid = 4 : i32, cbsz = 0 : i32 } : vector<4xf16>, vector<8xf16>, vector<4xf32> + func.return %d : vector<4xf32> +} + +// ----- + +func.func @sparse_mfma_wrong_idx_type_for_8bit(%a: vector<8xi8>, %b: vector<16xi8>, %c: vector<4xi32>, %idx: vector<4xi8>) -> vector<4xi32> { + // expected-error@+1 {{'amdgpu.sparse_mfma' op expected vector<2xi16> sparse indices for 8-bit source data, but got 'vector<4xi8>'}} + %d = amdgpu.sparse_mfma 16x16x64 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<8xi8>, vector<16xi8>, vector<4xi32> + func.return %d : vector<4xi32> +} + +// ----- + +func.func @sparse_mfma_wrong_idx_type_for_16bit(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<4xf32>, %idx: vector<2xi16>) -> vector<4xf32> { + // expected-error@+1 {{'amdgpu.sparse_mfma' op expected vector<4xi8> sparse indices for 16-bit source data, but got 'vector<2xi16>'}} + %d = 
amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<2xi16>) : vector<4xf16>, vector<8xf16>, vector<4xf32> + func.return %d : vector<4xf32> +} + +// ----- + +func.func @sparse_mfma_wrong_source_count(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<16xf32>, %idx: vector<4xi8>) -> vector<16xf32> { + // expected-error@+1 {{'amdgpu.sparse_mfma' op expected 16 source values for this operation but got 8}} + %d = amdgpu.sparse_mfma 32x32x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<16xf32> + func.return %d : vector<16xf32> +} + +// ----- + +func.func @sparse_mfma_wrong_dest_count(%a: vector<4xf16>, %b: vector<8xf16>, %c: vector<16xf32>, %idx: vector<4xi8>) -> vector<16xf32> { + // expected-error@+1 {{'amdgpu.sparse_mfma' op expected 4 result values for this operation but got 16}} + %d = amdgpu.sparse_mfma 16x16x32 %a * %b + %c sparse(%idx : vector<4xi8>) : vector<4xf16>, vector<8xf16>, vector<16xf32> + func.return %d : vector<16xf32> +}