[mlir][AMDGPU] Add scaled wmma ops for gfx1250 #169854
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-amdgpu

Author: Justin Rosner (justinrosner)

Changes: This PR adds scaled WMMA ops (available on gfx1250) and the lowering to ROCDL.

Patch is 24.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169854.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e07c72b839e7c..a2201d3127370 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -951,6 +951,13 @@ def MFMAOutTypes : AnyTypeOf<[F64,
def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
+
+// scaled_wmma
+def ScaledWMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
+ VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
+ VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
+def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32]>]>;
+
// wmma
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
@@ -1218,6 +1225,59 @@ def AMDGPU_ScaledMFMAOp :
let hasCanonicalizer = 1;
}
+def AMDGPU_ScaledWMMAOp :
+ AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>,
+ Pure]>,
+ Arguments<(ins
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
+ ScaledWMMAInTypes:$sourceA,
+ ScaledWMMAInTypes:$sourceB,
+ ScaledWMMAOutTypes:$destC,
+ AnyTypeOf<[I32, I64]>:$scaleA,
+ AnyTypeOf<[I32, I64]>:$scaleB,
+ DefaultValuedAttr<I32Attr, "0">:$scaleAType,
+ DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
+ DefaultValuedAttr<I32Attr, "0">:$scaleBType,
+ DefaultValuedAttr<I32Attr, "0">:$fmtScaleB
+ )>,
+ Results<(outs ScaledWMMAOutTypes: $destD)> {
+ let summary = "MLIR wrapper for RDNA scaled wmma instructions";
+ let description = [{
+ The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
+ `wmma` instructions in the RDNA architecture. These instructions perform
+ matrix multiplication with per-block scaling of inputs, supporting fp4, fp6,
+ and fp8 data formats.
+
+ The scale instructions support two tile sizes:
+ - 16x16x128 with mixed f8/f6/f4 formats (output: vector<4xf32>)
+ - 32x16x128 with f4 format only (output: vector<8xf32>)
+
+ The `scaleA` and `scaleB` parameters are scale exponents that can be either
+ i32 (for wmma.scale) or i64 (for wmma.scale16) to support per-block scaling.
+
+ Optional modifiers:
+ - `scaleAType`, `scaleBType`: Type of scale parameter
+ - `fmtScaleA`, `fmtScaleB`: Format of scale parameter
+
+ Example:
+ ```mlir
+ %0 = amdgpu.scaled_wmma (%sa * %matA) * (%sb * %matB) + %matC
+ { m = 16, n = 16, k = 128 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+
+ %1 = amdgpu.scaled_wmma (%sc * %matD) * (%sd * %matE) + %matF
+ { m = 32, n = 16, k = 128 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+ ```
+ }];
+ let assemblyFormat = [{
+ `(` $scaleA `*` $sourceA `)` `*` `(` $scaleB `*` $sourceB `)` `+` $destC
+ attr-dict
+ `:` type($scaleA) `,` type($sourceA) `,` type($scaleB) `,` type($sourceB) `,` type($destC)
+ }];
+ let hasVerifier = 1;
+}
+
def AMDGPU_MakeDmaBaseOp :
AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
Arguments<(ins
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b9a5e7d7f6eac..f4034f44d06b8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -612,8 +612,8 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
} // namespace
-/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
-/// and LLVM AMDGPU intrinsics convention.
+/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
+/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
///
/// Specifically:
/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
@@ -627,9 +627,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
/// Note that the type of `input` has already been LLVM type converted:
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
-static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
- Location loc, Value input,
- bool allowBf16 = true) {
+static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
+ Location loc, Value input,
+ bool allowBf16 = true) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16() && !allowBf16)
@@ -918,7 +918,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
return std::nullopt;
}
-static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
+static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
.Case([](Float8E4M3FNType) { return 0u; })
.Case([](Float8E5M2Type) { return 1u; })
@@ -947,8 +947,8 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
if (!isa<Float32Type>(destType))
return std::nullopt;
- std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
- std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
+ std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
+ std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
if (!aTypeCode || !bTypeCode)
return std::nullopt;
@@ -1212,11 +1212,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
}();
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
- loweredOp.addOperands({convertMFMAVectorOperand(
- rewriter, loc, adaptor.getSourceA(), allowBf16),
- convertMFMAVectorOperand(
- rewriter, loc, adaptor.getSourceB(), allowBf16),
- adaptor.getDestC()});
+ loweredOp.addOperands(
+ {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA(),
+ allowBf16),
+ packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB(),
+ allowBf16),
+ adaptor.getDestC()});
if (isScaled) {
Value zero = createI32Constant(rewriter, loc, 0);
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
@@ -1261,8 +1262,8 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
loweredOp.addOperands(
- {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
- convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+ {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
+ packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
adaptor.getDestC()});
Value scalesIdxA =
createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
@@ -1363,6 +1364,103 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};
+struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
+ ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+ : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
+
+ Chipset chipset;
+
+ LogicalResult
+ matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ auto outType =
+ typeConverter->convertType<VectorType>(op.getDestD().getType());
+ if (!outType)
+ return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+ if (chipset < Chipset(12, 5, 0))
+ return op->emitOpError("WMMA scale only supported on gfx1250+");
+
+ int64_t m = op.getM();
+ int64_t n = op.getN();
+ int64_t k = op.getK();
+
+ Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
+ Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
+
+ std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
+ std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
+
+ if (!aFmtCode || !bFmtCode)
+ return op.emitOpError("unsupported element types for scaled_wmma");
+
+ // Determine which intrinsic to use based on dimensions and scale type
+ StringRef intrinsicName;
+ bool isScale16 = adaptor.getScaleA().getType().isInteger(64);
+ bool is32x16 = (m == 32 && n == 16 && k == 128);
+
+ if (m == 16 && n == 16 && k == 128) {
+ intrinsicName = isScale16
+ ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
+ : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+ } else if (is32x16) {
+ intrinsicName = isScale16
+ ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
+ : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+ } else {
+ return op.emitOpError("unsupported scaled_wmma dimensions: ")
+ << m << "x" << n << "x" << k;
+ }
+
+ SmallVector<NamedAttribute, 8> attrs;
+
+ // The f4 variant does not have fmtA and fmtB attributes
+ if (!is32x16) {
+ attrs.push_back(rewriter.getNamedAttr("fmtA",
+ rewriter.getI32IntegerAttr(*aFmtCode)));
+ attrs.push_back(rewriter.getNamedAttr("fmtB",
+ rewriter.getI32IntegerAttr(*bFmtCode)));
+ }
+
+ // Add modifier attributes - modC and reuse flags default to 0/false
+ attrs.push_back(rewriter.getNamedAttr("reuseA",
+ rewriter.getBoolAttr(false)));
+ attrs.push_back(rewriter.getNamedAttr("reuseB",
+ rewriter.getBoolAttr(false)));
+ attrs.push_back(rewriter.getNamedAttr("modC",
+ rewriter.getI16IntegerAttr(0)));
+
+ // Scale type/format parameters from the operation
+ attrs.push_back(rewriter.getNamedAttr("scaleAType",
+ rewriter.getI32IntegerAttr(op.getScaleAType())));
+ attrs.push_back(rewriter.getNamedAttr("fmtScaleA",
+ rewriter.getI32IntegerAttr(op.getFmtScaleA())));
+ attrs.push_back(rewriter.getNamedAttr("scaleBType",
+ rewriter.getI32IntegerAttr(op.getScaleBType())));
+ attrs.push_back(rewriter.getNamedAttr("fmtScaleB",
+ rewriter.getI32IntegerAttr(op.getFmtScaleB())));
+
+ // Convert typed float vectors to packed i32 format if needed
+ Value sourceA =
+ packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
+ Value sourceB =
+ packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
+
+ // Create the intrinsic call
+ OperationState loweredOp(loc, intrinsicName);
+ loweredOp.addTypes(outType);
+ loweredOp.addOperands({sourceA, sourceB, adaptor.getDestC(),
+ adaptor.getScaleA(), adaptor.getScaleB()});
+ loweredOp.addAttributes(attrs);
+
+ Operation *lowered = rewriter.create(loweredOp);
+ rewriter.replaceOp(op, lowered->getResults());
+
+ return success();
+ }
+};
+
struct TransposeLoadOpLowering
: public ConvertOpToLLVMPattern<TransposeLoadOp> {
TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
@@ -2329,7 +2427,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
- WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
+ WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
+ ScaledExtPacked816OpLowering,
ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
GatherToLDSOpLowering, TransposeLoadOpLowering,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index cdc10c60a42ae..87bd1903290ae 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -442,6 +442,77 @@ LogicalResult WMMAOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScaledWMMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScaledWMMAOp::verify() {
+ auto sourceAType = cast<VectorType>(getSourceA().getType());
+ auto sourceBType = cast<VectorType>(getSourceB().getType());
+ auto destType = cast<VectorType>(getDestC().getType());
+
+ // Validate output type is F32
+ if (!destType.getElementType().isF32())
+ return emitOpError("destination must have f32 element type");
+
+ // Validate source element types are small floats (fp4/fp6/fp8)
+ Type aElemType = sourceAType.getElementType();
+ Type bElemType = sourceBType.getElementType();
+
+ bool aIsSmallFloat = aElemType.isFloat(4) || aElemType.isFloat(6) ||
+ aElemType.isFloat(8);
+ bool bIsSmallFloat = bElemType.isFloat(4) || bElemType.isFloat(6) ||
+ bElemType.isFloat(8);
+
+ if (!aIsSmallFloat || !bIsSmallFloat)
+ return emitOpError("source operands must have small float element types "
+ "(fp4/fp6/fp8)");
+
+ // Validate scale types match (both i32 or both i64)
+ Type scaleAType = getScaleA().getType();
+ Type scaleBType = getScaleB().getType();
+ if (scaleAType != scaleBType)
+ return emitOpError("scaleA and scaleB must have the same type");
+
+ // Validate vector lengths based on dimensions
+ int64_t m = getM();
+ int64_t aLen = sourceAType.getNumElements();
+ int64_t bLen = sourceBType.getNumElements();
+ int64_t expectedOutLen = (m == 16) ? 4 : 8;
+
+ if (destType.getNumElements() != expectedOutLen)
+ return emitOpError("expected output vector of length " +
+ Twine(expectedOutLen) + " but got " +
+ Twine(destType.getNumElements()));
+
+ if (m == 16) {
+ // For 16×16×128: both A and B must be 64 elements
+ if (aLen != 64)
+ return emitOpError(
+ "for 16x16x128, sourceA must have 64 elements but got " +
+ Twine(aLen));
+ if (bLen != 64)
+ return emitOpError(
+ "for 16x16x128, sourceB must have 64 elements but got " +
+ Twine(bLen));
+ } else { // m == 32
+ // For 32×16×128: only fp4 is supported, A is 128, B is 64
+ if (!aElemType.isFloat(4))
+ return emitOpError("32x16x128 only supports fp4 element types");
+
+ if (aLen != 128)
+ return emitOpError(
+ "for 32x16x128, sourceA must have 128 elements but got " +
+ Twine(aLen));
+ if (bLen != 64)
+ return emitOpError(
+ "for 32x16x128, sourceB must have 64 elements but got " +
+ Twine(bLen));
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// MFMAOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index 37259f6ed06eb..d187e62484059 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -89,6 +89,79 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
return
}
+// CHECK-LABEL: @wmma_scale_16x16x128_fp8
+func.func @wmma_scale_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+ %arg2 : vector<4xf32>, %arg3 : i32, %arg4 : i32) {
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 1 : i32, fmtB = 1 : i32} : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %1 = amdgpu.scaled_wmma (%arg3 * %arg1) * (%arg4 * %arg1) + %arg2
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_scale_16x16x128_fp6
+func.func @wmma_scale_16x16x128_fp6(%arg0 : vector<64xf6E2M3FN>, %arg1 : vector<64xf6E3M2FN>,
+ %arg2 : vector<4xf32>, %arg3 : i32, %arg4 : i32) {
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 2 : i32, fmtB = 2 : i32} : (vector<12xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 3 : i32, fmtB = 3 : i32} : (vector<12xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %1 = amdgpu.scaled_wmma (%arg3 * %arg1) * (%arg4 * %arg1) + %arg2
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E3M2FN>, i32, vector<64xf6E3M2FN>, vector<4xf32>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_scale_16x16x128_mixed
+func.func @wmma_scale_16x16x128_mixed(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E2M3FN>,
+ %arg2 : vector<64xf4E2M1FN>, %arg3 : vector<4xf32>,
+ %arg4 : i32, %arg5 : i32) {
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, %arg4, %arg5 {fmtB = 2 : i32} : (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %0 = amdgpu.scaled_wmma (%arg4 * %arg0) * (%arg5 * %arg1) + %arg3
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+
+ // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, %arg4, %arg5 {fmtA = 2 : i32, fmtB = 4 : i32} : (vector<12xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+ %1 = amdgpu.scaled_wmma (%arg4 * %arg1) * (%arg5 * %arg2) + %arg3
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf4E2M1FN>, vector<4xf32>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_scale16_16x16x128_fp8
+func.func @wmma_scale16_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+ %arg2 : vector<4xf32>, %arg3 : i64, %arg4 : i64) {
+ // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
+ %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+ { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
+
+ func.return
+}
+
+// CHECK-LABEL: @wmma_scale_32x16x128_fp4
+func.func @wmma_scale_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
+ ...
[truncated]
@llvm/pr-subscribers-backend-amdgpu Author: Justin Rosner (justinrosner)
✅ With the latest revision this PR passed the C/C++ code formatter.
krzysz00 left a comment
The design is too low level (exposes constants that are just types at the MLIR level and doesn't properly expose the underlying types of the "i32" scales) and also there are no tests
Have a few nits but other than that it looks great!
edit: +1 to what Krzysztof said
Muzammiluddin-Syed-ECE left a comment
looking good, just one last question for me
// Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M2|E4M3)
if (aIsF8F6 && isE8M0(scaleAElemType) && bIsF4 && (isE4M3(scaleBElemType)))
  return success();
Is there a reason we're not explicitly checking for B scale type E5M3 in this combination of legal A and B matrix data?
We don't seem to have support for E5M3 yet in MLIR. See my comment here: #169854 (comment).
I was just waiting to see if there was a preference for implementing that in this PR, or in a separate one.
Ah I see! I can't seem to find an issue or anything tracking support for F8E5M3, maybe we can create an issue and add a TODO here to document that this needs to be done.
Muzammiluddin-Syed-ECE left a comment
LGTM, thanks! But let's hold off on merging until @krzysz00 gets a chance to review.
kuhar left a comment
Looks OK but we need some negative tests to exercise the verification logic (with // expected-error)
kuhar left a comment
LGTM but let's wait for @krzysz00
krzysz00 left a comment
Overall, I'm liking this a lot better, thanks! (I think we're mainly in the nitpick zone)
ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB,
ScaledWMMAOutTypes:$destC,
VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA,
ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleAIdx,
If you go look at the conversion intrinsics, these scale indices have a much more non-trivial mapping. I think the name is fine, but I'd suggest we document the semantics of this attribute carefully.
Also, nit, scale_a_idx to make the generic pretty-print a bit nicer (but that's not uniform style, you're welcome to ignore it)
Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values
(either f8E8M0FNU, or f8E4M3FN). The index attributes (`scaleAIdx`, `scaleBIdx`)
select which element from the scale vector to use for scaling. During lowering,
I don't think this is correct
Remember that the layout of the scales starts (at least for one of the values of this bit)
lane 0: (m=0, kOuter=0) (0, 1), (0, 2), (0, 3) ...
and we should really write the formulas out
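One plausible reading of that layout (an assumption; the thread never spells the formula out): with block size 32 and k = 128 there are k/32 = 4 scale blocks per output row, and lane l, element e of the scale operand would supply scale(m = l, kOuter = e). The other value of the select bit could shift this, so treat it as a sketch, not a spec.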
I've updated the description of what the indices are doing.
| attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt)); | ||
|
|
||
| // Reuse flags use default value of false. | ||
| attrs.emplace_back("reuseA", rewriter.getBoolAttr(false)); |
Do we have a plan to enable these?
Not a concrete one as of yet. I was debating between exposing these attributes at the AMDGPU dialect level to give users the ability to set it themselves, or implementing a pass that detects when the same matrix value is used in multiple WMMA ops.
Well, if you had such a pass, it'd be at the AMDGPU level anyway
... in theory we could teach LLVM about this but they've got enough headaches already.
Are you okay with opening up an issue to implement this in a separate PR? Or would you prefer it be included in this one?
Oh, yeah, separate PR's fine - this isn't critical
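For context, a minimal sketch of the IR shape such a reuse-detection pass might look for — hypothetical only, written in this PR's first-revision syntax with i32 scale exponents; no such pass exists yet:

```mlir
// Hypothetical: %matA feeds two scaled_wmma ops back-to-back, so a future
// pass could mark the second use for operand reuse when lowering to ROCDL.
%0 = amdgpu.scaled_wmma (%sa * %matA) * (%sb * %matB) + %acc0
  { m = 16 : i32, n = 16 : i32, k = 128 : i32 }
  : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
%1 = amdgpu.scaled_wmma (%sa * %matA) * (%sb * %matC) + %acc1
  { m = 16 : i32, n = 16 : i32, k = 128 : i32 }
  : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
```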
VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA,
ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleAIdx,
VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleB,
ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleBIdx)>,
Why is it only 0 or 1? If it is packed into 32 bit it can also be 2 or 3.
This does not represent a scale select like in a scaled_mfma, where it is used to indicate which of the 4 scales should be used. Instead, here it refers to MatrixScale, an enum defined here, and is used in the intrinsic (how? I'm not sure yet).
Yeah, this is probably not going to want [] syntax unlike the gfx950 version, because it's a weird global modifier
I've removed the [] syntax and just have the indices in the attribute list.
Ok, so, pending some other clarifications, these indices are something along the lines of wave_half_select - that is, 0 means you're reading from lanes 0-15 for scales, and 1 means you're on lanes 16-31
They're only applicable to some modes of the instruction (block size 32 and 16x16xN matrices - the ones where you pass 64 elements per lane, I think).
So I'll argue for a_first_scale_lane and b_first_scale_lane ... and also, we should go fix scaled_ext_packed816 to use a 0/16 numbering scheme and not a 0/1 one.
Updated to use a_first_scale_lane and b_first_scale_lane. Also, I saw that you updated #170718 already.
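Based on the tablegen snippets quoted in this thread (vector f8 scales plus the renamed lane-select attributes), the revised op plausibly reads something like the following — a hypothetical rendering, since the final assembly format isn't shown here and the lane numbering changes again below (0/1 becomes 0/16):

```mlir
// Hypothetical revised syntax: f8E8M0FNU scale vectors and per-operand
// first-scale-lane attributes; the exact format may differ in the merged patch.
%0 = amdgpu.scaled_wmma (%scaleA * %matA) * (%scaleB * %matB) + %acc
  { m = 16 : i32, n = 16 : i32, k = 128 : i32,
    a_first_scale_lane = 0 : i32, b_first_scale_lane = 0 : i32 }
  : vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>,
    vector<4xf8E8M0FNU>, vector<64xf8E4M3FN>, vector<4xf32>
```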
| attrs.emplace_back("scaleAType", | ||
| rewriter.getI32IntegerAttr(op.getScaleAIdx())); |
Why is the Type attribute being mapped from the Idx value?
scaleAType is the name that the LLVM intrinsic uses to represent the lane selection logic (don't know why this name was chosen). See the updated AMDGPU.td description for how these indices are used.
Force-pushed 57c19ea to 0508304
Force-pushed 0508304 to e5579b9
Pull request overview
This PR adds support for scaled WMMA (Wave Matrix Multiply-Accumulate) operations for the gfx1250 AMD GPU architecture. These operations enable matrix multiplication with per-block scaling of inputs using low-precision floating-point formats (fp4, fp6, fp8).
Key changes:
- Adds `amdgpu.scaled_wmma` operation with support for 16x16x128 and 32x16x128 tile dimensions
- Implements verification logic to validate operand types, dimensions, and scale compatibility
- Provides lowering from AMDGPU dialect to ROCDL intrinsics with proper operand packing
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| mlir/test/Dialect/AMDGPU/ops.mlir | Adds test cases for the new scaled_wmma operation covering various type combinations and dimensions |
| mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir | Adds conversion tests showing proper lowering to ROCDL intrinsics, including error cases for invalid configurations |
| mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | Implements verification logic for ScaledWMMAOp to validate dimensions, operand types, and scale type combinations |
| mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | Implements lowering pattern for ScaledWMMAOp and refactors helper functions to support both MFMA and WMMA operations; contains unresolved merge conflict |
| mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | Defines the AMDGPU_ScaledWMMAOp operation with comprehensive documentation of supported formats and scaling behavior |
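Concretely, the conversion tests in the patch exercise lowerings like this one (copied from the first revision of wmma-gfx1250.mlir; later revisions switch the scale operands to f8 vectors):

```mlir
// 16x16x128 fp8 scaled_wmma with i32 scale exponents...
%0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
  { m = 16 : i32, n = 16 : i32, k = 128 : i32 }
  : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
// ...lowers to the ROCDL intrinsic, with sources packed to vector<16xi32>:
// rocdl.wmma.scale.f32.16x16x128.f8f6f4 %a, %b, %arg2, %arg3, %arg4
//   : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
```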
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed.
Force-pushed 2714938 to e330cc1
| attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt)); | ||
|
|
||
| // Reuse flags use default value of false. | ||
| attrs.emplace_back("reuseA", rewriter.getBoolAttr(false)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, if you had such a pass, it'd be at the AMDGPU level anyway
... in theory we could teach LLVM about this but they've got enough headaches already.
ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB,
ScaledWMMAOutTypes:$destC,
VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA,
ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$a_first_scale_lane,
So, I'd like to call for a change similar to scaled_ext_packed_matrix's recent change, where we make the value here 0 or 16, not 0 or 1. That way, it's actually indexing into the first lane that'll be read from, instead of awkwardly gesturing at a wave half
Updated.
lanes provide scale values:
- Block size 32: For tile size 16x16x128, each matrix gets 64 scales stored in half
  a VGPR, with `a_first_scale_lane`/`b_first_scale_lane` selecting lanes 0-15 (index=0) or
  16-31 (index=1). For a tile size of 32x16x128, matrix A gets 128 scales in
And updating this if the suggestion goes through
lowering. The index attributes (`a_first_scale_lane`, `b_first_scale_lane`) select which register
lanes provide scale values:
- Block size 32: For tile size 16x16x128, each matrix gets 64 scales stored in half
  a VGPR, with `a_first_scale_lane`/`b_first_scale_lane` selecting lanes 0-15 (index=0) or
... "half a VGPR" is, I don't think, correct, and also isn't the right level of abstraction here.
The scales are passed in as vectors of bytes, and the indices refer to elements of the vector on a particular lane.
Also, formulas, perhaps.
Updated the description. Let me know if there are additional formulas that you want added.
a VGPR, with `a_first_scale_lane`/`b_first_scale_lane` selecting lanes 0-15 (index=0) or
16-31 (index=1). For a tile size of 32x16x128, matrix A gets 128 scales in
a full VGPR (`a_first_scale_lane` is unused), while matrix B gets 64 scales in
half a VGPR.
I'm pretty sure this is wrong.
Also, as a validation, the scale index isn't unused, I'm pretty sure it must be 0 and non-zero values are reserved. So we need to be checking that.
Updated the descriptions.
Force-pushed 8e6f75f to 0e21e23
.Default([](Type) -> Value {
  llvm_unreachable("unexpected input type for scale operand");
});
Suggested change:
- .Default([](Type) -> Value {
-   llvm_unreachable("unexpected input type for scale operand");
- });
+ .DefaultUnreachable("unexpected input type for scale operand");
static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
                              Value input) {
  return TypeSwitch<Type, Value>(input.getType())
    .Case<IntegerType>([&](IntegerType) {
Suggested change:
- .Case<IntegerType>([&](IntegerType) {
+ .Case([&](IntegerType) {
return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
                            input);
})
.Case<VectorType>([&](VectorType vectorType) {
Suggested change:
- .Case<VectorType>([&](VectorType vectorType) {
+ .Case([&](VectorType vectorType) {
This PR adds scaled WMMA ops (available on gfx1250) and the lowering to ROCDL.