
Commit 9703bda

[mlir][xegpu] Add OptimizeBlockLoads pass. (#165483)
This pass rewrites certain xegpu `CreateNd` and `LoadNd` operations that feed into `vector.transpose` into a more optimal form to improve performance. Specifically, low-precision (bitwidth < 32) `LoadNd` ops that feed into transpose ops are rewritten as i32 loads with a valid transpose layout, so that later passes can use the HW load-with-transpose feature to accelerate such loads. **Update:** The pass is renamed to `OptimizeBlockLoads` because we later plan to add the array-length optimization to this pass as well. That optimization will break a larger load (like `32x32xf16`) into more DPAS-favorable array-length loads (`32x16xf16` with array length = 2). Both optimizations require rewriting `CreateNd` and `LoadNd`, so it makes sense to have a common pass for both.
1 parent 2141edf commit 9703bda
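
As a rough illustration of the IR shape this pass targets, here is a minimal sketch of a rewrite pattern that matches a low-precision `LoadNd` feeding a `vector.transpose`. The pattern name, matching strategy, and structure are hypothetical, not taken from the patch:

  #include "mlir/Dialect/Vector/IR/VectorOps.h"
  #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
  #include "mlir/IR/PatternMatch.h"

  using namespace mlir;

  // Hypothetical sketch, not the actual pass code: find a low-precision
  // xegpu.load_nd whose result feeds a vector.transpose. This is the kind
  // of IR the pass rewrites into an i32 load with a transpose layout.
  struct TransposeLoadSketch : OpRewritePattern<vector::TransposeOp> {
    using OpRewritePattern::OpRewritePattern;
    LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                  PatternRewriter &rewriter) const override {
      auto loadOp = transposeOp.getVector().getDefiningOp<xegpu::LoadNdOp>();
      if (!loadOp)
        return failure();
      auto tdescTy =
          cast<xegpu::TensorDescType>(loadOp.getTensorDesc().getType());
      // Only low-precision element types (bitwidth < 32) are candidates
      // for the i32 load-with-transpose rewrite.
      if (tdescTy.getElementType().getIntOrFloatBitWidth() >= 32)
        return failure();
      // The real pass would rebuild the CreateNd/LoadNd pair here as an
      // i32-typed load with a transpose-friendly layout (omitted).
      return failure();
    }
  };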

8 files changed: +827 −30 lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions

@@ -85,4 +85,16 @@ def XeGPUVectorLinearize : Pass<"xegpu-vector-linearize"> {
     "scf::SCFDialect", "ub::UBDialect", "vector::VectorDialect"];
 }
 
+def XeGPUOptimizeBlockLoads : Pass<"xegpu-optimize-block-loads"> {
+  let summary = "Optimize XeGPU block load operations";
+  let description = [{
+    This pass rewrites XeGPU loadNd operations into more optimal forms
+    to improve performance. This includes,
+    - Rewriting transpose B loads into more optimal forms to use HW block
+      transpose instructions for better performance.
+  }];
+  let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
+                           "vector::VectorDialect"];
+}
+
 #endif // MLIR_DIALECT_XEGPU_TRANSFORMS_PASSES_TD
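
Assuming the usual TableGen pass registration, the new pass should then be runnable standalone, e.g. via `mlir-opt --xegpu-optimize-block-loads` (the flag name comes from the `Pass<"xegpu-optimize-block-loads">` definition above).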

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion

@@ -61,7 +61,8 @@ struct UnrollOptions {
 
 /// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
 void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
-
+/// Appends patterns for optimizing block load operations into `patterns`.
+void populateXeGPUOptimizeBlockLoadsPatterns(RewritePatternSet &patterns);
 /// Appends patterns for XeGPU SIMT distribution into `patterns`.
 void populateXeGPUSubgroupDistributePatterns(RewritePatternSet &patterns);
 /// Appends patterns for moving function body into gpu.warp_execute_on_lane0 op.

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 9 additions & 0 deletions

@@ -166,6 +166,15 @@ SmallVector<OpFoldResult> addElementwise(OpBuilder &builder, Location loc,
 SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
                                               ArrayRef<OpFoldResult> lhs,
                                               ArrayRef<OpFoldResult> rhs);
+
+/// Helper Function to find a proper instruction multiple for the user-supplied
+/// sg-level data shape (given by `dim`). `candidates` are uArch allowed shapes.
+/// `candidateMultiples` are uArch multiples of such shapes (i.e. block count or
+/// array length).
+template <typename T>
+int getLargestDivisor(T dim, ArrayRef<T> candidates,
+                      ArrayRef<T> candidateMultiples = {});
+
 } // namespace xegpu
 
 } // namespace mlir
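
For illustration, a hypothetical call to this helper with made-up uArch numbers (not taken from a real block-parameter table):

  #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"

  using namespace mlir;

  // Hypothetical usage: pick the widest instruction width that evenly
  // divides a 24-wide sg-level tile. The width/block-count values are made up.
  int pickInstWidth() {
    SmallVector<int> widths = {8, 16};     // uArch-allowed widths (assumed)
    SmallVector<int> blockCounts = {1, 2}; // uArch block counts (assumed)
    // Candidate products are 8, 16, 16, and 32; only 8 divides 24 evenly,
    // so this returns 8. A return of -1 means no candidate divides the shape.
    return xegpu::getLargestDivisor<int>(24, widths, blockCounts);
  }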

mlir/lib/Dialect/XeGPU/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRXeGPUTransforms
   XeGPUWgToSgDistribute.cpp
   XeGPUPropagateLayout.cpp
   XeGPUVectorLinearize.cpp
+  XeGPUOptimizeBlockLoads.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU

mlir/lib/Dialect/XeGPU/Transforms/XeGPUOptimizeBlockLoads.cpp

Lines changed: 490 additions & 0 deletions
Large diffs are not rendered by default.

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 7 additions & 29 deletions

@@ -204,28 +204,6 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
   using Lattice::Lattice;
 };
 
-/// Helper Function to find a proper instruction multiple for the user-supplied
-/// sg-level data shape. `candidates` are uArch allowed shapes.
-/// `candidateMultiples` are uArch multiples of such shapes (e.g., block count).
-template <typename T>
-int getLargestDivisor(T dim, ArrayRef<T> candidates,
-                      ArrayRef<T> candidateMultiples = {}) {
-  static_assert(std::is_integral<T>::value, "T must be an integer type");
-  int largest = -1;
-  SmallVector<T> multiples = {1};
-  if (!candidateMultiples.empty())
-    multiples =
-        SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
-  for (T candidate : candidates) {
-    for (T multiple : multiples) {
-      int value = static_cast<int>(candidate * multiple);
-      if (value != 0 && dim % value == 0 && value > largest)
-        largest = value;
-    }
-  }
-  return largest;
-}
-
 /// Helper Functions to get default layouts. A `default layout` is a layout that
 /// is assigned to a value when the layout is not fixed by some anchor operation
 /// (like DPAS).
@@ -505,7 +483,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
     prefetch.emitWarning("No known block params found for the element type.");
   auto [bWidth, bHeight, bCount] = blockWHC.value();
   SmallVector<int> instData;
-  int instWidth = getLargestDivisor(
+  int instWidth = xegpu::getLargestDivisor(
       static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 1)), bWidth,
       bCount);
   if (instWidth == -1)
@@ -514,7 +492,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
   if (tdescTy.getRank() == 1)
     instData = {instWidth};
   else {
-    int instHeight = getLargestDivisor(
+    int instHeight = xegpu::getLargestDivisor(
         static_cast<int>(tdescTy.getDimSize(tdescTy.getRank() - 2)), bHeight);
     if (instHeight == -1)
       prefetch.emitWarning(
@@ -634,15 +612,15 @@ void LayoutInfoPropagation::visitDpasOp(
   const unsigned dataALen = aTy.getShape().front();
   auto supportedALen = uArchInstruction->getSupportedM(aTy.getElementType());
   const int maxALen =
-      getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
+      xegpu::getLargestDivisor(dataALen, ArrayRef<unsigned>(supportedALen));
   if (maxALen == -1)
     dpas.emitWarning(
         "No suitable instruction multiple found for the given shape.");
 
   const unsigned dataBLen = bTy.getShape().back();
   auto supportedBLen = uArchInstruction->getSupportedK(bTy.getElementType());
   const int maxBLen =
-      getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
+      xegpu::getLargestDivisor(dataBLen, ArrayRef<unsigned>(supportedBLen));
   if (maxBLen == -1)
     dpas.emitWarning(
         "No suitable instruction multiple found for the given shape.");
@@ -662,7 +640,7 @@ void LayoutInfoPropagation::visitDpasOp(
   const unsigned dataCLen = bTy.getShape().back();
   auto supportedCLen = uArchInstruction->getSupportedN(bTy.getElementType());
   const int maxCLen =
-      getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
+      xegpu::getLargestDivisor(dataCLen, ArrayRef<unsigned>(supportedCLen));
   if (maxCLen == -1)
     dpas.emitWarning(
         "No suitable instruction multiple found for the given shape.");
@@ -691,7 +669,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
     store.emitWarning("No known block params found for the element type.");
   auto [bWidth, bHeight, bCount] = blockWHC.value();
   SmallVector<int> instData;
-  int instWidth = getLargestDivisor(
+  int instWidth = xegpu::getLargestDivisor(
      static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 1)), bWidth,
      bCount);
   if (instWidth == -1)
@@ -700,7 +678,7 @@ void LayoutInfoPropagation::visitStoreNdOp(
   if (dataTy.getRank() == 1)
     instData = {instWidth};
   else {
-    int instHeight = getLargestDivisor(
+    int instHeight = xegpu::getLargestDivisor(
        static_cast<int>(dataTy.getDimSize(dataTy.getRank() - 2)), bHeight);
    if (instHeight == -1)
      store.emitWarning(

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 26 additions & 0 deletions

@@ -555,3 +555,29 @@ xegpu::addWithRightAligned(OpBuilder &builder, Location loc,
   results.append(addElementwise(builder, loc, a, b));
   return results;
 }
+
+template <typename T>
+int xegpu::getLargestDivisor(T dim, ArrayRef<T> candidates,
+                             ArrayRef<T> candidateMultiples) {
+  static_assert(std::is_integral<T>::value, "T must be an integer type");
+  int largest = -1;
+  SmallVector<T> multiples = {1};
+  if (!candidateMultiples.empty())
+    multiples =
+        SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
+  for (T candidate : candidates) {
+    for (T multiple : multiples) {
+      int value = static_cast<int>(candidate * multiple);
+      if (value != 0 && dim % value == 0 && value > largest)
+        largest = value;
+    }
+  }
+  return largest;
+}
+
+/// Explicit instantiations
+template int xegpu::getLargestDivisor<int>(int dim, ArrayRef<int> candidates,
+                                           ArrayRef<int> candidateMultiples);
+template int
+xegpu::getLargestDivisor<unsigned>(unsigned dim, ArrayRef<unsigned> candidates,
+                                   ArrayRef<unsigned> candidateMultiples);
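
One consequence of keeping the template definition in the .cpp file, sketched under the declarations above: only the two explicitly instantiated element types are usable from other translation units.

  #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"

  // Hypothetical caller in another translation unit (values made up).
  int demo() {
    // 8, 16, and 32 all divide 64; the largest, 32, wins.
    return mlir::xegpu::getLargestDivisor<int>(64, {8, 16, 32});
    // Instantiating with, say, int64_t would compile here but fail to link,
    // since XeGPUUtils.cpp only provides the int and unsigned instantiations.
  }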
