
Conversation

@justinrosner
Contributor

This PR adds scaled WMMA ops (available on gfx1250) and the lowering to ROCDL.

@llvmbot
Member

llvmbot commented Nov 27, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-amdgpu

Author: Justin Rosner (justinrosner)

Changes

This PR adds scaled WMMA ops (available on gfx1250) and the lowering to ROCDL.


Patch is 24.10 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169854.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+60)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+115-16)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+71)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir (+73)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+26)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index e07c72b839e7c..a2201d3127370 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -951,6 +951,13 @@ def MFMAOutTypes : AnyTypeOf<[F64,
 def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
                                    VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
 def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
+
+// scaled_wmma
+def ScaledWMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[64], [F8E5M2, F8E4M3FN]>,
+                                   VectorOfLengthAndType<[64], [F6E2M3FN, F6E3M2FN]>,
+                                   VectorOfLengthAndType<[64, 128], [F4E2M1FN]>]>;
+def ScaledWMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32]>]>;
+
 // wmma
 def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[2], [F32]>,
                              VectorOfLengthAndType<[4, 8, 16], [F16, BF16]>,
@@ -1218,6 +1225,59 @@ def AMDGPU_ScaledMFMAOp :
   let hasCanonicalizer = 1;
 }
 
+def AMDGPU_ScaledWMMAOp :
+    AMDGPU_Op<"scaled_wmma", [AllTypesMatch<["destC", "destD"]>,
+                              Pure]>,
+    Arguments<(ins
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[16]>]>:$n,
+                   ConfinedAttr<I32Attr, [IntIsOneOf<[128]>]>:$k,
+                   ScaledWMMAInTypes:$sourceA,
+                   ScaledWMMAInTypes:$sourceB,
+                   ScaledWMMAOutTypes:$destC,
+                   AnyTypeOf<[I32, I64]>:$scaleA,
+                   AnyTypeOf<[I32, I64]>:$scaleB,
+                   DefaultValuedAttr<I32Attr, "0">:$scaleAType,
+                   DefaultValuedAttr<I32Attr, "0">:$fmtScaleA,
+                   DefaultValuedAttr<I32Attr, "0">:$scaleBType,
+                   DefaultValuedAttr<I32Attr, "0">:$fmtScaleB
+                   )>,
+    Results<(outs ScaledWMMAOutTypes: $destD)> {
+  let summary = "MLIR wrapper for RDNA scaled wmma instructions";
+  let description = [{
+    The `amdgpu.scaled_wmma` op is an MLIR wrapper around intrinsics for scaled
+    `wmma` instructions in the RDNA architecture. These instructions perform
+    matrix multiplication with per-block scaling of inputs, supporting fp4, fp6,
+    and fp8 data formats.
+
+    The scale instructions support two tile sizes:
+    - 16x16x128 with mixed f8/f6/f4 formats (output: vector<4xf32>)
+    - 32x16x128 with f4 format only (output: vector<8xf32>)
+
+    The `scaleA` and `scaleB` parameters are scale exponents that can be either
+    i32 (for wmma.scale) or i64 (for wmma.scale16) to support per-block scaling.
+
+    Optional modifiers:
+    - `scaleAType`, `scaleBType`: Type of scale parameter
+    - `fmtScaleA`, `fmtScaleB`: Format of scale parameter
+
+    Example:
+    ```mlir
+      %0 = amdgpu.scaled_wmma (%sa * %matA) * (%sb * %matB) + %matC
+        { m = 16, n = 16, k = 128 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+
+      %1 = amdgpu.scaled_wmma (%sc * %matD) * (%sd * %matE) + %matF
+        { m = 32, n = 16, k = 128 } : i32, vector<128xf4E2M1FN>, i32, vector<64xf4E2M1FN>, vector<8xf32>
+    ```
+  }];
+  let assemblyFormat = [{
+    `(` $scaleA `*` $sourceA `)` `*` `(` $scaleB `*` $sourceB `)` `+` $destC
+    attr-dict
+    `:` type($scaleA) `,` type($sourceA) `,` type($scaleB) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
+
 def AMDGPU_MakeDmaBaseOp :
     AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
     Arguments<(ins
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index b9a5e7d7f6eac..f4034f44d06b8 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -612,8 +612,8 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 
 } // namespace
 
-/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
-/// and LLVM AMDGPU intrinsics convention.
+/// Pack small float vector operands (fp4/fp6/fp8/bf16) into the format
+/// expected by scaled matrix multiply intrinsics (MFMA/WMMA).
 ///
 /// Specifically:
 /// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic
@@ -627,9 +627,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 /// Note that the type of `input` has already been LLVM type converted:
 /// therefore 8-bit and smaller floats are represented as their corresponding
 /// `iN` integers.
-static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
-                                      Location loc, Value input,
-                                      bool allowBf16 = true) {
+static Value packSmallFloatVectorOperand(ConversionPatternRewriter &rewriter,
+                                         Location loc, Value input,
+                                         bool allowBf16 = true) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (vectorType.getElementType().isBF16() && !allowBf16)
@@ -918,7 +918,7 @@ static std::optional<StringRef> mfmaOpToIntrinsic(MFMAOp mfma,
   return std::nullopt;
 }
 
-static std::optional<uint32_t> mfmaTypeSelectCode(Type mlirElemType) {
+static std::optional<uint32_t> smallFloatTypeToFormatCode(Type mlirElemType) {
   return llvm::TypeSwitch<Type, std::optional<uint32_t>>(mlirElemType)
       .Case([](Float8E4M3FNType) { return 0u; })
       .Case([](Float8E5M2Type) { return 1u; })
@@ -947,8 +947,8 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
   if (!isa<Float32Type>(destType))
     return std::nullopt;
 
-  std::optional<uint32_t> aTypeCode = mfmaTypeSelectCode(aType);
-  std::optional<uint32_t> bTypeCode = mfmaTypeSelectCode(bType);
+  std::optional<uint32_t> aTypeCode = smallFloatTypeToFormatCode(aType);
+  std::optional<uint32_t> bTypeCode = smallFloatTypeToFormatCode(bType);
   if (!aTypeCode || !bTypeCode)
     return std::nullopt;
 
@@ -1212,11 +1212,12 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
     }();
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
-    loweredOp.addOperands({convertMFMAVectorOperand(
-                               rewriter, loc, adaptor.getSourceA(), allowBf16),
-                           convertMFMAVectorOperand(
-                               rewriter, loc, adaptor.getSourceB(), allowBf16),
-                           adaptor.getDestC()});
+    loweredOp.addOperands(
+        {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA(),
+                                      allowBf16),
+         packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB(),
+                                      allowBf16),
+         adaptor.getDestC()});
     if (isScaled) {
       Value zero = createI32Constant(rewriter, loc, 0);
       auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
@@ -1261,8 +1262,8 @@ struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
     OperationState loweredOp(loc, intrinsicName);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
-        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
-         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+        {packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB()),
          adaptor.getDestC()});
     Value scalesIdxA =
         createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
@@ -1363,6 +1364,103 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct ScaledWMMAOpLowering : public ConvertOpToLLVMPattern<ScaledWMMAOp> {
+  ScaledWMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<ScaledWMMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledWMMAOp op, ScaledWMMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto outType =
+        typeConverter->convertType<VectorType>(op.getDestD().getType());
+    if (!outType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    if (chipset < Chipset(12, 5, 0))
+      return op->emitOpError("WMMA scale only supported on gfx1250+");
+
+    int64_t m = op.getM();
+    int64_t n = op.getN();
+    int64_t k = op.getK();
+
+    Type aElemType = getElementTypeOrSelf(op.getSourceA().getType());
+    Type bElemType = getElementTypeOrSelf(op.getSourceB().getType());
+
+    std::optional<uint32_t> aFmtCode = smallFloatTypeToFormatCode(aElemType);
+    std::optional<uint32_t> bFmtCode = smallFloatTypeToFormatCode(bElemType);
+
+    if (!aFmtCode || !bFmtCode)
+      return op.emitOpError("unsupported element types for scaled_wmma");
+
+    // Determine which intrinsic to use based on dimensions and scale type
+    StringRef intrinsicName;
+    bool isScale16 = adaptor.getScaleA().getType().isInteger(64);
+    bool is32x16 = (m == 32 && n == 16 && k == 128);
+
+    if (m == 16 && n == 16 && k == 128) {
+      intrinsicName = isScale16
+                ? ROCDL::wmma_scale16_f32_16x16x128_f8f6f4::getOperationName()
+                : ROCDL::wmma_scale_f32_16x16x128_f8f6f4::getOperationName();
+    } else if (is32x16) {
+      intrinsicName = isScale16
+                ? ROCDL::wmma_scale16_f32_32x16x128_f4::getOperationName()
+                : ROCDL::wmma_scale_f32_32x16x128_f4::getOperationName();
+    } else {
+      return op.emitOpError("unsupported scaled_wmma dimensions: ")
+             << m << "x" << n << "x" << k;
+    }
+
+    SmallVector<NamedAttribute, 8> attrs;
+
+    // The f4 variant does not have fmtA and fmtB attributes
+    if (!is32x16) {
+      attrs.push_back(rewriter.getNamedAttr("fmtA",
+                              rewriter.getI32IntegerAttr(*aFmtCode)));
+      attrs.push_back(rewriter.getNamedAttr("fmtB",
+                              rewriter.getI32IntegerAttr(*bFmtCode)));
+    }
+
+    // Add modifier attributes - modC and reuse flags default to 0/false
+    attrs.push_back(rewriter.getNamedAttr("reuseA",
+                              rewriter.getBoolAttr(false)));
+    attrs.push_back(rewriter.getNamedAttr("reuseB",
+                              rewriter.getBoolAttr(false)));
+    attrs.push_back(rewriter.getNamedAttr("modC",
+                              rewriter.getI16IntegerAttr(0)));
+
+    // Scale type/format parameters from the operation
+    attrs.push_back(rewriter.getNamedAttr("scaleAType",
+                              rewriter.getI32IntegerAttr(op.getScaleAType())));
+    attrs.push_back(rewriter.getNamedAttr("fmtScaleA",
+                              rewriter.getI32IntegerAttr(op.getFmtScaleA())));
+    attrs.push_back(rewriter.getNamedAttr("scaleBType",
+                              rewriter.getI32IntegerAttr(op.getScaleBType())));
+    attrs.push_back(rewriter.getNamedAttr("fmtScaleB",
+                              rewriter.getI32IntegerAttr(op.getFmtScaleB())));
+
+    // Convert typed float vectors to packed i32 format if needed
+    Value sourceA =
+        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceA());
+    Value sourceB =
+        packSmallFloatVectorOperand(rewriter, loc, adaptor.getSourceB());
+
+    // Create the intrinsic call
+    OperationState loweredOp(loc, intrinsicName);
+    loweredOp.addTypes(outType);
+    loweredOp.addOperands({sourceA, sourceB, adaptor.getDestC(),
+                           adaptor.getScaleA(), adaptor.getScaleB()});
+    loweredOp.addAttributes(attrs);
+
+    Operation *lowered = rewriter.create(loweredOp);
+    rewriter.replaceOp(op, lowered->getResults());
+
+    return success();
+  }
+};
+
 struct TransposeLoadOpLowering
     : public ConvertOpToLLVMPattern<TransposeLoadOp> {
   TransposeLoadOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
@@ -2329,7 +2427,8 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, MemoryCounterWaitOpLowering, LDSBarrierOpLowering,
            SchedBarrierOpLowering, MFMAOpLowering, ScaledMFMAOpLowering,
-           WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPacked816OpLowering,
+           WMMAOpLowering, ScaledWMMAOpLowering, ExtPackedFp8OpLowering,
+           ScaledExtPacked816OpLowering,
            ScaledExtPackedOpLowering, PackedScaledTruncOpLowering,
            PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
            GatherToLDSOpLowering, TransposeLoadOpLowering,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index cdc10c60a42ae..87bd1903290ae 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -442,6 +442,77 @@ LogicalResult WMMAOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledWMMAOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ScaledWMMAOp::verify() {
+  auto sourceAType = cast<VectorType>(getSourceA().getType());
+  auto sourceBType = cast<VectorType>(getSourceB().getType());
+  auto destType = cast<VectorType>(getDestC().getType());
+  
+  // Validate output type is F32
+  if (!destType.getElementType().isF32())
+    return emitOpError("destination must have f32 element type");
+
+  // Validate source element types are small floats (fp4/fp6/fp8)
+  Type aElemType = sourceAType.getElementType();
+  Type bElemType = sourceBType.getElementType();
+
+  bool aIsSmallFloat = aElemType.isFloat(4) || aElemType.isFloat(6) ||
+                       aElemType.isFloat(8);
+  bool bIsSmallFloat = bElemType.isFloat(4) || bElemType.isFloat(6) ||
+                       bElemType.isFloat(8);
+
+  if (!aIsSmallFloat || !bIsSmallFloat)
+    return emitOpError("source operands must have small float element types "
+                       "(fp4/fp6/fp8)");
+
+  // Validate scale types match (both i32 or both i64)
+  Type scaleAType = getScaleA().getType();
+  Type scaleBType = getScaleB().getType();
+  if (scaleAType != scaleBType)
+    return emitOpError("scaleA and scaleB must have the same type");
+
+  // Validate vector lengths based on dimensions
+  int64_t m = getM();
+  int64_t aLen = sourceAType.getNumElements();
+  int64_t bLen = sourceBType.getNumElements();
+  int64_t expectedOutLen = (m == 16) ? 4 : 8;
+  
+  if (destType.getNumElements() != expectedOutLen)
+    return emitOpError("expected output vector of length " +
+                       Twine(expectedOutLen) + " but got " +
+                       Twine(destType.getNumElements()));
+
+  if (m == 16) {
+    // For 16×16×128: both A and B must be 64 elements
+    if (aLen != 64)
+      return emitOpError(
+          "for 16x16x128, sourceA must have 64 elements but got " +
+          Twine(aLen));
+    if (bLen != 64)
+      return emitOpError(
+          "for 16x16x128, sourceB must have 64 elements but got " +
+          Twine(bLen));
+  } else { // m == 32
+    // For 32×16×128: only fp4 is supported, A is 128, B is 64
+    if (!aElemType.isFloat(4))
+      return emitOpError("32x16x128 only supports fp4 element types");
+
+    if (aLen != 128)
+      return emitOpError(
+          "for 32x16x128, sourceA must have 128 elements but got " +
+          Twine(aLen));
+    if (bLen != 64)
+      return emitOpError(
+          "for 32x16x128, sourceB must have 64 elements but got " +
+          Twine(bLen));
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MFMAOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
index 37259f6ed06eb..d187e62484059 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir
@@ -89,6 +89,79 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
   return
 }
 
+// CHECK-LABEL: @wmma_scale_16x16x128_fp8
+func.func @wmma_scale_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+                                    %arg2 : vector<4xf32>, %arg3 : i32, %arg4 : i32) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 1 : i32, fmtB = 1 : i32} : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %1 = amdgpu.scaled_wmma (%arg3 * %arg1) * (%arg4 * %arg1) + %arg2
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E5M2>, i32, vector<64xf8E5M2>, vector<4xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale_16x16x128_fp6
+func.func @wmma_scale_16x16x128_fp6(%arg0 : vector<64xf6E2M3FN>, %arg1 : vector<64xf6E3M2FN>,
+                                    %arg2 : vector<4xf32>, %arg3 : i32, %arg4 : i32) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 2 : i32, fmtB = 2 : i32} : (vector<12xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 {fmtA = 3 : i32, fmtB = 3 : i32} : (vector<12xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %1 = amdgpu.scaled_wmma (%arg3 * %arg1) * (%arg4 * %arg1) + %arg2
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E3M2FN>, i32, vector<64xf6E3M2FN>, vector<4xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale_16x16x128_mixed
+func.func @wmma_scale_16x16x128_mixed(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf6E2M3FN>,
+                                      %arg2 : vector<64xf4E2M1FN>, %arg3 : vector<4xf32>,
+                                      %arg4 : i32, %arg5 : i32) {
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, %arg4, %arg5 {fmtB = 2 : i32} : (vector<16xi32>, vector<12xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %0 = amdgpu.scaled_wmma (%arg4 * %arg0) * (%arg5 * %arg1) + %arg3
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf6E2M3FN>, vector<4xf32>
+
+  // CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg3, %arg4, %arg5 {fmtA = 2 : i32, fmtB = 4 : i32} : (vector<12xi32>, vector<8xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
+  %1 = amdgpu.scaled_wmma (%arg4 * %arg1) * (%arg5 * %arg2) + %arg3
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf6E2M3FN>, i32, vector<64xf4E2M1FN>, vector<4xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale16_16x16x128_fp8
+func.func @wmma_scale16_16x16x128_fp8(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
+                                      %arg2 : vector<4xf32>, %arg3 : i64, %arg4 : i64) {
+  // CHECK: rocdl.wmma.scale16.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i64, i64) -> vector<4xf32>
+  %0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
+    { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i64, vector<64xf8E4M3FN>, i64, vector<64xf8E4M3FN>, vector<4xf32>
+
+  func.return
+}
+
+// CHECK-LABEL: @wmma_scale_32x16x128_fp4
+func.func @wmma_scale_32x16x128_fp4(%arg0 : vector<128xf4E2M1FN>, %arg1 : vector<64xf4E2M1FN>,
+                      ...
[truncated]

@llvmbot
Member

llvmbot commented Nov 27, 2025

@llvm/pr-subscribers-backend-amdgpu


@github-actions

github-actions bot commented Nov 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@krzysz00 krzysz00 left a comment

The design is too low level (exposes constants that are just types at the MLIR level and doesn't properly expose the underlying types of the "i32" scales) and also there are no tests
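
To make the contrast concrete: in the revision above, each scale is an opaque integer whose real element type is only implied by integer attributes, whereas the direction requested here (and adopted in the ODS quoted further down) types the scales directly. A hypothetical sketch:

```mlir
// Revision above: the scale operand is a bare i32 pack; the scaleAType/fmtScaleA
// attributes are the only hint at what the packed bytes mean.
%0 = amdgpu.scaled_wmma (%s_i32 * %matA) * (%s_i32 * %matB) + %acc
  { m = 16 : i32, n = 16 : i32, k = 128 : i32 }
  : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
// Requested direction: pass the scales as typed f8 vectors instead, e.g. a
// vector<4xf8E8M0FNU> value, so the scale element type is explicit in the IR.
```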

Contributor

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE left a comment

Have a few nits but other than that it looks great!

edit: +1 to what Krzysztof said

Contributor

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE left a comment

looking good, just one last question for me

Comment on lines 532 to 534
// Matrix A (F8|F6) x Matrix B (F4) with Scale A (E8M0), Scale B (E5M2|E4M3)
if (aIsF8F6 && isE8M0(scaleAElemType) && bIsF4 && (isE4M3(scaleBElemType)))
  return success();
Contributor

Is there a reason we're not explicitly checking for B scale type E5M3 in this combination of legal A and B matrix data?

Contributor Author

@justinrosner justinrosner Nov 28, 2025

We don't seem to have support for E5M3 yet in MLIR. See my comment here: #169854 (comment).

I was just waiting to see if there was a preference for implementing that in this PR, or in a separate one.

Contributor

Ah I see! I can't seem to find an issue or anything tracking support for F8E5M3, maybe we can create an issue and add a TODO here to document that this needs to be done.

Contributor

@Muzammiluddin-Syed-ECE Muzammiluddin-Syed-ECE left a comment

LGTM, thanks! But let's hold off on merging until @krzysz00 gets a chance to review.

@justinrosner justinrosner requested a review from kuhar December 1, 2025 14:21
@justinrosner justinrosner requested a review from kuhar December 1, 2025 14:59
Member

@kuhar kuhar left a comment

Looks OK but we need some negative tests to exercise the verification logic (with // expected-error)
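
A minimal sketch of one such negative test, using the op syntax and the verifier message from the patch above (the function and value names here are placeholders, and this reflects the revision where the scales are still plain i32 values):

```mlir
// vector<8xf32> passes the ODS constraint but not ScaledWMMAOp::verify(),
// which requires a vector<4xf32> accumulator when m = 16.
func.func @scaled_wmma_wrong_acc_len(%a : vector<64xf8E4M3FN>,
                                     %acc : vector<8xf32>, %s : i32) {
  // expected-error@+1 {{expected output vector of length 4 but got 8}}
  %0 = amdgpu.scaled_wmma (%s * %a) * (%s * %a) + %acc
    { m = 16 : i32, n = 16 : i32, k = 128 : i32 }
    : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<8xf32>
  func.return
}
```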

@justinrosner justinrosner requested a review from kuhar December 1, 2025 16:20
Member

@kuhar kuhar left a comment

LGTM but let's wait for @krzysz00

Contributor

@krzysz00 krzysz00 left a comment

Overall, I'm liking this a lot better, thanks! (I think we're mainly in the nitpick zone)

ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB,
ScaledWMMAOutTypes:$destC,
VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA,
ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleAIdx,
Contributor

If you go look at the conversion intrinsics, these scale indices have a much more non-trivial mapping. I think the name is fine, but I'd suggest we document the semantics of this attribute carefully.

Also, nit, scale_a_idx to make the generic pretty-print a bit nicer (but that's not uniform style, you're welcome to ignore it)


Scale parameters (`scaleA`, `scaleB`) are small vectors of f8 scale values
(either f8E8M0FNU, or f8E4M3FN). The index attributes (`scaleAIdx`, `scaleBIdx`)
select which element from the scale vector to use for scaling. During lowering,
Contributor

I don't think this is correct

Remember that the layout of the scales starts (at least for one of the values of this bit)

lane 0: (m=0, kOuter=0) (0, 1), (0, 2), (0, 3) ...

and we should really write the formulas out

Contributor Author

I've updated the description of what the indices are doing.

attrs.emplace_back("fmtScaleB", rewriter.getI32IntegerAttr(*scaleBFmt));

// Reuse flags use default value of false.
attrs.emplace_back("reuseA", rewriter.getBoolAttr(false));
Contributor

Do we have a plan to enable these?

Contributor Author

@justinrosner justinrosner Dec 1, 2025

Not a concrete one as of yet. I was debating between exposing these attributes at the AMDGPU dialect level to give users the ability to set it themselves, or implementing a pass that detects when the same matrix value is used in multiple WMMA ops.

Contributor

Well, if you had such a pass, it'd be at the AMDGPU level anyway

... in theory we could teach LLVM about this but they've got enough headaches already.

Contributor Author

Are you okay with opening up an issue to implement this in a separate PR? Or would you prefer it be included in this one?

Contributor

Oh, yeah, a separate PR's fine - this isn't critical

VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA,
ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleAIdx,
VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleB,
ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$scaleBIdx)>,
Contributor

Why is it only 0 or 1? If it is packed into 32 bit it can also be 2 or 3.

Contributor

This does not represent a scale select like in a scaled_mfma, where it is used to indicate which of the 4 scales should be used. Instead, here it refers to MatrixScale, an enum defined here, and is used in the intrinsic (how, I'm not sure yet).

Contributor

Yeah, this is probably not going to want [] syntax unlike the gfx950 version, because it's a weird global modifier

Contributor Author

I've removed the [] syntax and just have the indices in the attribute list.

Contributor

Ok, so, pending some other clarifications, these indices are something along the lines of wave_half_select - that is, 0 means you're reading from lanes 0-15 for scales, and 1 means you're on lanes 16-31

They're only applicable to some modes of the instruction (block size 32 and 16x16xN matrices - the ones where you pass 64 elements per lane, I think).

So I'll argue for a_first_scale_lane and b_first_scale_lane ... and also, we should go fix scaled_ext_packed816 to use a 0/16 numbering scheme and not a 0/1 one.

Contributor Author

Updated to use a_first_scale_lane and b_first_scale_lane. Also, I saw that you updated #170718 already.

Comment on lines 1469 to 1472
attrs.emplace_back("scaleAType",
                   rewriter.getI32IntegerAttr(op.getScaleAIdx()));
Contributor

Why is Type being mapped to Idx here?

Contributor Author

scaleAType is the name that the LLVM intrinsic uses to represent the lane selection logic (don't know why this name was chosen). See the updated AMDGPU.td description for how these indices are used.

Contributor

Copilot AI left a comment

Pull request overview

This PR adds support for scaled WMMA (Wave Matrix Multiply-Accumulate) operations for the gfx1250 AMD GPU architecture. These operations enable matrix multiplication with per-block scaling of inputs using low-precision floating-point formats (fp4, fp6, fp8).

Key changes:

  • Adds amdgpu.scaled_wmma operation with support for 16x16x128 and 32x16x128 tile dimensions
  • Implements verification logic to validate operand types, dimensions, and scale compatibility
  • Provides lowering from AMDGPU dialect to ROCDL intrinsics with proper operand packing

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.

Summary per file:

  • mlir/test/Dialect/AMDGPU/ops.mlir: Adds test cases for the new scaled_wmma operation covering various type combinations and dimensions
  • mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir: Adds conversion tests showing proper lowering to ROCDL intrinsics, including error cases for invalid configurations
  • mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp: Implements verification logic for ScaledWMMAOp to validate dimensions, operand types, and scale type combinations
  • mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp: Implements lowering pattern for ScaledWMMAOp and refactors helper functions to support both MFMA and WMMA operations; contains unresolved merge conflict
  • mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td: Defines the AMDGPU_ScaledWMMAOp operation with comprehensive documentation of supported formats and scaling behavior
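
For reference, the lowering under review pairs the new op with a ROCDL intrinsic, as in this case from the patch's wmma-gfx1250.mlir (16x16x128, fp8 inputs, i32 scales; later revisions replace the integer scales with f8 vectors):

```mlir
// CHECK: rocdl.wmma.scale.f32.16x16x128.f8f6f4 {{.*}}, {{.*}}, %arg2, %arg3, %arg4 : (vector<16xi32>, vector<16xi32>, vector<4xf32>, i32, i32) -> vector<4xf32>
%0 = amdgpu.scaled_wmma (%arg3 * %arg0) * (%arg4 * %arg0) + %arg2
  { m = 16 : i32, n = 16 : i32, k = 128 : i32 } : i32, vector<64xf8E4M3FN>, i32, vector<64xf8E4M3FN>, vector<4xf32>
```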


@github-actions

github-actions bot commented Dec 3, 2025

🐧 Linux x64 Test Results

  • 7202 tests passed
  • 597 tests skipped

✅ The build succeeded and all tests passed.


ScaledWMMAInTypes:$sourceA, ScaledWMMAInTypes:$sourceB,
ScaledWMMAOutTypes:$destC,
VectorOfLengthAndType<[4, 8], [F8E8M0FNU, F8E4M3FN]>:$scaleA,
ConfinedAttr<I32Attr, [IntIsOneOf<[0, 1]>]>:$a_first_scale_lane,
Contributor

So, I'd like to call for a change similar to scaled_ext_packed_matrix's recent change, where we make the value here 0 or 16, not 1. That way, it's actually indexing into the first lane that'll be read from, instead of awkwardly gesturing at a wave half

Contributor Author

Updated.

lanes provide scale values:
- Block size 32: For tile size 16x16x128, each matrix gets 64 scales stored in half
a VGPR, with `a_first_scale_lane`/`b_first_scale_lane` selecting lanes 0-15 (index=0) or
16-31 (index=1). For a tile size of 32x16x128, matrix A gets 128 scales in
Contributor

And updating this if the suggestion goes through

lowering. The index attributes (`a_first_scale_lane`, `b_first_scale_lane`) select which register
lanes provide scale values:
- Block size 32: For tile size 16x16x128, each matrix gets 64 scales stored in half
a VGPR, with `a_first_scale_lane`/`b_first_scale_lane` selecting lanes 0-15 (index=0) or
Contributor

... "half a VGPR" is, I don't think, correct, and also isn't the right level of abstraction here.

The scales are passed in as vectors of bytes, which refer to elements of the vector on a particular lane.

Also, formulas, perhaps.

Contributor Author

@justinrosner justinrosner Dec 8, 2025

Updated the description. Let me know if there are additional formulas that you want added.

a VGPR, with `a_first_scale_lane`/`b_first_scale_lane` selecting lanes 0-15 (index=0) or
16-31 (index=1). For a tile size of 32x16x128, matrix A gets 128 scales in
a full VGPR (`a_first_scale_lane` is unused), while matrix B gets 64 scales in
half a VGPR.
Contributor

I'm pretty sure this is wrong.

Also, as a validation, the scale index isn't unused, I'm pretty sure it must be 0 and non-zero values are reserved. So we need to be checking that.

Contributor Author

Updated the descriptions.

@justinrosner justinrosner requested a review from krzysz00 December 8, 2025 22:15
Comment on lines +690 to +692
      .Default([](Type) -> Value {
        llvm_unreachable("unexpected input type for scale operand");
      });
Member

Suggested change:
-      .Default([](Type) -> Value {
-        llvm_unreachable("unexpected input type for scale operand");
-      });
+      .DefaultUnreachable("unexpected input type for scale operand");

static Value castScaleOperand(ConversionPatternRewriter &rewriter, Location loc,
                              Value input) {
  return TypeSwitch<Type, Value>(input.getType())
      .Case<IntegerType>([&](IntegerType) {
Member

Suggested change:
-      .Case<IntegerType>([&](IntegerType) {
+      .Case([&](IntegerType) {

        return LLVM::ZExtOp::create(rewriter, loc, rewriter.getI32Type(),
                                    input);
      })
      .Case<VectorType>([&](VectorType vectorType) {
Member

Suggested change:
-      .Case<VectorType>([&](VectorType vectorType) {
+      .Case([&](VectorType vectorType) {
