diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 9fec5804d0b3b..64edb27be83d5 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -223,6 +223,27 @@ def ApplyMaterializeMasksPatternsOp : Op]> { + let description = [{ + Indicates that vector multi_reduction operations will be lowered to + vector arithmetic elementwise operations on vectors of rank 1 or + vector.reduction operations. + + This populates all multi_reduction lowering patterns, + i.e., reorder_and_expand, flattening, and unrolling. + }]; + + let arguments = (ins DefaultValuedAttr:$lowering_strategy + ); + + let assemblyFormat = [{ + (`lowering_strategy` `=` $lowering_strategy^)? attr-dict + }]; +} + def ApplyReorderAndExpandMultiReductionPatternsOp: Op]> { diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 9da4be88586f4..8cf7f7f5db8f4 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -129,6 +129,18 @@ void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( //===----------------------------------------------------------------------===// // Multi-reduction patterns //===----------------------------------------------------------------------===// +void transform::ApplyMultiReductionPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::VectorTransformsOptions vectorTransformOptions; + vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); + vector::populateVectorMultiReductionReorderAndExpandPatterns( + patterns, vectorTransformOptions.vectorMultiReductionLowering); + vector::populateVectorMultiReductionFlatteningPatterns( + patterns, vectorTransformOptions.vectorMultiReductionLowering); + vector::populateVectorMultiReductionUnrollingPatterns( + patterns, vectorTransformOptions.vectorMultiReductionLowering); +} + void transform::ApplyReorderAndExpandMultiReductionPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir index ab58dda91a914..7e930727f5cfb 100644 --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -30,9 +30,7 @@ module attributes {transform.with_named_sequence} { transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" transform.apply_patterns.vector.transfer_permutation_patterns - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel" - transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel" - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" + transform.apply_patterns.vector.multi_reduction lowering_strategy = "innerparallel" transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index a37105d573219..a19997909778a 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -39,9 +39,7 @@ module attributes {transform.with_named_sequence} { } : !transform.any_op transform.apply_patterns to %f { - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerparallel" - transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerparallel" - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerparallel" + transform.apply_patterns.vector.multi_reduction lowering_strategy = "innerparallel" } : !transform.any_op transform.apply_patterns to %f { diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir index 25b65080339d5..c25f9aa243163 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_1d.mlir @@ -150,9 +150,7 @@ module attributes {transform.with_named_sequence} { // Step 3: Lower vector.multi_reduction transform.apply_patterns to %func { transform.apply_patterns.vector.lower_masked_transfers - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction" - transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir index 6072b44adf4fa..948d9e7a5bc9a 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/reduce_2d.mlir @@ -155,9 +155,7 @@ module attributes {transform.with_named_sequence} { // Step 3: Lower vector.multi_reduction transform.apply_patterns to %func { transform.apply_patterns.vector.lower_masked_transfers - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction" - transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir index 3c4f10316d0f3..d3b925ff70714 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-matmul-masked-vec.mlir @@ -53,9 +53,7 @@ module attributes {transform.with_named_sequence} { %func_op = transform.get_parent_op %0 : (!transform.any_op) -> !transform.op<"func.func"> transform.structured.vectorize %0 vector_sizes [4, 4, 2] : !transform.any_op transform.apply_patterns to %func_op { - transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims lowering_strategy = "innerreduction" - transform.apply_patterns.vector.multi_reduction_flattening lowering_strategy = "innerreduction" - transform.apply_patterns.vector.multi_reduction_unrolling lowering_strategy = "innerreduction" + transform.apply_patterns.vector.multi_reduction lowering_strategy = "innerreduction" } : !transform.op<"func.func"> transform.yield } diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index 8a3091d0b1b02..255ee75379027 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -87,6 +87,19 @@ def enum_configurable_patterns(): lowering_strategy=vector.VectorContractLowering.ParallelArith ) + # CHECK: transform.apply_patterns.vector.multi_reduction + vector.ApplyMultiReductionPatternsOp() + # CHECK: transform.apply_patterns.vector.multi_reduction + # This is the default mode, not printed. + vector.ApplyMultiReductionPatternsOp( + lowering_strategy=vector.VectorMultiReductionLowering.InnerParallel + ) + # CHECK: transform.apply_patterns.vector.multi_reduction + # CHECK-SAME: lowering_strategy = innerreduction + vector.ApplyMultiReductionPatternsOp( + lowering_strategy=vector.VectorMultiReductionLowering.InnerReduction + ) + # CHECK: transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims vector.ApplyReorderAndExpandMultiReductionPatternsOp() # CHECK: transform.apply_patterns.vector.reorder_and_expand_multi_reduction_dims