-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][bufferization] Add BufferViewFlowOpInterface
#78718
[mlir][bufferization] Add BufferViewFlowOpInterface
#78718
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Matthias Springer (matthias-springer) ChangesThis commit adds the The new interface has two interface methods:
Ops that implement the This commit makes the This commit addresses a TODO in
It is no longer needed to hard-code ops. Patch is 24.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78718.diff 16 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 00000000000000..8f79ae4913d1c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry ®istry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
new file mode 100644
index 00000000000000..84e67fe72b623b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
@@ -0,0 +1,27 @@
+//===- BufferViewFlowOpInterface.h - Buffer View Flow Analysis --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class ValueRange;
+
+namespace bufferization {
+
+using RegisterDependenciesFn = std::function<void(ValueRange, ValueRange)>;
+
+} // namespace bufferization
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
new file mode 100644
index 00000000000000..baeb7308aad107
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
@@ -0,0 +1,67 @@
+//===-- BufferViewFlowOpInterface.td - Buffer View Flow ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFER_VIEW_FLOW_OP_INTERFACE
+#define BUFFER_VIEW_FLOW_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def BufferViewFlowOpInterface :
+ OpInterface<"BufferViewFlowOpInterface"> {
+ let description = [{
+ An op interface for the buffer view flow analysis. This interface describes
+ buffer dependencies between operands and op results/region entry block
+ arguments.
+ }];
+ let cppNamespace = "::mlir::bufferization";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate buffer dependencies between operands and op results/region
+ entry block arguments.
+
+ Implementations should register dependencies between an operand ("X")
+ and an op result/region entry block argument ("Y") if Y may depend
+ on X. Y depends on X if Y and X are the same buffer or if Y is a
+ subview of X.
+
+ Example:
+ ```
+ %r = arith.select %c, %m1, %m2 : memref<5xf32>
+ ```
+ In the above example, %0 may depend on %m1 or %m2 and a correct
+ interface implementation should call:
+ - "registerDependenciesFn(%m1, %r)".
+ - "registerDependenciesFn(%m2, %r)"
+ }],
+ /*retType=*/"void",
+ /*methodName=*/"populateDependencies",
+ /*args=*/(ins
+ "::mlir::bufferization::RegisterDependenciesFn"
+ :$registerDependenciesFn)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return "true" if the given value is a terminal buffer. A buffer value
+ is "terminal" if it cannot be traced back any further in the buffer
+ view flow analysis. E.g., because the value is a newly allocated
+ buffer or because there is not enough information available.
+
+ The given SSA value is guaranteed to be an OpResult of this operation
+ or a region entry block argument of this operation.
+ }],
+ /*retType=*/"bool",
+ /*methodName=*/"isTerminalBuffer",
+ /*args=*/(ins "Value":$value),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return false;"
+ >,
+ ];
+}
+
+#endif // BUFFER_VIEW_FLOW_OP_INTERFACE
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 31a553f9a32f55..13a5bc370a4fce 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
add_mlir_interface(AllocationOpInterface)
add_mlir_interface(BufferDeallocationOpInterface)
add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(BufferViewFlowOpInterface)
set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 24825db69f90c5..110f7f51007d5b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -63,6 +63,9 @@ class BufferViewFlowAnalysis {
/// results have to be changed.
void rename(Value from, Value to);
+ /// Return "true" if the given value is a terminal.
+ bool isTerminalBuffer(Value value) const;
+
private:
/// This function constructs a mapping from values to its immediate
/// dependencies.
@@ -70,6 +73,9 @@ class BufferViewFlowAnalysis {
/// Maps values to all immediate dependencies this value can have.
ValueMapT dependencies;
+
+ /// A set of all terminal values. I.e., values at which the analysis stopped.
+ DenseSet<Value> terminals;
};
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 00000000000000..1497070510dd6e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry ®istry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e04..93833723229569 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -21,6 +21,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -53,6 +54,7 @@
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -145,6 +147,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
affine::registerValueBoundsOpInterfaceExternalModels(registry);
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
+ arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
@@ -157,6 +160,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
linalg::registerTilingInterfaceExternalModels(registry);
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
+ memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerMemorySlotExternalModels(registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..77ba4269b5fcf6
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,45 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::arith;
+
+namespace mlir {
+namespace arith {
+namespace {
+
+struct SelectOpInterface
+ : public BufferViewFlowOpInterface::ExternalModel<SelectOpInterface,
+ SelectOp> {
+ void
+ populateDependencies(Operation *op,
+ RegisterDependenciesFn registerDependenciesFn) const {
+ auto selectOp = cast<SelectOp>(op);
+
+ // Either one of the true/false value may be selected at runtime.
+ registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult());
+ registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult());
+ }
+};
+
+} // namespace
+} // namespace arith
+} // namespace mlir
+
+void arith::registerBufferViewFlowOpInterfaceExternalModels(
+ DialectRegistry ®istry) {
+ registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+ SelectOp::attachInterface<SelectOpInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 02240601bcd35a..12659eaba1fa5e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ BufferViewFlowOpInterfaceImpl.cpp
EmulateUnsupportedFloats.cpp
EmulateWideInt.cpp
EmulateNarrowType.cpp
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
new file mode 100644
index 00000000000000..ea726a4bfc3fb9
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
@@ -0,0 +1,18 @@
+//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+
+namespace mlir {
+namespace bufferization {
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc"
+
+} // namespace bufferization
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 9895db9d93ce0b..63dcc1eb233e92 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
BufferDeallocationOpInterface.cpp
BufferizationOps.cpp
BufferizationDialect.cpp
+ BufferViewFlowOpInterface.cpp
UnstructuredControlFlow.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 88ef1b639fc5ce..7cf202ac81d7c0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -8,12 +8,16 @@
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
+using namespace mlir::bufferization;
/// Constructs a new alias analysis using the op provided.
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
@@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
void BufferViewFlowAnalysis::build(Operation *op) {
// Registers all dependencies of the given values.
auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
- for (auto [value, dep] : llvm::zip(values, dependencies))
+ for (auto [value, dep] : llvm::zip_equal(values, dependencies))
this->dependencies[value].insert(dep);
};
+ // Mark all buffer results and buffer region entry block arguments of the
+ // given op as terminals.
+ auto populateTerminalValues = [&](Operation *op) {
+ for (Value v : op->getResults())
+ if (isa<BaseMemRefType>(v.getType()))
+ this->terminals.insert(v);
+ for (Region &r : op->getRegions())
+ for (BlockArgument v : r.getArguments())
+ if (isa<BaseMemRefType>(v.getType()))
+ this->terminals.insert(v);
+ };
+
op->walk([&](Operation *op) {
- // TODO: We should have an op interface instead of a hard-coded list of
- // interfaces/ops.
+ // Query BufferViewFlowOpInterface. If the op does not implement that
+ // interface, try to infer the dependencies from other interfaces that the
+ // op may implement.
+ if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
+ bufferViewFlowOp.populateDependencies(registerDependencies);
+ for (Value v : op->getResults())
+ if (isa<BaseMemRefType>(v.getType()) &&
+ bufferViewFlowOp.isTerminalBuffer(v))
+ this->terminals.insert(v);
+ for (Region &r : op->getRegions())
+ for (BlockArgument v : r.getArguments())
+ if (isa<BaseMemRefType>(v.getType()) &&
+ bufferViewFlowOp.isTerminalBuffer(v))
+ this->terminals.insert(v);
+ return WalkResult::advance();
+ }
// Add additional dependencies created by view changes to the alias list.
if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
- dependencies[viewInterface.getViewSource()].insert(
- viewInterface->getResult(0));
+ registerDependencies(viewInterface.getViewSource(),
+ viewInterface->getResult(0));
return WalkResult::advance();
}
@@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) {
return WalkResult::advance();
}
- // Unknown op: Assume that all operands alias with all results.
- for (Value operand : op->getOperands()) {
- if (!isa<BaseMemRefType>(operand.getType()))
- continue;
- for (Value result : op->getResults()) {
- if (!isa<BaseMemRefType>(result.getType()))
- continue;
- registerDependencies({operand}, {result});
- }
+ // Region terminators are handled together with RegionBranchOpInterface.
+ if (isa<RegionBranchTerminatorOpInterface>(op))
+ return WalkResult::advance();
+
+ if (isa<CallOpInterface>(op)) {
+ // This is an intra-function analysis. We have no information about other
+ // functions. Conservatively assume that each operand may alias with each
+ // result. Also mark the results are terminals because the function could
+ // return newly allocated buffers.
+ populateTerminalValues(op);
+ for (Value operand : op->getOperands())
+ for (Value result : op->getResults())
+ registerDependencies({operand}, {result});
+ return WalkResult::advance();
}
+
+ // We have no information about unknown ops.
+ populateTerminalValues(op);
+
return WalkResult::advance();
});
}
+
+bool BufferViewFlowAnalysis::isTerminalBuffer(Value value) const {
+ assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
+ return terminals.contains(value);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 8617c17e7a5e5e..58feae66427b7a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRBufferizationDialect
+ MLIRBufferizationTransforms
MLIRControlFlowInterfaces
MLIRFuncDialect
MLIRFunctionInterfaces
diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..ff1df72ba37407
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,49 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::memref;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+struct ReallocOpInterface
+ : public B...
[truncated]
|
@llvm/pr-subscribers-mlir-bufferization Author: Matthias Springer (matthias-springer) ChangesThis commit adds the The new interface has two interface methods:
Ops that implement the This commit makes the This commit addresses a TODO in
It is no longer needed to hard-code ops. Patch is 24.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/78718.diff 16 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 00000000000000..8f79ae4913d1c6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace arith {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry ®istry);
+} // namespace arith
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARITH_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
new file mode 100644
index 00000000000000..84e67fe72b623b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h
@@ -0,0 +1,27 @@
+//===- BufferViewFlowOpInterface.h - Buffer View Flow Analysis --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class ValueRange;
+
+namespace bufferization {
+
+using RegisterDependenciesFn = std::function<void(ValueRange, ValueRange)>;
+
+} // namespace bufferization
+} // namespace mlir
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h.inc"
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERVIEWFLOWOPINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
new file mode 100644
index 00000000000000..baeb7308aad107
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
@@ -0,0 +1,67 @@
+//===-- BufferViewFlowOpInterface.td - Buffer View Flow ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUFFER_VIEW_FLOW_OP_INTERFACE
+#define BUFFER_VIEW_FLOW_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def BufferViewFlowOpInterface :
+ OpInterface<"BufferViewFlowOpInterface"> {
+ let description = [{
+ An op interface for the buffer view flow analysis. This interface describes
+ buffer dependencies between operands and op results/region entry block
+ arguments.
+ }];
+ let cppNamespace = "::mlir::bufferization";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Populate buffer dependencies between operands and op results/region
+ entry block arguments.
+
+ Implementations should register dependencies between an operand ("X")
+ and an op result/region entry block argument ("Y") if Y may depend
+ on X. Y depends on X if Y and X are the same buffer or if Y is a
+ subview of X.
+
+ Example:
+ ```
+ %r = arith.select %c, %m1, %m2 : memref<5xf32>
+ ```
+ In the above example, %0 may depend on %m1 or %m2 and a correct
+ interface implementation should call:
+ - "registerDependenciesFn(%m1, %r)".
+ - "registerDependenciesFn(%m2, %r)"
+ }],
+ /*retType=*/"void",
+ /*methodName=*/"populateDependencies",
+ /*args=*/(ins
+ "::mlir::bufferization::RegisterDependenciesFn"
+ :$registerDependenciesFn)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return "true" if the given value is a terminal buffer. A buffer value
+ is "terminal" if it cannot be traced back any further in the buffer
+ view flow analysis. E.g., because the value is a newly allocated
+ buffer or because there is not enough information available.
+
+ The given SSA value is guaranteed to be an OpResult of this operation
+ or a region entry block argument of this operation.
+ }],
+ /*retType=*/"bool",
+ /*methodName=*/"isTerminalBuffer",
+ /*args=*/(ins "Value":$value),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return false;"
+ >,
+ ];
+}
+
+#endif // BUFFER_VIEW_FLOW_OP_INTERFACE
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
index 31a553f9a32f55..13a5bc370a4fce 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
add_mlir_interface(AllocationOpInterface)
add_mlir_interface(BufferDeallocationOpInterface)
add_mlir_interface(BufferizableOpInterface)
+add_mlir_interface(BufferViewFlowOpInterface)
set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
index 24825db69f90c5..110f7f51007d5b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
@@ -63,6 +63,9 @@ class BufferViewFlowAnalysis {
/// results have to be changed.
void rename(Value from, Value to);
+ /// Return "true" if the given value is a terminal.
+ bool isTerminalBuffer(Value value) const;
+
private:
/// This function constructs a mapping from values to its immediate
/// dependencies.
@@ -70,6 +73,9 @@ class BufferViewFlowAnalysis {
/// Maps values to all immediate dependencies this value can have.
ValueMapT dependencies;
+
+ /// A set of all terminal values. I.e., values at which the analysis stopped.
+ DenseSet<Value> terminals;
};
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
new file mode 100644
index 00000000000000..1497070510dd6e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferViewFlowOpInterfaceImpl.h - Buffer View Analysis ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerBufferViewFlowOpInterfaceExternalModels(DialectRegistry ®istry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_BUFFERVIEWFLOWOPINTERFACEIMPL_H
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 19a62cadaa2e04..93833723229569 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -21,6 +21,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
@@ -53,6 +54,7 @@
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
@@ -145,6 +147,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
affine::registerValueBoundsOpInterfaceExternalModels(registry);
arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);
+ arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
arith::registerValueBoundsOpInterfaceExternalModels(registry);
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
registry);
@@ -157,6 +160,7 @@ inline void registerAllDialects(DialectRegistry ®istry) {
linalg::registerTilingInterfaceExternalModels(registry);
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
+ memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
memref::registerMemorySlotExternalModels(registry);
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..77ba4269b5fcf6
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,45 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::arith;
+
+namespace mlir {
+namespace arith {
+namespace {
+
+struct SelectOpInterface
+ : public BufferViewFlowOpInterface::ExternalModel<SelectOpInterface,
+ SelectOp> {
+ void
+ populateDependencies(Operation *op,
+ RegisterDependenciesFn registerDependenciesFn) const {
+ auto selectOp = cast<SelectOp>(op);
+
+ // Either one of the true/false value may be selected at runtime.
+ registerDependenciesFn(selectOp.getTrueValue(), selectOp.getResult());
+ registerDependenciesFn(selectOp.getFalseValue(), selectOp.getResult());
+ }
+};
+
+} // namespace
+} // namespace arith
+} // namespace mlir
+
+void arith::registerBufferViewFlowOpInterfaceExternalModels(
+ DialectRegistry ®istry) {
+ registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
+ SelectOp::attachInterface<SelectOpInterface>(*ctx);
+ });
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 02240601bcd35a..12659eaba1fa5e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRArithTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
+ BufferViewFlowOpInterfaceImpl.cpp
EmulateUnsupportedFloats.cpp
EmulateWideInt.cpp
EmulateNarrowType.cpp
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
new file mode 100644
index 00000000000000..ea726a4bfc3fb9
--- /dev/null
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp
@@ -0,0 +1,18 @@
+//===- BufferViewFlowOpInterface.cpp - Buffer View Flow Analysis ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+
+namespace mlir {
+namespace bufferization {
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.cpp.inc"
+
+} // namespace bufferization
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
index 9895db9d93ce0b..63dcc1eb233e92 100644
--- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
BufferDeallocationOpInterface.cpp
BufferizationOps.cpp
BufferizationDialect.cpp
+ BufferViewFlowOpInterface.cpp
UnstructuredControlFlow.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
index 88ef1b639fc5ce..7cf202ac81d7c0 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp
@@ -8,12 +8,16 @@
#include "mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h"
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
using namespace mlir;
+using namespace mlir::bufferization;
/// Constructs a new alias analysis using the op provided.
BufferViewFlowAnalysis::BufferViewFlowAnalysis(Operation *op) { build(op); }
@@ -65,18 +69,44 @@ void BufferViewFlowAnalysis::rename(Value from, Value to) {
void BufferViewFlowAnalysis::build(Operation *op) {
// Registers all dependencies of the given values.
auto registerDependencies = [&](ValueRange values, ValueRange dependencies) {
- for (auto [value, dep] : llvm::zip(values, dependencies))
+ for (auto [value, dep] : llvm::zip_equal(values, dependencies))
this->dependencies[value].insert(dep);
};
+ // Mark all buffer results and buffer region entry block arguments of the
+ // given op as terminals.
+ auto populateTerminalValues = [&](Operation *op) {
+ for (Value v : op->getResults())
+ if (isa<BaseMemRefType>(v.getType()))
+ this->terminals.insert(v);
+ for (Region &r : op->getRegions())
+ for (BlockArgument v : r.getArguments())
+ if (isa<BaseMemRefType>(v.getType()))
+ this->terminals.insert(v);
+ };
+
op->walk([&](Operation *op) {
- // TODO: We should have an op interface instead of a hard-coded list of
- // interfaces/ops.
+ // Query BufferViewFlowOpInterface. If the op does not implement that
+ // interface, try to infer the dependencies from other interfaces that the
+ // op may implement.
+ if (auto bufferViewFlowOp = dyn_cast<BufferViewFlowOpInterface>(op)) {
+ bufferViewFlowOp.populateDependencies(registerDependencies);
+ for (Value v : op->getResults())
+ if (isa<BaseMemRefType>(v.getType()) &&
+ bufferViewFlowOp.isTerminalBuffer(v))
+ this->terminals.insert(v);
+ for (Region &r : op->getRegions())
+ for (BlockArgument v : r.getArguments())
+ if (isa<BaseMemRefType>(v.getType()) &&
+ bufferViewFlowOp.isTerminalBuffer(v))
+ this->terminals.insert(v);
+ return WalkResult::advance();
+ }
// Add additional dependencies created by view changes to the alias list.
if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) {
- dependencies[viewInterface.getViewSource()].insert(
- viewInterface->getResult(0));
+ registerDependencies(viewInterface.getViewSource(),
+ viewInterface->getResult(0));
return WalkResult::advance();
}
@@ -131,16 +161,30 @@ void BufferViewFlowAnalysis::build(Operation *op) {
return WalkResult::advance();
}
- // Unknown op: Assume that all operands alias with all results.
- for (Value operand : op->getOperands()) {
- if (!isa<BaseMemRefType>(operand.getType()))
- continue;
- for (Value result : op->getResults()) {
- if (!isa<BaseMemRefType>(result.getType()))
- continue;
- registerDependencies({operand}, {result});
- }
+ // Region terminators are handled together with RegionBranchOpInterface.
+ if (isa<RegionBranchTerminatorOpInterface>(op))
+ return WalkResult::advance();
+
+ if (isa<CallOpInterface>(op)) {
+ // This is an intra-function analysis. We have no information about other
+ // functions. Conservatively assume that each operand may alias with each
+ // result. Also mark the results are terminals because the function could
+ // return newly allocated buffers.
+ populateTerminalValues(op);
+ for (Value operand : op->getOperands())
+ for (Value result : op->getResults())
+ registerDependencies({operand}, {result});
+ return WalkResult::advance();
}
+
+ // We have no information about unknown ops.
+ populateTerminalValues(op);
+
return WalkResult::advance();
});
}
+
+bool BufferViewFlowAnalysis::isTerminalBuffer(Value value) const {
+ assert(isa<BaseMemRefType>(value.getType()) && "expected memref");
+ return terminals.contains(value);
+}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
index 8617c17e7a5e5e..58feae66427b7a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
@@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
LINK_LIBS PUBLIC
MLIRArithDialect
MLIRBufferizationDialect
+ MLIRBufferizationTransforms
MLIRControlFlowInterfaces
MLIRFuncDialect
MLIRFunctionInterfaces
diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
new file mode 100644
index 00000000000000..ff1df72ba37407
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp
@@ -0,0 +1,49 @@
+//===- BufferViewFlowOpInterfaceImpl.cpp - Buffer View Flow Analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
+
+#include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+using namespace mlir;
+using namespace mlir::bufferization;
+using namespace mlir::memref;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+struct ReallocOpInterface
+ : public B...
[truncated]
|
f7082ae
to
5e32eba
Compare
mlir/include/mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.cpp
Outdated
Show resolved
Hide resolved
5e32eba
to
65e0e8b
Compare
mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to add a test case easily that exercises this?
mlir/include/mlir/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.h
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Good to have this gap covered.
mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
Outdated
Show resolved
Hide resolved
mlir/include/mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td
Outdated
Show resolved
Hide resolved
void | ||
populateDependencies(Operation *op, | ||
RegisterDependenciesFn registerDependenciesFn) const { | ||
auto reallocOp = cast<ReallocOp>(op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also update mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp simultaneously; your patch is needed there to correctly model alias relationships.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not very familiar with that part of the code base, let's do that in a follow-up PR.
65e0e8b
to
9d80c37
Compare
This commit adds the `BufferViewFlowOpInterface` to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by the `BufferViewFlowAnalysis`. There are currently no ops that implement this interface. The first op implementations will be added in a consecutive commit. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
9d80c37
to
fa46b2e
Compare
✅ With the latest revision this PR passed the Python code formatter. |
✅ With the latest revision this PR passed the C/C++ code formatter. |
This commit adds the
BufferViewFlowOpInterface
to the bufferization dialect. This interface can be implemented by ops that operate on buffers to indicate that a buffer op result and/or region entry block argument may be the same buffer as a buffer operand (or a view thereof). This interface is queried by theBufferViewFlowAnalysis
.The new interface has two interface methods:
populateDependencies
: Implementations use the provided callback to declare dependencies between operands and op results/region entry block arguments. E.g., for%r = arith.select %c, %m1, %m2 : memref<5xf32>
, the interface implementation should declare two dependencies: %m1 -> %r and %m2 -> %r.isTerminalBuffer
: An SSA value is a terminal buffer if the buffer view flow analysis stops at the specified value. E.g., because the value is a newly allocated buffer or because no further information is available about the origin of the buffer.Ops that implement the
RegionBranchOpInterface
orBranchOpInterface
do not have to implement theBufferViewFlowOpInterface
. The buffer dependencies can be inferred from those two interfaces.This commit makes the
BufferViewFlowAnalysis
more accurate. For unknown ops, it conservatively used to declare all combinations of operands and op results/region entry block arguments as dependencies (false positives). This is no longer the case. While the analysis is still a "maybe" analysis with false positives (e.g., when analyzing ops such asarith.select
orscf.if
where the taken branch is not known at compile time), results and region entry block arguments of unknown ops are now marked as terminal buffers.This commit addresses a TODO in
BufferViewFlowAnalysis.cpp
:It is no longer needed to hard-code ops.