Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][bufferization] Add BufferViewFlowOpInterface #78718

Merged

Conversation

matthias-springer
Copy link
Member

This commit adds the BufferViewFlowOpInterface to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by the BufferViewFlowAnalysis.

The new interface has two interface methods:

  • populateDependencies: Implementations use the provided callback to declare dependencies between operands and op results/region entry block arguments. E.g., for %r = arith.select %c, %m1, %m2 : memref<5xf32>, the interface implementation should declare two dependencies: %m1 -> %r and %m2 -> %r.
  • isTerminalBuffer: An SSA value is a terminal buffer if the buffer view flow analysis stops at the specified value. E.g., because the value is a newly allocated buffer or because no further information is available about the origin of the buffer.

Ops that implement the RegionBranchOpInterface or BranchOpInterface do not have to implement the BufferViewFlowOpInterface. The buffer dependencies can be inferred from those two interfaces.

This commit makes the BufferViewFlowAnalysis more accurate. For unknown ops, it conservatively used to declare all combinations of operands and op results/region entry block arguments as dependencies (false positives). This is no longer the case. While the analysis is still a "maybe" analysis with false positives (e.g., when analyzing ops such as arith.select or scf.if where the taken branch is not known at compile time), results and region entry block arguments of unknown ops are now marked as terminal buffers.

This commit addresses a TODO in BufferViewFlowAnalysis.cpp:

// TODO: We should have an op interface instead of a hard-coded list of
// interfaces/ops.

It is no longer needed to hard-code ops.

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 19, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-arith

Author: Matthias Springer (matthias-springer)

Changes

This commit adds the BufferViewFlowOpInterface to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by the BufferViewFlowAnalysis.

The new interface has two interface methods:

  • populateDependencies: Implementations use the provided callback to declare dependencies between operands and op results/region entry block arguments. E.g., for %r = arith.select %c, %m1, %m2 : memref&lt;5xf32&gt;, the interface implementation should declare two dependencies: %m1 -> %r and %m2 -> %r.
  • isTerminalBuffer: An SSA value is a terminal buffer if the buffer view flow analysis stops at the specified value. E.g., because the value is a newly allocated buffer or because no further information is available about the origin of the buffer.

Ops that implement the RegionBranchOpInterface or BranchOpInterface do not have to implement the BufferViewFlowOpInterface. The buffer dependencies can be inferred from those two interfaces.

This commit makes the BufferViewFlowAnalysis more accurate. For unknown ops, it conservatively used to declare all combinations of operands and op results/region entry block arguments as dependencies (false positives). This is no longer the case. While the analysis is still a "maybe" analysis with false positives (e.g., when analyzing ops such as arith.select or scf.if where the taken branch is not known at compile time), results and region entry block arguments of unknown ops are now marked as terminal buffers.

This commit addresses a TODO in BufferViewFlowAnalysis.cpp:

// TODO: We should have an op interface instead of a hard-coded list of
// interfaces/ops.

It is no longer needed to hard-code ops.


Patch is 24.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78718.diff

16 Files Affected:

  • (added) mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h (+20)
  • (added) mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h (+27)
  • (added) mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td (+67)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h (+6)
  • (added) mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h (+20)
  • (modified) mlir/include/mlir/InitAllDialects.h (+4)
  • (added) mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp (+45)
  • (modified) mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp (+18)
  • (modified) mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (+58-14)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp (+49)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt (+2)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+36)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 00000000000000..8f79ae4913d1c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
new file mode 100644
index 00000000000000..84e67fe72b623b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
@@ -0,0 +1,27 @@
+//===- BufferViewFlowOpInterface.h - Buffer View Flow Analysis --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class ValueRange;
+
+namespace bufferization {
+
+using RegisterDependenciesFn = std::function<void(ValueRange, ValueRange)>;
+
+} // namespace bufferization
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
new file mode 100644
index 00000000000000..baeb7308aad107
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
@@ -0,0 +1,67 @@
+//===-- BufferViewFlowOpInterface.td - Buffer View Flow ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFER_VIEW_FLOW_OP_INTERFACE
+#define BUFFER_VIEW_FLOW_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def BufferViewFlowOpInterface :
+    OpInterface<"BufferViewFlowOpInterface"> {
+  let description = [{
+    An op interface for the buffer view flow analysis. This interface describes
+    buffer dependencies between operands and op results/region entry block
+    arguments.
+  }];
+  let cppNamespace = "::mlir::bufferization";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Populate buffer dependencies between operands and op results/region
+          entry block arguments.
+
+          Implementations should register dependencies between an operand ("X")
+          and an op result/region entry block argument ("Y") if Y may depend
+          on X. Y depends on X if Y and X are the same buffer or if Y is a
+          subview of X.
+
+          Example:
+          ```
+          %r = arith.select %c, %m1, %m2 : memref<5xf32>
+          ```
+          In the above example, %0 may depend on %m1 or %m2 and a correct
+          interface implementation should call:
+          - "registerDependenciesFn(%m1, %r)".
+          - "registerDependenciesFn(%m2, %r)"
+        }],
+        /*retType=*/"void",
+        /*methodName=*/"populateDependencies",
+        /*args=*/(ins
+            "::mlir::bufferization::RegisterDependenciesFn"
+                :$registerDependenciesFn)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return "true" if the given value is a terminal buffer. A buffer value
+          is "terminal" if it cannot be traced back any further in the buffer
+          view flow analysis. E.g., because the value is a newly allocated
+          buffer or because there is not enough information available.
+
+          The given SSA value is guaranteed to be an OpResult of this operation
+          or a region entry block argument of this operation.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isTerminalBuffer",
+        /*args=*/(ins "Value":$value),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return false;"
+      >,
+  ];
+}
+
+#endif  // BUFFER_VIEW_FLOW_OP_INTERFACE
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 31a553f9a32f55..13a5bc370a4fce 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
 add_mlir_interface(AllocationOpInterface)
 add_mlir_interface(BufferDeallocationOpInterface)
 add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(BufferViewFlowOpInterface)
 
 set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
 mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 24825db69f90c5..110f7f51007d5b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -63,6 +63,9 @@ class BufferViewFlowAnalysis {
   /// results have to be changed.
   void rename(Value from, Value to);
 
+  /// Return "true" if the given value is a terminal.
+  bool isTerminalBuffer(Value value) const;
+
 private:
   /// This function constructs a mapping from values to its immediate
   /// dependencies.
@@ -70,6 +73,9 @@ class BufferViewFlowAnalysis {
 
   /// Maps values to all immediate dependencies this value can have.
   ValueMapT dependencies;
+
+  /// A set of all terminal values. I.e., values at which the analysis stopped.
+  DenseSet<Value> terminals;
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 00000000000000..1497070510dd6e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e04..93833723229569 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -53,6 +54,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -145,6 +147,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   affine::registerValueBoundsOpInterfaceExternalModels(registry);
   arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
   arith::registerBufferizableOpInterfaceExternalModels(registry);
+  arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
   arith::registerValueBoundsOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
       registry);
@@ -157,6 +160,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   linalg::registerTilingInterfaceExternalModels(registry);
   linalg::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerAllocationOpInterfaceExternalModels(registry);
+  memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerMemorySlotExternalModels(registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..77ba4269b5fcf6
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,45 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::arith;
+
+namespace mlir {
+namespace arith {
+namespace {
+
+struct SelectOpInterface
+    : public BufferViewFlowOpInterface::ExternalModel<SelectOpInterface,
+                                                      SelectOp> {
+  void
+  populateDependencies(Operation *op,
+                       RegisterDependenciesFn registerDependenciesFn) const {
+    auto selectOp = cast<SelectOp>(op);
+
+    // Either one of the true/false value may be selected at runtime.
+    registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult());
+    registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult());
+  }
+};
+
+} // namespace
+} // namespace arith
+} // namespace mlir
+
+void arith::registerBufferViewFlowOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+    SelectOp::attachInterface<SelectOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 02240601bcd35a..12659eaba1fa5e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  BufferViewFlowOpInterfaceImpl.cpp
   EmulateUnsupportedFloats.cpp
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
new file mode 100644
index 00000000000000..ea726a4bfc3fb9
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
@@ -0,0 +1,18 @@
+//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+
+namespace mlir {
+namespace bufferization {
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc"
+
+} // namespace bufferization
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 9895db9d93ce0b..63dcc1eb233e92 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   BufferDeallocationOpInterface.cpp
   BufferizationOps.cpp
   BufferizationDialect.cpp
+  BufferViewFlowOpInterface.cpp
   UnstructuredControlFlow.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 88ef1b639fc5ce..7cf202ac81d7c0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -8,12 +8,16 @@
 
 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
 
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
+using namespace mlir::bufferization;
 
 /// Constructs a new alias analysis using the op provided.
 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
@@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
 void BufferViewFlowAnalysis::build(Operation *op) {
   // Registers all dependencies of the given values.
   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
-    for (auto [value, dep] : llvm::zip(values, dependencies))
+    for (auto [value, dep] : llvm::zip_equal(values, dependencies))
       this->dependencies[value].insert(dep);
   };
 
+  // Mark all buffer results and buffer region entry block arguments of the
+  // given op as terminals.
+  auto populateTerminalValues = [&](Operation *op) {
+    for (Value v : op->getResults())
+      if (isa<BaseMemRefType>(v.getType()))
+        this->terminals.insert(v);
+    for (Region &r : op->getRegions())
+      for (BlockArgument v : r.getArguments())
+        if (isa<BaseMemRefType>(v.getType()))
+          this->terminals.insert(v);
+  };
+
   op->walk([&](Operation *op) {
-    // TODO: We should have an op interface instead of a hard-coded list of
-    // interfaces/ops.
+    // Query BufferViewFlowOpInterface. If the op does not implement that
+    // interface, try to infer the dependencies from other interfaces that the
+    // op may implement.
+    if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
+      bufferViewFlowOp.populateDependencies(registerDependencies);
+      for (Value v : op->getResults())
+        if (isa<BaseMemRefType>(v.getType()) &&
+            bufferViewFlowOp.isTerminalBuffer(v))
+          this->terminals.insert(v);
+      for (Region &r : op->getRegions())
+        for (BlockArgument v : r.getArguments())
+          if (isa<BaseMemRefType>(v.getType()) &&
+              bufferViewFlowOp.isTerminalBuffer(v))
+            this->terminals.insert(v);
+      return WalkResult::advance();
+    }
 
     // Add additional dependencies created by view changes to the alias list.
     if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
-      dependencies[viewInterface.getViewSource()].insert(
-          viewInterface->getResult(0));
+      registerDependencies(viewInterface.getViewSource(),
+                           viewInterface->getResult(0));
       return WalkResult::advance();
     }
 
@@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) {
       return WalkResult::advance();
     }
 
-    // Unknown op: Assume that all operands alias with all results.
-    for (Value operand : op->getOperands()) {
-      if (!isa<BaseMemRefType>(operand.getType()))
-        continue;
-      for (Value result : op->getResults()) {
-        if (!isa<BaseMemRefType>(result.getType()))
-          continue;
-        registerDependencies({operand}, {result});
-      }
+    // Region terminators are handled together with RegionBranchOpInterface.
+    if (isa<RegionBranchTerminatorOpInterface>(op))
+      return WalkResult::advance();
+
+    if (isa<CallOpInterface>(op)) {
+      // This is an intra-function analysis. We have no information about other
+      // functions. Conservatively assume that each operand may alias with each
+      // result. Also mark the results are terminals because the function could
+      // return newly allocated buffers.
+      populateTerminalValues(op);
+      for (Value operand : op->getOperands())
+        for (Value result : op->getResults())
+          registerDependencies({operand}, {result});
+      return WalkResult::advance();
     }
+
+    // We have no information about unknown ops.
+    populateTerminalValues(op);
+
     return WalkResult::advance();
   });
 }
+
+bool BufferViewFlowAnalysis::isTerminalBuffer(Value value) const {
+  assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
+  return terminals.contains(value);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 8617c17e7a5e5e..58feae66427b7a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIRBufferizationDialect
+  MLIRBufferizationTransforms
   MLIRControlFlowInterfaces
   MLIRFuncDialect
   MLIRFunctionInterfaces
diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..ff1df72ba37407
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,49 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::memref;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+struct ReallocOpInterface
+    : public B...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 19, 2024

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

Changes

This commit adds the BufferViewFlowOpInterface to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by the BufferViewFlowAnalysis.

The new interface has two interface methods:

  • populateDependencies: Implementations use the provided callback to declare dependencies between operands and op results/region entry block arguments. E.g., for %r = arith.select %c, %m1, %m2 : memref&lt;5xf32&gt;, the interface implementation should declare two dependencies: %m1 -> %r and %m2 -> %r.
  • isTerminalBuffer: An SSA value is a terminal buffer if the buffer view flow analysis stops at the specified value. E.g., because the value is a newly allocated buffer or because no further information is available about the origin of the buffer.

Ops that implement the RegionBranchOpInterface or BranchOpInterface do not have to implement the BufferViewFlowOpInterface. The buffer dependencies can be inferred from those two interfaces.

This commit makes the BufferViewFlowAnalysis more accurate. For unknown ops, it conservatively used to declare all combinations of operands and op results/region entry block arguments as dependencies (false positives). This is no longer the case. While the analysis is still a "maybe" analysis with false positives (e.g., when analyzing ops such as arith.select or scf.if where the taken branch is not known at compile time), results and region entry block arguments of unknown ops are now marked as terminal buffers.

This commit addresses a TODO in BufferViewFlowAnalysis.cpp:

// TODO: We should have an op interface instead of a hard-coded list of
// interfaces/ops.

It is no longer needed to hard-code ops.


Patch is 24.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78718.diff

16 Files Affected:

  • (added) mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h (+20)
  • (added) mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h (+27)
  • (added) mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td (+67)
  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt (+1)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h (+6)
  • (added) mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h (+20)
  • (modified) mlir/include/mlir/InitAllDialects.h (+4)
  • (added) mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp (+45)
  • (modified) mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp (+18)
  • (modified) mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (+58-14)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp (+49)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt (+2)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+36)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 00000000000000..8f79ae4913d1c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
new file mode 100644
index 00000000000000..84e67fe72b623b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
@@ -0,0 +1,27 @@
+//===- BufferViewFlowOpInterface.h - Buffer View Flow Analysis --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class ValueRange;
+
+namespace bufferization {
+
+using RegisterDependenciesFn = std::function<void(ValueRange, ValueRange)>;
+
+} // namespace bufferization
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
new file mode 100644
index 00000000000000..baeb7308aad107
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
@@ -0,0 +1,67 @@
+//===-- BufferViewFlowOpInterface.td - Buffer View Flow ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFER_VIEW_FLOW_OP_INTERFACE
+#define BUFFER_VIEW_FLOW_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def BufferViewFlowOpInterface :
+    OpInterface<"BufferViewFlowOpInterface"> {
+  let description = [{
+    An op interface for the buffer view flow analysis. This interface describes
+    buffer dependencies between operands and op results/region entry block
+    arguments.
+  }];
+  let cppNamespace = "::mlir::bufferization";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Populate buffer dependencies between operands and op results/region
+          entry block arguments.
+
+          Implementations should register dependencies between an operand ("X")
+          and an op result/region entry block argument ("Y") if Y may depend
+          on X. Y depends on X if Y and X are the same buffer or if Y is a
+          subview of X.
+
+          Example:
+          ```
+          %r = arith.select %c, %m1, %m2 : memref<5xf32>
+          ```
+          In the above example, %0 may depend on %m1 or %m2 and a correct
+          interface implementation should call:
+          - "registerDependenciesFn(%m1, %r)".
+          - "registerDependenciesFn(%m2, %r)"
+        }],
+        /*retType=*/"void",
+        /*methodName=*/"populateDependencies",
+        /*args=*/(ins
+            "::mlir::bufferization::RegisterDependenciesFn"
+                :$registerDependenciesFn)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Return "true" if the given value is a terminal buffer. A buffer value
+          is "terminal" if it cannot be traced back any further in the buffer
+          view flow analysis. E.g., because the value is a newly allocated
+          buffer or because there is not enough information available.
+
+          The given SSA value is guaranteed to be an OpResult of this operation
+          or a region entry block argument of this operation.
+        }],
+        /*retType=*/"bool",
+        /*methodName=*/"isTerminalBuffer",
+        /*args=*/(ins "Value":$value),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/"return false;"
+      >,
+  ];
+}
+
+#endif  // BUFFER_VIEW_FLOW_OP_INTERFACE
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 31a553f9a32f55..13a5bc370a4fce 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
 add_mlir_interface(AllocationOpInterface)
 add_mlir_interface(BufferDeallocationOpInterface)
 add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(BufferViewFlowOpInterface)
 
 set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
 mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 24825db69f90c5..110f7f51007d5b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -63,6 +63,9 @@ class BufferViewFlowAnalysis {
   /// results have to be changed.
   void rename(Value from, Value to);
 
+  /// Return "true" if the given value is a terminal.
+  bool isTerminalBuffer(Value value) const;
+
 private:
   /// This function constructs a mapping from values to its immediate
   /// dependencies.
@@ -70,6 +73,9 @@ class BufferViewFlowAnalysis {
 
   /// Maps values to all immediate dependencies this value can have.
   ValueMapT dependencies;
+
+  /// A set of all terminal values. I.e., values at which the analysis stopped.
+  DenseSet<Value> terminals;
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 00000000000000..1497070510dd6e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e04..93833723229569 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -53,6 +54,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
 #include "mlir/Dialect/Mesh/IR/MeshOps.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -145,6 +147,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   affine::registerValueBoundsOpInterfaceExternalModels(registry);
   arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
   arith::registerBufferizableOpInterfaceExternalModels(registry);
+  arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
   arith::registerValueBoundsOpInterfaceExternalModels(registry);
   bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
       registry);
@@ -157,6 +160,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   linalg::registerTilingInterfaceExternalModels(registry);
   linalg::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerAllocationOpInterfaceExternalModels(registry);
+  memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
   memref::registerMemorySlotExternalModels(registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..77ba4269b5fcf6
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,45 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::arith;
+
+namespace mlir {
+namespace arith {
+namespace {
+
+struct SelectOpInterface
+    : public BufferViewFlowOpInterface::ExternalModel<SelectOpInterface,
+                                                      SelectOp> {
+  void
+  populateDependencies(Operation *op,
+                       RegisterDependenciesFn registerDependenciesFn) const {
+    auto selectOp = cast<SelectOp>(op);
+
+    // Either one of the true/false value may be selected at runtime.
+    registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult());
+    registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult());
+  }
+};
+
+} // namespace
+} // namespace arith
+} // namespace mlir
+
+void arith::registerBufferViewFlowOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+    SelectOp::attachInterface<SelectOpInterface>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 02240601bcd35a..12659eaba1fa5e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
   BufferDeallocationOpInterfaceImpl.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
+  BufferViewFlowOpInterfaceImpl.cpp
   EmulateUnsupportedFloats.cpp
   EmulateWideInt.cpp
   EmulateNarrowType.cpp
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
new file mode 100644
index 00000000000000..ea726a4bfc3fb9
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
@@ -0,0 +1,18 @@
+//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+
+namespace mlir {
+namespace bufferization {
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc"
+
+} // namespace bufferization
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 9895db9d93ce0b..63dcc1eb233e92 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
   BufferDeallocationOpInterface.cpp
   BufferizationOps.cpp
   BufferizationDialect.cpp
+  BufferViewFlowOpInterface.cpp
   UnstructuredControlFlow.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 88ef1b639fc5ce..7cf202ac81d7c0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -8,12 +8,16 @@
 
 #include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
 
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
+using namespace mlir::bufferization;
 
 /// Constructs a new alias analysis using the op provided.
 BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
@@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
 void BufferViewFlowAnalysis::build(Operation *op) {
   // Registers all dependencies of the given values.
   auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
-    for (auto [value, dep] : llvm::zip(values, dependencies))
+    for (auto [value, dep] : llvm::zip_equal(values, dependencies))
       this->dependencies[value].insert(dep);
   };
 
+  // Mark all buffer results and buffer region entry block arguments of the
+  // given op as terminals.
+  auto populateTerminalValues = [&](Operation *op) {
+    for (Value v : op->getResults())
+      if (isa<BaseMemRefType>(v.getType()))
+        this->terminals.insert(v);
+    for (Region &r : op->getRegions())
+      for (BlockArgument v : r.getArguments())
+        if (isa<BaseMemRefType>(v.getType()))
+          this->terminals.insert(v);
+  };
+
   op->walk([&](Operation *op) {
-    // TODO: We should have an op interface instead of a hard-coded list of
-    // interfaces/ops.
+    // Query BufferViewFlowOpInterface. If the op does not implement that
+    // interface, try to infer the dependencies from other interfaces that the
+    // op may implement.
+    if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
+      bufferViewFlowOp.populateDependencies(registerDependencies);
+      for (Value v : op->getResults())
+        if (isa<BaseMemRefType>(v.getType()) &&
+            bufferViewFlowOp.isTerminalBuffer(v))
+          this->terminals.insert(v);
+      for (Region &r : op->getRegions())
+        for (BlockArgument v : r.getArguments())
+          if (isa<BaseMemRefType>(v.getType()) &&
+              bufferViewFlowOp.isTerminalBuffer(v))
+            this->terminals.insert(v);
+      return WalkResult::advance();
+    }
 
     // Add additional dependencies created by view changes to the alias list.
     if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
-      dependencies[viewInterface.getViewSource()].insert(
-          viewInterface->getResult(0));
+      registerDependencies(viewInterface.getViewSource(),
+                           viewInterface->getResult(0));
       return WalkResult::advance();
     }
 
@@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) {
       return WalkResult::advance();
     }
 
-    // Unknown op: Assume that all operands alias with all results.
-    for (Value operand : op->getOperands()) {
-      if (!isa<BaseMemRefType>(operand.getType()))
-        continue;
-      for (Value result : op->getResults()) {
-        if (!isa<BaseMemRefType>(result.getType()))
-          continue;
-        registerDependencies({operand}, {result});
-      }
+    // Region terminators are handled together with RegionBranchOpInterface.
+    if (isa<RegionBranchTerminatorOpInterface>(op))
+      return WalkResult::advance();
+
+    if (isa<CallOpInterface>(op)) {
+      // This is an intra-function analysis. We have no information about other
+      // functions. Conservatively assume that each operand may alias with each
+      // result. Also mark the results are terminals because the function could
+      // return newly allocated buffers.
+      populateTerminalValues(op);
+      for (Value operand : op->getOperands())
+        for (Value result : op->getResults())
+          registerDependencies({operand}, {result});
+      return WalkResult::advance();
     }
+
+    // We have no information about unknown ops.
+    populateTerminalValues(op);
+
     return WalkResult::advance();
   });
 }
+
+bool BufferViewFlowAnalysis::isTerminalBuffer(Value value) const {
+  assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
+  return terminals.contains(value);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 8617c17e7a5e5e..58feae66427b7a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIRBufferizationDialect
+  MLIRBufferizationTransforms
   MLIRControlFlowInterfaces
   MLIRFuncDialect
   MLIRFunctionInterfaces
diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..ff1df72ba37407
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,49 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::memref;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+struct ReallocOpInterface
+    : public B...
[truncated]

Copy link
Contributor

@bondhugula bondhugula left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to add a test case easily that exercises this?

Copy link
Contributor

@bondhugula bondhugula left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Good to have this gap covered.

void
populateDependencies(Operation *op,
RegisterDependenciesFn registerDependenciesFn) const {
auto reallocOp = cast<ReallocOp>(op);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also update mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp simultaneously; your patch is needed there to correctly model alias relationships.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not very familiar with that part of the code base, let's do that in a follow-up PR.

@llvmbot llvmbot added the bazel "Peripheral" support tier build system: utils/bazel label Mar 24, 2024
This commit adds the `BufferViewFlowOpInterface` to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by the `BufferViewFlowAnalysis`.

There are currently no ops that implement this interface. The first op implementations will be added in a consecutive commit.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
Copy link

✅ With the latest revision this PR passed the Python code formatter.

Copy link

✅ With the latest revision this PR passed the C/C++ code formatter.

@matthias-springer matthias-springer merged commit a45e58a into llvm:main Mar 24, 2024
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel mlir:arith mlir:bufferization Bufferization infrastructure mlir:memref mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants