Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][bufferization] Move memref specific implementation of AllocationOpInterface to memref dialect directory #66637

Merged
merged 1 commit into from
Sep 20, 2023

Conversation

maerhart
Copy link
Member

Follow-up on #65578

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 18, 2023

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Changes

Follow-up on #65578

Full diff: https://github.com/llvm/llvm-project/pull/66637.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h (-3)
  • (added) mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h (+20)
  • (modified) mlir/include/mlir/InitAllDialects.h (+2)
  • (modified) mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp (-1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp (-1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (-57)
  • (added) mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp (+69)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt (+1)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+1)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
index 92520eb13da6875..a6f668b26aa10e4 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h
@@ -211,9 +211,6 @@ std::unique_ptr<Pass> createBufferizationBufferizePass();
 // Registration
 //===----------------------------------------------------------------------===//
 
-/// Register external models for AllocationOpInterface.
-void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry);
-
 /// Generate the code for registering passes.
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h
new file mode 100644
index 000000000000000..aea05821fd1167c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- AllocationOpInterfaceImpl.h - Impl. of AllocationOpInterface -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerAllocationOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_ALLOCATIONOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 5b2b1ed24d5173d..f36b79e86832171 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -50,6 +50,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -147,6 +148,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   linalg::registerBufferizableOpInterfaceExternalModels(registry);
   linalg::registerTilingInterfaceExternalModels(registry);
   linalg::registerValueBoundsOpInterfaceExternalModels(registry);
+  memref::registerAllocationOpInterfaceExternalModels(registry);
   memref::registerBufferizableOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index b84cc452d0141cd..7a6d1858489d1e6 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -175,5 +175,4 @@ class BufferizationTransformDialectExtension
 void mlir::bufferization::registerTransformDialectExtension(
     DialectRegistry &registry) {
   registry.addExtensions<BufferizationTransformDialectExtension>();
-  bufferization::registerAllocationOpInterfaceExternalModels(registry);
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
index f74c6255c196ba5..a0a81d4add71210 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp
@@ -634,7 +634,6 @@ struct BufferDeallocationPass
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<bufferization::BufferizationDialect>();
     registry.insert<memref::MemRefDialect>();
-    registerAllocationOpInterfaceExternalModels(registry);
   }
 
   void runOnOperation() override {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 7358d0d465d3e3d..2edb27da98fe910 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -196,7 +196,6 @@ struct OneShotBufferizePass
   void getDependentDialects(DialectRegistry &registry) const override {
     registry
         .insert<bufferization::BufferizationDialect, memref::MemRefDialect>();
-    registerAllocationOpInterfaceExternalModels(registry);
   }
 
   void runOnOperation() override {
@@ -682,59 +681,3 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
   options.opFilter.allowDialect<BufferizationDialect>();
   return options;
 }
-
-//===----------------------------------------------------------------------===//
-// Default AllocationOpInterface implementation and registration
-//===----------------------------------------------------------------------===//
-
-namespace {
-struct DefaultAllocationInterface
-    : public bufferization::AllocationOpInterface::ExternalModel<
-          DefaultAllocationInterface, memref::AllocOp> {
-  static std::optional<Operation *> buildDealloc(OpBuilder &builder,
-                                                 Value alloc) {
-    return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
-        .getOperation();
-  }
-  static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
-    return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
-        .getResult();
-  }
-  static ::mlir::HoistingKind getHoistingKind() {
-    return HoistingKind::Loop | HoistingKind::Block;
-  }
-  static ::std::optional<::mlir::Operation *>
-  buildPromotedAlloc(OpBuilder &builder, Value alloc) {
-    Operation *definingOp = alloc.getDefiningOp();
-    return builder.create<memref::AllocaOp>(
-        definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
-        definingOp->getOperands(), definingOp->getAttrs());
-  }
-};
-
-struct DefaultAutomaticAllocationHoistingInterface
-    : public bufferization::AllocationOpInterface::ExternalModel<
-          DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
-  static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
-};
-
-struct DefaultReallocationInterface
-    : public bufferization::AllocationOpInterface::ExternalModel<
-          DefaultAllocationInterface, memref::ReallocOp> {
-  static std::optional<Operation *> buildDealloc(OpBuilder &builder,
-                                                 Value realloc) {
-    return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
-        .getOperation();
-  }
-};
-} // namespace
-
-void bufferization::registerAllocationOpInterfaceExternalModels(
-    DialectRegistry &registry) {
-  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
-    memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
-    memref::AllocaOp::attachInterface<
-        DefaultAutomaticAllocationHoistingInterface>(*ctx);
-    memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
-  });
-}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
new file mode 100644
index 000000000000000..c4334159443236e
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp
@@ -0,0 +1,69 @@
+//===- AllocationOpInterfaceImpl.cpp - Impl. of AllocationOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+
+namespace {
+struct DefaultAllocationInterface
+    : public bufferization::AllocationOpInterface::ExternalModel<
+          DefaultAllocationInterface, memref::AllocOp> {
+  static std::optional<Operation *> buildDealloc(OpBuilder &builder,
+                                                 Value alloc) {
+    return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
+        .getOperation();
+  }
+  static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) {
+    return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc)
+        .getResult();
+  }
+  static ::mlir::HoistingKind getHoistingKind() {
+    return HoistingKind::Loop | HoistingKind::Block;
+  }
+  static ::std::optional<::mlir::Operation *>
+  buildPromotedAlloc(OpBuilder &builder, Value alloc) {
+    Operation *definingOp = alloc.getDefiningOp();
+    return builder.create<memref::AllocaOp>(
+        definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]),
+        definingOp->getOperands(), definingOp->getAttrs());
+  }
+};
+
+struct DefaultAutomaticAllocationHoistingInterface
+    : public bufferization::AllocationOpInterface::ExternalModel<
+          DefaultAutomaticAllocationHoistingInterface, memref::AllocaOp> {
+  static ::mlir::HoistingKind getHoistingKind() { return HoistingKind::Loop; }
+};
+
+struct DefaultReallocationInterface
+    : public bufferization::AllocationOpInterface::ExternalModel<
+          DefaultAllocationInterface, memref::ReallocOp> {
+  static std::optional<Operation *> buildDealloc(OpBuilder &builder,
+                                                 Value realloc) {
+    return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc)
+        .getOperation();
+  }
+};
+} // namespace
+
+void mlir::memref::registerAllocationOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    memref::AllocOp::attachInterface<DefaultAllocationInterface>(*ctx);
+    memref::AllocaOp::attachInterface<
+        DefaultAutomaticAllocationHoistingInterface>(*ctx);
+    memref::ReallocOp::attachInterface<DefaultReallocationInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index ddd674c37c4e536..b16c281c93640ea 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMemRefTransforms
+  AllocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   ComposeSubView.cpp
   ExpandOps.cpp
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 9bea555f701757c..3449a9a1bbcabe0 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -11722,6 +11722,7 @@ cc_library(
         ":AffineDialect",
         ":AffineTransforms",
         ":AffineUtils",
+        ":AllocationOpInterface",
         ":ArithDialect",
         ":ArithTransforms",
         ":ArithUtils",

Copy link
Member

@matthias-springer matthias-springer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be marked NFC.

@maerhart maerhart merged commit 65341b0 into llvm:main Sep 20, 2023
5 checks passed
@maerhart maerhart deleted the merhart_move_allocationopinterfaceimpl branch September 20, 2023 12:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:memref mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants