
Conversation

Muzammiluddin-Syed-ECE (Contributor)

The ScaledMFMAOp accepts scales as a vector of 4 bytes (vector<4xf8E8M0FNU>) that can be stored in a single register, with a particular scale accessed using the OpSel attribute. Currently, we use only one byte of this 4-byte vector, resulting in three wasted registers.

This is fixed by identifying single-byte scale extractions and rewriting them into extractions of 4-byte vectors.

Example:

  %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
  %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
  amdgpu.scaled_mfma(%scale[0] * ...

to

  %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU> 
  %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
  amdgpu.scaled_mfma(%scale[0-3] * ...
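
For instance, with the vector<2x1x8x1xf8E8M0FNU> scale source used in the test added by this PR, the scale at position [0, 0, 6, 0] has flat index 6, so it lands in row 1 (6 / 4) at byte 2 (6 % 4) of the reshaped source:

  %reshaped = vector.shape_cast %ScaleSrc : vector<2x1x8x1xf8E8M0FNU> to vector<4x4xf8E8M0FNU>
  %scale = vector.extract %reshaped[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
  amdgpu.scaled_mfma(%scale[2] * ...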

llvmbot (Member) commented Aug 29, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-amdgpu

@llvm/pr-subscribers-backend-amdgpu

Author: Muzammil (Muzammiluddin-Syed-ECE)


Full diff: https://github.com/llvm/llvm-project/pull/155951.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+1)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+142)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt (+1)
  • (modified) mlir/test/Dialect/AMDGPU/canonicalize.mlir (+25)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2ccf350a359a8..a24a918357f2d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
     attr-dict
     `:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
   }];
+  let hasCanonicalizer = 1;
 }
 #endif // AMDGPU
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 11a40d663a201..4107ec53a0988 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <cstdint>
 #include <limits>
 #include <optional>
 
@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Check whether this scaled_mfma's scales input is used by other
+/// scaled_mfma ops; if none of those uses is packed yet, pack the scales.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a non zero opsel, packing has already been
+    // done.
+    auto checkIfUnpackable = [&](OpOperand &op) {
+      if (auto smfma = dyn_cast<ScaledMFMAOp>(op.getOwner())) {
+        switch (op.getOperandNumber()) {
+        case 3:
+          return smfma.getScalesIdxA() != 0;
+        case 4:
+          return smfma.getScalesIdxB() != 0;
+        default:
+          return true;
+        }
+      }
+      // Conservatively treat any non-scaled_mfma user as blocking packing.
+      return true;
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        return op.setScalesIdxA(val);
+      case 4:
+        return op.setScalesIdxB(val);
+      default:
+        break;
+      }
+    };
+
+    // Obtain flat index from offsets and shape.
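+    //   e.g. position [0, 0, 6, 0] in a 2x1x8x1 source maps to flat index 6.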
+    auto getIdxFromExtract = [](vector::ExtractOp op) {
+      ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
+      int64_t cumul = 1;
+      int64_t idx = 0;
+      for (auto [offset, size] :
+           reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
+        idx += offset * cumul;
+        cumul *= size;
+      }
+      return idx;
+    };
+
+    // Obtain offsets for new shape from flat index.
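+    //   e.g. flat index 6 with new shape {4, 4} maps to offsets {1, 2}.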
+    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+      SmallVector<int64_t> res;
+      auto shapedty = cast<ShapedType>(ty);
+      int64_t numElements = shapedty.getNumElements();
+      for (auto size : shapedty.getShape()) {
+        numElements /= size;
+        res.push_back(idx / numElements);
+        idx %= numElements;
+      }
+      return res;
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale matches the
+    // following pattern:
+    //
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
+    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates a duplicate shape_cast for every use, but CSE removes them.
+    for (int64_t opIdx : {3, 4}) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp) {
+        return failure();
+      }
+      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
+        return failure();
+      }
+
+      auto extractOp =
+          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+      if (!extractOp) {
+        return failure();
+      }
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
+      if (!stype) {
+        return failure();
+      }
+      // We do not handle dynamic dims yet; assume the input has been padded
+      // to a static shape.
+      if (!stype.hasStaticShape()) {
+        return failure();
+      }
+
+      int64_t numElements = stype.getNumElements();
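+      // Packing is only worthwhile when the source holds more than one
+      // 4-byte register's worth of scales.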
+      if (numElements <= 4) {
+        return failure();
+      }
+
+      Type newSrcType =
+          VectorType::get({numElements / 4, 4}, stype.getElementType());
+      Value newScaleSrc =
+          rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
+      int64_t idx = getIdxFromExtract(extractOp);
+      SmallVector<int64_t> offsets(getOffsetsFromIdx(idx, newSrcType));
+      auto scaleTy = VectorType::get({4}, stype.getElementType());
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, newScaleSrc, SmallVector<int64_t>{offsets[0], 0},
+          SmallVector<int64_t>{1, 4}, SmallVector<int64_t>{1, 1});
+      Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
+      op.setOperand(opIdx, scale);
+      setOpsel(opIdx, offsets[1]);
+    }
+    return success();
+  }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<PackScales>(context);
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 2a019954c8356..5d14a05945e95 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
   MLIRROCDLDialect
   # Needed for GPU address space enum definition
   MLIRGPUDialect
+  MLIRVectorDialect
   MLIRIR
   MLIRSideEffectInterfaces
   MLIRMemRefUtils
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 5501ad42dbd90..75cbf29c95f29 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,28 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
     : f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma
+// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+  %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+  %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+  %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+  %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+  return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
+}
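
To spell out the arithmetic the CHECK lines encode: the 2x1x8x1 shape has strides {8, 8, 1, 1}, so the extracted positions map to flat indices 3 ([0,0,3,0]), 6 ([0,0,6,0]), 9 ([1,0,1,0]), and 12 ([1,0,4,0]); after the reshape to vector<4x4xf8E8M0FNU> these become rows 0, 1, 2, and 3 with opsel values 3, 2, 1, and 0 (row = index / 4, opsel = index % 4).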

github-actions bot commented Aug 29, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

krzysz00 (Contributor) left a comment:

50% minor notes, some substantive problems

Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
krzysz00 (Contributor) commented:

... Oh, and while I'm here, the setOperand()s and such should be wrapped in a rewriter.modifyOpInplace()
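
(A minimal sketch of that suggestion, reusing names from the pattern above: wrap the two mutations as rewriter.modifyOpInPlace(op, [&] { op.setOperand(opIdx, scale); setOpsel(opIdx, offsets[1]); }); this routes the in-place change through the rewriter so pattern drivers are notified.)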

krzysz00 (Contributor) left a comment:

LGTM here

krzysz00 enabled auto-merge (squash) September 18, 2025 19:19
krzysz00 merged commit 9628061 into llvm:main Sep 18, 2025
9 checks passed
A review thread on a later revision of the code:

  offset = numElements - 4l;
  }
  Type scaleSrcElemType = scaleSrcType.getElementType();
  auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
(Member) You don't need SmallVector here
(Contributor) Yeah, probably a good trivial followup once commit access goes through, good catch
