1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -1048,5 +1048,6 @@ def AMDGPU_ScaledMFMAOp :
attr-dict
`:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
}];
let hasCanonicalizer = 1;
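// Setting hasCanonicalizer asks ODS to declare the static
// getCanonicalizationPatterns hook on the op, which AMDGPUDialect.cpp
// defines below.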
}
#endif // AMDGPU
137 changes: 137 additions & 0 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -26,8 +27,11 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <cstdint>
#include <limits>
#include <optional>

@@ -631,6 +635,139 @@ LogicalResult TransposeLoadOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// ScaledMFMAOp
//===----------------------------------------------------------------------===//

namespace {
/// If a scales operand of a ScaledMFMAOp is produced by inserting a single
/// scale extracted from a larger vector, pack four neighboring scales
/// instead and select the desired one via the opsel index.
struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(ScaledMFMAOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto setOpsel = [&op](unsigned idx, int64_t val) {
switch (idx) {
case 3:
op.setScalesIdxA(val);
break;
case 4:
op.setScalesIdxB(val);
break;
default:
break;
}
};
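// Operand 3 of amdgpu.scaled_mfma is scalesA and operand 4 is scalesB,
// which is the mapping the opsel setters above rely on.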

// For every scale operand of this ScaledMFMAOp, if the scale is produced by
// the extraction of a single scale from some vector, then attempt to
// extract 4 values from that vector instead.
//
// Example: (f8 here means f8E8M0FNU)
// %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
// %scale = vector.insert %unit, ... : f8 into vector<4xf8>
// amdgpu.scaled_mfma(%scale[0] * ...
//
// rewrite to:
//
// %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8>
// %scale = vector.extract_strided_slice %reshaped
//            {offsets = [?], sizes = [4], strides = [1]}
//            : vector<?xf8> to vector<4xf8>
// amdgpu.scaled_mfma(%scale[0-3] * ...
//
// This creates duplicate shape_casts for every use but these will be
// removed in CSE.
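// Concretely (from the tests below): for %ScaleSrc :
// vector<2x1x8x1xf8E8M0FNU> and an extract at [0, 0, 3, 0] (linear index
// 3), the source is shape_cast to vector<16xf8E8M0FNU>, the 4-element
// slice at offset 0 is extracted, and the op's scale index becomes 3.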
for (auto opIdx : std::array<int64_t, 2>({3, 4})) {
auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>();
if (!insertOp) {
return rewriter.notifyMatchFailure(op,
"defining op not a vector.insert");
}
// If the inserted value is already a vector, the scales have been packed.
if (isa<VectorType>(insertOp.getValueToStore().getType())) {
return rewriter.notifyMatchFailure(
op, "scaled mfma operand already packed");
}

auto extractOp =
insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
if (!extractOp) {
return rewriter.notifyMatchFailure(op,
"defining op not a vector.extract");
}

Value scaleSrc = extractOp.getOperand(0);
auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
if (!scaleSrcType) {
return rewriter.notifyMatchFailure(op, "not a vector type");
}

// We do not handle dynamic dimensions yet; assume the input has been
// padded to a static shape.
if (!scaleSrcType.hasStaticShape()) {
return rewriter.notifyMatchFailure(op,
"dynamic dims not yet supported");
}

int64_t numElements = scaleSrcType.getNumElements();
if (numElements <= 4) {
return rewriter.notifyMatchFailure(
op, "no packing if # of scales less than four");
}

// Compute a linearized index from the extract op's static position and
// the scale source's shape.
auto extractedPos = llvm::to_vector_of<int64_t>(
llvm::reverse(extractOp.getStaticPosition()));
ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
int64_t scaleSrcRank = scaleSrcType.getRank();
SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
for (int64_t i = 1; i < scaleSrcRank; ++i) {
extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
}
int64_t idx = linearize(extractedPos, extractSizes);
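// Worked example: for shape [2, 1, 8, 1], extractSizes is [1, 1, 8, 8],
// so an extract at position [1, 0, 1, 0] (reversed to [0, 1, 0, 1])
// linearizes to 0*1 + 1*1 + 0*8 + 1*8 = 9.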

// All n scales must now be extracted in chunks of 4 elements. This is
// done by dividing the original vector of scales into groups of 4
// elements at offsets 0, 4, ..., n - 4. Each extraction of a scale at a
// particular index is replaced by an extraction of the whole 4-element
// group containing that index, with opsel selecting the scale within
// the group.
//
// If n is not divisible by 4, indices in the trailing n mod 4 elements
// are served by a final, overlapping 4-element chunk starting at offset
// n - 4.
int64_t offset = idx - (idx % 4);
int64_t opsel = idx - offset;
int64_t size = 4;
// Accommodate the trailing elements of non-4-divisible vectors.
if (numElements - offset < size) {
opsel = size - (numElements - idx);
offset = numElements - size;
}
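// Worked example with n = 161 (vector<7x23xf8E8M0FNU> in the tests):
// idx = 158 gives offset = 156 and opsel = 2; idx = 160 hits the tail
// case, so offset becomes 157 and opsel = 4 - (161 - 160) = 3.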
Type scaleSrcElemType = scaleSrcType.getElementType();
auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
scaleSrcElemType);

Review thread (on the VectorType::get call above):
Member: You don't need SmallVector here
Contributor: Yeah, probably a good trivial followup once commit access goes through, good catch
Value newScaleSrc =
vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc);
auto extract = vector::ExtractStridedSliceOp::create(
rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset},
ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1});
rewriter.modifyOpInPlace(op, [&] {
op->setOperand(opIdx, extract);
setOpsel(opIdx, opsel);
});
}
return success();
}
};
} // namespace

void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<PackScales>(context);
}
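// This pattern runs under the standard canonicalizer, e.g. `mlir-opt
// --canonicalize` as in the tests below. A minimal sketch of driving it
// directly, assuming an MLIRContext `ctx` and an op `root` to rewrite;
// the greedy driver entry point is applyPatternsGreedily in recent MLIR
// (applyPatternsAndFoldGreedily in older releases):
//
//   RewritePatternSet patterns(ctx);
//   ScaledMFMAOp::getCanonicalizationPatterns(patterns, ctx);
//   if (failed(applyPatternsGreedily(root, std::move(patterns))))
//     root->emitError("canonicalization did not converge");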

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
1 change: 1 addition & 0 deletions mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUDialect
MLIRROCDLDialect
# Needed for GPU address space enum definition
MLIRGPUDialect
MLIRVectorDialect
MLIRIR
MLIRSideEffectInterfaces
MLIRMemRefUtils
85 changes: 85 additions & 0 deletions mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -159,3 +159,88 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
: f32, memref<128x72xf32, 1>, memref<?x?xf32, 3>
func.return
}

// -----

// CHECK-LABEL: func @scaled_mfma
// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
%scaleA = vector.extract %scalesA[0, 0, 3, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
}

// -----

// CHECK-LABEL: func @scaled_mfma_less_than_4
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
%scaleA = vector.extract %scalesA[0] : f8E8M0FNU from vector<2xf8E8M0FNU>
%sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
%sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_0 : vector<4xf32>
}

// -----

// CHECK-LABEL: func @scaled_mfma_ugly_shapes
// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
%scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
%scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
%scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
%scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
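// idx = 20..23 => offset = 20, opsel = 0..3 (n = 25, no tail adjustment)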

// idx = 160 => opsel = 3 (tail case: the final 4-element chunk starts at offset 157)
%scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
// idx = 159 => opsel = 3
%scaleB_6_21 = vector.extract %scalesB[6, 21] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
// idx = 158 => opsel = 2
%scaleB_6_20 = vector.extract %scalesB[6, 20] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
// idx = 157 => opsel = 1
%scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>

%sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>

%sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>

%res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
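// Note: the RUN line at the top of this file (outside this diff hunk)
// presumably pipes these functions through `mlir-opt --canonicalize
// --split-input-file`; each `// -----` above starts an independent case.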