diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp index e31c37a2459ad..0b18f2ef49736 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUVectorLinearize.cpp @@ -51,6 +51,46 @@ struct XeGPUVectorLinearizePass final return signalPassFailure(); } + // Lower vector.multi_reduction before linearization. Linearization flattens + // nD vectors to 1D, destroying axis information that multi_reduction relies + // on to know which elements to group together. By unrolling multi_reduction + // into row-wise shuffle + scalar reduction ops first, the IR contains only + // shape-agnostic ops by the time linearization runs. + // + // Two pattern sets are applied in order: + // 1. ReorderPatterns (InnerOuterDimReductionConversion): inserts + // vector.transpose to move all reduction dims to either the innermost + // or outermost positions. This normalizes arbitrary reductions into a + // canonical 2-D form that the unrolling patterns can handle. + // 2. UnrollingPatterns: with InnerParallel mode, the reduction dims are + // outermost, so the inner (parallel) dims are treated as rows and the + // outer loop is unrolled into a sequence of element-wise arith ops + // (TwoDimMultiReductionToElementWise). Any remaining 1-D + // multi_reduction is converted to vector.reduction + // (OneDimMultiReductionToReduction). 
+ // Example: reduce 4x8 matrix along rows (dim 0): + // %0 = vector.multi_reduction <add>, %arg0, %acc [0] + // : vector<4x8xf32> to vector<8xf32> + // is unrolled into: + // %flat = vector.shape_cast %arg0 : vector<4x8xf32> to vector<32xf32> + // %s0 = vector.shuffle %flat, %flat [0, 1, 2, 3, 4, 5, 6, 7] + // : vector<32xf32>, vector<32xf32> + // %r0 = arith.addf %s0, %acc : vector<8xf32> // row 0 + acc + // %s1 = vector.shuffle %flat, %flat [8, 9, 10, 11, 12, 13, 14, 15] + // : vector<32xf32>, vector<32xf32> + // %r1 = arith.addf %s1, %r0 : vector<8xf32> // row 1 + r0 + // ... // rows 2, 3 + // These shape-agnostic ops are then safely linearized. + // + { + auto options = vector::VectorMultiReductionLowering::InnerParallel; + RewritePatternSet patterns(&getContext()); + vector::populateVectorMultiReductionReorderPatterns(patterns, options); + vector::populateVectorMultiReductionUnrollingPatterns(patterns, options); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } + // Unroll load/store from <d1xd2x...xdk> to (d1*d2*...*d(k-1)) slices of // <1x1x...x1xdk>. 
{ diff --git a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir index 94205a6c26ba2..bac1e838b038e 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-vector-linearize.mlir @@ -265,4 +265,21 @@ gpu.module @test_kernel { } } +// ----- +// CHECK-LABEL: func.func @test_vector_multi_reduction_add +// CHECK-SAME: (%[[ARG0:.*]]: vector<16x1xf16>, %[[ARG1:.*]]: vector<1xf16>) -> vector<1xf16> +// CHECK: %[[FLAT:.*]] = vector.shape_cast %[[ARG0]] : vector<16x1xf16> to vector<16xf16> +// CHECK: vector.shuffle %[[FLAT]], %[[FLAT]] [0] : vector<16xf16>, vector<16xf16> +// CHECK: arith.addf {{.*}}, %[[ARG1]] : vector<1xf16> +// 14 more shuffle+addf pairs for indices 1..14 +// CHECK-COUNT-14: vector.shuffle %[[FLAT]], %[[FLAT]] {{.*}} : vector<16xf16>, vector<16xf16> +// Final shuffle (index 15) + addf + return +// CHECK: vector.shuffle %[[FLAT]], %[[FLAT]] [15] : vector<16xf16>, vector<16xf16> +// CHECK: %[[LAST:.*]] = arith.addf +// CHECK: return %[[LAST]] : vector<1xf16> +func.func @test_vector_multi_reduction_add(%arg0: vector<16x1xf16>, %arg1: vector<1xf16>) -> vector<1xf16> { + %0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<16x1xf16> to vector<1xf16> + return %0 : vector<1xf16> +} +