diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 72a69a056c46e..03d25505dc65c 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -85,6 +85,20 @@ def ApplyDropUnitDimWithShapeCastPatternsOp : Op]> { + let description = [{ + Apply vector patterns to drop the inner most unit dims from + vector.transfer_read and vector.transfer_write Ops by taking a subview (via + memref.subview) of the original source/destination MemRef. Since it + requires the input/ouptu to be MemRefs, this Op is only helpful + past-bufferization. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyTransferPermutationPatternsOp : Op]> { diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 6bb390aa09d3e..18f105ef62e38 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -88,6 +88,11 @@ void transform::ApplyDropUnitDimWithShapeCastPatternsOp::populatePatterns( vector::populateDropUnitDimWithShapeCastPatterns(patterns); } +void transform::ApplyDropInnerMostUnitDimsFromXferOpsPatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + vector::populateDropInnerMostUnitDimsXferOpPatterns(patterns); +} + void transform::ApplyLowerBitCastPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorBitCastLoweringPatterns(patterns); diff --git a/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir b/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir new file mode 100644 index 0000000000000..5bffa20842b0c --- /dev/null +++ b/mlir/test/Dialect/Vector/td/xfer-drop-unit-dims.mlir @@ -0,0 +1,11 @@ +module @transforms attributes { transform.with_named_sequence } { + transform.named_sequence @drop_unit_dims(%module: !transform.any_op {transform.readonly}) { + + %func_op = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func"> + transform.apply_patterns to %func_op { + transform.apply_patterns.vector.drop_inner_most_unit_dims_from_xfer_ops + } : !transform.op<"func.func"> + + transform.yield + } +} diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir index cd56c1bf9695b..18c28799a62e5 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -1,4 +1,6 @@ -// RUN: mlir-opt %s -test-vector-transfer-collapse-inner-most-dims -split-input-file | FileCheck %s +// RUN: mlir-opt -split-input-file \ +// RUN: -transform-preload-library='transform-library-paths=%p/td/xfer-drop-unit-dims.mlir' \ +// RUN: -transform-interpreter=entry-point=drop_unit_dims %s | FileCheck %s //----------------------------------------------------------------------------- // 1. vector.transfer_read diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index d6596cd341df7..c2d184626818f 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -344,36 +344,6 @@ struct TestVectorTransferOpt } }; -struct TestVectorTransferCollapseInnerMostContiguousDims - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestVectorTransferCollapseInnerMostContiguousDims) - - TestVectorTransferCollapseInnerMostContiguousDims() = default; - TestVectorTransferCollapseInnerMostContiguousDims( - const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } - - StringRef getArgument() const final { - return "test-vector-transfer-collapse-inner-most-dims"; - } - - StringRef getDescription() const final { - return "Test lowering patterns that reduces the rank of the vector " - "transfer memory and vector operands."; - } - - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateDropInnerMostUnitDimsXferOpPatterns(patterns); - (void)applyPatternsGreedily(getOperation(), std::move(patterns)); - } -}; - struct TestVectorSinkPatterns : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorSinkPatterns) @@ -1079,8 +1049,6 @@ void registerTestVectorLowerings() { PassRegistration(); - PassRegistration(); - PassRegistration(); PassRegistration();