Skip to content

[vector][multi_reduction] Add unrolling for vector.multi_reduction#185033

Closed
amd-eochoalo wants to merge 6 commits into
llvm:mainfrom
amd-eochoalo:eochoa/2026-03-06/multi-reduction
Closed

[vector][multi_reduction] Add unrolling for vector.multi_reduction#185033
amd-eochoalo wants to merge 6 commits into
llvm:mainfrom
amd-eochoalo:eochoa/2026-03-06/multi-reduction

Conversation

@amd-eochoalo
Copy link
Copy Markdown
Contributor

vector.multi_reduction's unrolling patterns are currently a misnomer. They don't unroll, but instead lower multi_reduction to either arith operations or vector.reduction. This PR:

  • renames multi_reduction_unrolling to multi_reduction_lowering
  • adds multi_reduction_unrolling where actual unrolling happens.

The benefits of vector.multi_reduction expand, flattening, unrolling and lowering patterns are such that they preserve existing behaviour. The benefit of adding unrolling is that now we can unroll without going through expand or flattening. This allows different backends finer control over which patterns to apply. For example, to better handle differences between SPIR-V and LLVM lowerings.

This will break downstream projects. All you need to do is apply multi_reduction_lowering where multi_reduction_unrolling was applied before.

Stated previously:

> Below are some notes on naming
> * RewriteXAsY or LowerXToY (if we're changing op kind)
>
> Concretely, TwoDimMultiReductionToReduction looks like a lowering
> (it rewrites to vector.reduction), not an unrolling,

llvm#182301 (comment)
When unrolling vector.multi_reduction with outermost
dimensions being reduced, exctract outermost dimension
vectors and chain them with vector.multi_reductions.
Unrolls multi_reduction by series of extractions, multi_reduction
and insertions.
@github-actions
Copy link
Copy Markdown

github-actions Bot commented Mar 6, 2026

🐧 Linux x64 Test Results

  • 7600 tests passed
  • 603 tests skipped

✅ The build succeeded and all tests passed.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Mar 6, 2026

🪟 Windows x64 Test Results

  • 3525 tests passed
  • 416 tests skipped

✅ The build succeeded and all tests passed.

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Mar 6, 2026

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sve

@llvm/pr-subscribers-mlir-llvm

Author: Erick Ochoa Lopez (amd-eochoalo)

Changes

vector.multi_reduction's unrolling patterns are currently a misnomer. They don't unroll, but instead lower multi_reduction to either arith operations or vector.reduction. This PR:

  • renames multi_reduction_unrolling to multi_reduction_lowering
  • adds multi_reduction_unrolling where actual unrolling happens.

The benefits of vector.multi_reduction expand, flattening, unrolling and lowering patterns are such that they preserve existing behaviour. The benefit of adding unrolling is that now we can unroll without going through expand or flattening. This allows different backends finer control over which patterns to apply. For example, to better handle differences between SPIR-V and LLVM lowerings.

This will break downstream projects. All you need to do is apply multi_reduction_lowering where multi_reduction_unrolling was applied before.


Patch is 28.86 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/185033.diff

13 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td (+25-2)
  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+25-3)
  • (modified) mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp (+8)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+207-12)
  • (modified) mlir/test/Dialect/LLVM/transform-e2e.mlir (+1)
  • (modified) mlir/test/Dialect/Vector/transform-vector.mlir (+1)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir (+2-2)
  • (added) mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir (+27)
  • (added) mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir (+62)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir (+1)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir (+1)
  • (modified) mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir (+1)
  • (modified) mlir/test/python/dialects/transform_vector_ext.py (+8)
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index dcd5f6ff3ad74..b91bd09053c7e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -266,14 +266,37 @@ def ApplyMultiReductionUnrollingPatternsOp: Op<Transform_Dialect,
     "apply_patterns.vector.multi_reduction_unrolling",
     [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that vector multi_reduction operations should be unrolled.
+    Indicates that vector multi_reduction operations with more than one
+    reduction dimension should be unrolled in a rank-reducing way.
+
+    This populates the patterns from
+    `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+    * `UnrollMultiReductionInnerReduction` (inner_reduction)
+    * `UnrollMultiReductionInnerParallel` (inner_parallel)
+  }];
+
+  let arguments = (ins DefaultValuedAttr<VectorMultiReductionLoweringAttr,
+      "vector::VectorMultiReductionLowering::InnerParallel">:$lowering_strategy
+  );
+
+  let assemblyFormat = [{
+    (`lowering_strategy` `=` $lowering_strategy^)? attr-dict
+  }];
+}
+
+
+def ApplyMultiReductionLoweringPatternsOp: Op<Transform_Dialect,
+    "apply_patterns.vector.multi_reduction_lowering",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector multi_reduction operations should be lowered.
     1-D multi_reductions are converted directly to vector.reduction.
     2-D multi_reductions are unrolled into either a sequence of
     vector.reduction ops (innerreduction) or element-wise arith ops
     (innerparallel).
 
     This populates the patterns from
-    `populateVectorMultiReductionUnrollingPatterns`, i.e.:
+    `populateVectorMultiReductionLoweringPatterns`, i.e.:
     * `OneDimMultiReductionToReduction`
     * `TwoDimMultiReductionToReduction` (innerreduction)
     * `TwoDimMultiReductionToElementWise` (innerparallel)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index aa75eff409ef9..f9eeaf4136370 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -66,9 +66,12 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns,
 /// Rewrites vector.multi_reduction such that all reduction dimensions are
 /// either innermost or outermost, by adding the proper vector.transpose
 /// operations.
+///
+/// The benefit is set to be higher than the unrolling patterns. Otherwise
+/// patterns here would match the same operations as those in unrolling.
 void populateVectorMultiReductionReorderPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit = 1);
+    PatternBenefit benefit = 2);
 
 /// Populate the pattern set with the following patterns:
 ///
@@ -77,7 +80,23 @@ void populateVectorMultiReductionReorderPatterns(
 /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction,
 /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand
 /// back.
+///
+/// The benefit is set to be higher than the unrolling patterns. Otherwise
+/// patterns here would match the same operations as those in unrolling.
 void populateVectorMultiReductionFlatteningPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit = 2);
+
+/// Populate the pattern set with the following patterns:
+///
+/// [UnrollMultiReductionInnerReduction]
+/// Extracts vectors along outer dimension
+/// and chains multiple multi_reduction operations.
+///
+/// [UnrollMultiReductionInnerParallel]
+/// Extracts vectors along outer dimension, performs multi_reduction operations
+/// and inserts them back to a vector.multi_reduction with a lower rank.
+void populateVectorMultiReductionUnrollingPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit = 1);
 
@@ -96,9 +115,12 @@ void populateVectorMultiReductionFlatteningPatterns(
 /// dimension, unroll the outer dimension to obtain a sequence of extract +
 /// vector.reduction + insert. This can further lower to horizontal reduction
 /// ops.
-void populateVectorMultiReductionUnrollingPatterns(
+///
+/// The benefit is set to be higher than the unrolling patterns. Otherwise
+/// patterns here would match the same operations as those in unrolling.
+void populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
-    PatternBenefit benefit = 1);
+    PatternBenefit benefit = 2);
 
 /// Populate the pattern set with the following patterns:
 ///
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 312bd28ad48cf..55953aa22dd88 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -153,6 +153,14 @@ void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns(
       patterns, vectorTransformOptions.vectorMultiReductionLowering);
 }
 
+void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::VectorTransformsOptions vectorTransformOptions;
+  vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy());
+  vector::populateVectorMultiReductionLoweringPatterns(
+      patterns, vectorTransformOptions.vectorMultiReductionLowering);
+}
+
 void transform::ApplyLowerOuterProductPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
   populateVectorOuterProductLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 76599822fbfe4..3f2b399e55343 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -443,6 +443,196 @@ struct TwoDimMultiReductionToReduction
   }
 };
 
+/// Unrolls outermost dimension for vector.multi_reduction.
+/// Matches when the outermost dimension is not the only
+/// reduction dimension.
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [0, [[REDUCTION_DIMS]] ] :
+/// vector<NxMx...xf32> to vector<Ix...xf32>
+/// ```
+///
+/// ```mlir
+/// %0 = vector.extract %src[0] : vector<Mx...xf32> from vector<NxMx...xf32>
+/// ...
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector<Mx...x.f32> from
+/// vector<NxMx...xf32>
+///
+/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ...
+/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] :
+/// vector<Mx...xf32> to vector<Ix...xf32>
+/// ```
+struct UnrollMultiReductionInnerParallel
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    if (!multiReductionOp.isReducedDim(0))
+      return failure();
+
+    ArrayRef<int64_t> reductionDims = multiReductionOp.getReductionDims();
+    if (reductionDims.size() <= 1)
+      return failure();
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    int64_t outerDimSize = srcShape.front();
+
+    Value mask = maskingOp ? maskingOp.getMask() : nullptr;
+
+    SmallVector<Value> vectors(outerDimSize);
+    for (int64_t i = 0; i < outerDimSize; ++i)
+      vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i);
+
+    SmallVector<Value> masks(outerDimSize);
+    if (mask)
+      for (int64_t i = 0; i < outerDimSize; ++i)
+        masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i);
+
+    SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+    ArrayRef<bool> reductionMask =
+        ArrayRef<bool>(fullReductionMask).drop_front();
+    Value result = multiReductionOp.getAcc();
+    for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks)) {
+      auto reductionOp = vector::MultiDimReductionOp::create(
+          rewriter, loc, innerVector, result, reductionMask,
+          multiReductionOp.getKind());
+
+      if (innerMask) {
+        Operation *maskOp =
+            vector::maskOperation(rewriter, reductionOp, innerMask);
+        result = maskOp->getResult(0);
+      } else {
+        result = reductionOp.getResult();
+      }
+    }
+
+    return result;
+  }
+};
+
+/// Unrolls vector.multi_reduction along the outermost parallel dimension
+/// when the innermost dimension is a reduction dimension.
+///
+/// This pattern matches operations where:
+/// - The innermost dimension is a reduction dimension
+/// - The outermost dimension is a parallel dimension
+///
+/// The transformation extracts slices along the outermost parallel dimension,
+/// creates smaller multi_reductions, and assembles the results:
+///
+/// ```mlir
+/// %res = vector.multi_reduction <add> %src, %acc [1, 3]
+///     : vector<AxBxCxDxf32> to vector<AxCxf32>
+/// ```
+///
+/// becomes:
+///
+/// ```mlir
+/// %result = arith.constant dense<0.0> : vector<AxCxf32>
+/// %0 = vector.extract %src[0] : vector<BxCxDxf32> from vector<AxBxCxDxf32>
+/// %acc0 = vector.extract %acc[0] : vector<Cxf32> from vector<AxCxf32>
+/// %red0 = vector.multi_reduction <add>, %0, %acc0 [0, 2]
+///     : vector<BxCxDxf32> to vector<Cxf32>
+/// %res0 = vector.insert %red0, %result[0]
+///     : vector<Cxf32> into vector<AxCxf32>
+/// // ... repeat for indices 1 to A-1
+/// ```
+struct UnrollMultiReductionInnerReduction
+    : public vector::MaskableOpRewritePattern<vector::MultiDimReductionOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
+
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+    if (srcRank < 2)
+      return rewriter.notifyMatchFailure(multiReductionOp,
+                                         "expected source rank >= 2.");
+
+    if (!multiReductionOp.isReducedDim(srcRank - 1))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected innermost dimension to be a reduction dimension.");
+
+    if (multiReductionOp.isReducedDim(0))
+      return rewriter.notifyMatchFailure(
+          multiReductionOp,
+          "expected outermost dimension to be a parallel dimension.");
+
+    Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType());
+    if (!elementType.isIntOrIndexOrFloat())
+      return rewriter.notifyMatchFailure(
+          multiReductionOp, "expected integer or float element type.");
+
+    Location loc = multiReductionOp.getLoc();
+    Value source = multiReductionOp.getSource();
+    Value acc = multiReductionOp.getAcc();
+
+    ArrayRef<int64_t> srcShape =
+        multiReductionOp.getSourceVectorType().getShape();
+    int64_t numSlices = srcShape.front();
+
+    Value mask = maskingOp ? maskingOp.getMask() : Value();
+
+    SmallVector<Value> srcSlices;
+    for (int64_t i = 0; i < numSlices; ++i)
+      srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i));
+
+    SmallVector<Value> accSlices;
+    for (int64_t i = 0; i < numSlices; ++i)
+      accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i));
+
+    SmallVector<Value> maskSlices;
+    for (int64_t i = 0; i < numSlices; ++i)
+      if (mask)
+        maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i));
+      else
+        maskSlices.push_back(nullptr);
+
+    // Compute new reduction mask by dropping the first element (dimension 0).
+    // Since dimension 0 is parallel (not reduced), all reduction indices shift
+    // down by 1.
+    SmallVector<bool> fullReductionMask = multiReductionOp.getReductionMask();
+    ArrayRef<bool> newReductionMask =
+        ArrayRef<bool>(fullReductionMask).drop_front();
+
+    SmallVector<Value> reductionResults;
+    for (auto [srcSlice, accSlice, maskSlice] :
+         llvm::zip(srcSlices, accSlices, maskSlices)) {
+      Operation *newReductionOp = vector::MultiDimReductionOp::create(
+          rewriter, loc, srcSlice, accSlice, newReductionMask,
+          multiReductionOp.getKind());
+
+      if (maskSlice)
+        newReductionOp =
+            mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice);
+
+      reductionResults.push_back(newReductionOp->getResult(0));
+    }
+
+    Value result = arith::ConstantOp::create(
+        rewriter, loc, multiReductionOp.getDestType(),
+        rewriter.getZeroAttr(multiReductionOp.getDestType()));
+
+    for (int64_t i = 0; i < numSlices; ++i)
+      result = vector::InsertOp::create(rewriter, loc, reductionResults[i],
+                                        result, i);
+
+    return result;
+  }
+};
+
 /// Converts 1D vector.multi_reduction directly to vector.reduction.
 ///
 /// Example:
@@ -496,19 +686,13 @@ struct LowerVectorMultiReductionPass
     RewritePatternSet patterns(context);
     mlir::vector::populateVectorMultiReductionReorderPatterns(
         patterns, this->loweringStrategy);
-    if (failed(applyPatternsGreedily(op, std::move(patterns))))
-      signalPassFailure();
-
-    RewritePatternSet flatteningPatterns(context);
     mlir::vector::populateVectorMultiReductionFlatteningPatterns(
-        flatteningPatterns, this->loweringStrategy);
-    if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns))))
-      signalPassFailure();
-
-    RewritePatternSet unrollingPatterns(context);
+        patterns, this->loweringStrategy);
     mlir::vector::populateVectorMultiReductionUnrollingPatterns(
-        unrollingPatterns, this->loweringStrategy);
-    if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns))))
+        patterns, this->loweringStrategy);
+    mlir::vector::populateVectorMultiReductionLoweringPatterns(
+        patterns, this->loweringStrategy);
+    if (failed(applyPatternsGreedily(op, std::move(patterns))))
       signalPassFailure();
   }
 
@@ -532,7 +716,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns(
   patterns.add<FlattenMultiReduction>(patterns.getContext(), options, benefit);
 }
 
-void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
     RewritePatternSet &patterns, VectorMultiReductionLowering options,
     PatternBenefit benefit) {
   patterns.add<OneDimMultiReductionToReduction>(patterns.getContext(), benefit);
@@ -544,6 +728,17 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
                                                     benefit);
 }
 
+void mlir::vector::populateVectorMultiReductionUnrollingPatterns(
+    RewritePatternSet &patterns, VectorMultiReductionLowering options,
+    PatternBenefit benefit) {
+  if (options == VectorMultiReductionLowering ::InnerReduction)
+    patterns.add<UnrollMultiReductionInnerReduction>(patterns.getContext(),
+                                                     benefit);
+  else
+    patterns.add<UnrollMultiReductionInnerParallel>(patterns.getContext(),
+                                                    benefit);
+}
+
 std::unique_ptr<Pass> vector::createLowerVectorMultiReductionPass(
     vector::VectorMultiReductionLowering option) {
   return std::make_unique<LowerVectorMultiReductionPass>(option);
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index bf7eba6e50174..b4cb4b75a5f4d 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -33,6 +33,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy"
       transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true
       transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1
diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir
index 4dc11c26e83f1..b8729d06c9cfc 100644
--- a/mlir/test/Dialect/Vector/transform-vector.mlir
+++ b/mlir/test/Dialect/Vector/transform-vector.mlir
@@ -42,6 +42,7 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel"
       transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
     } : !transform.any_op
 
     transform.apply_patterns to %f {
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
index 447416ccba637..a7ec4b9206c33 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir
@@ -117,7 +117,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
     transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction"
     } : !transform.op<"func.func">
     transform.yield
   }
@@ -125,7 +125,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
     transform.apply_patterns to %func_op {
-      transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel"
+      transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel"
     } : !transform.op<"func.func">
     transform.yield
   }
diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
new file mode 100644
index 0000000000000..4f67d5678d6d3
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @unroll_multi_reduction_inner_parallel
+// CHECK-SAME:    %[[INPUT:.+]]: vector<4x2x3xf32>, %[[ACC:.+]]: vector<2xf32>
+func.func @unroll_multi_reduction_inner_parallel(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
+    // CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2x3xf32> from vector<4x2x3xf32>
+    // CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2x3xf32> from vector<4x2x3xf32>
+    // CHECK: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2x3xf32> from vect...
[truncated]

@amd-eochoalo
Copy link
Copy Markdown
Contributor Author

Closing due to change of direction.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants