[mlir][vector] Add multi_reduction_flattening (#181244)
Merged
amd-eochoalo merged 7 commits into the target branch on Feb 18, 2026
Merged
Conversation
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
62b630a to
4f92de4
Compare
Member
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Erick Ochoa Lopez (amd-eochoalo) Changes: Adds tests for populateVectorMultiReductionFlatteningPatterns. This follows PR #180977. Full diff: https://github.com/llvm/llvm-project/pull/181244.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 462c61df72108..6eb96e2a8fdab 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -266,6 +266,23 @@ def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
}];
}
+def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
+ "apply_patterns.vector.multi_reduction_flattening",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that vector multi_reduction operations should be flattened from
+ more than 2-D to 2-D.
+ }];
+
+ let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+ "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+ );
+
+ let assemblyFormat = [{
+ (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+ }];
+}
+
def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.lower_outerproduct",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 4e2b97aa07084..f3529ac26523f 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -146,6 +146,14 @@ void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
patterns, vectorTransformOptions.vectorMultiReductionLowering);
}
+void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::VectorTransformsOptions vectorTransformOptions;
+ vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+ vector::populateVectorMultiReductionFlatteningPatterns(
+ patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2d6a49bad27bc..fec04c967c9e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -287,7 +287,8 @@ class ReduceMultiDimReductionRank
return success();
}
- // 8. Creates shape cast for the output n-D -> 2-D.
+ // 8. Shape cast the flattened result back to the original n-D parallel
+ // shape.
VectorType outputCastedType = VectorType::get(
parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
parallelScalableDims);
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
new file mode 100644
index 0000000000000..b8f970912909b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
@@ -0,0 +1,122 @@
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefix=INNER_REDUCTION,ALL
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefix=INNER_PARALLEL,ALL
+
+// ALL-LABEL: func @negative_flattening_cases
+func.func @negative_flattening_cases(
+ %v1d: vector<8xf32>,
+ %v2d: vector<4x8xf32>,
+ %v_scalable: vector<[2]x[4]x8xf32>,
+ %v_non_contig: vector<2x3x4x5xi32>,
+ %acc_scalar: f32,
+ %acc_1d: vector<8xf32>,
+ %acc_2d: vector<2x4xi32>) -> (f32, vector<8xf32>, vector<8xf32>, vector<2x4xi32>) {
+
+ // Test 1: Less than 2 dimensions
+ // ALL: %[[R1:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [0] : vector<8xf32> to f32
+ %r1 = vector.multi_reduction <add>, %v1d, %acc_scalar [0] : vector<8xf32> to f32
+
+ // Test 2: More than one scalable dimensions
+ // ALL: %[[R2:.+]] = vector.multi_reduction <mul>, %{{.+}}, %{{.+}} [0, 1] : vector<[2]x[4]x8xf32> to vector<8xf32>
+ %r2 = vector.multi_reduction <mul>, %v_scalable, %acc_1d [0, 1] : vector<[2]x[4]x8xf32> to vector<8xf32>
+
+ // Test 3: Already 2D with reduction on single dim
+ // ALL: %[[R3:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [0] : vector<4x8xf32> to vector<8xf32>
+ %r3 = vector.multi_reduction <add>, %v2d, %acc_1d [0] : vector<4x8xf32> to vector<8xf32>
+
+ // Test 4: Non-contiguous parallel dimensions
+ // ALL: %[[R4:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [1, 3] : vector<2x3x4x5xi32> to vector<2x4xi32>
+ %r4 = vector.multi_reduction <add>, %v_non_contig, %acc_2d [1, 3] : vector<2x3x4x5xi32> to vector<2x4xi32>
+
+ // ALL: return %[[R1]], %[[R2]], %[[R3]], %[[R4]]
+ return %r1, %r2, %r3, %r4 : f32, vector<8xf32>, vector<8xf32>, vector<2x4xi32>
+}
+
+// ALL-LABEL: func @vector_multi_reduction_flattening
+// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+func.func @vector_multi_reduction_flattening(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+ // ALL: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+ // ALL: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+ // ALL: return %[[RESULT]]
+ return %0 : f32
+}
+
+// INNER_REDUCTION-LABEL: func @vector_multi_reduction_parallel_dim_innerreduction
+// INNER_REDUCTION-SAME: %[[INPUT:.+]]: vector<2x3x4xi32>
+// INNER_REDUCTION-SAME: %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim_innerreduction(%arg0: vector<2x3x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ // INNER_REDUCTION: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4xi32> to vector<2x12xi32>
+ // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [1]
+ %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4xi32> to vector<2xi32>
+ // INNER_REDUCTION: return %[[RESULT]]
+ return %0 : vector<2xi32>
+}
+
+// INNER_REDUCTION-LABEL: func @output_shapecast_multiple_parallel
+// INNER_REDUCTION-SAME: %[[INPUT:.+]]: vector<2x3x4x5x6xi32>
+// INNER_REDUCTION-SAME: %[[ACC:.+]]: vector<2x3x4xi32>
+func.func @output_shapecast_multiple_parallel(%arg0: vector<2x3x4x5x6xi32>, %acc: vector<2x3x4xi32>) -> vector<2x3x4xi32> {
+ // INNER_REDUCTION: %[[INPUT_CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5x6xi32> to vector<24x30xi32>
+ // INNER_REDUCTION: %[[ACC_CAST:.+]] = vector.shape_cast %[[ACC]] : vector<2x3x4xi32> to vector<24xi32>
+ // INNER_REDUCTION: %[[RESULT_FLAT:.+]] = vector.multi_reduction <mul>, %[[INPUT_CAST]], %[[ACC_CAST]] [1]
+ // INNER_REDUCTION: %[[RESULT:.+]] = vector.shape_cast %[[RESULT_FLAT]] : vector<24xi32> to vector<2x3x4xi32>
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [3, 4] : vector<2x3x4x5x6xi32> to vector<2x3x4xi32>
+ // INNER_REDUCTION: return %[[RESULT]]
+ return %0 : vector<2x3x4xi32>
+}
+
+// INNER_PARALLEL-LABEL: func @vector_multi_reduction_parallel_dim_innerparallel
+// INNER_PARALLEL-SAME: %[[INPUT:.+]]: vector<3x4x2xi32>
+// INNER_PARALLEL-SAME: %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim_innerparallel(%arg0: vector<3x4x2xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+ // INNER_PARALLEL: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<3x4x2xi32> to vector<12x2xi32>
+ // INNER_PARALLEL: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+ %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<3x4x2xi32> to vector<2xi32>
+ // INNER_PARALLEL: return %[[RESULT]]
+ return %0 : vector<2xi32>
+}
+
+// ALL-LABEL: func @single_scalable_dim
+// ALL-SAME: %[[INPUT:.+]]: vector<4x[8]xf32>
+// ALL-SAME: %[[ACC:.+]]: f32
+func.func @single_scalable_dim(%arg0: vector<4x[8]xf32>, %acc: f32) -> f32 {
+ // ALL: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<4x[8]xf32> to vector<[32]xf32>
+ // ALL: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [0]
+ %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<4x[8]xf32> to f32
+ // ALL: return %[[RESULT]]
+ return %0 : f32
+}
+
+// ALL-LABEL: func @masked_multi_reduction
+// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+// ALL-SAME: %[[ACC:.+]]: f32
+// ALL-SAME: %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_multi_reduction(%arg0: vector<2x4xf32>, %acc: f32, %mask: vector<2x4xi1>) -> f32 {
+ // ALL: %[[CASTED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<2x4xi1> to vector<8xi1>
+ // ALL: %[[CASTED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+ // ALL: %[[RESULT:.+]] = vector.mask %[[CASTED_MASK]]
+ // ALL: vector.multi_reduction <mul>, %[[CASTED_INPUT]], %[[ACC]] [0]
+ %0 = vector.mask %mask {
+ vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+ } : vector<2x4xi1> -> f32
+ // ALL: return %[[RESULT]]
+ return %0 : f32
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+
+ transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+ } : !transform.op<"func.func">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 3ce9f57edf9d6..6b79a78e6a42a 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -19,18 +19,6 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
// CHECK: %[[RESULT_VEC:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
// CHECK: return %[[RESULT_VEC]]
-// Patterns applied:
-// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
- %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
- return %0 : f32
-}
-// CHECK-LABEL: func @vector_multi_reduction_to_scalar
-// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
-// CHECK: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
-// CHECK: %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
-// CHECK: return %[[REDUCED]]
-
// Patterns applied:
// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
// * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 33adb55456475..d0ab71e3f400f 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -181,13 +181,6 @@ func.func @vector_reduction_1D(%arg0 : vector<2xf32>, %acc: f32) -> f32 {
// CHECK-LABEL: func @vector_reduction_1D
// CHECK: return %{{.+}}
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
- %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
- return %0 : f32
-}
-// CHECK-LABEL: func @vector_multi_reduction_to_scalar
-// CHECK: return %{{.+}}
-
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 0f9aab29ca5ca..29ce8ba63cd53 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -108,6 +108,14 @@ def enum_configurable_patterns():
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
)
+ # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
+ vector.ApplyMultiReductionFlatteningPatternsOp()
+ # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
+ # CHECK-SAME: lowering_strategy = innerreduction
+ vector.ApplyMultiReductionFlatteningPatternsOp(
+ lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
+ )
+
# CHECK: transform.apply_patterns.vector.lower_transpose
vector.ApplyLowerTransposePatternsOp()
# CHECK: transform.apply_patterns.vector.lower_transpose
|
kuhar
approved these changes
Feb 18, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Adds tests for populateVectorMultiReductionFlatteningPatterns. This follows PR #180977.
Assisted-by: claude