Commit ab6b1ae

PR Review round 2
Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
1 parent 970aa1a commit ab6b1ae

File tree: 2 files changed (+52, -85 lines)

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 50 additions & 54 deletions
@@ -30,6 +30,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include <algorithm>
 #include <cstdint>
 #include <limits>
 #include <optional>
@@ -647,22 +648,6 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
   LogicalResult matchAndRewrite(ScaledMFMAOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    // If this use of a scale has a non zero opsel, packing has already been
-    // done.
-    auto checkIfUnpackable = [&](OpOperand &op) {
-      if (auto smfma = dyn_cast<ScaledMFMAOp>(op.getOwner())) {
-        switch (op.getOperandNumber()) {
-        case 3:
-          return smfma.getScalesIdxA() != 0;
-        case 4:
-          return smfma.getScalesIdxB() != 0;
-        default:
-          break;
-        }
-      }
-      return true;
-    };
-
     auto setOpsel = [&](unsigned idx, int64_t val) {
       switch (idx) {
       case 3:
@@ -676,22 +661,11 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
       }
     };
 
-    // Obtain flat index from offsets and shape.
-    auto getIdxFromExtract = [](vector::ExtractOp op) {
-      ShapedType ty = dyn_cast<ShapedType>(op.getOperand(0).getType());
-      int64_t cumul = 1;
-      int64_t idx = 0;
-      for (auto [offset, size] :
-           reverse(llvm::zip_equal(op.getStaticPosition(), ty.getShape()))) {
-        idx += offset * cumul;
-        cumul *= size;
-      }
-      return idx;
-    };
-
-    // For every scale operand of this ScaledMFMAOp, if the scale follows the
-    // following pattern:
-    // (f8 here means f8E8M0FNU)
+    // For every scale operand of this ScaledMFMAOp, if the scale is produced by
+    // the extraction of a single scale from some vector, then attempt to
+    // extract 4 values from that vector instead.
+    //
+    // Example: (f8 here means f8E8M0FNU)
     // %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...>
     // %scale = vector.insert %unit, ... : f8 into vector<4xf8>
     // amdgpu.scaled_mfma(%scale[0] * ...
@@ -710,57 +684,79 @@ struct PackScales final : OpRewritePattern<ScaledMFMAOp> {
         return rewriter.notifyMatchFailure(op,
                                            "defining op not a vector.insert");
       }
-      if (llvm::any_of(insertOp.getResult().getUses(), checkIfUnpackable)) {
-        return rewriter.notifyMatchFailure(op,
-                                           "some scaled mfma's already packed");
+      // If the extracted value is not a single scalar, then it has been packed.
+      if (dyn_cast<VectorType>(insertOp.getValueToStore().getType())) {
+        return rewriter.notifyMatchFailure(
+            op, "scaled mfma operand already packed");
       }
 
       auto extractOp =
-          insertOp.getOperand(0).getDefiningOp<vector::ExtractOp>();
+          insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>();
       if (!extractOp) {
         return rewriter.notifyMatchFailure(op,
                                            "defining op not a vector.extract");
       }
 
       Value scaleSrc = extractOp.getOperand(0);
-      auto stype = dyn_cast<VectorType>(scaleSrc.getType());
-      if (!stype) {
+      auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType());
+      if (!scaleSrcType) {
         return rewriter.notifyMatchFailure(op, "not a vector type");
       }
+
       // We do not handle dynamic dims yet, assume that the input is padded to
       // a static shape now.
-      if (!stype.hasStaticShape()) {
+      if (!scaleSrcType.hasStaticShape()) {
         return rewriter.notifyMatchFailure(op,
                                            "dynamic dims not yet supported");
       }
 
-      int64_t numElements = stype.getNumElements();
+      int64_t numElements = scaleSrcType.getNumElements();
       if (numElements <= 4) {
         return rewriter.notifyMatchFailure(
             op, "no packing if # of scales less than four");
       }
-      int64_t idx = getIdxFromExtract(extractOp);
+
+      // Find a linearized idx using the shape and offsets of the extract op.
+      ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape();
+      int64_t scaleSrcRank = scaleSrcType.getRank();
+      SmallVector<int64_t> extractedPos(extractOp.getStaticPosition());
+      SmallVector<int64_t> extractSizes(scaleSrcRank, 1);
+      std::reverse(extractedPos.begin(), extractedPos.end());
+      for (int64_t i = 1; i < scaleSrcRank; i++) {
+        extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i];
+      }
+      int64_t idx = linearize(extractedPos, extractSizes);
+
+      // All n scales (where n is the total number of scales) must now be
+      // extracted in chunks of 4 elements. This is done by dividing the
+      // original vector of scales into groups of 4 elements at offsets
+      // 0, 4, ..., 4 * (m - 1) (where m = n / 4). All extractions of a
+      // scale at a particular index are now replaced with an extraction
+      // of the entire group of 4 elements to which that index belongs.
+      //
+      // If the number of scales happens to be indivisible by 4, extract
+      // the remaining n - 4 * m scales in a chunk of 4 elements starting
+      // at offset n - 4.
       int64_t offset = idx - (idx % 4);
-      int64_t size = std::min(4l, numElements - offset);
       int64_t opsel = idx - offset;
-      if (size != 4l) {
-        opsel += 4l - size;
+      int64_t size = 4l;
+      // Accommodate remaining elements in the case of non-4-divisible vectors.
+      if (numElements - offset < size) {
+        opsel = size - (numElements - idx);
         offset = numElements - 4l;
-        size = 4l;
       }
-
-      Type newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
-                                        stype.getElementType());
+      Type scaleSrcElemType = scaleSrcType.getElementType();
+      auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}),
+                                        scaleSrcElemType);
       Value newScaleSrc =
           rewriter.create<vector::ShapeCastOp>(loc, newSrcType, scaleSrc);
-      auto scaleTy = VectorType::get({4}, stype.getElementType());
-      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+      auto extract = rewriter.create<vector::ExtractStridedSliceOp>(
           loc, newScaleSrc, ArrayRef<int64_t>{offset}, ArrayRef<int64_t>{size},
          ArrayRef<int64_t>{1});
-      Value scale = rewriter.create<vector::ShapeCastOp>(loc, scaleTy, extract);
-      rewriter.modifyOpInPlace(
-          op, [&op, opIdx, scale] { op->setOperand(opIdx, scale); });
-      setOpsel(opIdx, opsel);
+      rewriter.modifyOpInPlace(op, [&] {
+        op->setOperand(opIdx, extract);
+        setOpsel(opIdx, opsel);
+      });
     }
     return success();
  }
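
To make the new indexing arithmetic concrete, here is a minimal standalone sketch in plain C++ (no MLIR dependencies; the helpers flattenIndex and chooseWindow are illustrative names, not part of this patch or of MLIR). It mirrors the row-major linearization and the offset/opsel window selection above:

#include <cstdint>
#include <vector>

// Row-major flattening of a static extract position; equivalent to the
// reversed suffix-product loop plus linearize() in the pattern.
int64_t flattenIndex(const std::vector<int64_t> &pos,
                     const std::vector<int64_t> &shape) {
  int64_t idx = 0;
  for (size_t i = 0; i < pos.size(); ++i)
    idx = idx * shape[i] + pos[i];
  return idx;
}

struct Window {
  int64_t offset; // start of the 4-element slice in the flattened vector
  int64_t opsel;  // lane of the requested scale within that slice
};

Window chooseWindow(int64_t idx, int64_t numElements) {
  // Snap the window down to the enclosing multiple of 4.
  int64_t offset = idx - (idx % 4);
  int64_t opsel = idx - offset;
  if (numElements - offset < 4) {
    // Tail case: fewer than 4 elements remain, so slide the window back to
    // the last 4 in-bounds elements and shift the lane accordingly.
    opsel = 4 - (numElements - idx);
    offset = numElements - 4;
  }
  return {offset, opsel};
}

For the vector<7x23xf8E8M0FNU> scale source used in the test below, flattenIndex({6, 22}, {7, 23}) gives idx = 160; with numElements = 161 the tail case fires, yielding offset = 157 and opsel = 3, which matches the "idx = 160 => opsel = 3" annotation in the test.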

mlir/test/Dialect/AMDGPU/canonicalize.mlir

Lines changed: 2 additions & 31 deletions
@@ -204,38 +204,21 @@ func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
   return %res_0 : vector<4xf32>
 }
 
-
 // -----
 
 // CHECK-LABEL: func @scaled_mfma_ugly_shapes
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[0] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
 // CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
 // CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
 // CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
 // CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
+func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
   %cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
   %cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
-  %scaleA_0_0 = vector.extract %scalesA[0, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
-  %scaleA_0_1 = vector.extract %scalesA[1, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
-  %scaleA_0_2 = vector.extract %scalesA[2, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
-  %scaleA_0_3 = vector.extract %scalesA[3, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
   %scaleA_0_4 = vector.extract %scalesA[4, 0] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
   %scaleA_0_5 = vector.extract %scalesA[4, 1] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
   %scaleA_0_6 = vector.extract %scalesA[4, 2] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
   %scaleA_0_7 = vector.extract %scalesA[4, 3] : f8E8M0FNU from vector<5x5xf8E8M0FNU>
 
-  // idx = 138 + 8 = 146 => opsel = 2
-  %scaleB_6_8 = vector.extract %scalesB[6, 8] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
-  // idx = 147 => opsel = 3
-  %scaleB_6_9 = vector.extract %scalesB[6, 9] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
-  // idx = 148 => opsel = 0
-  %scaleB_6_10 = vector.extract %scalesB[6, 10] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
-  // idx = 149 => opsel = 1
-  %scaleB_6_11 = vector.extract %scalesB[6, 11] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
   // idx = 160 => opsel = 3 (last idx of last 4 bytes)
   %scaleB_6_22 = vector.extract %scalesB[6, 22] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
   // idx = 159 => opsel = 3
@@ -245,31 +228,19 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
   // idx = 157 => opsel = 1
   %scaleB_6_19 = vector.extract %scalesB[6, 19] : f8E8M0FNU from vector<7x23xf8E8M0FNU>
 
-  %sA_0_0 = vector.insert %scaleA_0_0, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sA_0_1 = vector.insert %scaleA_0_1, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sA_0_2 = vector.insert %scaleA_0_2, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sA_0_3 = vector.insert %scaleA_0_3, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sA_0_4 = vector.insert %scaleA_0_4, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sA_0_5 = vector.insert %scaleA_0_5, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sA_0_6 = vector.insert %scaleA_0_6, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sA_0_7 = vector.insert %scaleA_0_7, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
 
-  %sB_6_8 = vector.insert %scaleB_6_8, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sB_6_9 = vector.insert %scaleB_6_9, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sB_6_10 = vector.insert %scaleB_6_10, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-  %sB_6_11 = vector.insert %scaleB_6_11, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_22 = vector.insert %scaleB_6_22, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
   %sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
 
-  %res_0 = amdgpu.scaled_mfma(%sA_0_0[0] * %opA) * (%sB_6_8[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_1 = amdgpu.scaled_mfma(%sA_0_1[0] * %opA) * (%sB_6_9[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_2 = amdgpu.scaled_mfma(%sA_0_2[0] * %opA) * (%sB_6_10[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  %res_3 = amdgpu.scaled_mfma(%sA_0_3[0] * %opA) * (%sB_6_11[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
   %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-  return %res_0, %res_1, %res_2, %res_3, %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
+  return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
 }
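
As a sanity check on the idx => opsel annotations above, a small driver over the chooseWindow sketch from the previous section (again an illustrative helper, not part of the patch) reproduces them:

#include <cassert>
#include <cstdint>

int main() {
  const int64_t n = 7 * 23; // vector<7x23xf8E8M0FNU> flattens to 161 scales
  assert(chooseWindow(157, n).opsel == 1); // %scalesB[6, 19]
  assert(chooseWindow(158, n).opsel == 2); // %scalesB[6, 20]
  assert(chooseWindow(159, n).opsel == 3); // %scalesB[6, 21]
  // idx 160 lands in a 1-element tail, so the window slides back to
  // offset 157 and the lane becomes 3 rather than 0.
  assert(chooseWindow(160, n).opsel == 3); // %scalesB[6, 22]
  return 0;
}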
