[mlir][amdgpu] Define an amdgpu.scaling_mfma wrapper #137498


Merged
7 commits merged into llvm:main on May 2, 2025

Conversation

Muzammiluddin-Syed-ECE
Contributor

Create a wrapper around the new scaled MFMAs that operate on specific element types and tile sizes.

See [Issue](iree-org/iree#20616).
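For reference, a minimal usage sketch of the new op (syntax follows the tests added in this PR; the operand names and result binding are illustrative):

```mlir
// 32x32x64 scaled MFMA: two vector<32xf8E4M3FN> sources accumulating into
// vector<16xf32>. scaleA/scaleB select the scale values; opselA/opselB
// select which byte of each scale to use.
%d = amdgpu.scaled_mfma %a * %b + %acc
  { k = 64 : i32, m = 32 : i32, n = 32 : i32,
    scaleA = 1 : i32, opselA = 1 : i32,
    scaleB = 2 : i32, opselB = 2 : i32 }
  : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
```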

@llvmbot
Member

llvmbot commented Apr 27, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-amdgpu

Author: Muzammil (Muzammiluddin-Syed-ECE)

Changes

Create a wrapper around the new scaled MFMAs that operate on specific element types and tile sizes.

See [Issue](iree-org/iree#20616).


Full diff: https://github.com/llvm/llvm-project/pull/137498.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+48)
  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+60-4)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+33)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+47)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index f14aa5a2e1564..d1c601882fc93 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -830,4 +830,52 @@ def AMDGPU_GatherToLDSOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_ScaledMFMAOp :
+    AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
+                        Pure]>,
+    Arguments<(ins
+                   I32Attr:$m,
+                   I32Attr:$n,
+                   I32Attr:$k,
+                   MFMAInTypes:$sourceA,
+                   MFMAInTypes:$sourceB,
+                   MFMAOutTypes:$destC,
+                   I32Attr:$scaleA,
+                   I32Attr:$scaleB,
+                   I32Attr:$opselA,
+                   I32Attr:$opselB)>,
+    Results<(outs MFMAOutTypes: $destD)> {
+  let summary = "MLIR wrapper for CDNA scaled mfma instructions";
+  let description = [{
+    The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
+    for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
+    multiple outer products in order to allow fast matrix multiplication.
+
+    The wrapper will select an appropriate `mfma` instruction, if one is available,
+    based on the provided `m`, `n`, and `k` attributes, along with the
+    types of the source and destination arguments.
+
+    Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
+    intrinsics that take an integer type of width `4K`. For example,
+    one can provide a vector<4xi8> as an argument to an MFMA instruction that
+    logically takes 4 i8s but whose intrinsics are specified to take an i32.
+    In these cases, the bytes in the vector will be concatenated in little-endian
+    order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
+
+    This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
+    - `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and 
+    fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile 
+    size. 
+    - `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp` 
+    are omitted from this wrapper.
+    - The negateA, negateB, and negateC flags in `amdgpu.mfma` are only supported for 
+    double-precision operations on gfx94x and so are not included here. 
+  }];
+  let assemblyFormat = [{
+    $sourceA `*` $sourceB `+` $destC
+    attr-dict
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
 #endif // AMDGPU
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 91dbc2de65c4e..527c22ad34782 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -803,7 +803,6 @@ mfmaOpToScaledIntrinsic(Type aType, Type bType, Type destType, uint32_t m,
   aType = getElementTypeOrSelf(aType);
   bType = getElementTypeOrSelf(bType);
   destType = getElementTypeOrSelf(destType);
-
   if (chipset < kGfx950)
     return std::nullopt;
   if (!isa<Float32Type>(destType))
@@ -833,6 +832,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
       mfma.getBlocks(), chipset);
 }
 
+static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
+  return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
+                                 smfma.getSourceB().getType(),
+                                 smfma.getDestC().getType(), smfma.getM(),
+                                 smfma.getN(), smfma.getK(), 1u, chipset);
+}
+
 /// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -954,6 +961,54 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
   }
 };
 
+struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
+  ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<ScaledMFMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type outType = typeConverter->convertType(op.getDestD().getType());
+    Type intrinsicOutType = outType;
+    if (auto outVecType = dyn_cast<VectorType>(outType))
+      if (outVecType.getElementType().isBF16())
+        intrinsicOutType = outVecType.clone(rewriter.getI16Type());
+
+    if (chipset.majorVersion != 9 || chipset < kGfx908)
+      return op->emitOpError("Scaled MFMA only supported on gfx908+");
+    std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
+        maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
+    if (!maybeScaledIntrinsic.has_value())
+      return op.emitOpError(
+          "no intrinsic matching Scaled MFMA size on given chipset");
+
+    StringRef intrinsicName = std::get<0>(*maybeScaledIntrinsic);
+    OperationState loweredOp(loc, intrinsicName);
+    loweredOp.addTypes(intrinsicOutType);
+    loweredOp.addOperands(
+        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
+         adaptor.getDestC()});
+    Value scaleA = createI32Constant(rewriter, loc, adaptor.getScaleA());
+    Value scaleB = createI32Constant(rewriter, loc, adaptor.getScaleB());
+    Value opselA = createI32Constant(rewriter, loc, adaptor.getOpselA());
+    Value opselB = createI32Constant(rewriter, loc, adaptor.getOpselB());
+    auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
+    loweredOp.addOperands({createI32Constant(rewriter, loc, aTypeCode),
+                           createI32Constant(rewriter, loc, bTypeCode),
+                           /*scale A byte=*/opselA, /*scale A=*/scaleA,
+                           /*scale B byte=*/opselB, /*scale B=*/scaleB});
+    Value lowered = rewriter.create(loweredOp)->getResult(0);
+    if (outType != intrinsicOutType)
+      lowered = rewriter.create<LLVM::BitcastOp>(loc, outType, lowered);
+    rewriter.replaceOp(op, lowered);
+    return success();
+  }
+};
+
 struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -1474,8 +1529,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            RawBufferOpLowering<RawBufferAtomicCmpswapOp,
                                ROCDL::RawPtrBufferAtomicCmpSwap>,
            AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
-           MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
-           PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
-           GatherToLDSOpLowering>(converter, chipset);
+           MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
+           ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
+           PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
+                                                                 chipset);
   patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 549a4376a4a04..90131a66448fc 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -506,6 +506,39 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+LogicalResult ScaledMFMAOp::verify() {
+  unsigned opselA = getOpselA();
+  unsigned opselB = getOpselB();
+
+  opselA >>= 8;
+  opselB >>= 8;
+
+  if (opselA != 0)
+    return emitOpError("Opsel A must be a zero extended 8 bit value.");
+
+  if (opselB != 0)
+    return emitOpError("Opsel B must be a zero extended 8 bit value.");
+
+  auto valid = [&](Type mlirElemType){
+    return llvm::TypeSwitch<Type, bool>(mlirElemType)
+    .Case([](Float8E4M3FNType) { return true; })
+    .Case([](Float8E5M2Type) { return true; })
+    .Case([](Float6E2M3FNType) { return true; })
+    .Case([](Float6E3M2FNType) { return true; })
+    .Case([](Float4E2M1FNType) { return true; })
+    .Default([](Type) { return false; });
+  };
+  
+  Type aType = getSourceA().getType();
+  Type bType = getSourceB().getType();
+  if (!valid(aType))
+    return emitOpError("Source A must be of element type fp4, fp6 or fp8."); 
+  if (!valid(bType)) 
+    return emitOpError("Source B must be of element type fp4, fp6 or fp8.");
+  
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index de63f249bb530..f525a37e5ec80 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -51,3 +51,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
 
   func.return
 }
+
+func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
+                    %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
+                    %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
+                    %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>) {
+  
+  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
+
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg2 * %arg2 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg2 * %arg2 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+  
+  // CHECK: llvm.bitcast
+  
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg3 * %arg3 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg3 * %arg3 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+  
+  // CHECK: llvm.bitcast
+  
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg4 * %arg4 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg4 * %arg4 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
+  
+  // CHECK: llvm.bitcast
+  // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
+
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg5 * %arg5 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg5 * %arg5 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
+  
+  // CHECK: llvm.bitcast
+  // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
+  
+  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
+  amdgpu.scaled_mfma %arg6 * %arg6 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
+  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
+  amdgpu.scaled_mfma %arg6 * %arg6 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32,  scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
+
+  func.return
+}
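From the CHECK lines in the test above, the first two i32 constants passed to each rocdl.mfma.scale.* intrinsic encode the source element formats. Reading off the test expectations, the mapping appears to be:

  • 0: f8E4M3FN
  • 1: f8E5M2
  • 2: f6E2M3FN
  • 3: f6E3M2FN
  • 4: f4E2M1FN

The two supported tile shapes select the intrinsic: M=N=32, K=64 lowers to rocdl.mfma.scale.f32.32x32x64.f8f6f4, and M=N=16, K=128 lowers to rocdl.mfma.scale.f32.16x16x128.f8f6f4.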



github-actions bot commented Apr 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Muzammiluddin-Syed-ECE force-pushed the muzasyed/scaled branch 2 times, most recently from 24f7bcd to 131d99a on April 27, 2025
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
Muzammiluddin-Syed-ECE changed the title from "Define a amdgpu.scaling_mfma wrapper" to "Define an amdgpu.scaling_mfma wrapper" on Apr 27, 2025
@@ -833,6 +833,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
mfma.getBlocks(), chipset);
}

static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
Member

Rather than creating an overloaded function, you can pass the `Operation *` and then do the `dyn_cast` + if/else.

Member

What's the benefit of branching on the operation type?

Member

It can be templatized as well; my point was to remove the unnecessary function that does the same thing.

Muzammiluddin-Syed-ECE (Contributor, Author) commented Apr 28, 2025

IIUC, templating doesn't remove the need for branching, since I would still be branching on the class of the input op to pass the appropriate arguments (they differ between the proposed scaled_mfma and the existing mfma). So I opted for branching on the op type, which I find clearer.

Contributor

Yeah, the reason this isn't a pure `template<typename MfmaOp>` is that the arguments don't match (MFMA has a `getBlocks()`, ScalingMfma doesn't, for example).
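For illustration, here is a sketch of the single-entry-point alternative discussed in this thread (hypothetical; the PR kept the two overloads instead, and the `MFMAOp` argument list is taken from the existing overload):

```cpp
// Hypothetical alternative: one function that branches on the op type via
// dyn_cast instead of overloading. Not the code that landed in this PR.
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
mfmaOpToScaledIntrinsic(Operation *op, Chipset chipset) {
  if (auto mfma = dyn_cast<MFMAOp>(op))
    return mfmaOpToScaledIntrinsic(
        mfma.getSourceA().getType(), mfma.getSourceB().getType(),
        mfma.getDestC().getType(), mfma.getM(), mfma.getN(), mfma.getK(),
        mfma.getBlocks(), chipset);
  if (auto smfma = dyn_cast<ScaledMFMAOp>(op))
    return mfmaOpToScaledIntrinsic(
        smfma.getSourceA().getType(), smfma.getSourceB().getType(),
        smfma.getDestC().getType(), smfma.getM(), smfma.getN(), smfma.getK(),
        /*blocks=*/1u, chipset);
  return std::nullopt;
}
```

As the thread notes, the two branches still diverge (MFMAOp has getBlocks(), ScaledMFMAOp does not), so the dispatch cannot be folded into a single template.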

@kuhar kuhar changed the title Define an amdgpu.scaling_mfma wrapper [mlir][amdgpu] Define an amdgpu.scaling_mfma wrapper Apr 27, 2025
kuhar (Member) left a comment

Just some drive-by nits


kuhar requested a review from krzysz00 on April 27, 2025
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
krzysz00 (Contributor) left a comment

Substantial semantics note

Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
krzysz00 (Contributor) left a comment

Overall feels better, I've got some comments

Also, can you stick a test under mlir/test/Dialect/AMDGPU to make sure this parses/prints correctly?
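A roundtrip test of the kind requested might look like the following sketch (file placement and CHECK labels are assumptions; the op syntax follows the conversion tests in this PR):

```mlir
// RUN: mlir-opt %s | mlir-opt | FileCheck %s

// CHECK-LABEL: func @scaled_mfma
func.func @scaled_mfma(%a : vector<32xf8E4M3FN>, %b : vector<32xf8E4M3FN>,
                       %c : vector<16xf32>) -> vector<16xf32> {
  // CHECK: amdgpu.scaled_mfma
  %d = amdgpu.scaled_mfma %a * %b + %c
    { k = 64 : i32, m = 32 : i32, n = 32 : i32,
      scaleA = 0 : i32, opselA = 0 : i32,
      scaleB = 0 : i32, opselB = 0 : i32 }
    : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
  func.return %d : vector<16xf32>
}
```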

Muzammiluddin-Syed-ECE force-pushed the muzasyed/scaled branch 3 times, most recently from 9082e3f to 55bd3b5 on April 30, 2025
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
krzysz00 (Contributor) left a comment

Commented on the test; it's looking about right.

Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
krzysz00 (Contributor) left a comment

LGTM, thanks for all your work on this

krzysz00 merged commit 105ce58 into llvm:main on May 2, 2025
9 of 10 checks passed
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
Create a wrapper around the new scaled MFMAs that operate on specific
element types and tile sizes.

See [Issue](iree-org/iree#20616).

---------

Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
Create a wrapper around the new scaled MFMAs that operate on specific
element types and tile sizes.

See [Issue](iree-org/iree#20616).

---------

Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>