diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index f256af2f6b12b..42057d8d0c910 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -41,6 +41,18 @@ def ApplyEraseUnnecessaryInputsPatternsOp : Op]> { + let description = [{ + Collect patterns to generalize tensor.pack and tensor.unpack (i.e. to + decompose it into e.g. tensor::PadOp, linalg::transposeOp etc). Requires + all outer dims to be unit. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 89e9a3b70d2ab..0b55a76f88433 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1516,8 +1516,8 @@ struct GeneralizePadOpPattern : public OpRewritePattern { }; /// Rewrites a tensor::PackOp into a sequence of: -/// * tensor::PadOp + linalg::TransposeOp + -/// tensor::EmptyOp + tensor::InsertSliceOp ops. +/// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp + +/// tensor::InsertSliceOp ops. /// /// Required that all the outer dims of the input tensor::PackOp are 1. /// @@ -1683,6 +1683,11 @@ void populateLinalgGenericOpsSpecializationPatterns( void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g. +/// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all +/// outer dims to be unit. +void populateGeneralizePatterns(RewritePatternSet &patterns); + /// Populates patterns to transform linalg.conv_2d_xxx operations into /// linalg.generic (for img2col packing) and linalg.matmul. /// \see rewriteInIm2Col for more details. diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 1956fc634ef39..a00c609779c3a 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -229,6 +229,11 @@ void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns( linalg::populateEraseUnnecessaryInputsPatterns(patterns); } +void transform::ApplyGeneralizeTensorPackUnpackPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateGeneralizePatterns(patterns); +} + void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns( RewritePatternSet &patterns) { linalg::ControlDropUnitDims options; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index ed9ebca4f306a..c9eac66367559 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1618,3 +1618,8 @@ void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, DownscaleSizeOneWindowed2DConvolution>( patterns.getContext(), benefit); } + +void linalg::populateGeneralizePatterns(RewritePatternSet &patterns) { + // TODO: Add and test patterns for tensor.unpack + patterns.add(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir index f4b1d9a55f091..ad20541e301d3 100644 --- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir +++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir @@ -1,5 +1,4 @@ -// RUN: mlir-opt -split-input-file --test-linalg-transform-patterns="test-generalize-tensor-pack" %s | FileCheck %s - +// RUN: mlir-opt --transform-preload-library='transform-library-paths=%p/td/generalize-pack.mlir' -split-input-file --transform-interpreter %s | FileCheck %s func.func @simple_KCRS_to_KCRSsr(%arg0: tensor, %arg1: tensor<1x1x?x1xi32>) -> tensor<1x1x?x1xi32> { %c8 = arith.constant 8 : index diff --git a/mlir/test/Dialect/Linalg/lit.local.cfg b/mlir/test/Dialect/Linalg/lit.local.cfg new file mode 100644 index 0000000000000..62743008a3e3a --- /dev/null +++ b/mlir/test/Dialect/Linalg/lit.local.cfg @@ -0,0 +1,2 @@ +# Skip the directory with input TD sequences +config.excludes = ["td"] diff --git a/mlir/test/Dialect/Linalg/td/generalize-pack.mlir b/mlir/test/Dialect/Linalg/td/generalize-pack.mlir new file mode 100644 index 0000000000000..62e5b779ff361 --- /dev/null +++ b/mlir/test/Dialect/Linalg/td/generalize-pack.mlir @@ -0,0 +1,12 @@ +module @transforms attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) { + %pack = transform.structured.match ops{["tensor.pack"]} in %module : (!transform.any_op) -> !transform.any_op + + %1 = transform.get_parent_op %pack {isolated_from_above} : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %1 { + transform.apply_patterns.linalg.generalize_pack_unpack + } : !transform.any_op + + transform.yield + } +}