From d07bef658ff8167892b8e56769bc4e52b5b9a712 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 6 Mar 2026 09:47:24 -0500 Subject: [PATCH 1/6] [mlir][vector] Rename multi_reduction's unrolling to lowering. Stated previously: > Below are some notes on naming > * RewriteXAsY or LowerXToY (if we're changing op kind) > > Concretely, TwoDimMultiReductionToReduction looks like a lowering > (it rewrites to vector.reduction), not an unrolling, https://github.com/llvm/llvm-project/pull/182301#issuecomment-3951827980 --- .../Dialect/Vector/TransformOps/VectorTransformOps.td | 8 ++++---- .../mlir/Dialect/Vector/Transforms/LoweringPatterns.h | 2 +- .../Dialect/Vector/TransformOps/VectorTransformOps.cpp | 4 ++-- .../Vector/Transforms/LowerVectorMultiReduction.cpp | 10 +++++----- mlir/test/Dialect/LLVM/transform-e2e.mlir | 2 +- mlir/test/Dialect/Vector/transform-vector.mlir | 2 +- .../Vector/vector-multi-reduction-unrolling.mlir | 4 ++-- .../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir | 2 +- .../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir | 2 +- .../Dialect/Linalg/CPU/test-matmul-masked-vec.mlir | 2 +- mlir/test/python/dialects/transform_vector_ext.py | 6 +++--- 11 files changed, 22 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index dcd5f6ff3ad74..76e77b9f90a13 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -262,18 +262,18 @@ def ApplyMultiReductionFlatteningPatternsOp: Op]> { let description = [{ - Indicates that vector multi_reduction operations should be unrolled. + Indicates that vector multi_reduction operations should be lowered. 1-D multi_reductions are converted directly to vector.reduction. 2-D multi_reductions are unrolled into either a sequence of vector.reduction ops (innerreduction) or element-wise arith ops (innerparallel). 
This populates the patterns from - `populateVectorMultiReductionUnrollingPatterns`, i.e.: + `populateVectorMultiReductionLoweringPatterns`, i.e.: * `OneDimMultiReductionToReduction` * `TwoDimMultiReductionToReduction` (innerreduction) * `TwoDimMultiReductionToElementWise` (innerparallel) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index aa75eff409ef9..2ea3367da414d 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -96,7 +96,7 @@ void populateVectorMultiReductionFlatteningPatterns( /// dimension, unroll the outer dimension to obtain a sequence of extract + /// vector.reduction + insert. This can further lower to horizontal reduction /// ops. -void populateVectorMultiReductionUnrollingPatterns( +void populateVectorMultiReductionLoweringPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit = 1); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 312bd28ad48cf..d1799b3bd84d8 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -145,11 +145,11 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns( patterns, vectorTransformOptions.vectorMultiReductionLowering); } -void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns( +void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); - vector::populateVectorMultiReductionUnrollingPatterns( + vector::populateVectorMultiReductionLoweringPatterns( patterns, 
vectorTransformOptions.vectorMultiReductionLowering); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 76599822fbfe4..88e669d1a34b0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -505,10 +505,10 @@ struct LowerVectorMultiReductionPass if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns)))) signalPassFailure(); - RewritePatternSet unrollingPatterns(context); - mlir::vector::populateVectorMultiReductionUnrollingPatterns( - unrollingPatterns, this->loweringStrategy); - if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns)))) + RewritePatternSet loweringPatterns(context); + mlir::vector::populateVectorMultiReductionLoweringPatterns( + loweringPatterns, this->loweringStrategy); + if (failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) signalPassFailure(); } @@ -532,7 +532,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns( patterns.add(patterns.getContext(), options, benefit); } -void mlir::vector::populateVectorMultiReductionUnrollingPatterns( +void mlir::vector::populateVectorMultiReductionLoweringPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir index bf7eba6e50174..7f4c78a7fc2c8 100644 --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -32,7 +32,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns.vector.transfer_permutation_patterns transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel" - 
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel" transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index 4dc11c26e83f1..9027bef2eac5c 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -41,7 +41,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns to %f { transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel" - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel" } : !transform.any_op transform.apply_patterns to %f { diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir index 447416ccba637..a7ec4b9206c33 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir @@ -117,7 +117,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> transform.apply_patterns to %func_op { - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + 
transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield } @@ -125,7 +125,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> transform.apply_patterns to %func_op { - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel" } : !transform.op<"func.func"> transform.yield } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir index a7b0b27ca5fb9..15f5bfd3619eb 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir @@ -152,7 +152,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns.vector.lower_masked_transfers transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir index 4adc68966f17a..0cfba97654769 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir @@ -157,7 +157,7 @@ module attributes {transform.with_named_sequence} { 
transform.apply_patterns.vector.lower_masked_transfers transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir index 0883e7b698f55..ce85720d665c2 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir @@ -55,7 +55,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns to %func_op { transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield } diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index a3c53a45048b2..cfaf91d24f471 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -103,9 +103,9 @@ def enum_configurable_patterns(): lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction ) - # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling - vector.ApplyMultiReductionUnrollingPatternsOp() - # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling + # CHECK: 
transform.apply_patterns.vector.multi_reduction_lowering + vector.ApplyMultiReductionLoweringPatternsOp() + # CHECK: transform.apply_patterns.vector.multi_reduction_lowering # CHECK-SAME: lowering_strategy = innerreduction vector.ApplyMultiReductionUnrollingPatternsOp( lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction From ba1fffa23f4b9cf6862897959db2cbecf2eea9d4 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 6 Mar 2026 10:01:00 -0500 Subject: [PATCH 2/6] [mlir][vector] Add multi_reduction unrolling innerparallel. When unrolling vector.multi_reduction with outermost dimensions being reduced, extract outermost dimension vectors and chain them with vector.multi_reductions. --- .../Vector/TransformOps/VectorTransformOps.td | 14 ++++ .../Vector/Transforms/LoweringPatterns.h | 5 ++ .../TransformOps/VectorTransformOps.cpp | 5 ++ .../Transforms/LowerVectorMultiReduction.cpp | 82 +++++++++++++++++++ 4 files changed, 106 insertions(+) diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 76e77b9f90a13..dff2803f95ae6 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -262,6 +262,20 @@ def ApplyMultiReductionFlatteningPatternsOp: Op]> { + let description = [{ + Indicates that vector multi_reduction operations with more than one + reduction dimension should be unrolled in a rank-reducing way. 
+ + This populates the patterns: + }]; + + let assemblyFormat = "attr-dict"; +} + + def ApplyMultiReductionLoweringPatternsOp: Op]> { diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 2ea3367da414d..d9237bac63b78 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -81,6 +81,11 @@ void populateVectorMultiReductionFlatteningPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit = 1); +/// Populate the pattern set with the following patterns: +/// +void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate the pattern set with the following patterns: /// /// [OneDimMultiReductionToReduction] diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index d1799b3bd84d8..256384f6b7c18 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -145,6 +145,11 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns( patterns, vectorTransformOptions.vectorMultiReductionLowering); } +void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorMultiReductionUnrollingPatterns(patterns); +} + void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 88e669d1a34b0..51b2f64c9a7da 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ 
b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -443,6 +443,82 @@ struct TwoDimMultiReductionToReduction } }; +/// Unrolls outermost dimension for vector.multi_reduction. +/// Matches when the outermost dimension is not the only +/// reduction dimension. +/// +/// ```mlir +/// %res = vector.multi_reduction %src, %acc [0, [[REDUCTION_DIMS]] ] : +/// vector to vector +/// ``` +/// +/// ```mlir +/// %0 = vector.extract %src[0] : vector from vector +/// ... +/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector from +/// vector +/// +/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] : +/// vector to vector +/// ... +/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] : +/// vector to vector +/// ``` +struct UnrollMultiReductionInnerParallel + : public vector::MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; + + FailureOr + matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp, + vector::MaskingOpInterface maskingOp, + PatternRewriter &rewriter) const override { + if (!multiReductionOp.isReducedDim(0)) + return failure(); + + ArrayRef reductionDims = multiReductionOp.getReductionDims(); + if (reductionDims.size() <= 1) + return failure(); + + Location loc = multiReductionOp.getLoc(); + Value source = multiReductionOp.getSource(); + + ArrayRef srcShape = + multiReductionOp.getSourceVectorType().getShape(); + int64_t outerDimSize = srcShape.front(); + + Value mask = maskingOp ? 
maskingOp.getMask() : nullptr; + + SmallVector vectors(outerDimSize); + for (int64_t i = 0; i < outerDimSize; ++i) + vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i); + + SmallVector masks(outerDimSize); + if (mask) + for (int64_t i = 0; i < outerDimSize; ++i) + masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i); + + SmallVector fullReductionMask = multiReductionOp.getReductionMask(); + ArrayRef reductionMask = + ArrayRef(fullReductionMask).drop_front(); + Value result = multiReductionOp.getAcc(); + for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks)) { + auto reductionOp = vector::MultiDimReductionOp::create( + rewriter, loc, innerVector, result, reductionMask, + multiReductionOp.getKind()); + + if (innerMask) { + Operation *maskOp = + vector::maskOperation(rewriter, reductionOp, innerMask); + result = maskOp->getResult(0); + } else { + result = reductionOp.getResult(); + } + } + + return result; + } +}; + /// Converts 1D vector.multi_reduction directly to vector.reduction. /// /// Example: @@ -544,6 +620,12 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns( benefit); } +void mlir::vector::populateVectorMultiReductionUnrollingPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), + benefit); +} + std::unique_ptr vector::createLowerVectorMultiReductionPass( vector::VectorMultiReductionLowering option) { return std::make_unique(option); From 2ea3b59224e103f515c1abebc802435147a4dbff Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 6 Mar 2026 10:37:47 -0500 Subject: [PATCH 3/6] [mlir][vector] Add multi_reduction unrolling innerreduction. Unrolls multi_reduction by series of extractions, multi_reduction and insertions. 
--- .../Vector/TransformOps/VectorTransformOps.td | 13 +- .../Vector/Transforms/LoweringPatterns.h | 5 +- .../TransformOps/VectorTransformOps.cpp | 5 +- .../Transforms/LowerVectorMultiReduction.cpp | 125 +++++++++++++++++- .../python/dialects/transform_vector_ext.py | 4 +- 5 files changed, 143 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index dff2803f95ae6..c9226af30e674 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -269,10 +269,19 @@ def ApplyUnrollMultiReductionUnrollingPatternsOp: Op:$lowering_strategy + ); + + let assemblyFormat = [{ + (`lowering_strategy` `=` $lowering_strategy^)? attr-dict + }]; } diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index d9237bac63b78..744a7d4b3b425 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -83,8 +83,9 @@ void populateVectorMultiReductionFlatteningPatterns( /// Populate the pattern set with the following patterns: /// -void populateVectorMultiReductionUnrollingPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); +void populateVectorMultiReductionUnrollingPatterns( + RewritePatternSet &patterns, VectorMultiReductionLowering options, + PatternBenefit benefit = 1); /// Populate the pattern set with the following patterns: /// diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 256384f6b7c18..f63d43a7d165c 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -147,7 +147,10 @@ void 
transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns( void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorMultiReductionUnrollingPatterns(patterns); + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); + vector::populateVectorMultiReductionUnrollingPatterns( + patterns, vectorTransformOptions.vectorMultiReductionLowering); } void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 51b2f64c9a7da..e347ba15537bd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -519,6 +519,120 @@ struct UnrollMultiReductionInnerParallel } }; +/// Unrolls vector.multi_reduction along the outermost parallel dimension +/// when the innermost dimension is a reduction dimension. +/// +/// This pattern matches operations where: +/// - The innermost dimension is a reduction dimension +/// - The outermost dimension is a parallel dimension +/// +/// The transformation extracts slices along the outermost parallel dimension, +/// creates smaller multi_reductions, and assembles the results: +/// +/// ```mlir +/// %res = vector.multi_reduction %src, %acc [1, 3] +/// : vector to vector +/// ``` +/// +/// becomes: +/// +/// ```mlir +/// %result = arith.constant dense<0.0> : vector +/// %0 = vector.extract %src[0] : vector from vector +/// %acc0 = vector.extract %acc[0] : vector from vector +/// %red0 = vector.multi_reduction , %0, %acc0 [0, 2] +/// : vector to vector +/// %res0 = vector.insert %red0, %result[0] +/// : vector into vector +/// // ... 
repeat for indices 1 to A-1 +/// ``` +struct UnrollMultiReductionInnerReduction + : public vector::MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; + + FailureOr + matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp, + vector::MaskingOpInterface maskingOp, + PatternRewriter &rewriter) const override { + auto srcRank = multiReductionOp.getSourceVectorType().getRank(); + + if (srcRank < 2) + return rewriter.notifyMatchFailure(multiReductionOp, + "expected source rank >= 2."); + + if (!multiReductionOp.isReducedDim(srcRank - 1)) + return rewriter.notifyMatchFailure( + multiReductionOp, + "expected innermost dimension to be a reduction dimension."); + + if (multiReductionOp.isReducedDim(0)) + return rewriter.notifyMatchFailure( + multiReductionOp, + "expected outermost dimension to be a parallel dimension."); + + Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType()); + if (!elementType.isIntOrIndexOrFloat()) + return rewriter.notifyMatchFailure( + multiReductionOp, "expected integer or float element type."); + + Location loc = multiReductionOp.getLoc(); + Value source = multiReductionOp.getSource(); + Value acc = multiReductionOp.getAcc(); + + ArrayRef srcShape = + multiReductionOp.getSourceVectorType().getShape(); + int64_t numSlices = srcShape.front(); + + Value mask = maskingOp ? maskingOp.getMask() : Value(); + + SmallVector srcSlices; + for (int64_t i = 0; i < numSlices; ++i) + srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i)); + + SmallVector accSlices; + for (int64_t i = 0; i < numSlices; ++i) + accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i)); + + SmallVector maskSlices; + for (int64_t i = 0; i < numSlices; ++i) + if (mask) + maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i)); + else + maskSlices.push_back(nullptr); + + // Compute new reduction mask by dropping the first element (dimension 0). 
+ // Since dimension 0 is parallel (not reduced), all reduction indices shift + // down by 1. + SmallVector fullReductionMask = multiReductionOp.getReductionMask(); + ArrayRef newReductionMask = + ArrayRef(fullReductionMask).drop_front(); + + SmallVector reductionResults; + for (auto [srcSlice, accSlice, maskSlice] : + llvm::zip(srcSlices, accSlices, maskSlices)) { + Operation *newReductionOp = vector::MultiDimReductionOp::create( + rewriter, loc, srcSlice, accSlice, newReductionMask, + multiReductionOp.getKind()); + + if (maskSlice) + newReductionOp = + mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice); + + reductionResults.push_back(newReductionOp->getResult(0)); + } + + Value result = arith::ConstantOp::create( + rewriter, loc, multiReductionOp.getDestType(), + rewriter.getZeroAttr(multiReductionOp.getDestType())); + + for (int64_t i = 0; i < numSlices; ++i) + result = vector::InsertOp::create(rewriter, loc, reductionResults[i], + result, i); + + return result; + } +}; + /// Converts 1D vector.multi_reduction directly to vector.reduction. 
/// /// Example: @@ -621,9 +735,14 @@ void mlir::vector::populateVectorMultiReductionLoweringPatterns( } void mlir::vector::populateVectorMultiReductionUnrollingPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), - benefit); + RewritePatternSet &patterns, VectorMultiReductionLowering options, + PatternBenefit benefit) { + if (options == VectorMultiReductionLowering ::InnerReduction) + patterns.add(patterns.getContext(), + benefit); + else + patterns.add(patterns.getContext(), + benefit); } std::unique_ptr vector::createLowerVectorMultiReductionPass( diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index cfaf91d24f471..84d70181b5288 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -105,7 +105,9 @@ def enum_configurable_patterns(): # CHECK: transform.apply_patterns.vector.multi_reduction_lowering vector.ApplyMultiReductionLoweringPatternsOp() - # CHECK: transform.apply_patterns.vector.multi_reduction_lowering + # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling + vector.ApplyUnrollMultiReductionUnrollingPatternsOp() + # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling # CHECK-SAME: lowering_strategy = innerreduction vector.ApplyMultiReductionUnrollingPatternsOp( lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction From 4328d792b9c104cb50346deb039440e08f071f1f Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 6 Mar 2026 10:39:07 -0500 Subject: [PATCH 4/6] Add tests --- ...unroll-multi-reduction-inner-parallel.mlir | 27 ++++++++ ...nroll-multi-reduction-inner-reduction.mlir | 62 +++++++++++++++++++ 2 files changed, 89 insertions(+) create mode 100644 mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir create mode 100644 mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir diff --git 
a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir new file mode 100644 index 0000000000000..4f67d5678d6d3 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s --transform-interpreter | FileCheck %s + +// CHECK-LABEL: func @unroll_multi_reduction_inner_parallel +// CHECK-SAME: %[[INPUT:.+]]: vector<4x2x3xf32>, %[[ACC:.+]]: vector<2xf32> +func.func @unroll_multi_reduction_inner_parallel(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>) -> vector<2xf32> { + // CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2x3xf32> from vector<4x2x3xf32> + // CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1] : vector<2x3xf32> from vector<4x2x3xf32> + // CHECK: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2x3xf32> from vector<4x2x3xf32> + // CHECK: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2x3xf32> from vector<4x2x3xf32> + // CHECK: %[[RV0:.+]] = vector.multi_reduction , %[[V0]], %[[ACC]] [1] : vector<2x3xf32> to vector<2xf32> + // CHECK: %[[RV1:.+]] = vector.multi_reduction , %[[V1]], %[[RV0]] [1] : vector<2x3xf32> to vector<2xf32> + // CHECK: %[[RV2:.+]] = vector.multi_reduction , %[[V2]], %[[RV1]] [1] : vector<2x3xf32> to vector<2xf32> + // CHECK: %[[RESULT:.+]] = vector.multi_reduction , %[[V3]], %[[RV2]] [1] : vector<2x3xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0, %acc [0, 2] : vector<4x2x3xf32> to vector<2xf32> + // CHECK: return %[[RESULT]] + return %0 : vector<2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) { + %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> + transform.apply_patterns to %func_op { + transform.apply_patterns.vector.multi_reduction_unrolling + } : 
!transform.op<"func.func"> + transform.yield + } +} diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir new file mode 100644 index 0000000000000..04217479e377b --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt %s --transform-interpreter | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test UnrollVectorMultiReduction for Inner Reduction +//===----------------------------------------------------------------------===// +// +// The general case handles multiple reduction dimensions. +// For vector<2x3x5xf32> with reduction on dims [1, 2]: +// UnrollMultiReductionInnerReduction unrolls along dim 0 (size 2), creating +// two vector<3x5xf32> multi_reductions with dims [0, 1], then insert results. + +// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general( +// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>, +// CHECK-SAME: %[[ACC:.+]]: vector<2xf32> +func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) { + // CHECK-DAG: %[[SRC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32> + // CHECK-DAG: %[[SRC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32> + // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32> + // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32> + // CHECK: %[[R0:.+]] = vector.multi_reduction , %[[SRC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32 + // CHECK: %[[R1:.+]] = vector.multi_reduction , %[[SRC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32 + // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0]], %{{.+}} [0] : f32 into vector<2xf32> + // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1]], %[[INSERT_0]] [1] : f32 into 
vector<2xf32> + %1 = vector.multi_reduction , %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32> + + // CHECK: return %[[INSERT_1]] + return %1 : vector<2xf32> +} + +// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general_masked( +// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>, +// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>, +// CHECK-SAME: %[[ACC:.+]]: vector<2xf32> +func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2xf32>) -> (vector<2xf32>) { + // CHECK-DAG: %[[SRC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32> + // CHECK-DAG: %[[SRC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32> + // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32> + // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32> + // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1> + // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1> + // CHECK: %[[R0:.+]] = vector.mask %[[MASK_0]] {{.*}} vector.multi_reduction , %[[SRC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32 + // CHECK: %[[R1:.+]] = vector.mask %[[MASK_1]] {{.*}} vector.multi_reduction , %[[SRC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32 + // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0]], %{{.+}} [0] : f32 into vector<2xf32> + // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1]], %[[INSERT_0]] [1] : f32 into vector<2xf32> + + %0 = vector.mask %mask { + %1 = vector.multi_reduction , %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32> + } : vector<2x3x5xi1> -> vector<2xf32> + + // CHECK: return %[[INSERT_1]] + return %0 : vector<2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) { + %func_op = transform.structured.match 
ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> + transform.apply_patterns to %func_op { + transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = innerreduction + } : !transform.op<"func.func"> + transform.yield + } +} From 9245142ef7cad0ebe090216ad5b179507f6b5edd Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 6 Mar 2026 11:07:58 -0500 Subject: [PATCH 5/6] Update benefits to ensure previous behaviour remains the same --- .../Vector/Transforms/LoweringPatterns.h | 22 ++++++++++++++++--- .../Transforms/LowerVectorMultiReduction.cpp | 16 +++++--------- mlir/test/Dialect/LLVM/transform-e2e.mlir | 1 + .../test/Dialect/Vector/transform-vector.mlir | 1 + .../Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir | 1 + .../Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir | 1 + .../Linalg/CPU/test-matmul-masked-vec.mlir | 1 + 7 files changed, 29 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 744a7d4b3b425..f9eeaf4136370 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -66,9 +66,12 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, /// Rewrites vector.multi_reduction such that all reduction dimensions are /// either innermost or outermost, by adding the proper vector.transpose /// operations. +/// +/// The benefit is set to be higher than the unrolling patterns. Otherwise +/// patterns here would match the same operations as those in unrolling. 
void populateVectorMultiReductionReorderPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, - PatternBenefit benefit = 1); + PatternBenefit benefit = 2); /// Populate the pattern set with the following patterns: /// @@ -77,12 +80,22 @@ void populateVectorMultiReductionReorderPatterns( /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction, /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand /// back. +/// +/// The benefit is set to be higher than the unrolling patterns. Otherwise +/// patterns here would match the same operations as those in unrolling. void populateVectorMultiReductionFlatteningPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, - PatternBenefit benefit = 1); + PatternBenefit benefit = 2); /// Populate the pattern set with the following patterns: /// +/// [UnrollMultiReductionInnerReduction] +/// Extracts vectors along the outer dimension, applies a lower-rank +/// multi_reduction to each one and inserts the results into a vector. +/// +/// [UnrollMultiReductionInnerParallel] +/// Extracts vectors along the outer dimension and chains multiple +/// lower-rank multi_reduction operations. void populateVectorMultiReductionUnrollingPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit = 1); @@ -102,9 +115,12 @@ void populateVectorMultiReductionUnrollingPatterns( /// dimension, unroll the outer dimension to obtain a sequence of extract + /// vector.reduction + insert. This can further lower to horizontal reduction /// ops. +/// +/// The benefit is set to be higher than the unrolling patterns. Otherwise +/// patterns here would match the same operations as those in unrolling.
void populateVectorMultiReductionLoweringPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, - PatternBenefit benefit = 1); + PatternBenefit benefit = 2); /// Populate the pattern set with the following patterns: /// diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index e347ba15537bd..3f2b399e55343 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -686,19 +686,13 @@ struct LowerVectorMultiReductionPass RewritePatternSet patterns(context); mlir::vector::populateVectorMultiReductionReorderPatterns( patterns, this->loweringStrategy); - if (failed(applyPatternsGreedily(op, std::move(patterns)))) - signalPassFailure(); - - RewritePatternSet flatteningPatterns(context); mlir::vector::populateVectorMultiReductionFlatteningPatterns( - flatteningPatterns, this->loweringStrategy); - if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns)))) - signalPassFailure(); - - RewritePatternSet loweringPatterns(context); + patterns, this->loweringStrategy); + mlir::vector::populateVectorMultiReductionUnrollingPatterns( + patterns, this->loweringStrategy); mlir::vector::populateVectorMultiReductionLoweringPatterns( - loweringPatterns, this->loweringStrategy); - if (failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) + patterns, this->loweringStrategy); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir index 7f4c78a7fc2c8..b4cb4b75a5f4d 100644 --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -32,6 +32,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns.vector.transfer_permutation_patterns 
transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel" + transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel" transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index 9027bef2eac5c..b8729d06c9cfc 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -41,6 +41,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns to %f { transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel" + transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel" } : !transform.any_op diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir index 15f5bfd3619eb..23623b8b3f21e 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir @@ -152,6 +152,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns.vector.lower_masked_transfers transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" + 
transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : !transform.op<"func.func"> diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir index 0cfba97654769..bf6ab21afc909 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir @@ -157,6 +157,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns.vector.lower_masked_transfers transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : !transform.op<"func.func"> diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir index ce85720d665c2..8d973527847c1 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir @@ -55,6 +55,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns to %func_op { transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield From 
365480502c15eada5308f9d14af3d09132d45b2c Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Fri, 6 Mar 2026 13:24:38 -0500 Subject: [PATCH 6/6] Fix identifier name --- .../Vector/TransformOps/VectorTransformOps.td | 2 +- .../Vector/TransformOps/VectorTransformOps.cpp | 2 +- mlir/test/python/dialects/transform_vector_ext.py | 12 +++++++++--- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index c9226af30e674..b91bd09053c7e 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -262,7 +262,7 @@ def ApplyMultiReductionFlatteningPatternsOp: Op]> { let description = [{ diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index f63d43a7d165c..55953aa22dd88 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -145,7 +145,7 @@ void transform::ApplyMultiReductionFlatteningPatternsOp::populatePatterns( patterns, vectorTransformOptions.vectorMultiReductionLowering); } -void transform::ApplyUnrollMultiReductionUnrollingPatternsOp::populatePatterns( +void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index 84d70181b5288..b1dae910052b7 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -103,16 +103,22 @@ def enum_configurable_patterns(): 
lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction ) - # CHECK: transform.apply_patterns.vector.multi_reduction_lowering - vector.ApplyMultiReductionLoweringPatternsOp() # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling - vector.ApplyUnrollMultiReductionUnrollingPatternsOp() + vector.ApplyMultiReductionUnrollingPatternsOp() # CHECK: transform.apply_patterns.vector.multi_reduction_unrolling # CHECK-SAME: lowering_strategy = innerreduction vector.ApplyMultiReductionUnrollingPatternsOp( lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction ) + # CHECK: transform.apply_patterns.vector.multi_reduction_lowering + vector.ApplyMultiReductionLoweringPatternsOp() + # CHECK: transform.apply_patterns.vector.multi_reduction_lowering + # CHECK-SAME: lowering_strategy = innerreduction + vector.ApplyMultiReductionLoweringPatternsOp( + lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction + ) + # CHECK: transform.apply_patterns.vector.lower_transpose vector.ApplyLowerTransposePatternsOp() # CHECK: transform.apply_patterns.vector.lower_transpose