-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][AMDGPU] Add canonicalization pattern to pack scales for ScaledMFMAOp #155951
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
krzysz00
merged 7 commits into
llvm:main
from
Muzammiluddin-Syed-ECE:muzasyed/packScales
Sep 18, 2025
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
2b6d917
Add packing of scales for ScaledMFMAOp
Muzammiluddin-Syed-ECE 3873eda
PR review round 0
Muzammiluddin-Syed-ECE 2404d99
PR review round 1
Muzammiluddin-Syed-ECE 970aa1a
Perform packing for inputs with shapes non-divisible by 4
Muzammiluddin-Syed-ECE ab6b1ae
PR Review round 2
Muzammiluddin-Syed-ECE 9d8ffbd
PR Review round 3
Muzammiluddin-Syed-ECE f4ca565
minor fixups
Muzammiluddin-Syed-ECE File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" | ||
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" | ||
#include "mlir/Dialect/Utils/IndexingUtils.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/Diagnostics.h" | ||
|
@@ -26,8 +27,11 @@ | |
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
#include "llvm/ADT/DenseMap.h" | ||
#include "llvm/ADT/SmallVector.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
|
||
#include <algorithm> | ||
#include <cstdint> | ||
#include <limits> | ||
#include <optional> | ||
|
||
|
@@ -631,6 +635,139 @@ LogicalResult TransposeLoadOp::verify() { | |
return success(); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ScaledMFMAOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
namespace { | ||
/// Check if the scales input is used in other scaled mfma's while they exist. | ||
/// If theyre unused then pack the scales. | ||
struct PackScales final : OpRewritePattern<ScaledMFMAOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(ScaledMFMAOp op, | ||
PatternRewriter &rewriter) const override { | ||
Location loc = op.getLoc(); | ||
auto setOpsel = [&op](unsigned idx, int64_t val) { | ||
switch (idx) { | ||
case 3: | ||
op.setScalesIdxA(val); | ||
break; | ||
case 4: | ||
op.setScalesIdxB(val); | ||
break; | ||
default: | ||
break; | ||
} | ||
}; | ||
|
||
// For every scale operand of this ScaledMFMAOp, if the scale is produced by | ||
// the extraction of a single scale from some vector, then attempt to | ||
// extract 4 values from that vector instead. | ||
// | ||
// Example: (f8 here means f8E8M0FNU) | ||
// %unit = vector.extract %ScaleSrc[offsets] : f8 from vector<...> | ||
// %scale = vector.insert %unit, ... : f8 into vector<4xf8> | ||
// amdgpu.scaled_mfma(%scale[0] * ... | ||
// | ||
// rewrite to: | ||
// | ||
// %reshaped = vector.shape_cast %ScaleSrc : vector<...> to vector<?xf8> | ||
// %scale = vector.extract %reshaped[?] : vector<4xf8> from vector<?xf8> | ||
// amdgpu.scaled_mfma(%scale[0-3] * ... | ||
// | ||
// This creates duplicate shape_casts for every use but these will be | ||
// removed in CSE. | ||
for (auto opIdx : std::array<int64_t, 2>({3, 4})) { | ||
auto insertOp = op.getOperand(opIdx).getDefiningOp<vector::InsertOp>(); | ||
if (!insertOp) { | ||
return rewriter.notifyMatchFailure(op, | ||
"defining op not a vector.insert"); | ||
} | ||
// If the extracted value is not a single scalar, then it has been packed. | ||
if (isa<VectorType>(insertOp.getValueToStore().getType())) { | ||
return rewriter.notifyMatchFailure( | ||
op, "scaled mfma operand already packed"); | ||
} | ||
|
||
auto extractOp = | ||
insertOp.getValueToStore().getDefiningOp<vector::ExtractOp>(); | ||
if (!extractOp) { | ||
return rewriter.notifyMatchFailure(op, | ||
"defining op not a vector.extract"); | ||
} | ||
|
||
Value scaleSrc = extractOp.getOperand(0); | ||
auto scaleSrcType = dyn_cast<VectorType>(scaleSrc.getType()); | ||
if (!scaleSrcType) { | ||
return rewriter.notifyMatchFailure(op, "not a vector type"); | ||
} | ||
|
||
// We do not handle dynamic dims yet, assume that the input is padded to | ||
// a static shape now. | ||
if (!scaleSrcType.hasStaticShape()) { | ||
return rewriter.notifyMatchFailure(op, | ||
"dynamic dims not yet supported"); | ||
} | ||
|
||
int64_t numElements = scaleSrcType.getNumElements(); | ||
if (numElements <= 4) { | ||
return rewriter.notifyMatchFailure( | ||
op, "no packing if # of scales less than four"); | ||
} | ||
|
||
// Find a linearized idx using the size and offsets of the extract op. | ||
auto extractedPos = llvm::to_vector_of<int64_t>( | ||
llvm::reverse(extractOp.getStaticPosition())); | ||
ArrayRef<int64_t> scaleSrcShape = scaleSrcType.getShape(); | ||
int64_t scaleSrcRank = scaleSrcType.getRank(); | ||
SmallVector<int64_t> extractSizes(scaleSrcRank, 1); | ||
for (int64_t i = 1; i < scaleSrcRank; ++i) { | ||
extractSizes[i] = extractSizes[i - 1] * scaleSrcShape[scaleSrcRank - i]; | ||
} | ||
int64_t idx = linearize(extractedPos, extractSizes); | ||
|
||
// All n scales (where n is the total number of scales) must now be | ||
// extracted in chunks of 4 elements. This is done by dividing the | ||
// original vector of scales into groups of 4 elements | ||
// at offsets 0, 4, ..., m (where m = n/4). All extractions of a | ||
// scale at a particular index are now replaced with an extraction | ||
// of the entire group of 4 elements to which that index belongs. | ||
// | ||
// If the number of scales happens to be indivisible by 4, extract | ||
// the remaining n - m scales in a chunk of 4 elements starting at | ||
// offset n - 4. | ||
int64_t offset = idx - (idx % 4); | ||
int64_t opsel = idx - offset; | ||
int64_t size = 4l; | ||
// Accomdate remaining elements in the case of non-4-divisible vectors. | ||
if (numElements - offset < size) { | ||
opsel = size - (numElements - idx); | ||
offset = numElements - 4l; | ||
} | ||
Type scaleSrcElemType = scaleSrcType.getElementType(); | ||
auto newSrcType = VectorType::get(SmallVector<int64_t>({numElements}), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't need SmallVector here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, probably a good trivial followup once commit access goes through, good catch |
||
scaleSrcElemType); | ||
Value newScaleSrc = | ||
vector::ShapeCastOp::create(rewriter, loc, newSrcType, scaleSrc); | ||
auto extract = vector::ExtractStridedSliceOp::create( | ||
rewriter, loc, newScaleSrc, ArrayRef<int64_t>{offset}, | ||
ArrayRef<int64_t>{size}, ArrayRef<int64_t>{1}); | ||
rewriter.modifyOpInPlace(op, [&] { | ||
op->setOperand(opIdx, extract); | ||
setOpsel(opIdx, opsel); | ||
}); | ||
} | ||
return success(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
/// Registers canonicalization patterns for amdgpu.scaled_mfma: currently only
/// the scale-packing rewrite (PackScales).
void ScaledMFMAOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<PackScales>(context);
}
|
||
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc" | ||
|
||
#define GET_ATTRDEF_CLASSES | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.