diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 70d424bae9285..caec229207ea6 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -717,7 +717,7 @@ def LowerUnPackOp : Op { let description = [{ Lower a linalg.unpack into empty + linalg.transpose + tensor.collapse_shape + - tensor.extract_slice. + tensor.extract_slice + linalg.copy. #### Return modes @@ -725,7 +725,7 @@ def LowerUnPackOp : Op:$target, @@ -733,7 +733,8 @@ def LowerUnPackOp : Op:$empty_op, Transform_ConcreteOpType<"linalg.transpose">:$transpose_op, Transform_ConcreteOpType<"tensor.collapse_shape">:$collapse_shape_op, - Transform_ConcreteOpType<"tensor.extract_slice">:$extract_slice_op); + Transform_ConcreteOpType<"tensor.extract_slice">:$extract_slice_op, + Transform_ConcreteOpType<"linalg.copy">:$copy_op); let assemblyFormat = [{ $target attr-dict `:` functional-type(operands, results) }]; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index d1f313098a2c1..fb9cede670801 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1360,9 +1360,10 @@ struct LowerUnPackOpResult { linalg::TransposeOp transposeOp; tensor::CollapseShapeOp collapseShapeOp; tensor::ExtractSliceOp extractSliceOp; + linalg::CopyOp copyOp; }; -/// Rewrite pack as empty + transpose + reshape + extract_slice. +/// Rewrite pack as empty + transpose + reshape + extract_slice + copy. FailureOr lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, bool lowerUnpadLikeWithExtractSlice = true); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index e945a15476b3a..309a4d989465d 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1556,6 +1556,7 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne( transformResults.push_back(res->transposeOp); transformResults.push_back(res->collapseShapeOp); transformResults.push_back(res->extractSliceOp); + transformResults.push_back(res->copyOp); return DiagnosedSilenceableFailure::success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index eb3eb48a7fe34..2b4986aeac14f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -382,7 +382,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, rewriter.replaceOp(unPackOp, extractSliceOp->getResults()); return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr, - /*reshapeOp=*/nullptr, extractSliceOp}; + /*reshapeOp=*/nullptr, extractSliceOp, + /*copyOp=*/nullptr}; } // 1. Compute the permutation vector to shuffle packed shape into the shape @@ -444,7 +445,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, // 7. Replace unPackOp by copyOp. rewriter.replaceOp(unPackOp, copyOp->getResults()); - return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp}; + return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp, + copyOp}; } SmallVector diff --git a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir index 9e7681d1a1b7d..b6fe67a9ae1f3 100644 --- a/mlir/test/Dialect/Linalg/transform-lower-pack.mlir +++ b/mlir/test/Dialect/Linalg/transform-lower-pack.mlir @@ -185,7 +185,8 @@ module attributes {transform.with_named_sequence} { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) transform.yield } } @@ -220,7 +221,8 @@ module attributes {transform.with_named_sequence} { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) transform.yield } } @@ -254,7 +256,8 @@ module attributes {transform.with_named_sequence} { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) transform.yield } } @@ -286,7 +289,8 @@ module attributes {transform.with_named_sequence} { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) transform.yield } } @@ -554,7 +558,8 @@ module attributes {transform.with_named_sequence} { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) transform.yield } } @@ -594,7 +599,8 @@ module attributes {transform.with_named_sequence} { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) transform.yield } } @@ -637,7 +643,8 @@ module attributes {transform.with_named_sequence} { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) transform.yield } } @@ -677,7 +684,8 @@ module attributes {transform.with_named_sequence} { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) transform.yield } } @@ -711,7 +719,8 @@ module attributes {transform.with_named_sequence} { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) transform.yield } } diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir index d72ab080f3c5c..dc4d2e434c0db 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir @@ -159,7 +159,8 @@ module { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) %root = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op @@ -220,7 +221,8 @@ module { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) %root = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir index a7bb039b04102..08dbe7c0ef345 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/pack-unpack-mmt4d.mlir @@ -159,7 +159,8 @@ module @transforms attributes { transform.with_named_sequence } { -> (!transform.op<"tensor.empty">, !transform.op<"linalg.transpose">, !transform.op<"tensor.collapse_shape">, - !transform.op<"tensor.extract_slice">) + !transform.op<"tensor.extract_slice">, + !transform.op<"linalg.copy">) transform.yield }