[mlir][vector] Add multi_reduction_flattening #181244

Merged
amd-eochoalo merged 7 commits into llvm:main from amd-eochoalo:eochoa/2026-02-12/flattening-tests on Feb 18, 2026

@amd-eochoalo
Contributor

@amd-eochoalo amd-eochoalo commented Feb 12, 2026

  • Adds tests for populateVectorMultiReductionFlatteningPatterns.
  • Adds the apply_patterns.vector.multi_reduction_flattening transform op.

This follows PR #180977.

Assisted-by: claude
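
As a rough illustration of what the flattening tests in this PR exercise, the shape arithmetic can be modelled in plain Python, without touching the MLIR API. The function name `flatten_multi_reduction` and its parameters are hypothetical. The bail-out conditions are inferred from the negative test cases and from which check prefix each positive test uses; the contiguity rule in particular (parallel dims leading under innerreduction, trailing under innerparallel) is an assumption and is not stated in the diff.

```python
from math import prod


def flatten_multi_reduction(shape, reduction_dims, inner_reduction=True,
                            scalable=None):
    """Hypothetical model of the flattening; not the MLIR implementation.

    Returns (flat_shape, flat_reduction_dims), or None when the op is
    assumed to be left untouched.
    """
    rank = len(shape)
    scalable = scalable if scalable is not None else [False] * rank
    red = sorted(set(reduction_dims))
    par = [d for d in range(rank) if d not in red]

    if rank < 2:                      # Test 1: fewer than 2 dimensions
        return None
    if sum(scalable) > 1:             # Test 2: more than one scalable dim
        return None
    if rank == 2 and len(red) == 1:   # Test 3: already 2-D, single reduced dim
        return None
    if not red:                       # Assumption: nothing to reduce, skip
        return None

    # Assumed contiguity rule (Test 4 fails it under both strategies):
    # innerreduction wants parallel dims first, innerparallel wants them last.
    if inner_reduction:
        expected_par = list(range(len(par)))
    else:
        expected_par = list(range(len(red), rank))
    if par != expected_par:
        return None

    # Collapse each group into one dim; (size, is_reduction) pairs.
    dims = []
    if par:
        dims.append((prod(shape[d] for d in par), False))
    dims.append((prod(shape[d] for d in red), True))
    if not inner_reduction:
        dims.reverse()  # reduction dim goes first under innerparallel

    flat_shape = [size for size, _ in dims]
    flat_reduction_dims = [i for i, (_, is_red) in enumerate(dims) if is_red]
    return flat_shape, flat_reduction_dims
```

Under these assumptions, `flatten_multi_reduction([2, 3, 4, 5, 6], [3, 4])` mirrors the output_shapecast_multiple_parallel test, where the source is cast to vector<24x30xi32> and reduced over dimension 1. The sketch only counts scalable dimensions; it does not track whether a flattened dimension is itself scalable.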

@github-actions

github-actions Bot commented Feb 12, 2026

✅ With the latest revision this PR passed the C/C++ code formatter.

@amd-eochoalo amd-eochoalo force-pushed the eochoa/2026-02-12/flattening-tests branch from 62b630a to 4f92de4 Compare February 18, 2026 15:26
@llvmbot
Member

llvmbot commented Feb 18, 2026

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

Adds tests for populateVectorMultiReductionFlatteningPatterns.


Full diff: https://github.com/llvm/llvm-project/pull/181244.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+17)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+8)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+2-1)
  • (added) mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir (+122)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (-12)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir (-7)
  • (modified) mlir/test/python/dialects/transform_vector_ext.py (+8)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index 462c61df72108..6eb96e2a8fdab 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -266,6 +266,23 @@ def ApplyReorderAndExpandMultiReductionPatternsOp: Op<Transform_Dialect,
   }];
 }
 
+def ApplyMultiReductionFlatteningPatternsOp: Op<Transform_Dialect,
+    "apply_patterns.vector.multi_reduction_flattening",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector multi_reduction operations should be flattened from
+    more than 2-D to 2-D.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+      "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+  );
+
+  let assemblyFormat = [{
+    (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+  }];
+}
+
 def ApplyLowerOuterProductPatternsOp : Op<Transform_Dialect,
     "apply_patterns.vector.lower_outerproduct",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 4e2b97aa07084..f3529ac26523f 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -146,6 +146,14 @@ void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
+void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::VectorTransformsOptions vectorTransformOptions;
+  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+  vector::populateVectorMultiReductionFlatteningPatterns(
+      patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2d6a49bad27bc..fec04c967c9e1 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -287,7 +287,8 @@ class ReduceMultiDimReductionRank
       return success();
     }
 
-    // 8. Creates shape cast for the output n-D -> 2-D.
+    // 8. Shape cast the flattened result back to the original n-D parallel
+    // shape.
     VectorType outputCastedType = VectorType::get(
         parallelShapes, multiReductionOp.getSourceVectorType().getElementType(),
         parallelScalableDims);
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
new file mode 100644
index 0000000000000..b8f970912909b
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-flattening.mlir
@@ -0,0 +1,122 @@
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerreduction' | FileCheck %s --check-prefix=INNER_REDUCTION,ALL
+// RUN: mlir-opt %s --transform-interpreter='entry-point=innerparallel' | FileCheck %s --check-prefix=INNER_PARALLEL,ALL
+
+// ALL-LABEL: func @negative_flattening_cases
+func.func @negative_flattening_cases(
+    %v1d: vector<8xf32>,
+    %v2d: vector<4x8xf32>,
+    %v_scalable: vector<[2]x[4]x8xf32>,
+    %v_non_contig: vector<2x3x4x5xi32>,
+    %acc_scalar: f32,
+    %acc_1d: vector<8xf32>,
+    %acc_2d: vector<2x4xi32>) -> (f32, vector<8xf32>, vector<8xf32>, vector<2x4xi32>) {
+
+    // Test 1: Less than 2 dimensions
+    // ALL: %[[R1:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [0] : vector<8xf32> to f32
+    %r1 = vector.multi_reduction <add>, %v1d, %acc_scalar [0] : vector<8xf32> to f32
+
+    // Test 2: More than one scalable dimensions
+    // ALL: %[[R2:.+]] = vector.multi_reduction <mul>, %{{.+}}, %{{.+}} [0, 1] : vector<[2]x[4]x8xf32> to vector<8xf32>
+    %r2 = vector.multi_reduction <mul>, %v_scalable, %acc_1d [0, 1] : vector<[2]x[4]x8xf32> to vector<8xf32>
+
+    // Test 3: Already 2D with reduction on single dim
+    // ALL: %[[R3:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [0] : vector<4x8xf32> to vector<8xf32>
+    %r3 = vector.multi_reduction <add>, %v2d, %acc_1d [0] : vector<4x8xf32> to vector<8xf32>
+
+    // Test 4: Non-contiguous parallel dimensions
+    // ALL: %[[R4:.+]] = vector.multi_reduction <add>, %{{.+}}, %{{.+}} [1, 3] : vector<2x3x4x5xi32> to vector<2x4xi32>
+    %r4 = vector.multi_reduction <add>, %v_non_contig, %acc_2d [1, 3] : vector<2x3x4x5xi32> to vector<2x4xi32>
+
+    // ALL: return %[[R1]], %[[R2]], %[[R3]], %[[R4]]
+    return %r1, %r2, %r3, %r4 : f32, vector<8xf32>, vector<8xf32>, vector<2x4xi32>
+}
+
+// ALL-LABEL: func @vector_multi_reduction_flattening
+// ALL-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
+func.func @vector_multi_reduction_flattening(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
+    // ALL: %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+    // ALL: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+    // ALL: return %[[RESULT]]
+    return %0 : f32
+}
+
+// INNER_REDUCTION-LABEL: func @vector_multi_reduction_parallel_dim_innerreduction
+// INNER_REDUCTION-SAME:    %[[INPUT:.+]]: vector<2x3x4xi32>
+// INNER_REDUCTION-SAME:    %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim_innerreduction(%arg0: vector<2x3x4xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+    // INNER_REDUCTION: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4xi32> to vector<2x12xi32>
+    // INNER_REDUCTION: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [1]
+    %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4xi32> to vector<2xi32>
+    // INNER_REDUCTION: return %[[RESULT]]
+    return %0 : vector<2xi32>
+}
+
+// INNER_REDUCTION-LABEL: func @output_shapecast_multiple_parallel
+// INNER_REDUCTION-SAME:    %[[INPUT:.+]]: vector<2x3x4x5x6xi32>
+// INNER_REDUCTION-SAME:    %[[ACC:.+]]: vector<2x3x4xi32>
+func.func @output_shapecast_multiple_parallel(%arg0: vector<2x3x4x5x6xi32>, %acc: vector<2x3x4xi32>) -> vector<2x3x4xi32> {
+    // INNER_REDUCTION: %[[INPUT_CAST:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5x6xi32> to vector<24x30xi32>
+    // INNER_REDUCTION: %[[ACC_CAST:.+]] = vector.shape_cast %[[ACC]] : vector<2x3x4xi32> to vector<24xi32>
+    // INNER_REDUCTION: %[[RESULT_FLAT:.+]] = vector.multi_reduction <mul>, %[[INPUT_CAST]], %[[ACC_CAST]] [1]
+    // INNER_REDUCTION: %[[RESULT:.+]] = vector.shape_cast %[[RESULT_FLAT]] : vector<24xi32> to vector<2x3x4xi32>
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [3, 4] : vector<2x3x4x5x6xi32> to vector<2x3x4xi32>
+    // INNER_REDUCTION: return %[[RESULT]]
+    return %0 : vector<2x3x4xi32>
+}
+
+// INNER_PARALLEL-LABEL: func @vector_multi_reduction_parallel_dim_innerparallel
+// INNER_PARALLEL-SAME:    %[[INPUT:.+]]: vector<3x4x2xi32>
+// INNER_PARALLEL-SAME:    %[[ACC:.+]]: vector<2xi32>
+func.func @vector_multi_reduction_parallel_dim_innerparallel(%arg0: vector<3x4x2xi32>, %acc: vector<2xi32>) -> vector<2xi32> {
+    // INNER_PARALLEL: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<3x4x2xi32> to vector<12x2xi32>
+    // INNER_PARALLEL: %[[RESULT:.+]] = vector.multi_reduction <mul>, %[[CASTED]], %[[ACC]] [0]
+    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<3x4x2xi32> to vector<2xi32>
+    // INNER_PARALLEL: return %[[RESULT]]
+    return %0 : vector<2xi32>
+}
+
+// ALL-LABEL: func @single_scalable_dim
+// ALL-SAME:    %[[INPUT:.+]]: vector<4x[8]xf32>
+// ALL-SAME:    %[[ACC:.+]]: f32
+func.func @single_scalable_dim(%arg0: vector<4x[8]xf32>, %acc: f32) -> f32 {
+    // ALL: %[[CASTED:.+]] = vector.shape_cast %[[INPUT]] : vector<4x[8]xf32> to vector<[32]xf32>
+    // ALL: %[[RESULT:.+]] = vector.multi_reduction <add>, %[[CASTED]], %[[ACC]] [0]
+    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<4x[8]xf32> to f32
+    // ALL: return %[[RESULT]]
+    return %0 : f32
+}
+
+// ALL-LABEL: func @masked_multi_reduction
+// ALL-SAME:    %[[INPUT:.+]]: vector<2x4xf32>
+// ALL-SAME:    %[[ACC:.+]]: f32
+// ALL-SAME:    %[[MASK:.+]]: vector<2x4xi1>
+func.func @masked_multi_reduction(%arg0: vector<2x4xf32>, %acc: f32, %mask: vector<2x4xi1>) -> f32 {
+    // ALL: %[[CASTED_MASK:.+]] = vector.shape_cast %[[MASK]] : vector<2x4xi1> to vector<8xi1>
+    // ALL: %[[CASTED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
+    // ALL: %[[RESULT:.+]] = vector.mask %[[CASTED_MASK]]
+    // ALL:   vector.multi_reduction <mul>, %[[CASTED_INPUT]], %[[ACC]] [0]
+    %0 = vector.mask %mask {
+      vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
+    } : vector<2x4xi1> -> f32
+    // ALL: return %[[RESULT]]
+    return %0 : f32
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+
+  transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
+    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
+    transform.apply_patterns to %func_op {
+      transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
+    } : !transform.op<"func.func">
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 3ce9f57edf9d6..6b79a78e6a42a 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -19,18 +19,6 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
 //       CHECK:       %[[RESULT_VEC:.+]] = vector.insert %[[RV1:.+]], %[[RESULT_VEC_1]] [1] : f32 into vector<2xf32>
 //       CHECK:       return %[[RESULT_VEC]]
 
-// Patterns applied:
-// * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
-    %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
-    return %0 : f32
-}
-// CHECK-LABEL: func @vector_multi_reduction_to_scalar
-//  CHECK-SAME:   %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.*]]: f32)
-//       CHECK:   %[[CASTED:.*]] = vector.shape_cast %[[INPUT]] : vector<2x4xf32> to vector<8xf32>
-//       CHECK:   %[[REDUCED:.*]] = vector.reduction <mul>, %[[CASTED]], %[[ACC]] : vector<8xf32> into f32
-//       CHECK:   return %[[REDUCED]]
-
 // Patterns applied:
 // * ReduceMultiDimReductionRank from populateVectorMultiReductionFlatteningPatterns
 // * TwoDimMultiReductionToReduction from populateVectorMultiReductionUnrollingPatterns
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
index 33adb55456475..d0ab71e3f400f 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -181,13 +181,6 @@ func.func @vector_reduction_1D(%arg0 : vector<2xf32>, %acc: f32) -> f32 {
 // CHECK-LABEL: func @vector_reduction_1D
 //       CHECK:   return %{{.+}}
 
-func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) -> f32 {
-  %0 = vector.multi_reduction <add>, %arg0, %acc [0, 1] : vector<2x3xf32> to f32
-  return %0 : f32
-}
-// CHECK-LABEL: func @vector_multi_reduction_to_scalar
-//       CHECK:   return %{{.+}}
-
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py
index 0f9aab29ca5ca..29ce8ba63cd53 100644
--- a/mlir/test/python/dialects/transform_vector_ext.py
+++ b/mlir/test/python/dialects/transform_vector_ext.py
@@ -108,6 +108,14 @@ def enum_configurable_patterns():
         lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
     )
 
+    # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
+    vector.ApplyMultiReductionFlatteningPatternsOp()
+    # CHECK: transform.apply_patterns.vector.multi_reduction_flattening
+    # CHECK-SAME: lowering_strategy = innerreduction
+    vector.ApplyMultiReductionFlatteningPatternsOp(
+        lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction
+    )
+
     # CHECK: transform.apply_patterns.vector.lower_transpose
     vector.ApplyLowerTransposePatternsOp()
     # CHECK: transform.apply_patterns.vector.lower_transpose
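
The updated comment in LowerVectorMultiReduction.cpp says the flattened result is shape cast back to the original n-D parallel shape. A minimal sketch of why that round trip should preserve values, assuming vector.shape_cast keeps row-major element order: reducing the trailing dimensions of an n-D buffer gives the same result as reducing the rows of its 2-D view and then reading the 1-D result with n-D parallel indices. The helper names below are hypothetical, and addition is used instead of the `<mul>` from the test only to keep the numbers small.

```python
from itertools import product
from math import prod


def linear_index(index, shape):
    """Row-major linearization of an n-D index."""
    lin = 0
    for i, extent in zip(index, shape):
        lin = lin * extent + i
    return lin


def reduce_nd(data, acc, shape, num_parallel):
    """Add-reduce the trailing dims using explicit n-D indexing."""
    par_shape, red_shape = shape[:num_parallel], shape[num_parallel:]
    result = {}
    for p in product(*(range(e) for e in par_shape)):
        total = acc[linear_index(p, par_shape)]
        for r in product(*(range(e) for e in red_shape)):
            total += data[linear_index(p + r, shape)]
        result[p] = total
    return result


def reduce_flattened(data, acc, shape, num_parallel):
    """Same reduction on the 2-D view, then 'shape cast' back to n-D."""
    par_shape = shape[:num_parallel]
    rows, cols = prod(par_shape), prod(shape[num_parallel:])
    # 2-D reduction over dim 1: one contiguous slice per flattened row.
    flat = [acc[row] + sum(data[row * cols:(row + 1) * cols])
            for row in range(rows)]
    # Reading the 1-D result with n-D indices models the output shape cast.
    return {p: flat[linear_index(p, par_shape)]
            for p in product(*(range(e) for e in par_shape))}


# Shapes taken from the output_shapecast_multiple_parallel test.
shape = (2, 3, 4, 5, 6)
data = list(range(prod(shape)))
acc = list(range(prod(shape[:3])))
```

Comparing `reduce_nd(data, acc, shape, 3)` with `reduce_flattened(data, acc, shape, 3)` on this input is a quick way to check the 2x3x4x5x6 to 24x30 view and the 24 to 2x3x4 cast back.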

@amd-eochoalo amd-eochoalo changed the title from "[mlir][vector] Add lower_multi_reduction_flattening" to "[mlir][vector] Add multi_reduction_flattening" on Feb 18, 2026
@amd-eochoalo amd-eochoalo merged commit 6ec5c1e into llvm:main Feb 18, 2026
15 checks passed