diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index dcd5f6ff3ad74..b91bd09053c7e 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -266,14 +266,37 @@ def ApplyMultiReductionUnrollingPatternsOp: Op]> { let description = [{ - Indicates that vector multi_reduction operations should be unrolled. + Indicates that vector multi_reduction operations with more than one + reduction dimension should be unrolled in a rank-reducing way. + + This populates the patterns from + `populateVectorMultiReductionUnrollingPatterns`, i.e.: + * `UnrollMultiReductionInnerReduction` (inner_reduction) + * `UnrollMultiReductionInnerParallel` (inner_parallel) + }]; + + let arguments = (ins DefaultValuedAttr:$lowering_strategy + ); + + let assemblyFormat = [{ + (`lowering_strategy` `=` $lowering_strategy^)? attr-dict + }]; +} + + +def ApplyMultiReductionLoweringPatternsOp: Op]> { + let description = [{ + Indicates that vector multi_reduction operations should be lowered. 1-D multi_reductions are converted directly to vector.reduction. 2-D multi_reductions are unrolled into either a sequence of vector.reduction ops (innerreduction) or element-wise arith ops (innerparallel). 
This populates the patterns from - `populateVectorMultiReductionUnrollingPatterns`, i.e.: + `populateVectorMultiReductionLoweringPatterns`, i.e.: * `OneDimMultiReductionToReduction` * `TwoDimMultiReductionToReduction` (innerreduction) * `TwoDimMultiReductionToElementWise` (innerparallel) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index aa75eff409ef9..f9eeaf4136370 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -66,9 +66,12 @@ void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, /// Rewrites vector.multi_reduction such that all reduction dimensions are /// either innermost or outermost, by adding the proper vector.transpose /// operations. +/// +/// The benefit is set to be higher than the unrolling patterns. Otherwise +/// patterns here would match the same operations as those in unrolling. void populateVectorMultiReductionReorderPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, - PatternBenefit benefit = 1); + PatternBenefit benefit = 2); /// Populate the pattern set with the following patterns: /// @@ -77,7 +80,23 @@ void populateVectorMultiReductionReorderPatterns( /// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction, /// by introducing vector.shape_cast ops to collapse + multi-reduce + expand /// back. +/// +/// The benefit is set to be higher than the unrolling patterns. Otherwise +/// patterns here would match the same operations as those in unrolling. 
void populateVectorMultiReductionFlatteningPatterns( + RewritePatternSet &patterns, VectorMultiReductionLowering options, + PatternBenefit benefit = 2); + +/// Populate the pattern set with the following patterns: +/// +/// [UnrollMultiReductionInnerReduction] +/// Extracts vectors along the outer (parallel) dimension, reduces each with a +/// lower-rank multi_reduction and inserts the results into the result vector. +/// +/// [UnrollMultiReductionInnerParallel] +/// Extracts vectors along the outer (reduction) dimension and chains +/// lower-rank multi_reduction operations through the accumulator. +void populateVectorMultiReductionUnrollingPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit = 1); @@ -96,9 +115,12 @@ void populateVectorMultiReductionFlatteningPatterns( /// dimension, unroll the outer dimension to obtain a sequence of extract + /// vector.reduction + insert. This can further lower to horizontal reduction /// ops. -void populateVectorMultiReductionUnrollingPatterns( +/// +/// The benefit is set to be higher than the unrolling patterns. Otherwise +/// patterns here would match the same operations as those in unrolling. 
+void populateVectorMultiReductionLoweringPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, - PatternBenefit benefit = 1); + PatternBenefit benefit = 2); /// Populate the pattern set with the following patterns: /// diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 312bd28ad48cf..55953aa22dd88 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -153,6 +153,14 @@ void transform::ApplyMultiReductionUnrollingPatternsOp::populatePatterns( patterns, vectorTransformOptions.vectorMultiReductionLowering); } +void transform::ApplyMultiReductionLoweringPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); + vector::populateVectorMultiReductionLoweringPatterns( + patterns, vectorTransformOptions.vectorMultiReductionLowering); +} + void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( RewritePatternSet &patterns) { populateVectorOuterProductLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 76599822fbfe4..3f2b399e55343 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -443,6 +443,196 @@ struct TwoDimMultiReductionToReduction } }; +/// Unrolls outermost dimension for vector.multi_reduction. +/// Matches when the outermost dimension is not the only +/// reduction dimension. +/// +/// ```mlir +/// %res = vector.multi_reduction %src, %acc [0, [[REDUCTION_DIMS]] ] : +/// vector to vector +/// ``` +/// +/// ```mlir +/// %0 = vector.extract %src[0] : vector from vector +/// ... 
+/// %Nminus1 = vector.extract %src[ [[N-1]] ] : vector from +/// vector +/// +/// %red0 = vector.multi_reduction %0, %acc [ [[REDUCTION_DIMS]] ] : +/// vector to vector +/// ... +/// %res = vector.multi_reduction %Nminus1, %redNminus2 [ [[REDUCTION_DIMS]] ] : +/// vector to vector +/// ``` +struct UnrollMultiReductionInnerParallel + : public vector::MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; + + FailureOr + matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp, + vector::MaskingOpInterface maskingOp, + PatternRewriter &rewriter) const override { + if (!multiReductionOp.isReducedDim(0)) + return failure(); + + ArrayRef reductionDims = multiReductionOp.getReductionDims(); + if (reductionDims.size() <= 1) + return failure(); + + Location loc = multiReductionOp.getLoc(); + Value source = multiReductionOp.getSource(); + + ArrayRef srcShape = + multiReductionOp.getSourceVectorType().getShape(); + int64_t outerDimSize = srcShape.front(); + + Value mask = maskingOp ? 
maskingOp.getMask() : nullptr; + + SmallVector vectors(outerDimSize); + for (int64_t i = 0; i < outerDimSize; ++i) + vectors[i] = vector::ExtractOp::create(rewriter, loc, source, i); + + SmallVector masks(outerDimSize); + if (mask) + for (int64_t i = 0; i < outerDimSize; ++i) + masks[i] = vector::ExtractOp::create(rewriter, loc, mask, i); + + SmallVector fullReductionMask = multiReductionOp.getReductionMask(); + ArrayRef reductionMask = + ArrayRef(fullReductionMask).drop_front(); + Value result = multiReductionOp.getAcc(); + for (auto [innerVector, innerMask] : llvm::zip_equal(vectors, masks)) { + auto reductionOp = vector::MultiDimReductionOp::create( + rewriter, loc, innerVector, result, reductionMask, + multiReductionOp.getKind()); + + if (innerMask) { + Operation *maskOp = + vector::maskOperation(rewriter, reductionOp, innerMask); + result = maskOp->getResult(0); + } else { + result = reductionOp.getResult(); + } + } + + return result; + } +}; + +/// Unrolls vector.multi_reduction along the outermost parallel dimension +/// when the innermost dimension is a reduction dimension. +/// +/// This pattern matches operations where: +/// - The innermost dimension is a reduction dimension +/// - The outermost dimension is a parallel dimension +/// +/// The transformation extracts slices along the outermost parallel dimension, +/// creates smaller multi_reductions, and assembles the results: +/// +/// ```mlir +/// %res = vector.multi_reduction %src, %acc [1, 3] +/// : vector to vector +/// ``` +/// +/// becomes: +/// +/// ```mlir +/// %result = arith.constant dense<0.0> : vector +/// %0 = vector.extract %src[0] : vector from vector +/// %acc0 = vector.extract %acc[0] : vector from vector +/// %red0 = vector.multi_reduction , %0, %acc0 [0, 2] +/// : vector to vector +/// %res0 = vector.insert %red0, %result[0] +/// : vector into vector +/// // ... 
repeat for indices 1 to A-1 +/// ``` +struct UnrollMultiReductionInnerReduction + : public vector::MaskableOpRewritePattern { + using MaskableOpRewritePattern::MaskableOpRewritePattern; + + FailureOr + matchAndRewriteMaskableOp(vector::MultiDimReductionOp multiReductionOp, + vector::MaskingOpInterface maskingOp, + PatternRewriter &rewriter) const override { + auto srcRank = multiReductionOp.getSourceVectorType().getRank(); + + if (srcRank < 2) + return rewriter.notifyMatchFailure(multiReductionOp, + "expected source rank >= 2."); + + if (!multiReductionOp.isReducedDim(srcRank - 1)) + return rewriter.notifyMatchFailure( + multiReductionOp, + "expected innermost dimension to be a reduction dimension."); + + if (multiReductionOp.isReducedDim(0)) + return rewriter.notifyMatchFailure( + multiReductionOp, + "expected outermost dimension to be a parallel dimension."); + + Type elementType = getElementTypeOrSelf(multiReductionOp.getDestType()); + if (!elementType.isIntOrIndexOrFloat()) + return rewriter.notifyMatchFailure( + multiReductionOp, "expected integer or float element type."); + + Location loc = multiReductionOp.getLoc(); + Value source = multiReductionOp.getSource(); + Value acc = multiReductionOp.getAcc(); + + ArrayRef srcShape = + multiReductionOp.getSourceVectorType().getShape(); + int64_t numSlices = srcShape.front(); + + Value mask = maskingOp ? maskingOp.getMask() : Value(); + + SmallVector srcSlices; + for (int64_t i = 0; i < numSlices; ++i) + srcSlices.push_back(vector::ExtractOp::create(rewriter, loc, source, i)); + + SmallVector accSlices; + for (int64_t i = 0; i < numSlices; ++i) + accSlices.push_back(vector::ExtractOp::create(rewriter, loc, acc, i)); + + SmallVector maskSlices; + for (int64_t i = 0; i < numSlices; ++i) + if (mask) + maskSlices.push_back(vector::ExtractOp::create(rewriter, loc, mask, i)); + else + maskSlices.push_back(nullptr); + + // Compute new reduction mask by dropping the first element (dimension 0). 
+ // Since dimension 0 is parallel (not reduced), all reduction indices shift + // down by 1. + SmallVector fullReductionMask = multiReductionOp.getReductionMask(); + ArrayRef newReductionMask = + ArrayRef(fullReductionMask).drop_front(); + + SmallVector reductionResults; + for (auto [srcSlice, accSlice, maskSlice] : + llvm::zip(srcSlices, accSlices, maskSlices)) { + Operation *newReductionOp = vector::MultiDimReductionOp::create( + rewriter, loc, srcSlice, accSlice, newReductionMask, + multiReductionOp.getKind()); + + if (maskSlice) + newReductionOp = + mlir::vector::maskOperation(rewriter, newReductionOp, maskSlice); + + reductionResults.push_back(newReductionOp->getResult(0)); + } + + Value result = arith::ConstantOp::create( + rewriter, loc, multiReductionOp.getDestType(), + rewriter.getZeroAttr(multiReductionOp.getDestType())); + + for (int64_t i = 0; i < numSlices; ++i) + result = vector::InsertOp::create(rewriter, loc, reductionResults[i], + result, i); + + return result; + } +}; + /// Converts 1D vector.multi_reduction directly to vector.reduction. 
/// /// Example: @@ -496,19 +686,13 @@ struct LowerVectorMultiReductionPass RewritePatternSet patterns(context); mlir::vector::populateVectorMultiReductionReorderPatterns( patterns, this->loweringStrategy); - if (failed(applyPatternsGreedily(op, std::move(patterns)))) - signalPassFailure(); - - RewritePatternSet flatteningPatterns(context); mlir::vector::populateVectorMultiReductionFlatteningPatterns( - flatteningPatterns, this->loweringStrategy); - if (failed(applyPatternsGreedily(op, std::move(flatteningPatterns)))) - signalPassFailure(); - - RewritePatternSet unrollingPatterns(context); + patterns, this->loweringStrategy); mlir::vector::populateVectorMultiReductionUnrollingPatterns( - unrollingPatterns, this->loweringStrategy); - if (failed(applyPatternsGreedily(op, std::move(unrollingPatterns)))) + patterns, this->loweringStrategy); + mlir::vector::populateVectorMultiReductionLoweringPatterns( + patterns, this->loweringStrategy); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) signalPassFailure(); } @@ -532,7 +716,7 @@ void mlir::vector::populateVectorMultiReductionFlatteningPatterns( patterns.add(patterns.getContext(), options, benefit); } -void mlir::vector::populateVectorMultiReductionUnrollingPatterns( +void mlir::vector::populateVectorMultiReductionLoweringPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, PatternBenefit benefit) { patterns.add(patterns.getContext(), benefit); @@ -544,6 +728,17 @@ void mlir::vector::populateVectorMultiReductionUnrollingPatterns( benefit); } +void mlir::vector::populateVectorMultiReductionUnrollingPatterns( + RewritePatternSet &patterns, VectorMultiReductionLowering options, + PatternBenefit benefit) { + if (options == VectorMultiReductionLowering ::InnerReduction) + patterns.add(patterns.getContext(), + benefit); + else + patterns.add(patterns.getContext(), + benefit); +} + std::unique_ptr vector::createLowerVectorMultiReductionPass( vector::VectorMultiReductionLowering option) { 
return std::make_unique(option); diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir index bf7eba6e50174..b4cb4b75a5f4d 100644 --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -33,6 +33,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel" transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index 4dc11c26e83f1..b8729d06c9cfc 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -42,6 +42,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel" transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel" } : !transform.any_op transform.apply_patterns to %f { diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir index 447416ccba637..a7ec4b9206c33 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir +++ 
b/mlir/test/Dialect/Vector/vector-multi-reduction-unrolling.mlir @@ -117,7 +117,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @innerreduction(%root : !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> transform.apply_patterns to %func_op { - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield } @@ -125,7 +125,7 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @innerparallel(%root : !transform.any_op {transform.readonly}) { %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> transform.apply_patterns to %func_op { - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerparallel" } : !transform.op<"func.func"> transform.yield } diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir new file mode 100644 index 0000000000000..4f67d5678d6d3 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-parallel.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s --transform-interpreter | FileCheck %s + +// CHECK-LABEL: func @unroll_multi_reduction_inner_parallel +// CHECK-SAME: %[[INPUT:.+]]: vector<4x2x3xf32>, %[[ACC:.+]]: vector<2xf32> +func.func @unroll_multi_reduction_inner_parallel(%arg0: vector<4x2x3xf32>, %acc: vector<2xf32>) -> vector<2xf32> { + // CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0] : vector<2x3xf32> from vector<4x2x3xf32> + // CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1] : 
vector<2x3xf32> from vector<4x2x3xf32> + // CHECK: %[[V2:.+]] = vector.extract %[[INPUT]][2] : vector<2x3xf32> from vector<4x2x3xf32> + // CHECK: %[[V3:.+]] = vector.extract %[[INPUT]][3] : vector<2x3xf32> from vector<4x2x3xf32> + // CHECK: %[[RV0:.+]] = vector.multi_reduction , %[[V0]], %[[ACC]] [1] : vector<2x3xf32> to vector<2xf32> + // CHECK: %[[RV1:.+]] = vector.multi_reduction , %[[V1]], %[[RV0]] [1] : vector<2x3xf32> to vector<2xf32> + // CHECK: %[[RV2:.+]] = vector.multi_reduction , %[[V2]], %[[RV1]] [1] : vector<2x3xf32> to vector<2xf32> + // CHECK: %[[RESULT:.+]] = vector.multi_reduction , %[[V3]], %[[RV2]] [1] : vector<2x3xf32> to vector<2xf32> + %0 = vector.multi_reduction , %arg0, %acc [0, 2] : vector<4x2x3xf32> to vector<2xf32> + // CHECK: return %[[RESULT]] + return %0 : vector<2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) { + %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> + transform.apply_patterns to %func_op { + transform.apply_patterns.vector.multi_reduction_unrolling + } : !transform.op<"func.func"> + transform.yield + } +} diff --git a/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir new file mode 100644 index 0000000000000..04217479e377b --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-unroll-multi-reduction-inner-reduction.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt %s --transform-interpreter | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test UnrollVectorMultiReduction for Inner Reduction +//===----------------------------------------------------------------------===// +// +// The general case handles multiple reduction dimensions. 
+// For vector<2x3x5xf32> with reduction on dims [1, 2]: +// UnrollMultiReductionInnerReduction unrolls along dim 0 (size 2), creating +// two vector<3x5xf32> multi_reductions with dims [0, 1], then insert results. + +// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general( +// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>, +// CHECK-SAME: %[[ACC:.+]]: vector<2xf32> +func.func @unroll_vector_multi_reduction_inner_general(%source: vector<2x3x5xf32>, %acc: vector<2xf32>) -> (vector<2xf32>) { + // CHECK-DAG: %[[SRC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32> + // CHECK-DAG: %[[SRC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from vector<2x3x5xf32> + // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32> + // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32> + // CHECK: %[[R0:.+]] = vector.multi_reduction , %[[SRC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32 + // CHECK: %[[R1:.+]] = vector.multi_reduction , %[[SRC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32 + // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0]], %{{.+}} [0] : f32 into vector<2xf32> + // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1]], %[[INSERT_0]] [1] : f32 into vector<2xf32> + %1 = vector.multi_reduction , %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32> + + // CHECK: return %[[INSERT_1]] + return %1 : vector<2xf32> +} + +// CHECK-LABEL: func @unroll_vector_multi_reduction_inner_general_masked( +// CHECK-SAME: %[[SOURCE:.+]]: vector<2x3x5xf32>, +// CHECK-SAME: %[[MASK:.+]]: vector<2x3x5xi1>, +// CHECK-SAME: %[[ACC:.+]]: vector<2xf32> +func.func @unroll_vector_multi_reduction_inner_general_masked(%source: vector<2x3x5xf32>, %mask: vector<2x3x5xi1>, %acc: vector<2xf32>) -> (vector<2xf32>) { + // CHECK-DAG: %[[SRC_0:.+]] = vector.extract %[[SOURCE]][0] : vector<3x5xf32> from vector<2x3x5xf32> + // CHECK-DAG: %[[SRC_1:.+]] = vector.extract %[[SOURCE]][1] : vector<3x5xf32> from 
vector<2x3x5xf32> + // CHECK-DAG: %[[ACC_0:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32> + // CHECK-DAG: %[[ACC_1:.+]] = vector.extract %[[ACC]][1] : f32 from vector<2xf32> + // CHECK-DAG: %[[MASK_0:.+]] = vector.extract %[[MASK]][0] : vector<3x5xi1> from vector<2x3x5xi1> + // CHECK-DAG: %[[MASK_1:.+]] = vector.extract %[[MASK]][1] : vector<3x5xi1> from vector<2x3x5xi1> + // CHECK: %[[R0:.+]] = vector.mask %[[MASK_0]] {{.*}} vector.multi_reduction , %[[SRC_0]], %[[ACC_0]] [0, 1] : vector<3x5xf32> to f32 + // CHECK: %[[R1:.+]] = vector.mask %[[MASK_1]] {{.*}} vector.multi_reduction , %[[SRC_1]], %[[ACC_1]] [0, 1] : vector<3x5xf32> to f32 + // CHECK: %[[INSERT_0:.+]] = vector.insert %[[R0]], %{{.+}} [0] : f32 into vector<2xf32> + // CHECK: %[[INSERT_1:.+]] = vector.insert %[[R1]], %[[INSERT_0]] [1] : f32 into vector<2xf32> + + %0 = vector.mask %mask { + %1 = vector.multi_reduction , %source, %acc [1, 2] : vector<2x3x5xf32> to vector<2xf32> + } : vector<2x3x5xi1> -> vector<2xf32> + + // CHECK: return %[[INSERT_1]] + return %0 : vector<2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) { + %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func"> + transform.apply_patterns to %func_op { + transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = innerreduction + } : !transform.op<"func.func"> + transform.yield + } +} diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir index a7b0b27ca5fb9..23623b8b3f21e 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir @@ -153,6 +153,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns.vector.reorder_multi_reduction_dims 
lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir index 4adc68966f17a..bf6ab21afc909 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir @@ -158,6 +158,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir index 0883e7b698f55..8d973527847c1 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir @@ -56,6 +56,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns.vector.reorder_multi_reduction_dims lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction_lowering lowering_strategy = "innerreduction" } : 
!transform.op<"func.func"> transform.yield } diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index a3c53a45048b2..b1dae910052b7 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -111,6 +111,14 @@ def enum_configurable_patterns(): lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction ) + # CHECK: transform.apply_patterns.vector.multi_reduction_lowering + vector.ApplyMultiReductionLoweringPatternsOp() + # CHECK: transform.apply_patterns.vector.multi_reduction_lowering + # CHECK-SAME: lowering_strategy = innerreduction + vector.ApplyMultiReductionLoweringPatternsOp( + lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction + ) + # CHECK: transform.apply_patterns.vector.lower_transpose vector.ApplyLowerTransposePatternsOp() # CHECK: transform.apply_patterns.vector.lower_transpose