[mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp #155951
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-backend-amdgpu

Author: Muzammil (Muzammiluddin-Syed-ECE)

Changes

The ScaledMFMAOp accepts scales as a vector of 4 bytes (vector<4xf8E8M0FNU>) that can be stored in a single register, with a particular scale accessed using the OpSel attribute. Currently we only use one byte of this 4-byte vector, resulting in 3 wasted registers. This is fixed by identifying when single-byte extractions are performed and rewriting them into extractions of 4-byte vectors.

Example:
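A sketch of the rewrite, mirroring the pattern's documentation comment in the diff below (shapes, offsets, and SSA names are placeholders):

%unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU from vector<?x?x?xf8E8M0FNU>
%scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
amdgpu.scaled_mfma(%scale[0] * ...

to

%reshaped = vector.shape_cast %ScaleSrc : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
%scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU> from vector<?x4xf8E8M0FNU>
amdgpu.scaled_mfma(%scale[0-3] * ...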
Full diff: https://github.com/llvm/llvm-project/pull/155951.diff

4 files affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 2ccf350a359a8..a24a918357f2d 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
attr-dict
`:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
}];
+ let hasCanonicalizer = 1;
}
#endif // AMDGPU
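In ODS, setting hasCanonicalizer = 1 declares a static getCanonicalizationPatterns hook on the op; the definition that registers the PackScales pattern appears in AMDGPUDialect.cpp below.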
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 11a40d663a201..4107ec53a0988 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -28,6 +29,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <cstdint>
#include <limits>
#include <optional>
@@ -631,6 +633,146 @@ LogicalResult TransposeLoadOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ScaledMFMAOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Pack the scale operands of a ScaledMFMAOp into 4-byte registers. If any
+/// other scaled_mfma use of the scales already carries a nonzero opsel, the
+/// scales have already been packed and the pattern bails out.
+struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ScaledMFMAOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // If this use of a scale has a nonzero opsel, packing has already been
+    // done.
+    auto checkIfUnpackable = [&](OpOperand &use) {
+      if (auto smfma = dyn_cast<ScaledMFMAOp>(use.getOwner())) {
+        switch (use.getOperandNumber()) {
+        case 3:
+          return smfma.getScalesIdxA() != 0;
+        case 4:
+          return smfma.getScalesIdxB() != 0;
+        default:
+          return true;
+        }
+      }
+      // Non-mfma users carry no opsel, so they give no evidence of packing.
+      return false;
+    };
+
+    auto setOpsel = [&](unsigned idx, int64_t val) {
+      switch (idx) {
+      case 3:
+        op.setScalesIdxA(val);
+        break;
+      case 4:
+        op.setScalesIdxB(val);
+        break;
+      default:
+        break;
+      }
+    };
+
+    // Obtain the flat index of the extracted element from its offsets and
+    // the source shape.
+    auto getIdxFromExtract = [](vector::ExtractOp op) {
+      ShapedType ty = cast<ShapedType>(op.getOperand(0).getType());
+      int64_t cumul = 1;
+      int64_t idx = 0;
+      for (auto [offset, size] : llvm::reverse(
+               llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
+        idx += offset * cumul;
+        cumul *= size;
+      }
+      return idx;
+    };
+
+    // Delinearize a flat index into offsets for the given (static) shape.
+    auto getOffsetsFromIdx = [](int64_t idx, Type ty) {
+      SmallVector<int64_t> res;
+      ShapedType shapedTy = cast<ShapedType>(ty);
+      int64_t numElements = shapedTy.getNumElements();
+      for (int64_t size : shapedTy.getShape()) {
+        // After the division, numElements is the stride of this dimension.
+        numElements /= size;
+        res.push_back(idx / numElements);
+        idx %= numElements;
+      }
+      return res;
+    };
+
+    // For every scale operand of this ScaledMFMAOp, if the scale follows the
+    // following pattern:
+    //
+    // %unit = vector.extract %ScaleSrc[offsets] : f8E8M0FNU
+    //             from vector<?x?x?xf8E8M0FNU>
+    // %scale = vector.insert %unit, ... : f8E8M0FNU into vector<4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0] * ...
+    //
+    // rewrite to:
+    //
+    // %reshaped = vector.shape_cast %ScaleSrc
+    //                 : vector<?x?x?xf8E8M0FNU> to vector<?x4xf8E8M0FNU>
+    // %scale = vector.extract %reshaped[?] : vector<4xf8E8M0FNU>
+    //              from vector<?x4xf8E8M0FNU>
+    // amdgpu.scaled_mfma(%scale[0-3] * ...
+    //
+    // This creates a duplicate shape_cast for every use, but the duplicates
+    // are cleaned up by CSE.
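+    //
+    // Worked example (not from the original patch text): extracting element
+    // [0, 0, 6, 0] from vector<2x1x8x1xf8E8M0FNU> is flat index 6. After
+    // reshaping to vector<4x4xf8E8M0FNU>, that element lives in row
+    // 6 / 4 = 1 at byte 6 % 4 = 2, so the rewrite extracts row 1 and sets
+    // the consuming operand's opsel to 2.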
+    // Process the two scale operands (3 = scalesA, 4 = scalesB)
+    // independently; skip any operand that does not match, and report
+    // success only if at least one operand was actually packed. Returning
+    // failure after mutating the IR would violate the rewriter contract.
+    bool didPack = false;
+    for (int64_t opIdx : {3, 4}) {
+      auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
+      if (!insertOp)
+        continue;
+      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable))
+        continue;
+
+      auto extractOp =
+          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+      if (!extractOp)
+        continue;
+
+      Value scaleSrc = extractOp.getOperand(0);
+      auto stype = dyn_cast<ShapedType>(scaleSrc.getType());
+      if (!stype)
+        continue;
+      // Only a static, scalar extraction has a well-defined flat index.
+      if (extractOp.hasDynamicPosition() ||
+          extractOp.getType() != stype.getElementType())
+        continue;
+      // We do not handle dynamic dims yet; assume that the input is padded
+      // to a static shape.
+      if (!stype.hasStaticShape())
+        continue;
+
+      // The source must split evenly into more than one packed vector of
+      // four scales.
+      int64_t numElements = stype.getNumElements();
+      if (numElements <= 4 || numElements % 4 != 0)
+        continue;
+
+      // Reshape the scale source so that every row is one 4-byte register
+      // worth of scales.
+      Type newSrcType =
+          VectorType::get({numElements / 4, 4}, stype.getElementType());
+      Value newScaleSrc =
+          rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
+      // Map the original extraction onto (row, byte) coordinates of the
+      // reshaped source, slice out that row, and select the byte via opsel.
+      int64_t idx = getIdxFromExtract(extractOp);
+      SmallVector<int64_t> offsets = getOffsetsFromIdx(idx, newSrcType);
+      auto scaleTy = VectorType::get({4}, stype.getElementType());
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, newScaleSrc, /*offsets=*/ArrayRef<int64_t>{offsets[0], 0},
+          /*sizes=*/ArrayRef<int64_t>{1, 4},
+          /*strides=*/ArrayRef<int64_t>{1, 1});
+      Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
+      rewriter.modifyOpInPlace(op, [&] {
+        op.setOperand(opIdx, scale);
+        setOpsel(opIdx, offsets[1]);
+      });
+      didPack = true;
+    }
+    return success(didPack);
+  }
+};
+} // namespace
+
+void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<PackScales>(context);
+}
+
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
index 2a019954c8356..5d14a05945e95 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
MLIRROCDLDialect
# Needed for GPU address space enum definition
MLIRGPUDialect
+ MLIRVectorDialect
MLIRIR
MLIRSideEffectInterfaces
MLIRMemRefUtils
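The pattern builds vector.shape_cast and vector.extract_strided_slice ops, hence the new dependence on MLIRVectorDialect here and the VectorOps.h include above.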
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 5501ad42dbd90..75cbf29c95f29 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,28 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
: f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
func.return
}
+
+// -----
+
+// CHECK-LABEL: func @scaled_mfma
+// CHECK: %[[SCALE_1:.*]] = vector.extract %{{.*}}[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_2:.*]] = vector.extract %{{.*}}[1] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: %[[SCALE_3:.*]] = vector.extract %{{.*}}[2] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: %[[SCALE_4:.*]] = vector.extract %{{.*}}[3] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
+// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
+ %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
+ %scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ %scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
+ %sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
+ %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
+}
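For the first scaled_mfma in this test, the pattern, together with the usual shape_cast/extract folding during canonicalization and CSE of the duplicated shape_casts, yields IR of roughly this shape (a sketch consistent with the CHECK lines above; SSA names illustrative):

%cast = vector.shape_cast %scalesA : vector<2x1x8x1xf8E8M0FNU> to vector<4x4xf8E8M0FNU>
%s = vector.extract %cast[0] : vector<4xf8E8M0FNU> from vector<4x4xf8E8M0FNU>
%res = amdgpu.scaled_mfma(%s[3] * %opA) * (...) + %cst_0 ...

Element [0, 0, 3, 0] is flat index 3, i.e. row 0, byte 3, matching SCALE_1 and the [3] opsel in the first amdgpu.scaled_mfma CHECK line.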
✅ With the latest revision this PR passed the C/C++ code formatter.
50% minor notes, some substantive problems
... Oh, and while I'm here, the
LGTM here
offset = numElements - 4l;
}
Type scaleSrcElemType = scaleSrcType.getElementType();
auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
You don't need SmallVector here
Yeah, probably a good trivial followup once commit access goes through, good catch