-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][memref]: Add expand/collapse rewrite pattern to MemRef::CopyOp #67808
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-memref Changes: This pattern is useful to adjust the memref copy ranks. Patch is 26.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67808.diff — 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
new file mode 100644
index 000000000000000..27a69ab93e42c74
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
@@ -0,0 +1,45 @@
+//===-- ExpandCollapseCopyOps.h - Expand/Collapse MemRef copy ranks --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Patterns for expanding/collapsing the rank of MemRef copies.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+#include <functional>
+
+namespace mlir {
+class MLIRContext;
+class RewritePatternSet;
+
+namespace memref {
+
+/// Callback used to select which `memref::CopyOp` operations are considered
+/// for rank expansion/collapse. Returning false leaves the operation untouched.
+using ExpandCollapseFuncCB = std::function<bool(memref::CopyOp)>;
+
+/// Default selection callback: accepts every `memref::CopyOp`.
+inline bool expandCollapseAny([[maybe_unused]] memref::CopyOp copyOp) {
+  return true;
+}
+
+/// Populates `patterns` with a rewrite pattern that expands or collapses
+/// `memref::CopyOp` operations so that their rank falls within the range
+/// [`minRank`, `maxRank`]. A selective callback `funcCB` may be provided to
+/// choose which operations are rewritten; by default every copy is considered.
+/// In some cases (e.g. the source/target are strided in whole dims), it is
+/// not possible to collapse the `memref::CopyOp`; such copies are left as is.
+///
+void populateExpandCollapseCopyOpsPatterns(
+ RewritePatternSet &patterns, unsigned minRank = 1, unsigned maxRank = 1,
+ ExpandCollapseFuncCB funcCB = expandCollapseAny);
+
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index b16c281c93640ea..924feca4cad3012 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
AllocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
ComposeSubView.cpp
+ ExpandCollapseCopyOps.cpp
ExpandOps.cpp
ExpandRealloc.cpp
ExpandStridedMetadata.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
new file mode 100644
index 000000000000000..7905254e71e19fc
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
@@ -0,0 +1,238 @@
+//===- ExpandCollapseCopyOps.cpp - Expand/collapse MemRef copy ranks ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===---------------------------------------------------------------------------===//
+//
+// This file contains rewrite patterns (transformations) to expand/collapse
+// MemRef copies. This is useful on architectures that have limitations on
+// the number of dimensions (rank) of the copy operation.
+//
+//===--------------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+#define DEBUG_TYPE "expand-collapse-copy-ops"
+
+using namespace mlir;
+
+#ifndef NDEBUG
+static inline std::string shape_to_string(ArrayRef<int64_t> shape);
+#endif // NDEBUG
+
+namespace {
+/// ExpandCollapseCopyOpConverter is a rewrite pattern that brings the rank of
+/// a `memref::CopyOp` into the inclusive range [`minRank`, `maxRank`]:
+///  * a copy whose rank is below `minRank` is expanded to `minRank`;
+///  * a copy whose rank is above `maxRank` is collapsed to `maxRank`.
+/// A selective callback may be provided to distinguish which operations
+/// should be expanded/collapsed.
+/// In some cases (e.g. the source/target are strided in each dim), it is not
+/// possible to collapse the `memref::CopyOp`; the pattern then fails to match.
+struct ExpandCollapseCopyOpConverter : public OpRewritePattern<memref::CopyOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  ExpandCollapseCopyOpConverter(MLIRContext *context, unsigned minRank,
+                                unsigned maxRank,
+                                memref::ExpandCollapseFuncCB funcCB)
+      : OpRewritePattern<memref::CopyOp>(context, /*benefit=*/1),
+        minRank(minRank), maxRank(maxRank), funcCB(funcCB) {
+    assert(minRank <= maxRank && "invalid ranks range");
+  }
+
+  /// Returns failure when the copy is filtered out by the callback, already
+  /// has a rank within range, operates on unranked memrefs, or cannot be
+  /// collapsed; otherwise replaces `copyOp` with a copy of reshaped operands.
+  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
+                                PatternRewriter &rewriter) const final {
+    // `memref.copy` operands may be unranked memrefs. Those have no static
+    // rank to adjust, so bail out gracefully instead of asserting in `cast`.
+    auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+    auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+    if (!sourceType || !targetType)
+      return rewriter.notifyMatchFailure(
+          copyOp, "expected ranked memref source and target");
+
+    if (!funcCB(copyOp)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp << ", filtered by funcCB\n");
+      return failure();
+    }
+
+    const unsigned rank = sourceType.getRank();
+    if (rank >= minRank && rank <= maxRank) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp
+                 << ", operation does not need to expand/collapse\n");
+      return failure();
+    }
+
+    if (rank > maxRank)
+      return collapseCopyOpRank(copyOp, maxRank, rewriter);
+
+    assert(rank < minRank);
+    expandCopyOpRank(copyOp, minRank, rewriter);
+    // Expand is always successful.
+    return success();
+  }
+
+private:
+  // Defaults mirror those of `populateExpandCollapseCopyOpsPatterns` so that
+  // an instance built through the inherited constructors is still in a valid
+  // state (in particular, `funcCB` is never an empty std::function).
+  unsigned minRank = 1;
+  unsigned maxRank = 1;
+  // Accept callback to select which `memref::CopyOp` to collapse/expand.
+  memref::ExpandCollapseFuncCB funcCB = memref::expandCollapseAny;
+
+  // Expand the `copyOp` source/target dims to newRank by
+  // adding new dims in size of `1`.
+  void expandCopyOpRank(memref::CopyOp copyOp, unsigned newRank,
+                        PatternRewriter &rewriter) const;
+  // Collapse the `copyOp` source/target dims to newRank.
+  // The function tries to collapse starting from the most inner dims
+  // to the most outer dims.
+  // This function return failure if there are no dims to collapse.
+  LogicalResult collapseCopyOpRank(memref::CopyOp copyOp, unsigned newRank,
+                                   PatternRewriter &rewriter) const;
+  // Fill `collapsedShape` with a shape in size of `newRank`.
+  // The function tries to collapse starting from the most inner dims
+  // to the most outer dims of `memrefToCollapse`.
+  // This function return failure if there are no dims to collapse.
+  LogicalResult getCollapsedShape(MemRefType memrefToCollapse, unsigned newRank,
+                                  SmallVector<int64_t> &collapsedShape) const;
+};
+
+} // namespace
+
+/// Replaces `copyOp` with a copy between `memref.expand_shape` views of its
+/// source and target, where the expanded shape is the original shape with
+/// `newRank - rank` unit dims prepended.
+void ExpandCollapseCopyOpConverter::expandCopyOpRank(
+    memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const {
+  MemRefType memRefType = cast<MemRefType>(copyOp.getSource().getType());
+
+  // New outer most dims will be 1s, rest dims are same as original shape.
+  auto shape = memRefType.getShape();
+  SmallVector<int64_t> newShape(newRank - memRefType.getRank(), 1);
+  newShape.insert(newShape.end(), shape.begin(), shape.end());
+
+  // `shape_to_string` only exists in builds without NDEBUG, which are also
+  // the only builds in which LLVM_DEBUG emits anything.
+#ifndef NDEBUG
+  LLVM_DEBUG(llvm::dbgs() << "Expanding shape " << shape_to_string(shape)
+                          << " to " << shape_to_string(newShape) << "\n");
+#endif // NDEBUG
+
+  // Expand reassociation is the same as collapse with opposing source/target
+  // shapes.
+  std::optional<SmallVector<ReassociationIndices>> reassociation =
+      getReassociationIndicesForCollapse(newShape, shape);
+  assert(reassociation && "expected reassociation to be valid for expand");
+
+  rewriter.setInsertionPoint(copyOp);
+  Value expandShapeSrc = rewriter.create<memref::ExpandShapeOp>(
+      copyOp.getLoc(), newShape, copyOp.getSource(), *reassociation);
+  Value expandShapeTarget = rewriter.create<memref::ExpandShapeOp>(
+      copyOp.getLoc(), newShape, copyOp.getTarget(), *reassociation);
+
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, expandShapeSrc,
+                                              expandShapeTarget);
+}
+
+/// Replaces `copyOp` with a copy between `memref.collapse_shape` views of its
+/// source and target, both collapsed to `newRank` dims with the same
+/// reassociation. Returns failure, leaving the IR untouched, when no valid
+/// collapse exists for the source or when the chosen reassociation is not
+/// guaranteed to be collapsible for the target.
+LogicalResult ExpandCollapseCopyOpConverter::collapseCopyOpRank(
+    memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const {
+  auto sourceType = dyn_cast<MemRefType>(copyOp.getSource().getType());
+  auto targetType = dyn_cast<MemRefType>(copyOp.getTarget().getType());
+  // collapse_shape needs ranked operands on both sides of the copy.
+  if (!sourceType || !targetType)
+    return failure();
+
+  SmallVector<int64_t> collapsedShape;
+  if (failed(getCollapsedShape(sourceType, newRank, collapsedShape)))
+    return failure();
+
+  std::optional<SmallVector<ReassociationIndices>> reassociation =
+      getReassociationIndicesForCollapse(sourceType.getShape(), collapsedShape);
+  if (!reassociation)
+    return failure();
+
+  // The collapsed shape was validated against the source layout only. The
+  // target may carry a different (e.g. strided) layout, so make sure the same
+  // reassociation is valid for it before creating any op.
+  if (!memref::CollapseShapeOp::isGuaranteedCollapsible(targetType,
+                                                        *reassociation))
+    return failure();
+
+  rewriter.setInsertionPoint(copyOp);
+  Value collapseShapeSrc = rewriter.create<memref::CollapseShapeOp>(
+      copyOp.getLoc(), copyOp.getSource(), *reassociation);
+  Value collapseShapeTarget = rewriter.create<memref::CollapseShapeOp>(
+      copyOp.getLoc(), copyOp.getTarget(), *reassociation);
+
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, collapseShapeSrc,
+                                              collapseShapeTarget);
+
+  return success();
+}
+
+/// Computes into `collapsedShape` a shape of `newRank` dims obtained from the
+/// shape of `memrefToCollapse` by merging one run of `rank - newRank + 1`
+/// adjacent dims into a single dim. Candidate runs are tried from the most
+/// inner position towards the most outer one; the first candidate that
+/// `memref::CollapseShapeOp::isGuaranteedCollapsible` accepts is returned.
+/// Returns failure if the shape is not static or if no candidate is
+/// collapsible; `collapsedShape` is left unmodified in that case.
+LogicalResult ExpandCollapseCopyOpConverter::getCollapsedShape(
+    MemRefType memrefToCollapse, unsigned newRank,
+    SmallVector<int64_t> &collapsedShape) const {
+  // The collapsed extent is computed as a product of extents, which is only
+  // meaningful when every extent is statically known.
+  if (!memrefToCollapse.hasStaticShape())
+    return failure();
+
+  ArrayRef<int64_t> shape = memrefToCollapse.getShape();
+  const int64_t rank = memrefToCollapse.getRank();
+  const int64_t dimsToCollapse = rank - static_cast<int64_t>(newRank);
+  assert(dimsToCollapse > 0);
+
+  // Extents are int64_t; multiply in that type to avoid truncating large
+  // shapes.
+  auto product = [](ArrayRef<int64_t> dims) {
+    return std::accumulate(dims.begin(), dims.end(), int64_t{1},
+                           std::multiplies<int64_t>());
+  };
+
+  // Try to find `dimsToCollapse` dims we can collapse, starting with most inner
+  // dim to collapse.
+  for (int64_t firstDimToCollapse = rank - dimsToCollapse - 1;
+       firstDimToCollapse >= 0; --firstDimToCollapse) {
+    const int64_t lastDimToCollapse = firstDimToCollapse + dimsToCollapse;
+
+    // Generate new shape in `newRank` size: dims before the run are kept, the
+    // run [firstDimToCollapse, lastDimToCollapse] becomes a single dim, and
+    // dims after the run are kept.
+    SmallVector<int64_t> newShape(shape.begin(),
+                                  shape.begin() + firstDimToCollapse);
+    newShape.push_back(
+        product(shape.slice(firstDimToCollapse, dimsToCollapse + 1)));
+    newShape.append(shape.begin() + lastDimToCollapse + 1, shape.end());
+
+    assert(newShape.size() == newRank);
+    assert(product(shape) == product(newShape));
+
+    // `shape_to_string` only exists in builds without NDEBUG, which are also
+    // the only builds in which LLVM_DEBUG emits anything.
+#ifndef NDEBUG
+    LLVM_DEBUG(llvm::dbgs()
+               << "trying to collapse shape " << shape_to_string(shape)
+               << " to " << shape_to_string(newShape) << "\n");
+#endif // NDEBUG
+
+    std::optional<SmallVector<ReassociationIndices>> reassociation =
+        getReassociationIndicesForCollapse(shape, newShape);
+    // No reassociation could be derived for this candidate; try the next one
+    // rather than dereferencing an empty optional.
+    if (!reassociation)
+      continue;
+
+    if (memref::CollapseShapeOp::isGuaranteedCollapsible(memrefToCollapse,
+                                                         *reassociation)) {
+      collapsedShape = std::move(newShape);
+      return success();
+    }
+  }
+
+  return failure();
+}
+
+#ifndef NDEBUG
+/// Debug helper: renders `shape` as its extents joined by 'x'
+/// (e.g. {1, 5, 24, 48} -> "1x5x24x48"). An empty shape gives an empty string.
+/// Built on std::string only, so it does not depend on <sstream>, which this
+/// file does not include.
+static inline std::string shape_to_string(ArrayRef<int64_t> shape) {
+  std::string shapeStr;
+
+  for (int64_t dim : shape) {
+    // Separator goes between extents only, so there is no trailing 'x'.
+    if (!shapeStr.empty())
+      shapeStr += 'x';
+    shapeStr += std::to_string(dim);
+  }
+
+  return shapeStr;
+}
+#endif // NDEBUG
+
+/// Adds an `ExpandCollapseCopyOpConverter` configured with the given rank
+/// range and selection callback to `patterns`.
+void memref::populateExpandCollapseCopyOpsPatterns(
+    RewritePatternSet &patterns, unsigned minRank, unsigned maxRank,
+    memref::ExpandCollapseFuncCB funcCB) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ExpandCollapseCopyOpConverter>(context, minRank, maxRank,
+                                              funcCB);
+}
diff --git a/mlir/test/Transforms/expand-collapse-copy-ops.mlir b/mlir/test/Transforms/expand-collapse-copy-ops.mlir
new file mode 100644
index 000000000000000..b3cd187424e084b
--- /dev/null
+++ b/mlir/test/Transforms/expand-collapse-copy-ops.mlir
@@ -0,0 +1,141 @@
+// RUN: mlir-opt -test-expand-collapse-copy-ops="minRank=2 maxRank=3" %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @empty() {
+// CHECK: return
+// CHECK: }
+// Input with no `memref.copy` at all; the expectations above only require the
+// function to still contain its `return`.
+func.func @empty() -> () {
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @memref_copy_to_expand(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<6xi32>) {
+// CHECK: %[[VAL_1:.*]] = memref.alloc() : memref<6xi32>
+// CHECK: %[[VAL_2:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32>
+// CHECK: %[[VAL_3:.*]] = memref.expand_shape %[[VAL_1]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32>
+// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x6xi32> to memref<1x6xi32>
+// CHECK: return
+// CHECK: }
+// Rank-1 copy, below the `minRank=2` given on the RUN line; the expectations
+// above require both operands to be expanded to 1x6 before the copy.
+func.func @memref_copy_to_expand(%arg0: memref<6xi32>) {
+ %0 = memref.alloc() : memref<6xi32>
+ memref.copy %arg0, %0 : memref<6xi32> to memref<6xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @memref_copy_to_collapse(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x5x24x48xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x5x24x48xi32>) {
+// CHECK: %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32>
+// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32>
+// CHECK: memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x5x1152xi32> to memref<1x5x1152xi32>
+// CHECK: return
+// CHECK: }
+// Rank-4 copy, above the `maxRank=3` given on the RUN line; the expectations
+// above require the two innermost dims (24x48) to be merged into 1152.
+func.func @memref_copy_to_collapse(%arg0: memref<1x5x24x48xi32>, %arg1: memref<1x5x24x48xi32>) {
+ memref.copy %arg0, %arg1 : memref<1x5x24x48xi32> to memref<1x5x24x48xi32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @memref_copy_collapse_expand_in_loop(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x5x24x48xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 5760 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK: %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32>
+// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK: %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK: scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] {
+// CHECK: %[[VAL_10:.*]] = memref.subview %[[VAL_6]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK: %[[VAL_11:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK: %[[VAL_12:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<16xf32>
+// CHECK: %[[VAL_14:.*]] = memref.expand_shape %[[VAL_10]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK: %[[VAL_15:.*]] = memref.expand_shape %[[VAL_13]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32>
+// CHECK: %[[VAL_16:.*]] = memref.alloc() : memref<16xf32>
+// CHECK: %[[VAL_17:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK: %[[VAL_18:.*]] = memref.expand_shape %[[VAL_16]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK: memref.copy %[[VAL_17]], %[[VAL_18]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32>
+// CHECK: %[[VAL_19:.*]] = memref.alloc() : memref<16xf32>
+// CHECK: linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%[[VAL_13]], %[[VAL_16]] : memref<16xf32>, memref<16xf32>) outs(%[[VAL_19]] : memref<16xf32>) {
+// CHECK: ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK: %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK: linalg.yield %[[VAL_23]] : f32
+// CHECK: }
+// CHECK: %[[VAL_24:.*]] = memref.expand_shape %[[VAL_19]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK: %[[VAL_25:.*]] = memref.expand_shape %[[VAL_12]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK: memref.copy %[[VAL_24]], %[[VAL_25]] : memref<1x16xf32> to memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK: }
+// CHECK: return %[[VAL_5]] : memref<1x5x24x48xf32>
+// CHECK: }
+// Three rank-1 copies inside an `scf.for` loop: two from subviews into fresh
+// allocations and one from an allocation back into a subview. The
+// expectations above require each of them to be expanded to 1x16, while the
+// pre-existing `memref.collapse_shape` ops and the `linalg.generic` stay as
+// written.
+#map = affine_map<(d0) -> (d0)>
+module {
+ func.func @memref_copy_collapse_expand_in_loop(%arg0: memref<1x5x24x48xf32>, %arg1: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> {
+ %c5760 = arith.constant 5760 : index
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32>
+ %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+ %collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+ %collapse_shape_1 = memref.collapse_shape %alloc [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+ scf.for %arg2 = %c0 to %c5760 step %c16 {
+ %subview = memref.subview %collapse_shape[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+ %subview_2 = memref.subview %collapse_shape_0[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+ %subview_3 = memref.subview %collapse_shape_1[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+ %alloc_4 = memref.alloc() : memref<16xf32>
+ memref.copy %subview, %alloc_4 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
+ %alloc_5 = memref.alloc() : memref<16xf32>
+ memref.copy %subview_2, %alloc_5 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
+ %alloc_6 = memref.alloc() : memref<16xf32>
+ linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%alloc_4, %alloc_5 : memref<16xf32>, memref<16xf32>) outs(%alloc_6 : memref<16xf32>) {
+ ^bb0(%in: f32, %in_7: f32, %out: f32):
+ %0 = arith.addf %in, %in_7 : f32
+ linalg.yield %0 : f32
+ }
+ memref.copy %alloc_6, %subview_3 : memref<16xf32> to memref<16xf32, strided<[1], offset: ?>>
+ }
+ return %alloc : memref<1x5x24x48xf32>
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @memref_copy_strided_to_collapse(
+// CHECK-SAME: %[[VAL_0:.*]]: memref<1x5x24x48xi32>,
+// CHECK-SAME: ...
[truncated]
|
Hmm, could you please elaborate on the rationale for this? I am generally quite opposed to adding transformations on Have you considered using a If anything, I would just delete |
The rationale for this pattern is this:
I can refactor PromoteOp and this pattern to be based on |
@AviadCo thanks for your description. I would recommend trying to use/refactor the logic of Atm this works on tensors but could be refactored to work with memref if needed. If possible, I'd recommend going for transforms on tensors and using bufferization.bufferize_to_alloc_tensor in mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir, which takes a memory space and connects properly to bufferization. There is also mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir for other related tests. |
@nicolasvasilache thanks for the ideas! Regarding I also investigated |
@nicolasvasilache I created 2 PRs: I believe with those 2 PRs I may use linalg::CopyOp as a promotion (DMA import/export) and then will be able to filter the linalg::CopyOp operation by supplying |
This pattern is useful to adjust the memref copy ranks.