
[mlir][memref]: Add expand/collapse rewrite pattern to MemRef::CopyOp #67808

Closed

Conversation

AviadCo
Contributor

@AviadCo AviadCo commented Sep 29, 2023

This pattern is useful for adjusting the rank of memref copies.

@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Sep 29, 2023
@AviadCo AviadCo requested a review from amrami September 29, 2023 14:18
@llvmbot
Collaborator

llvmbot commented Sep 29, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Changes

This pattern is useful for adjusting the rank of memref copies.


Patch is 26.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67808.diff

7 Files Affected:

  • (added) mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h (+45)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp (+238)
  • (added) mlir/test/Transforms/expand-collapse-copy-ops.mlir (+141)
  • (modified) mlir/test/lib/Dialect/MemRef/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Dialect/MemRef/TestExpandCollapseCopyOps.cpp (+66)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
new file mode 100644
index 000000000000000..27a69ab93e42c74
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h
@@ -0,0 +1,45 @@
+//===-- ExpandCollapseCopyOps.h - Expand/Collapse MemRef copy ranks      --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Patterns for expanding/collapsing MemRef copies.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
+#define MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+#include <functional>
+
+namespace mlir {
+class MLIRContext;
+class RewritePatternSet;
+
+namespace memref {
+
+using ExpandCollapseFuncCB = std::function<bool(memref::CopyOp)>;
+inline bool expandCollapseAny([[maybe_unused]] memref::CopyOp copyOp) {
+  return true;
+}
+
+/// ExpandCollapseCopyOpConverter is a rewrite pattern that checks
+/// if a `memref::CopyOp` should be expanded/collapsed into `minRank`
+/// `maxRank` ranks. A selective callback may be provided to distinguish
+/// which operations should be expanded/collapsed.
+/// In some cases (e.g. when the source/target are strided in all dims),
+/// it will not be possible to expand/collapse the `memref::CopyOp`.
+
+void populateExpandCollapseCopyOpsPatterns(
+    RewritePatternSet &patterns, unsigned minRank = 1, unsigned maxRank = 1,
+    ExpandCollapseFuncCB funcCB = expandCollapseAny);
+
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_TRANSFORMS_EXPAND_COLLAPSE_COPY_OPS_H_
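The rank-expansion half of this entry point boils down to prepending unit dims until `minRank` is reached. As a stdlib-only sketch of that computation (`expandShapeTo` is a hypothetical helper name, not part of the patch; the real pattern operates on `MemRefType` shapes and builds `memref.expand_shape` ops):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stdlib-only analogue of the `newShape` computation in the
// patch's expandCopyOpRank: reach `newRank` by prepending unit dimensions
// to the original shape.
std::vector<int64_t> expandShapeTo(const std::vector<int64_t> &shape,
                                   unsigned newRank) {
  assert(newRank >= shape.size() && "expansion cannot drop dims");
  std::vector<int64_t> newShape(newRank - shape.size(), 1);
  newShape.insert(newShape.end(), shape.begin(), shape.end());
  return newShape;
}
```

For example, `expandShapeTo({6}, 2)` yields `{1, 6}`, matching the `memref<6xi32>` into `memref<1x6xi32>` expansion exercised by the patch's tests.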
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index b16c281c93640ea..924feca4cad3012 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   AllocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   ComposeSubView.cpp
+  ExpandCollapseCopyOps.cpp
   ExpandOps.cpp
   ExpandRealloc.cpp
   ExpandStridedMetadata.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
new file mode 100644
index 000000000000000..7905254e71e19fc
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.cpp
@@ -0,0 +1,238 @@
+//===- ExpandCollapseCopyOps.cpp - Expand/Collapse MemRef copy ranks -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains rewrite patterns (transformations) to expand/collapse
+// MemRef copies. This is useful for architectures which have limitations on
+// the dimensions of the copy operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/ExpandCollapseCopyOps.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+#define DEBUG_TYPE "expand-collapse-copy-ops"
+
+using namespace mlir;
+
+#ifndef NDEBUG
+static inline std::string shape_to_string(ArrayRef<int64_t> shape);
+#endif // NDEBUG
+
+namespace {
+/// ExpandCollapseCopyOpConverter is a rewrite pattern that checks
+/// if a `memref::CopyOp` should be expanded/collapsed into `minRank`
+/// `maxRank` ranks. A selective callback may be provided to distinguish
+/// which operations should be expanded/collapsed.
+/// In some cases (e.g. when the source/target are strided in all dims),
+/// it will not be possible to expand/collapse the `memref::CopyOp`.
+
+struct ExpandCollapseCopyOpConverter : public OpRewritePattern<memref::CopyOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  ExpandCollapseCopyOpConverter(MLIRContext *context, unsigned minRank,
+                                unsigned maxRank,
+                                memref::ExpandCollapseFuncCB funcCB)
+      : OpRewritePattern<memref::CopyOp>(context, /*benefit=*/1),
+        minRank(minRank), maxRank(maxRank), funcCB(funcCB) {
+    assert(minRank <= maxRank && "invalid ranks range");
+  }
+
+  LogicalResult matchAndRewrite(memref::CopyOp copyOp,
+                                PatternRewriter &rewriter) const final {
+    MemRefType memRefType = cast<MemRefType>(copyOp.getSource().getType());
+    unsigned rank = memRefType.getRank();
+
+    if (!funcCB(copyOp)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp << ", filtered by funcCB\n");
+      return failure();
+    } else if (rank >= minRank && rank <= maxRank) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Skip rewriting " << copyOp
+                 << ", operation does not need to expand/collapse\n");
+      return failure();
+    }
+
+    if (rank > maxRank) {
+      return collapseCopyOpRank(copyOp, maxRank, rewriter);
+    } else {
+      assert(rank < minRank);
+      expandCopyOpRank(copyOp, minRank, rewriter);
+      // Expand is always successful.
+      return success();
+    }
+  }
+
+private:
+  unsigned minRank;
+  unsigned maxRank;
+  // Callback that selects which `memref::CopyOp` ops to collapse/expand.
+  memref::ExpandCollapseFuncCB funcCB;
+
+  // Expand the `copyOp` source/target dims to `newRank` by
+  // prepending new outer dims of size 1.
+  void expandCopyOpRank(memref::CopyOp copyOp, unsigned newRank,
+                        PatternRewriter &rewriter) const;
+  // Collapse the `copyOp` source/target dims to `newRank`.
+  // The function tries to collapse starting from the innermost dims
+  // and moving outwards.
+  // This function returns failure if there are no dims to collapse.
+  LogicalResult collapseCopyOpRank(memref::CopyOp copyOp, unsigned newRank,
+                                   PatternRewriter &rewriter) const;
+  // Fill `collapsedShape` with a shape of size `newRank`.
+  // The function tries to collapse starting from the innermost dims
+  // of `memrefToCollapse` and moving outwards.
+  // This function returns failure if there are no dims to collapse.
+  LogicalResult getCollapsedShape(MemRefType memrefToCollapse, unsigned newRank,
+                                  SmallVector<int64_t> &collapsedShape) const;
+};
+
+} // namespace
+
+void ExpandCollapseCopyOpConverter::expandCopyOpRank(
+    memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const {
+  MemRefType memRefType = cast<MemRefType>(copyOp.getSource().getType());
+
+  // New outermost dims will be 1s; the remaining dims match the original
+  // shape.
+  auto shape = memRefType.getShape();
+  SmallVector<int64_t> newShape(newRank - memRefType.getRank(), 1);
+  newShape.insert(newShape.end(), shape.begin(), shape.end());
+
+#ifndef NDEBUG
+  LLVM_DEBUG(llvm::dbgs() << "Expanding shape " << shape_to_string(shape)
+                          << " to " << shape_to_string(newShape) << "\n");
+#endif // NDEBUG
+
+  // Expand reassociation is the same as collapse with opposing source/target
+  // shapes.
+  std::optional<SmallVector<ReassociationIndices>> reassociation =
+      getReassociationIndicesForCollapse(newShape, shape);
+  assert(reassociation && "expected reassociation to be valid for expand");
+
+  rewriter.setInsertionPoint(copyOp);
+  Value expandShapeSrc = rewriter.create<memref::ExpandShapeOp>(
+      copyOp.getLoc(), newShape, copyOp.getSource(), *reassociation);
+  Value expandShapeTarget = rewriter.create<memref::ExpandShapeOp>(
+      copyOp.getLoc(), newShape, copyOp.getTarget(), *reassociation);
+
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, expandShapeSrc,
+                                              expandShapeTarget);
+}
+
+LogicalResult ExpandCollapseCopyOpConverter::collapseCopyOpRank(
+    memref::CopyOp copyOp, unsigned newRank, PatternRewriter &rewriter) const {
+  MemRefType memRefType = cast<MemRefType>(copyOp.getSource().getType());
+
+  auto shape = memRefType.getShape();
+  SmallVector<int64_t> collapsedShape;
+  if (failed(getCollapsedShape(memRefType, newRank, collapsedShape)))
+    return failure();
+
+  std::optional<SmallVector<ReassociationIndices>> reassociation =
+      getReassociationIndicesForCollapse(shape, collapsedShape);
+  assert(reassociation && "expected reassociation to be valid for collapse");
+
+  rewriter.setInsertionPoint(copyOp);
+  Value collapseShapeSrc = rewriter.create<memref::CollapseShapeOp>(
+      copyOp.getLoc(), copyOp.getSource(), *reassociation);
+  Value collapseShapeTarget = rewriter.create<memref::CollapseShapeOp>(
+      copyOp.getLoc(), copyOp.getTarget(), *reassociation);
+
+  rewriter.replaceOpWithNewOp<memref::CopyOp>(copyOp, collapseShapeSrc,
+                                              collapseShapeTarget);
+
+  return success();
+}
+
+LogicalResult ExpandCollapseCopyOpConverter::getCollapsedShape(
+    MemRefType memrefToCollapse, unsigned newRank,
+    SmallVector<int64_t> &collapsedShape) const {
+  auto shape = memrefToCollapse.getShape();
+  auto rank = memrefToCollapse.getRank();
+  int dimsToCollapse = rank - newRank;
+  assert(dimsToCollapse > 0);
+
+  // Try to find `dimsToCollapse` dims we can collapse, starting from the
+  // innermost candidate.
+  for (int firstDimToCollapse = rank - dimsToCollapse - 1;
+       firstDimToCollapse >= 0; --firstDimToCollapse) {
+    SmallVector<int64_t> newShape;
+
+    int64_t collapsedDims =
+        std::accumulate(shape.begin() + firstDimToCollapse,
+                        shape.begin() + firstDimToCollapse + dimsToCollapse + 1,
+                        int64_t{1}, std::multiplies<int64_t>());
+
+    // Generate the new shape of size `newRank`; the collapsed dims are
+    // replaced by the single dim `collapsedDims`.
+    for (int i = 0; i < rank; ++i) {
+      if (i == firstDimToCollapse)
+        newShape.push_back(collapsedDims);
+      else if (i < firstDimToCollapse ||
+               i > firstDimToCollapse + dimsToCollapse)
+        newShape.push_back(shape[i]);
+    }
+    assert(newShape.size() == newRank);
+    assert(std::accumulate(shape.begin(), shape.end(), int64_t{1},
+                           std::multiplies<int64_t>()) ==
+           std::accumulate(newShape.begin(), newShape.end(), int64_t{1},
+                           std::multiplies<int64_t>()));
+
+#ifndef NDEBUG
+    LLVM_DEBUG(llvm::dbgs()
+               << "trying to collapse shape " << shape_to_string(shape)
+               << " to " << shape_to_string(newShape) << "\n");
+#endif // NDEBUG
+
+    std::optional<SmallVector<ReassociationIndices>> reassociation =
+        getReassociationIndicesForCollapse(shape, newShape);
+    assert(reassociation && "reassociation must be valid for collapse");
+    if (memref::CollapseShapeOp::isGuaranteedCollapsible(memrefToCollapse,
+                                                         *reassociation)) {
+      collapsedShape = std::move(newShape);
+      return success();
+    }
+  }
+
+  return failure();
+}
+
+#ifndef NDEBUG
+static inline std::string shape_to_string(ArrayRef<int64_t> shape) {
+  std::ostringstream shapeStream;
+
+  for (auto dim : shape) {
+    shapeStream << dim << 'x';
+  }
+
+  std::string shapeStr = shapeStream.str();
+
+  // Remove the trailing 'x' character.
+  if (!shapeStr.empty()) {
+    shapeStr.pop_back();
+  }
+
+  return shapeStr;
+}
+#endif // NDEBUG
+
+void memref::populateExpandCollapseCopyOpsPatterns(
+    RewritePatternSet &patterns, unsigned minRank, unsigned maxRank,
+    memref::ExpandCollapseFuncCB funcCB) {
+  patterns.add<ExpandCollapseCopyOpConverter>(patterns.getContext(), minRank,
+                                              maxRank, funcCB);
+}
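The search in `getCollapsedShape` can be summarized as: fold `rank - newRank + 1` contiguous dims into their product, starting from the innermost candidate, and accept the first reassociation the collapsibility check allows. A stdlib-only sketch of that search (`collapseShapeTo` and the `isCollapsible` predicate are hypothetical stand-ins; the real patch consults `memref::CollapseShapeOp::isGuaranteedCollapsible`, which also inspects strides):

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical analogue of the patch's getCollapsedShape search. Starting
// from the innermost candidate position, fold `rank - newRank + 1`
// contiguous dims into their product; keep the first candidate accepted by
// `isCollapsible(first, count)`. Returns an empty vector when no candidate
// is accepted.
std::vector<int64_t> collapseShapeTo(
    const std::vector<int64_t> &shape, unsigned newRank,
    const std::function<bool(int first, int count)> &isCollapsible) {
  int rank = static_cast<int>(shape.size());
  int dimsToCollapse = rank - static_cast<int>(newRank);
  assert(dimsToCollapse > 0 && "nothing to collapse");

  for (int first = rank - dimsToCollapse - 1; first >= 0; --first) {
    if (!isCollapsible(first, dimsToCollapse + 1))
      continue;
    // Product of the `dimsToCollapse + 1` dims being folded together.
    int64_t folded = std::accumulate(
        shape.begin() + first, shape.begin() + first + dimsToCollapse + 1,
        int64_t{1}, std::multiplies<int64_t>());
    std::vector<int64_t> newShape(shape.begin(), shape.begin() + first);
    newShape.push_back(folded);
    newShape.insert(newShape.end(),
                    shape.begin() + first + dimsToCollapse + 1, shape.end());
    return newShape;
  }
  return {};
}
```

With an always-true predicate, `collapseShapeTo({1, 5, 24, 48}, 3, ...)` folds the two innermost dims into `{1, 5, 1152}`, matching the `memref<1x5x24x48xi32>` into `memref<1x5x1152xi32>` collapse in the patch's tests.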
diff --git a/mlir/test/Transforms/expand-collapse-copy-ops.mlir b/mlir/test/Transforms/expand-collapse-copy-ops.mlir
new file mode 100644
index 000000000000000..b3cd187424e084b
--- /dev/null
+++ b/mlir/test/Transforms/expand-collapse-copy-ops.mlir
@@ -0,0 +1,141 @@
+// RUN: mlir-opt -test-expand-collapse-copy-ops="minRank=2 maxRank=3" %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL:   func.func @empty() {
+// CHECK:           return
+// CHECK:         }
+func.func @empty() -> () {
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @memref_copy_to_expand(
+// CHECK-SAME:                                     %[[VAL_0:.*]]: memref<6xi32>) {
+// CHECK:           %[[VAL_1:.*]] = memref.alloc() : memref<6xi32>
+// CHECK:           %[[VAL_2:.*]] = memref.expand_shape %[[VAL_0]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32>
+// CHECK:           %[[VAL_3:.*]] = memref.expand_shape %[[VAL_1]] {{\[\[}}0, 1]] : memref<6xi32> into memref<1x6xi32>
+// CHECK:           memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x6xi32> to memref<1x6xi32>
+// CHECK:           return
+// CHECK:         }
+func.func @memref_copy_to_expand(%arg0: memref<6xi32>) {
+  %0 = memref.alloc() : memref<6xi32>
+  memref.copy %arg0, %0 : memref<6xi32> to memref<6xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @memref_copy_to_collapse(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: memref<1x5x24x48xi32>,
+// CHECK-SAME:                                       %[[VAL_1:.*]]: memref<1x5x24x48xi32>) {
+// CHECK:           %[[VAL_2:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32>
+// CHECK:           %[[VAL_3:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0], [1], [2, 3]] : memref<1x5x24x48xi32> into memref<1x5x1152xi32>
+// CHECK:           memref.copy %[[VAL_2]], %[[VAL_3]] : memref<1x5x1152xi32> to memref<1x5x1152xi32>
+// CHECK:           return
+// CHECK:         }
+func.func @memref_copy_to_collapse(%arg0: memref<1x5x24x48xi32>, %arg1: memref<1x5x24x48xi32>) {
+  memref.copy %arg0, %arg1 : memref<1x5x24x48xi32> to memref<1x5x24x48xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @memref_copy_collapse_expand_in_loop(
+// CHECK-SAME:                                                   %[[VAL_0:.*]]: memref<1x5x24x48xf32>,
+// CHECK-SAME:                                                   %[[VAL_1:.*]]: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 5760 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 16 : index
+// CHECK:           %[[VAL_5:.*]] = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32>
+// CHECK:           %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_0]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK:           %[[VAL_7:.*]] = memref.collapse_shape %[[VAL_1]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK:           %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+// CHECK:           scf.for %[[VAL_9:.*]] = %[[VAL_3]] to %[[VAL_2]] step %[[VAL_4]] {
+// CHECK:             %[[VAL_10:.*]] = memref.subview %[[VAL_6]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK:             %[[VAL_11:.*]] = memref.subview %[[VAL_7]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK:             %[[VAL_12:.*]] = memref.subview %[[VAL_8]]{{\[}}%[[VAL_9]]] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+// CHECK:             %[[VAL_13:.*]] = memref.alloc() : memref<16xf32>
+// CHECK:             %[[VAL_14:.*]] = memref.expand_shape %[[VAL_10]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:             %[[VAL_15:.*]] = memref.expand_shape %[[VAL_13]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK:             memref.copy %[[VAL_14]], %[[VAL_15]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32>
+// CHECK:             %[[VAL_16:.*]] = memref.alloc() : memref<16xf32>
+// CHECK:             %[[VAL_17:.*]] = memref.expand_shape %[[VAL_11]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:             %[[VAL_18:.*]] = memref.expand_shape %[[VAL_16]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK:             memref.copy %[[VAL_17]], %[[VAL_18]] : memref<1x16xf32, strided<[16, 1], offset: ?>> to memref<1x16xf32>
+// CHECK:             %[[VAL_19:.*]] = memref.alloc() : memref<16xf32>
+// CHECK:             linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%[[VAL_13]], %[[VAL_16]] : memref<16xf32>, memref<16xf32>) outs(%[[VAL_19]] : memref<16xf32>) {
+// CHECK:             ^bb0(%[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32, %[[VAL_22:.*]]: f32):
+// CHECK:               %[[VAL_23:.*]] = arith.addf %[[VAL_20]], %[[VAL_21]] : f32
+// CHECK:               linalg.yield %[[VAL_23]] : f32
+// CHECK:             }
+// CHECK:             %[[VAL_24:.*]] = memref.expand_shape %[[VAL_19]] {{\[\[}}0, 1]] : memref<16xf32> into memref<1x16xf32>
+// CHECK:             %[[VAL_25:.*]] = memref.expand_shape %[[VAL_12]] {{\[\[}}0, 1]] : memref<16xf32, strided<[1], offset: ?>> into memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:             memref.copy %[[VAL_24]], %[[VAL_25]] : memref<1x16xf32> to memref<1x16xf32, strided<[16, 1], offset: ?>>
+// CHECK:           }
+// CHECK:           return %[[VAL_5]] : memref<1x5x24x48xf32>
+// CHECK:         }
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @memref_copy_collapse_expand_in_loop(%arg0: memref<1x5x24x48xf32>, %arg1: memref<1x5x24x48xf32>) -> memref<1x5x24x48xf32> {
+    %c5760 = arith.constant 5760 : index
+    %c0 = arith.constant 0 : index
+    %c16 = arith.constant 16 : index
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x5x24x48xf32>
+    %collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+    %collapse_shape_0 = memref.collapse_shape %arg1 [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+    %collapse_shape_1 = memref.collapse_shape %alloc [[0, 1, 2, 3]] : memref<1x5x24x48xf32> into memref<5760xf32>
+    scf.for %arg2 = %c0 to %c5760 step %c16 {
+      %subview = memref.subview %collapse_shape[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+      %subview_2 = memref.subview %collapse_shape_0[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+      %subview_3 = memref.subview %collapse_shape_1[%arg2] [16] [1] : memref<5760xf32> to memref<16xf32, strided<[1], offset: ?>>
+      %alloc_4 = memref.alloc() : memref<16xf32>
+      memref.copy %subview, %alloc_4 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
+      %alloc_5 = memref.alloc() : memref<16xf32>
+      memref.copy %subview_2, %alloc_5 : memref<16xf32, strided<[1], offset: ?>> to memref<16xf32>
+      %alloc_6 = memref.alloc() : memref<16xf32>
+      linalg.generic {doc = "", indexing_maps = [#map, #map, #map], iterator_types = ["parallel"], library_call = ""} ins(%alloc_4, %alloc_5 : memref<16xf32>, memref<16xf32>) outs(%alloc_6 : memref<16xf32>) {
+      ^bb0(%in: f32, %in_7: f32, %out: f32):
+        %0 = arith.addf %in, %in_7 : f32
+        linalg.yield %0 : f32
+      }
+      memref.copy %alloc_6, %subview_3 : memref<16xf32> to memref<16xf32, strided<[1], offset: ?>>
+    }
+    return %alloc : memref<1x5x24x48xf32>
+  }
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @memref_copy_strided_to_collapse(
+// CHECK-SAME:                                               %[[VAL_0:.*]]: memref<1x5x24x48xi32>,
+// CHECK-SAME:                 ...
[truncated]

@nicolasvasilache
Contributor

Hmm, could you please elaborate on the rationale for this?

I am generally quite opposed to adding transformations on memref.copy because this is a very specific specialization of linalg.copy that almost always results in code duplication.

Have you considered using a linalg.copy instead and using/extending the existing transforms on such abstractions?

If anything, I would just delete memref.copy.

@AviadCo
Contributor Author

AviadCo commented Oct 3, 2023

Hi @nicolasvasilache

The rationale for this pattern is as follows:

  1. Due to HW limitations, compute operations must use local memory space (we need to DMA import/export).
  2. We use PromoteOp to inject memref copies from global memory to local memory (copy between memory spaces).
  3. Due to HW limitations, not all shapes/strides are supported by the DMA engine, so we need to expand/collapse the memref copies that act as DMA imports/exports - which is the motivation for this rewrite pattern.

I can refactor PromoteOp and this pattern to be based on linalg.copy (instead of memref::copy) and add this pattern as a new transform operation. If this would be useful, we can generalize it to any linalg.generic operation. I couldn't find any existing pattern/transform that does such a job of expanding/collapsing according to rank parameters.

@nicolasvasilache
Contributor

@AviadCo thanks for your description.

I would recommend trying to use/refactor the logic of linalg::collapseGenericOpIterationDims in mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp if possible.

Atm this works on tensors but could be refactored to work with memref if needed.
I don't know whether transformations on tensors are reasonable for you (vs on memref).
If you need to stay in memref land for your transforms, then refactoring makes sense.

If possible, I'd recommend going for transforms on tensors and using bufferization.bufferize_to_alloc_tensor in mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir, which takes a memory space and connects properly to bufferization. There is also mlir/test/Dialect/Linalg/pad-to-specific-memory-space.mlir for other related tests.

@AviadCo
Contributor Author

AviadCo commented Oct 5, 2023

@nicolasvasilache thanks for the ideas!

Regarding linalg::collapseGenericOpIterationDims - this may work with some refactoring. I'll see what I can do to create a pattern, based on linalg.generic, that collapses a linalg.copy, and base the DMA on linalg.copy instead.

I also investigated bufferization.bufferize_to_alloc_tensor. We expect the inputs to exist in global memory first, so that we can promote them (DMA import/export) to local memory. So it seems better to bufferize to an allocation in global memory first and then use the promote transform to DMA it to local memory.

@AviadCo
Contributor Author

AviadCo commented Oct 8, 2023

@nicolasvasilache I created 2 PRs:
#68522
#68526

I believe that with those 2 PRs I can use linalg::CopyOp as a promotion (DMA import/export) and then filter the linalg::CopyOp operations by supplying GetCollapsableDimensionsFn, which selectively decides if and how to collapse the copy operation.

Labels
mlir:core MLIR Core Infrastructure mlir:memref mlir