[MLIR][XeGPU] Lower vector.multi_reduction before linearization in XeGPUVectorLinearize#190272
nbpatel wants to merge 2 commits into
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir
Author: Nishant Patel (nbpatel)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/190272.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
index e31c37a2459ad..0b18f2ef49736 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp
@@ -51,6 +51,46 @@ struct XeGPUVectorLinearizePass final
return signalPassFailure();
}
+ // Lower vector.multi_reduction before linearization. Linearization flattens
+ // nD vectors to 1D, destroying axis information that multi_reduction relies
+ // on to know which elements to group together. By unrolling multi_reduction
+ // into row-wise shuffle + scalar reduction ops first, the IR contains only
+ // shape-agnostic ops by the time linearization runs.
+ //
+ // Two pattern sets are applied in order:
+ // 1. ReorderPatterns (InnerOuterDimReductionConversion): inserts
+ // vector.transpose to move all reduction dims to either the innermost
+ // or outermost positions. This normalizes arbitrary reductions into a
+ // canonical 2-D form that the unrolling patterns can handle.
+ // 2. UnrollingPatterns: with InnerParallel mode, the reduction dims are
+ // outermost, so the inner (parallel) dims are treated as rows and the
+ // outer loop is unrolled into a sequence of element-wise arith ops
+ // (TwoDimMultiReductionToElementWise). Any remaining 1-D
+ // multi_reduction is converted to vector.reduction
+ // (OneDimMultiReductionToReduction).
+ // Example: reduce 4x8 matrix along rows (dim 0):
+ // %0 = vector.multi_reduction <add>, %arg0, %acc [0]
+ // : vector<4x8xf32> to vector<8xf32>
+ // is unrolled into:
+ // %flat = vector.shape_cast %arg0 : vector<4x8xf32> to vector<32xf32>
+ // %s0 = vector.shuffle %flat, %flat [0, 1, 2, 3, 4, 5, 6, 7]
+ // : vector<32xf32>, vector<32xf32>
+ // %r0 = arith.addf %s0, %acc : vector<8xf32> // row 0 + acc
+ // %s1 = vector.shuffle %flat, %flat [8, 9, 10, 11, 12, 13, 14, 15]
+ // : vector<32xf32>, vector<32xf32>
+ // %r1 = arith.addf %s1, %r0 : vector<8xf32> // row 1 + r0
+ // ... // rows 2, 3
+ // These shape-agnostic ops are then safely linearized.
+ //
+ {
+ auto options = vector::VectorMultiReductionLowering::InnerParallel;
+ RewritePatternSet patterns(&getContext());
+ vector::populateVectorMultiReductionReorderPatterns(patterns, options);
+ vector::populateVectorMultiReductionUnrollingPatterns(patterns, options);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+
// Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of
// <1x1x...x1xdk>.
{
diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
index 94205a6c26ba2..20b555d768337 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir
@@ -265,4 +265,22 @@ gpu.module @test_kernel {
}
}
+// -----
+// CHECK-LABEL: func.func @test_vector_multi_reduction_add
+// CHECK-SAME: (%[[ARG0:.*]]: vector<4x8xf32>, %[[ARG1:.*]]: vector<8xf32>) -> vector<8xf32>
+// CHECK: %[[FLAT:.*]] = vector.shape_cast %[[ARG0]] : vector<4x8xf32> to vector<32xf32>
+// CHECK: %[[S0:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [0, 1, 2, 3, 4, 5, 6, 7]
+// CHECK: %[[R0:.*]] = arith.addf %[[S0]], %[[ARG1]] : vector<8xf32>
+// CHECK: %[[S1:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [8, 9, 10, 11, 12, 13, 14, 15]
+// CHECK: %[[R1:.*]] = arith.addf %[[S1]], %[[R0]] : vector<8xf32>
+// CHECK: %[[S2:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [16, 17, 18, 19, 20, 21, 22, 23]
+// CHECK: %[[R2:.*]] = arith.addf %[[S2]], %[[R1]] : vector<8xf32>
+// CHECK: %[[S3:.*]] = vector.shuffle %[[FLAT]], %[[FLAT]] [24, 25, 26, 27, 28, 29, 30, 31]
+// CHECK: %[[R3:.*]] = arith.addf %[[S3]], %[[R2]] : vector<8xf32>
+// CHECK: return %[[R3]] : vector<8xf32>
+func.func @test_vector_multi_reduction_add(%arg0: vector<4x8xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+ %0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<4x8xf32> to vector<8xf32>
+ return %0 : vector<8xf32>
+}
+
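The lowering that the test above checks can be modeled outside MLIR. The following is a minimal, self-contained C++ sketch, not the pass itself: it uses plain arrays instead of MLIR APIs, and the names `reduceDim0Reference`, `extractRow`, `addElementwise`, and `reduceDim0Unrolled` are hypothetical. It flattens a 4x8 matrix to 32 elements, takes each row as a contiguous 8-element slice (the role the `vector.shuffle` masks play in the CHECK lines), and accumulates the slices element-wise starting from the accumulator.

```cpp
#include <array>
#include <cstddef>

constexpr std::size_t kRows = 4;
constexpr std::size_t kCols = 8;

using Flat = std::array<float, kRows * kCols>; // models vector<32xf32>
using Row = std::array<float, kCols>;          // models vector<8xf32>

// Reference semantics of the reduction along dim 0:
// out[c] = acc[c] + sum over r of m[r][c], with m stored row-major in `flat`.
Row reduceDim0Reference(const Flat &flat, const Row &acc) {
  Row out = acc;
  for (std::size_t c = 0; c < kCols; ++c)
    for (std::size_t r = 0; r < kRows; ++r)
      out[c] += flat[r * kCols + c];
  return out;
}

// Models one shuffle with mask [row*8, row*8+1, ..., row*8+7].
Row extractRow(const Flat &flat, std::size_t row) {
  Row slice{};
  for (std::size_t c = 0; c < kCols; ++c)
    slice[c] = flat[row * kCols + c];
  return slice;
}

// Models an element-wise add on two 8-element vectors.
Row addElementwise(const Row &a, const Row &b) {
  Row out{};
  for (std::size_t c = 0; c < kCols; ++c)
    out[c] = a[c] + b[c];
  return out;
}

// Unrolled form: r0 = s0 + acc, r1 = s1 + r0, and so on for every row.
Row reduceDim0Unrolled(const Flat &flat, const Row &acc) {
  Row running = acc;
  for (std::size_t row = 0; row < kRows; ++row)
    running = addElementwise(extractRow(flat, row), running);
  return running;
}
```

Calling both functions on the same flattened input and accumulator should give identical results, which is the property that lets the unrolled chain replace the 2-D reduction before linearization.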
// multi_reduction is converted to vector.reduction
// (OneDimMultiReductionToReduction).
// Example: reduce 4x8 matrix along rows (dim 0):
// %0 = vector.multi_reduction <add>, %arg0, %acc [0]
How come the source is 2D after sg distribution?
For lane-local reductions, the SG→WI distribution pass lowers to the vector.multi_reduction op, since no cross-lane shuffle is needed. Hence we end up with 2D source vectors in some cases.
auto options = vector::VectorMultiReductionLowering::InnerParallel;
RewritePatternSet patterns(&getContext());
vector::populateVectorMultiReductionReorderPatterns(patterns, options);
vector::populateVectorMultiReductionUnrollingPatterns(patterns, options);
I think this should be a separate pattern in VectorLinearize.cpp. Vector linearization should not depend on unrolling for multi_reduction, because the other linearization patterns are in VectorLinearize.cpp.
Why? There are already patterns for it in LowerVectorMultiReduction.cpp. We would have to implement similar patterns in VectorLinearize.
// CHECK: %[[R3:.*]] = arith.addf %[[S3]], %[[R2]] : vector<8xf32>
// CHECK: return %[[R3]] : vector<8xf32>
func.func @test_vector_multi_reduction_add(%arg0: vector<4x8xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
%0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<4x8xf32> to vector<8xf32>
Does each lane really own 4x8 elements after distribution? Can we use an example from a real use case here? That would help explain the motivation for this lowering.
// : vector<4x8xf32> to vector<8xf32>
// is unrolled into:
// %flat = vector.shape_cast %arg0 : vector<4x8xf32> to vector<32xf32>
// %s0 = vector.shuffle %flat, %flat [0, 1, 2, 3, 4, 5, 6, 7]
How is this shuffle lowered further after XeGPU, at the XeVM/LLVM level? I have previously observed long code sequences for such shuffles.
Why not use vector.deinterleave to get each column and then do the reduction? I believe that is more efficient. The shuffle is very general and uses constant offsets, so it loses the metadata information (subtile vs. tile).
// -----
// CHECK-LABEL: func.func @test_vector_multi_reduction_add
// CHECK-SAME: (%[[ARG0:.*]]: vector<16x1xf16>, %[[ARG1:.*]]: vector<1xf16>) -> vector<1xf16>
The code looks very inefficient to me. Why not just remove the trailing dimension and use a vector.reduction of 16 elements?
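To make the reviewer's point concrete, here is a small self-contained C++ sketch, offered only as an illustration under stated assumptions: it uses `float` in place of f16, plain arrays in place of MLIR vectors, and the hypothetical names `reduceUnitColumnUnrolled` and `reduceAsOneDim`. It compares the row-by-row chain that the unrolling would produce for a 16x1 source against a single 16-element reduction seeded with the accumulator, which is what dropping the trailing unit dimension would allow. Whether the existing patterns can be made to emit the second form is the open question in this thread; the sketch only shows that the two computations agree.

```cpp
#include <array>
#include <cstddef>
#include <numeric>

constexpr std::size_t kLen = 16;
using Vec16 = std::array<float, kLen>; // models the 16x1 source, flattened

// Unrolled form: 16 steps, each taking a 1-element "row" and adding it
// to the running value, starting from the accumulator.
float reduceUnitColumnUnrolled(const Vec16 &flat, float acc) {
  float running = acc;
  for (std::size_t row = 0; row < kLen; ++row)
    running = flat[row] + running;
  return running;
}

// Suggested form: treat the source as a 1-D vector of 16 elements and
// perform one horizontal reduction seeded with the accumulator.
float reduceAsOneDim(const Vec16 &v, float acc) {
  return std::accumulate(v.begin(), v.end(), acc);
}
```

Both functions visit the elements in the same order, so they return the same value; the difference the reviewer points at is the number of ops in the generated IR, not the result.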
Closing this one for now.