diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h new file mode 100644 index 0000000000000..b88270f1c150a --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h @@ -0,0 +1,217 @@ +//===- BufferDeallocationOpInterface.h --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_ +#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_ + +#include "mlir/Analysis/Liveness.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Support/LLVM.h" + +namespace mlir { +namespace bufferization { + +/// Compare two SSA values in a deterministic manner. Two block arguments are +/// ordered by argument number, block arguments are always less than operation +/// results, and operation results are ordered by the `isBeforeInBlock` order of +/// their defining operation. +struct ValueComparator { + bool operator()(const Value &lhs, const Value &rhs) const; +}; + +/// This class is used to track the ownership of values. The ownership can +/// either be not initialized yet ('Uninitialized' state), set to a unique SSA +/// value which indicates the ownership at runtime (or statically if it is a +/// constant value) ('Unique' state), or it cannot be represented in a single +/// SSA value ('Unknown' state). 
An artificial example of a case where ownership +/// cannot be represented in a single i1 SSA value could be the following: +/// `%0 = test.non_deterministic_select %arg0, %arg1 : i32` +/// Since the operation does not provide us a separate boolean indicator on +/// which of the two operands was selected, we would need to either insert an +/// alias check at runtime to determine if `%0` aliases with `%arg0` or `%arg1`, +/// or insert a `bufferization.clone` operation to get a fresh buffer which we +/// could assign ownership to. +/// +/// The three states this class can represent form a lattice on a partial order: +/// forall X in SSA values. uninitialized < unique(X) < unknown +/// forall X, Y in SSA values. +/// unique(X) == unique(Y) iff X and Y always evaluate to the same value +/// unique(X) != unique(Y) otherwise +class Ownership { +public: + /// Constructor that creates an 'Uninitialized' ownership. This is needed for + /// default-construction when used in DenseMap. + Ownership() = default; + + /// Constructor that creates an 'Unique' ownership. This is a non-explicit + /// constructor to allow implicit conversion from 'Value'. + Ownership(Value indicator); + + /// Get an ownership value in 'Unknown' state. + static Ownership getUnknown(); + /// Get an ownership value in 'Unique' state with 'indicator' as parameter. + static Ownership getUnique(Value indicator); + /// Get an ownership value in 'Uninitialized' state. + static Ownership getUninitialized(); + + /// Check if this ownership value is in the 'Uninitialized' state. + bool isUninitialized() const; + /// Check if this ownership value is in the 'Unique' state. + bool isUnique() const; + /// Check if this ownership value is in the 'Unknown' state. + bool isUnknown() const; + + /// If this ownership value is in 'Unique' state, this function can be used to + /// get the indicator parameter. Using this function in any other state is UB. 
+ Value getIndicator() const; + + /// Get the join of the two-element subset {this,other}. Does not modify + /// 'this'. + Ownership getCombined(Ownership other) const; + + /// Modify 'this' ownership to be the join of the current 'this' and 'other'. + void combine(Ownership other); + +private: + enum class State { + Uninitialized, + Unique, + Unknown, + }; + + // The indicator value is only relevant in the 'Unique' state. + Value indicator; + State state = State::Uninitialized; +}; + +/// Options for BufferDeallocationOpInterface-based buffer deallocation. +struct DeallocationOptions { + // A pass option indicating whether private functions should be modified to + // pass the ownership of MemRef values instead of adhering to the function + // boundary ABI. + bool privateFuncDynamicOwnership = false; +}; + +/// This class collects all the state that we need to perform the buffer +/// deallocation pass with associated helper functions such that we have easy +/// access to it in the BufferDeallocationOpInterface implementations and the +/// BufferDeallocation pass. +class DeallocationState { +public: + DeallocationState(Operation *op); + + // The state should always be passed by reference. + DeallocationState(const DeallocationState &) = delete; + + /// Small helper function to update the ownership map by taking the current + /// ownership ('Uninitialized' state if not yet present), computing the join + /// with the passed ownership and storing this new value in the map. By + /// default, it will be performed for the block where 'memref' is defined. If + /// the ownership of the given value should be updated for another block, the + /// 'block' argument can be explicitly passed. + void updateOwnership(Value memref, Ownership ownership, + Block *block = nullptr); + + /// Removes ownerships associated with all values in the passed range for + /// 'block'.
+ void resetOwnerships(ValueRange memrefs, Block *block); + + /// Returns the ownership of 'memref' for the given basic block. + Ownership getOwnership(Value memref, Block *block) const; + + /// Remember the given 'memref' to deallocate it at the end of the 'block'. + void addMemrefToDeallocate(Value memref, Block *block); + + /// Forget about a MemRef that we originally wanted to deallocate at the end + /// of 'block', possibly because it already gets deallocated before the end of + /// the block. + void dropMemrefToDeallocate(Value memref, Block *block); + + /// Return a sorted list of MemRef values which are live at the start of the + /// given block. + void getLiveMemrefsIn(Block *block, SmallVectorImpl &memrefs); + + /// Given an SSA value of MemRef type, this function queries the ownership and + /// if it is not already in the 'Unique' state, potentially inserts IR to get + /// a new SSA value, returned as the first element of the pair, which has + /// 'Unique' ownership and can be used instead of the passed Value with the + /// ownership indicator returned as the second element of the pair. + std::pair getMemrefWithUniqueOwnership(OpBuilder &builder, + Value memref); + + /// Given two basic blocks and the values passed via block arguments to the + /// destination block, compute the list of MemRefs that have to be retained in + /// the 'fromBlock' to not run into a use-after-free situation. + /// This list consists of the MemRefs in the successor operand list of the + /// terminator and the MemRefs in the 'out' set of the liveness analysis + /// intersected with the 'in' set of the destination block.
+ /// + /// toRetain = filter(successorOperands + (liveOut(fromBlock) intersect + /// liveIn(toBlock)), isMemRef) + void getMemrefsToRetain(Block *fromBlock, Block *toBlock, + ValueRange destOperands, + SmallVectorImpl &toRetain) const; + + /// For a given block, computes the list of MemRefs that potentially need to + /// be deallocated at the end of that block. This list also contains values + /// that have to be retained (and are thus part of the list returned by + /// `getMemrefsToRetain`) and is computed by taking the MemRefs in the 'in' + /// set of the liveness analysis of 'block' appended by the set of MemRefs + /// allocated in 'block' itself and subtracted by the set of MemRefs + /// deallocated in 'block'. + /// Note that we don't have to take the intersection of the liveness 'in' set + /// with the 'out' set of the predecessor block because a value that is in the + /// 'in' set must be defined in an ancestor block that dominates all direct + /// predecessors and thus the 'in' set of this block is a subset of the 'out' + /// sets of each predecessor. + /// + /// memrefs = filter((liveIn(block) U + /// allocated(block) U arguments(block)) \ deallocated(block), isMemRef) + /// + /// The list of conditions is then populated by querying the internal + /// data structures for the ownership value of that MemRef. + LogicalResult + getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, + Block *block, + SmallVectorImpl &memrefs, + SmallVectorImpl &conditions) const; + + /// Returns the symbol cache to lookup functions from call operations to check + /// attributes on the function operation. + SymbolTableCollection *getSymbolTable() { return &symbolTable; } + +private: + // Symbol cache to lookup functions from call operations to check attributes + // on the function operation. + SymbolTableCollection symbolTable; + + // Mapping from each SSA value with MemRef type to the associated ownership in + // each block.
+ DenseMap, Ownership> ownershipMap; + + // Collects the list of MemRef values that potentially need to be deallocated + // per block. It is also fine (albeit not efficient) to add MemRef values that + // don't have to be deallocated, but only when the ownership is not 'Unknown'. + DenseMap> memrefsToDeallocatePerBlock; + + // The underlying liveness analysis to compute fine grained information about + // alloc and dealloc positions. + Liveness liveness; +}; + +} // namespace bufferization +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Buffer Deallocation Interface +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h.inc" + +#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERDEALLOCATIONOPINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td new file mode 100644 index 0000000000000..c35fe417184ff --- /dev/null +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.td @@ -0,0 +1,46 @@ +//===-- BufferDeallocationOpInterface.td -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef BUFFER_DEALLOCATION_OP_INTERFACE +#define BUFFER_DEALLOCATION_OP_INTERFACE + +include "mlir/IR/OpBase.td" + +def BufferDeallocationOpInterface : + OpInterface<"BufferDeallocationOpInterface"> { + let description = [{ + An op interface for Buffer Deallocation. 
Ops that implement this interface + can provide custom logic for computing the ownership of OpResults, modify + the operation to properly pass the ownership values around, and insert + `bufferization.dealloc` operations when necessary. + }]; + let cppNamespace = "::mlir::bufferization"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + This method takes the current deallocation state and transformation + options and updates the deallocation state as necessary for the + operation implementing this interface. It may also insert + `bufferization.dealloc` operations and rebuild itself with different + result types. For operations implementing this interface all other + interface handlers (e.g., default handlers for interfaces like + RegionBranchOpInterface, CallOpInterface, etc.) are skipped by the + deallocation pass. On success, either the current operation or one of + the newly inserted operations is returned from which on the driver + should continue the processing. On failure, the deallocation pass + will terminate. It is recommended to emit a useful error message in + that case. 
+ }], + /*retType=*/"FailureOr", + /*methodName=*/"process", + /*args=*/(ins "DeallocationState &":$state, + "const DeallocationOptions &":$options)> + ]; +} + +#endif // BUFFER_DEALLOCATION_OP_INTERFACE diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt index 440125031b1ac..38057d4910d29 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect(BufferizationOps bufferization) add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc) add_mlir_interface(AllocationOpInterface) +add_mlir_interface(BufferDeallocationOpInterface) add_mlir_interface(BufferizableOpInterface) add_mlir_interface(SubsetInsertionOpInterface) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h index 83e55fd70de6b..85e9c47ad5302 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h @@ -121,14 +121,6 @@ class BufferPlacementTransformationBase { Liveness liveness; }; -/// Compare two SSA values in a deterministic manner. Two block arguments are -/// ordered by argument number, block arguments are always less than operation -/// results, and operation results are ordered by the `isBeforeInBlock` order of -/// their defining operation. -struct ValueComparator { - bool operator()(const Value &lhs, const Value &rhs) const; -}; - // Create a global op for the given tensor-valued constant in the program. // Globals are created lazily at the top of the enclosing ModuleOp with pretty // names. Duplicates are avoided. 
diff --git a/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h b/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h new file mode 100644 index 0000000000000..c34ebd0494fec --- /dev/null +++ b/mlir/include/mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h @@ -0,0 +1,22 @@ +//===- BufferDeallocationOpInterfaceImpl.h ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_CONTROLFLOW_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H +#define MLIR_DIALECT_CONTROLFLOW_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H + +namespace mlir { + +class DialectRegistry; + +namespace cf { +void registerBufferDeallocationOpInterfaceExternalModels( + DialectRegistry ®istry); +} // namespace cf +} // namespace mlir + +#endif // MLIR_DIALECT_CONTROLFLOW_TRANSFORMS_BUFFERDEALLOCATIONOPINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 6eaa0cc0d46aa..ee91bfa57d12a 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -29,6 +29,7 @@ #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h" #include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" @@ -138,6 +139,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { registry); builtin::registerCastOpInterfaceExternalModels(registry); 
cf::registerBufferizableOpInterfaceExternalModels(registry); + cf::registerBufferDeallocationOpInterfaceExternalModels(registry); linalg::registerBufferizableOpInterfaceExternalModels(registry); linalg::registerTilingInterfaceExternalModels(registry); linalg::registerValueBoundsOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp new file mode 100644 index 0000000000000..2314cee2ff2c1 --- /dev/null +++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp @@ -0,0 +1,274 @@ +//===- BufferDeallocationOpInterface.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/SetOperations.h" + +//===----------------------------------------------------------------------===// +// BufferDeallocationOpInterface +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace bufferization { + +#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc" + +} // namespace bufferization +} // namespace mlir + +using namespace mlir; +using namespace bufferization; + +//===----------------------------------------------------------------------===// +// Helpers +//===----------------------------------------------------------------------===// + 
+static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { + return builder.create(loc, builder.getBoolAttr(value)); +} + +static bool isMemref(Value v) { return v.getType().isa(); } + +//===----------------------------------------------------------------------===// +// Ownership +//===----------------------------------------------------------------------===// + +Ownership::Ownership(Value indicator) + : indicator(indicator), state(State::Unique) {} + +Ownership Ownership::getUnknown() { + Ownership unknown; + unknown.indicator = Value(); + unknown.state = State::Unknown; + return unknown; +} +Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); } +Ownership Ownership::getUninitialized() { return Ownership(); } + +bool Ownership::isUninitialized() const { + return state == State::Uninitialized; +} +bool Ownership::isUnique() const { return state == State::Unique; } +bool Ownership::isUnknown() const { return state == State::Unknown; } + +Value Ownership::getIndicator() const { + assert(isUnique() && "must have unique ownership to get the indicator"); + return indicator; +} + +Ownership Ownership::getCombined(Ownership other) const { + if (other.isUninitialized()) + return *this; + if (isUninitialized()) + return other; + + if (!isUnique() || !other.isUnique()) + return getUnknown(); + + // Since we create a new constant i1 value for (almost) each use-site, we + // should compare the actual value rather than just the SSA Value to avoid + // unnecessary invalidations. + if (isEqualConstantIntOrValue(indicator, other.indicator)) + return *this; + + // Return the join of the lattice if the indicator of both ownerships cannot + // be merged. 
+ return getUnknown(); +} + +void Ownership::combine(Ownership other) { *this = getCombined(other); } + +//===----------------------------------------------------------------------===// +// DeallocationState +//===----------------------------------------------------------------------===// + +DeallocationState::DeallocationState(Operation *op) : liveness(op) {} + +void DeallocationState::updateOwnership(Value memref, Ownership ownership, + Block *block) { + // In most cases we care about the block where the value is defined. + if (block == nullptr) + block = memref.getParentBlock(); + + // Update ownership of current memref itself. + ownershipMap[{memref, block}].combine(ownership); +} + +void DeallocationState::resetOwnerships(ValueRange memrefs, Block *block) { + for (Value val : memrefs) + ownershipMap[{val, block}] = Ownership::getUninitialized(); +} + +Ownership DeallocationState::getOwnership(Value memref, Block *block) const { + return ownershipMap.lookup({memref, block}); +} + +void DeallocationState::addMemrefToDeallocate(Value memref, Block *block) { + memrefsToDeallocatePerBlock[block].push_back(memref); +} + +void DeallocationState::dropMemrefToDeallocate(Value memref, Block *block) { + llvm::erase_if(memrefsToDeallocatePerBlock[block], + [&](const auto &mr) { return mr == memref; }); +} + +void DeallocationState::getLiveMemrefsIn(Block *block, + SmallVectorImpl &memrefs) { + SmallVector liveMemrefs( + llvm::make_filter_range(liveness.getLiveIn(block), isMemref)); + llvm::sort(liveMemrefs, ValueComparator()); + memrefs.append(liveMemrefs); +} + +std::pair +DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder, + Value memref) { + auto iter = ownershipMap.find({memref, memref.getParentBlock()}); + assert(iter != ownershipMap.end() && + "Value must already have been registered in the ownership map"); + + Ownership ownership = iter->second; + if (ownership.isUnique()) + return {memref, ownership.getIndicator()}; + + // Instead of inserting a 
clone operation we could also insert a dealloc + // operation earlier in the block and use the updated ownerships returned by + // the op for the retained values. Alternatively, we could insert code to + // check aliasing at runtime and use this information to combine two unique + // ownerships more intelligently to not end up with an 'Unknown' ownership in + // the first place. + auto cloneOp = + builder.create(memref.getLoc(), memref); + Value condition = buildBoolValue(builder, memref.getLoc(), true); + Value newMemref = cloneOp.getResult(); + updateOwnership(newMemref, condition); + memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref); + return {newMemref, condition}; +} + +void DeallocationState::getMemrefsToRetain( + Block *fromBlock, Block *toBlock, ValueRange destOperands, + SmallVectorImpl &toRetain) const { + for (Value operand : destOperands) { + if (!isMemref(operand)) + continue; + toRetain.push_back(operand); + } + + SmallPtrSet liveOut; + for (auto val : liveness.getLiveOut(fromBlock)) + if (isMemref(val)) + liveOut.insert(val); + + if (toBlock) + llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock)); + + // liveOut has non-deterministic order because it was constructed by iterating + // over a hash-set. 
+ SmallVector retainedByLiveness(liveOut.begin(), liveOut.end()); + std::sort(retainedByLiveness.begin(), retainedByLiveness.end(), + ValueComparator()); + toRetain.append(retainedByLiveness); +} + +LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate( + OpBuilder &builder, Location loc, Block *block, + SmallVectorImpl &memrefs, SmallVectorImpl &conditions) const { + + for (auto [i, memref] : + llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) { + Ownership ownership = ownershipMap.lookup({memref, block}); + if (!ownership.isUnique()) + return emitError(memref.getLoc(), + "MemRef value does not have valid ownership"); + + // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such + // that we can call extract_strided_metadata on it. + if (auto unrankedMemRefTy = dyn_cast(memref.getType())) + memref = builder.create( + loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref, + 0, SmallVector{}, SmallVector{}); + + // Use the `memref.extract_strided_metadata` operation to get the base + // memref. This is needed because the same MemRef that was produced by the + // alloc operation has to be passed to the dealloc operation. Passing + // subviews, etc. to a dealloc operation is not allowed. + memrefs.push_back( + builder.create(loc, memref) + .getResult(0)); + conditions.push_back(ownership.getIndicator()); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// ValueComparator +//===----------------------------------------------------------------------===// + +bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const { + if (lhs == rhs) + return false; + + // Block arguments are less than results. 
+ bool lhsIsBBArg = lhs.isa(); + if (lhsIsBBArg != rhs.isa()) { + return lhsIsBBArg; + } + + Region *lhsRegion; + Region *rhsRegion; + if (lhsIsBBArg) { + auto lhsBBArg = llvm::cast(lhs); + auto rhsBBArg = llvm::cast(rhs); + if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) { + return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber(); + } + lhsRegion = lhsBBArg.getParentRegion(); + rhsRegion = rhsBBArg.getParentRegion(); + assert(lhsRegion != rhsRegion && + "lhsRegion == rhsRegion implies lhs == rhs"); + } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) { + return llvm::cast(lhs).getResultNumber() < + llvm::cast(rhs).getResultNumber(); + } else { + lhsRegion = lhs.getDefiningOp()->getParentRegion(); + rhsRegion = rhs.getDefiningOp()->getParentRegion(); + if (lhsRegion == rhsRegion) { + return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp()); + } + } + + // lhsRegion != rhsRegion, so if we look at their ancestor chain, they + // - have different heights + // - or there's a spot where their region numbers differ + // - or their parent regions are the same and their parent ops are + // different. 
+ while (lhsRegion && rhsRegion) { + if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) { + return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber(); + } + if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) { + return lhsRegion->getParentOp()->isBeforeInBlock( + rhsRegion->getParentOp()); + } + lhsRegion = lhsRegion->getParentRegion(); + rhsRegion = rhsRegion->getParentRegion(); + } + if (rhsRegion) + return true; + assert(lhsRegion && "this should only happen if lhs == rhs"); + return false; +} diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt index 3fd9221624d0f..b1940e40ba341 100644 --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect AllocationOpInterface.cpp BufferizableOpInterface.cpp + BufferDeallocationOpInterface.cpp BufferizationOps.cpp BufferizationDialect.cpp SubsetInsertionOpInterface.cpp diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp index b8fd99a554124..119801f9cc92f 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -202,62 +202,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, global->moveBefore(&moduleOp.front()); return global; } - -//===----------------------------------------------------------------------===// -// ValueComparator -//===----------------------------------------------------------------------===// - -bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const { - if (lhs == rhs) - return false; - - // Block arguments are less than results. 
- bool lhsIsBBArg = lhs.isa(); - if (lhsIsBBArg != rhs.isa()) { - return lhsIsBBArg; - } - - Region *lhsRegion; - Region *rhsRegion; - if (lhsIsBBArg) { - auto lhsBBArg = llvm::cast(lhs); - auto rhsBBArg = llvm::cast(rhs); - if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) { - return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber(); - } - lhsRegion = lhsBBArg.getParentRegion(); - rhsRegion = rhsBBArg.getParentRegion(); - assert(lhsRegion != rhsRegion && - "lhsRegion == rhsRegion implies lhs == rhs"); - } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) { - return llvm::cast(lhs).getResultNumber() < - llvm::cast(rhs).getResultNumber(); - } else { - lhsRegion = lhs.getDefiningOp()->getParentRegion(); - rhsRegion = rhs.getDefiningOp()->getParentRegion(); - if (lhsRegion == rhsRegion) { - return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp()); - } - } - - // lhsRegion != rhsRegion, so if we look at their ancestor chain, they - // - have different heights - // - or there's a spot where their region numbers differ - // - or their parent regions are the same and their parent ops are - // different. 
- while (lhsRegion && rhsRegion) { - if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) { - return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber(); - } - if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) { - return lhsRegion->getParentOp()->isBeforeInBlock( - rhsRegion->getParentOp()); - } - lhsRegion = lhsRegion->getParentRegion(); - rhsRegion = rhsRegion->getParentRegion(); - } - if (rhsRegion) - return true; - assert(lhsRegion && "this should only happen if lhs == rhs"); - return false; -} diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt index cbbfe7a812058..ed8dbd57bf40b 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -35,7 +35,6 @@ add_mlir_dialect_library(MLIRBufferizationTransforms MLIRPass MLIRTensorDialect MLIRSCFDialect - MLIRControlFlowDialect MLIRSideEffectInterfaces MLIRTransforms MLIRViewLikeInterface diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index eaced7202f4e6..d4b8e0dff67ba 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -18,16 +18,14 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" -#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Iterators.h" #include 
"mlir/Interfaces/ControlFlowInterfaces.h" -#include "llvm/ADT/SetOperations.h" namespace mlir { namespace bufferization { @@ -137,103 +135,13 @@ class Backedges { //===----------------------------------------------------------------------===// namespace { -/// This class is used to track the ownership of values. The ownership can -/// either be not initialized yet ('Uninitialized' state), set to a unique SSA -/// value which indicates the ownership at runtime (or statically if it is a -/// constant value) ('Unique' state), or it cannot be represented in a single -/// SSA value ('Unknown' state). An artificial example of a case where ownership -/// cannot be represented in a single i1 SSA value could be the following: -/// `%0 = test.non_deterministic_select %arg0, %arg1 : i32` -/// Since the operation does not provide us a separate boolean indicator on -/// which of the two operands was selected, we would need to either insert an -/// alias check at runtime to determine if `%0` aliases with `%arg0` or `%arg1`, -/// or insert a `bufferization.clone` operation to get a fresh buffer which we -/// could assign ownership to. -/// -/// The three states this class can represent form a lattice on a partial order: -/// forall X in SSA values. uninitialized < unique(X) < unknown -/// forall X, Y in SSA values. -/// unique(X) == unique(Y) iff X and Y always evaluate to the same value -/// unique(X) != unique(Y) otherwise -class Ownership { -public: - /// Constructor that creates an 'Uninitialized' ownership. This is needed for - /// default-construction when used in DenseMap. - Ownership() = default; - - /// Constructor that creates an 'Unique' ownership. This is a non-explicit - /// constructor to allow implicit conversion from 'Value'. - Ownership(Value indicator) : indicator(indicator), state(State::Unique) {} - - /// Get an ownership value in 'Unknown' state. 
- static Ownership getUnknown() { - Ownership unknown; - unknown.indicator = Value(); - unknown.state = State::Unknown; - return unknown; - } - /// Get an ownership value in 'Unique' state with 'indicator' as parameter. - static Ownership getUnique(Value indicator) { return Ownership(indicator); } - /// Get an ownership value in 'Uninitialized' state. - static Ownership getUninitialized() { return Ownership(); } - - /// Check if this ownership value is in the 'Uninitialized' state. - bool isUninitialized() const { return state == State::Uninitialized; } - /// Check if this ownership value is in the 'Unique' state. - bool isUnique() const { return state == State::Unique; } - /// Check if this ownership value is in the 'Unknown' state. - bool isUnknown() const { return state == State::Unknown; } - - /// If this ownership value is in 'Unique' state, this function can be used to - /// get the indicator parameter. Using this function in any other state is UB. - Value getIndicator() const { - assert(isUnique() && "must have unique ownership to get the indicator"); - return indicator; - } - - /// Get the join of the two-element subset {this,other}. Does not modify - /// 'this'. - Ownership getCombined(Ownership other) const { - if (other.isUninitialized()) - return *this; - if (isUninitialized()) - return other; - - if (!isUnique() || !other.isUnique()) - return getUnknown(); - - // Since we create a new constant i1 value for (almost) each use-site, we - // should compare the actual value rather than just the SSA Value to avoid - // unnecessary invalidations. - if (isEqualConstantIntOrValue(indicator, other.indicator)) - return *this; - - // Return the join of the lattice if the indicator of both ownerships cannot - // be merged. - return getUnknown(); - } - - /// Modify 'this' ownership to be the join of the current 'this' and 'other'. 
- void combine(Ownership other) { *this = getCombined(other); } - -private: - enum class State { - Uninitialized, - Unique, - Unknown, - }; - - // The indicator value is only relevant in the 'Unique' state. - Value indicator; - State state = State::Uninitialized; -}; - /// The buffer deallocation transformation which ensures that all allocs in the /// program have a corresponding de-allocation. class BufferDeallocation { public: BufferDeallocation(Operation *op, bool privateFuncDynamicOwnership) - : liveness(op), privateFuncDynamicOwnership(privateFuncDynamicOwnership) { + : state(op) { + options.privateFuncDynamicOwnership = privateFuncDynamicOwnership; } /// Performs the actual placement/creation of all dealloc operations. @@ -291,57 +199,17 @@ class BufferDeallocation { /// Apply all supported interface handlers to the given op. FailureOr handleAllInterfaces(Operation *op) { + if (auto deallocOpInterface = dyn_cast(op)) + return deallocOpInterface.process(state, options); + if (failed(verifyOperationPreconditions(op))) return failure(); return handleOp(op); } - /// While CondBranchOp also implements the BranchOpInterface, we add a - /// special-case implementation here because the BranchOpInterface does not - /// offer all of the functionality we need to insert dealloc operations in an - /// efficient way. More precisely, there is no way to extract the branch - /// condition without casting to CondBranchOp specifically. It would still be - /// possible to implement deallocation for cases where we don't know to which - /// successor the terminator branches before the actual branch happens by - /// inserting auxiliary blocks and putting the dealloc op there, however, this - /// can lead to less efficient code. 
- /// This function inserts two dealloc operations (one for each successor) and - /// adjusts the dealloc conditions according to the branch condition, then the - /// ownerships of the retained MemRefs are updated by combining the result - /// values of the two dealloc operations. - /// - /// Example: - /// ``` - /// ^bb1: - /// - /// cf.cond_br cond, ^bb2(), ^bb3() - /// ``` - /// becomes - /// ``` - /// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1) - /// // let r0 = getMemrefsToRetain(bb1, bb2, ) - /// // let r1 = getMemrefsToRetain(bb1, bb3, ) - /// ^bb1: - /// - /// let thenCond = map(c, (c) -> arith.andi cond, c) - /// let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c) - /// o0 = bufferization.dealloc m if thenCond retain r0 - /// o1 = bufferization.dealloc m if elseCond retain r1 - /// // replace ownership(r0) with o0 element-wise - /// // replace ownership(r1) with o1 element-wise - /// // let ownership0 := (r) -> o in o0 corresponding to r - /// // let ownership1 := (r) -> o in o1 corresponding to r - /// // let cmn := intersection(r0, r1) - /// foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)): - /// forall r in r0: replace ownership0(r) with arith.select cond, a, b) - /// forall r in r1: replace ownership1(r) with arith.select cond, a, b) - /// cf.cond_br cond, ^bb2(, o0), ^bb3(, o1) - /// ``` - FailureOr handleInterface(cf::CondBranchOp op); - /// Make sure that for each forwarded MemRef value, an ownership indicator /// `i1` value is forwarded as well such that the successor block knows /// whether the MemRef has to be deallocated. @@ -492,18 +360,6 @@ class BufferDeallocation { /// this function has to be called on blocks in a region in dominance order. 
LogicalResult deallocate(Block *block); - /// Small helper function to update the ownership map by taking the current - /// ownership ('Uninitialized' state if not yet present), computing the join - /// with the passed ownership and storing this new value in the map. By - /// default, it will be performed for the block where 'owned' is defined. If - /// the ownership of the given value should be updated for another block, the - /// 'block' argument can be explicitly passed. - void joinOwnership(Value owned, Ownership ownership, Block *block = nullptr); - - /// Removes ownerships associated with all values in the passed range for - /// 'block'. - void clearOwnershipOf(ValueRange values, Block *block); - /// After all relevant interfaces of an operation have been processed by the /// 'handleInterface' functions, this function sets the ownership of operation /// results that have not been set yet by the 'handleInterface' functions. It @@ -517,51 +373,6 @@ class BufferDeallocation { /// operations, etc.). void populateRemainingOwnerships(Operation *op); - /// Given two basic blocks and the values passed via block arguments to the - /// destination block, compute the list of MemRefs that have to be retained in - /// the 'fromBlock' to not run into a use-after-free situation. - /// This list consists of the MemRefs in the successor operand list of the - /// terminator and the MemRefs in the 'out' set of the liveness analysis - /// intersected with the 'in' set of the destination block. - /// - /// toRetain = filter(successorOperands + (liveOut(fromBlock) insersect - /// liveIn(toBlock)), isMemRef) - void getMemrefsToRetain(Block *fromBlock, Block *toBlock, - ValueRange destOperands, - SmallVectorImpl &toRetain) const; - - /// For a given block, computes the list of MemRefs that potentially need to - /// be deallocated at the end of that block. 
This list also contains values - /// that have to be retained (and are thus part of the list returned by - /// `getMemrefsToRetain`) and is computed by taking the MemRefs in the 'in' - /// set of the liveness analysis of 'block' appended by the set of MemRefs - /// allocated in 'block' itself and subtracted by the set of MemRefs - /// deallocated in 'block'. - /// Note that we don't have to take the intersection of the liveness 'in' set - /// with the 'out' set of the predecessor block because a value that is in the - /// 'in' set must be defined in an ancestor block that dominates all direct - /// predecessors and thus the 'in' set of this block is a subset of the 'out' - /// sets of each predecessor. - /// - /// memrefs = filter((liveIn(block) U - /// allocated(block) U arguments(block)) \ deallocated(block), isMemRef) - /// - /// The list of conditions is then populated by querying the internal - /// datastructures for the ownership value of that MemRef. - LogicalResult - getMemrefsAndConditionsToDeallocate(OpBuilder &builder, Location loc, - Block *block, - SmallVectorImpl &memrefs, - SmallVectorImpl &conditions) const; - - /// Given an SSA value of MemRef type, this function queries the ownership and - /// if it is not already in the 'Unique' state, potentially inserts IR to get - /// a new SSA value, returned as the first element of the pair, which has - /// 'Unique' ownership and can be used instead of the passed Value with the - /// the ownership indicator returned as the second element of the pair. - std::pair getMemrefWithUniqueOwnership(OpBuilder &builder, - Value memref); - /// Given an SSA value of MemRef type, returns the same of a new SSA value /// which has 'Unique' ownership where the ownership indicator is guaranteed /// to be always 'true'. 
@@ -602,27 +413,13 @@ class BufferDeallocation { static LogicalResult updateFunctionSignature(FunctionOpInterface op); private: - // Mapping from each SSA value with MemRef type to the associated ownership in - // each block. - DenseMap, Ownership> ownershipMap; - - // Collects the list of MemRef values that potentially need to be deallocated - // per block. It is also fine (albeit not efficient) to add MemRef values that - // don't have to be deallocated, but only when the ownership is not 'Unknown'. - DenseMap> memrefsToDeallocatePerBlock; - - // Symbol cache to lookup functions from call operations to check attributes - // on the function operation. - SymbolTableCollection symbolTable; - - // The underlying liveness analysis to compute fine grained information about - // alloc and dealloc positions. - Liveness liveness; - - // A pass option indicating whether private functions should be modified to - // pass the ownership of MemRef values instead of adhering to the function - // boundary ABI. - bool privateFuncDynamicOwnership; + /// Collects all analysis state, including liveness, caches, ownerships of + /// already processed values and operations, and the MemRefs that have to be + /// deallocated at the end of each block. + DeallocationState state; + + /// Collects all pass options in a single place. + DeallocationOptions options; }; } // namespace @@ -631,22 +428,6 @@ class BufferDeallocation { // BufferDeallocation Implementation //===----------------------------------------------------------------------===// -void BufferDeallocation::joinOwnership(Value owned, Ownership ownership, - Block *block) { - // In most cases we care about the block where the value is defined. - if (block == nullptr) - block = owned.getParentBlock(); - - // Update ownership of current memref itself. 
- ownershipMap[{owned, block}].combine(ownership); -} - -void BufferDeallocation::clearOwnershipOf(ValueRange values, Block *block) { - for (Value val : values) { - ownershipMap[{val, block}] = Ownership::getUninitialized(); - } -} - static bool regionOperatesOnMemrefValues(Region ®ion) { WalkResult result = region.walk([](Block *block) { if (llvm::any_of(block->getArguments(), isMemref)) @@ -717,10 +498,10 @@ LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) { // We only support terminators with 0 or 1 successors for now and // special-case the conditional branch op. - if (op->getSuccessors().size() > 1 && !isa(op)) + if (op->getSuccessors().size() > 1) return op->emitError("Terminators with more than one successor " - "are not supported (except cf.cond_br)!"); + "are not supported!"); } return success(); @@ -776,80 +557,26 @@ LogicalResult BufferDeallocation::deallocate(FunctionOpInterface op) { return updateFunctionSignature(op); } -void BufferDeallocation::getMemrefsToRetain( - Block *fromBlock, Block *toBlock, ValueRange destOperands, - SmallVectorImpl &toRetain) const { - for (Value operand : destOperands) { - if (!isMemref(operand)) - continue; - toRetain.push_back(operand); - } - - SmallPtrSet liveOut; - for (auto val : liveness.getLiveOut(fromBlock)) - if (isMemref(val)) - liveOut.insert(val); - - if (toBlock) - llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock)); - - // liveOut has non-deterministic order because it was constructed by iterating - // over a hash-set. 
- SmallVector retainedByLiveness(liveOut.begin(), liveOut.end()); - std::sort(retainedByLiveness.begin(), retainedByLiveness.end(), - ValueComparator()); - toRetain.append(retainedByLiveness); -} - -LogicalResult BufferDeallocation::getMemrefsAndConditionsToDeallocate( - OpBuilder &builder, Location loc, Block *block, - SmallVectorImpl &memrefs, SmallVectorImpl &conditions) const { - - for (auto [i, memref] : - llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) { - Ownership ownership = ownershipMap.lookup({memref, block}); - assert(ownership.isUnique() && "MemRef value must have valid ownership"); - - // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such - // that we can call extract_strided_metadata on it. - if (auto unrankedMemRefTy = dyn_cast(memref.getType())) - memref = builder.create( - loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref, - 0, SmallVector{}, SmallVector{}); - - // Use the `memref.extract_strided_metadata` operation to get the base - // memref. This is needed because the same MemRef that was produced by the - // alloc operation has to be passed to the dealloc operation. Passing - // subviews, etc. to a dealloc operation is not allowed. - memrefs.push_back( - builder.create(loc, memref) - .getResult(0)); - conditions.push_back(ownership.getIndicator()); - } - - return success(); -} - LogicalResult BufferDeallocation::deallocate(Block *block) { OpBuilder builder = OpBuilder::atBlockBegin(block); // Compute liveness transfers of ownership to this block. - for (auto li : liveness.getLiveIn(block)) { - if (!isMemref(li)) - continue; - + SmallVector liveMemrefs; + state.getLiveMemrefsIn(block, liveMemrefs); + for (auto li : liveMemrefs) { // Ownership of implicitly captured memrefs from other regions is never // taken, but ownership of memrefs in the same region (but different block) // is taken. 
if (li.getParentRegion() == block->getParent()) { - joinOwnership(li, ownershipMap[{li, li.getParentBlock()}], block); - memrefsToDeallocatePerBlock[block].push_back(li); + state.updateOwnership(li, state.getOwnership(li, li.getParentBlock()), + block); + state.addMemrefToDeallocate(li, block); continue; } if (li.getParentRegion()->isProperAncestor(block->getParent())) { Value falseVal = buildBoolValue(builder, li.getLoc(), false); - joinOwnership(li, falseVal, block); + state.updateOwnership(li, falseVal, block); } } @@ -863,14 +590,15 @@ LogicalResult BufferDeallocation::deallocate(Block *block) { if (isFunctionWithoutDynamicOwnership(block->getParentOp()) && block->isEntryBlock()) { Value newArg = buildBoolValue(builder, arg.getLoc(), false); - joinOwnership(arg, newArg); + state.updateOwnership(arg, newArg); + state.addMemrefToDeallocate(arg, block); continue; } // Pass MemRef ownerships along via `i1` values. Value newArg = block->addArgument(builder.getI1Type(), arg.getLoc()); - joinOwnership(arg, newArg); - memrefsToDeallocatePerBlock[block].push_back(arg); + state.updateOwnership(arg, newArg); + state.addMemrefToDeallocate(arg, block); } // For each operation in the block, handle the interfaces that affect aliasing @@ -906,97 +634,6 @@ Operation *BufferDeallocation::appendOpResults(Operation *op, return newOp; } -FailureOr -BufferDeallocation::handleInterface(cf::CondBranchOp op) { - OpBuilder builder(op); - - // The list of memrefs to pass to the `bufferization.dealloc` op as "memrefs - // to deallocate" in this block is independent of which branch is taken. - SmallVector memrefs, ownerships; - if (failed(getMemrefsAndConditionsToDeallocate( - builder, op.getLoc(), op->getBlock(), memrefs, ownerships))) - return failure(); - - // Helper lambda to factor out common logic for inserting the dealloc - // operations for each successor. 
- auto insertDeallocForBranch = - [&](Block *target, MutableOperandRange destOperands, - ArrayRef conditions, - DenseMap &ownershipMapping) -> DeallocOp { - SmallVector toRetain; - getMemrefsToRetain(op->getBlock(), target, OperandRange(destOperands), - toRetain); - auto deallocOp = builder.create( - op.getLoc(), memrefs, conditions, toRetain); - clearOwnershipOf(deallocOp.getRetained(), op->getBlock()); - for (auto [retained, ownership] : - llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) { - joinOwnership(retained, ownership, op->getBlock()); - ownershipMapping[retained] = ownership; - } - SmallVector replacements, ownerships; - for (Value operand : destOperands) { - replacements.push_back(operand); - if (isMemref(operand)) { - assert(ownershipMapping.contains(operand) && - "Should be contained at this point"); - ownerships.push_back(ownershipMapping[operand]); - } - } - replacements.append(ownerships); - destOperands.assign(replacements); - return deallocOp; - }; - - // Call the helper lambda and make sure the dealloc conditions are properly - // modified to reflect the branch condition as well. - DenseMap thenOwnershipMap, elseOwnershipMap; - - // Retain `trueDestOperands` if "true" branch is taken. - SmallVector thenOwnerships( - llvm::map_range(ownerships, [&](Value cond) { - return builder.create(op.getLoc(), cond, - op.getCondition()); - })); - DeallocOp thenTakenDeallocOp = - insertDeallocForBranch(op.getTrueDest(), op.getTrueDestOperandsMutable(), - thenOwnerships, thenOwnershipMap); - - // Retain `elseDestOperands` if "false" branch is taken. 
- SmallVector elseOwnerships( - llvm::map_range(ownerships, [&](Value cond) { - Value trueVal = builder.create( - op.getLoc(), builder.getBoolAttr(true)); - Value negation = builder.create(op.getLoc(), trueVal, - op.getCondition()); - return builder.create(op.getLoc(), cond, negation); - })); - DeallocOp elseTakenDeallocOp = insertDeallocForBranch( - op.getFalseDest(), op.getFalseDestOperandsMutable(), elseOwnerships, - elseOwnershipMap); - - // We specifically need to update the ownerships of values that are retained - // in both dealloc operations again to get a combined 'Unique' ownership - // instead of an 'Unknown' ownership. - SmallPtrSet thenValues(thenTakenDeallocOp.getRetained().begin(), - thenTakenDeallocOp.getRetained().end()); - SetVector commonValues; - for (Value val : elseTakenDeallocOp.getRetained()) { - if (thenValues.contains(val)) - commonValues.insert(val); - } - - for (Value retained : commonValues) { - clearOwnershipOf(retained, op->getBlock()); - Value combinedOwnership = builder.create( - op.getLoc(), op.getCondition(), thenOwnershipMap[retained], - elseOwnershipMap[retained]); - joinOwnership(retained, combinedOwnership, op->getBlock()); - } - - return op.getOperation(); -} - FailureOr BufferDeallocation::handleInterface(RegionBranchOpInterface op) { OpBuilder builder = OpBuilder::atBlockBegin(op->getBlock()); @@ -1033,44 +670,18 @@ BufferDeallocation::handleInterface(RegionBranchOpInterface op) { RegionBranchOpInterface newOp = appendOpResults(op, ownershipResults); for (auto result : llvm::make_filter_range(newOp->getResults(), isMemref)) { - joinOwnership(result, newOp->getResult(counter++)); - memrefsToDeallocatePerBlock[newOp->getBlock()].push_back(result); + state.updateOwnership(result, newOp->getResult(counter++)); + state.addMemrefToDeallocate(result, newOp->getBlock()); } return newOp.getOperation(); } -std::pair -BufferDeallocation::getMemrefWithUniqueOwnership(OpBuilder &builder, - Value memref) { - auto iter = 
ownershipMap.find({memref, memref.getParentBlock()}); - assert(iter != ownershipMap.end() && - "Value must already have been registered in the ownership map"); - - Ownership ownership = iter->second; - if (ownership.isUnique()) - return {memref, ownership.getIndicator()}; - - // Instead of inserting a clone operation we could also insert a dealloc - // operation earlier in the block and use the updated ownerships returned by - // the op for the retained values. Alternatively, we could insert code to - // check aliasing at runtime and use this information to combine two unique - // ownerships more intelligently to not end up with an 'Unknown' ownership in - // the first place. - auto cloneOp = - builder.create(memref.getLoc(), memref); - Value condition = buildBoolValue(builder, memref.getLoc(), true); - Value newMemref = cloneOp.getResult(); - joinOwnership(newMemref, condition); - memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref); - return {newMemref, condition}; -} - Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder, Value memref) { // First, make sure we at least have 'Unique' ownership already. std::pair newMemrefAndOnwership = - getMemrefWithUniqueOwnership(builder, memref); + state.getMemrefWithUniqueOwnership(builder, memref); Value newMemref = newMemrefAndOnwership.first; Value condition = newMemrefAndOnwership.second; @@ -1096,17 +707,16 @@ Value BufferDeallocation::getMemrefWithGuaranteedOwnership(OpBuilder &builder, }) .getResult(0); Value trueVal = buildBoolValue(builder, memref.getLoc(), true); - joinOwnership(maybeClone, trueVal); - memrefsToDeallocatePerBlock[maybeClone.getParentBlock()].push_back( - maybeClone); + state.updateOwnership(maybeClone, trueVal); + state.addMemrefToDeallocate(maybeClone, maybeClone.getParentBlock()); return maybeClone; } FailureOr BufferDeallocation::handleInterface(BranchOpInterface op) { - // Skip conditional branches since we special case them for now. 
- if (isa(op.getOperation())) - return op.getOperation(); + if (op->getNumSuccessors() > 1) + return op->emitError("BranchOpInterface operations with multiple " + "successors are not supported yet"); if (op->getNumSuccessors() != 1) return emitError(op.getLoc(), @@ -1121,23 +731,24 @@ BufferDeallocation::handleInterface(BranchOpInterface op) { Block *block = op->getBlock(); OpBuilder builder(op); SmallVector memrefs, conditions, toRetain; - if (failed(getMemrefsAndConditionsToDeallocate(builder, op.getLoc(), block, - memrefs, conditions))) + if (failed(state.getMemrefsAndConditionsToDeallocate( + builder, op.getLoc(), block, memrefs, conditions))) return failure(); OperandRange forwardedOperands = op.getSuccessorOperands(0).getForwardedOperands(); - getMemrefsToRetain(block, op->getSuccessor(0), forwardedOperands, toRetain); + state.getMemrefsToRetain(block, op->getSuccessor(0), forwardedOperands, + toRetain); auto deallocOp = builder.create( op.getLoc(), memrefs, conditions, toRetain); // We want to replace the current ownership of the retained values with the // result values of the dealloc operation as they are always unique. - clearOwnershipOf(deallocOp.getRetained(), block); + state.resetOwnerships(deallocOp.getRetained(), block); for (auto [retained, ownership] : llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) { - joinOwnership(retained, ownership, block); + state.updateOwnership(retained, ownership, block); } unsigned numAdditionalReturns = llvm::count_if(forwardedOperands, isMemref); @@ -1156,7 +767,7 @@ FailureOr BufferDeallocation::handleInterface(CallOpInterface op) { // Lookup the function operation and check if it has private visibility. If // the function is referenced by SSA value instead of a Symbol, it's assumed // to be always private. 
- Operation *funcOp = op.resolveCallable(&symbolTable); + Operation *funcOp = op.resolveCallable(state.getSymbolTable()); bool isPrivate = true; if (auto symbol = dyn_cast(funcOp)) isPrivate &= (symbol.getVisibility() == SymbolTable::Visibility::Private); @@ -1166,14 +777,15 @@ FailureOr BufferDeallocation::handleInterface(CallOpInterface op) { // argument/result for each MemRef argument/result to dynamically pass the // current ownership indicator rather than adhering to the function boundary // ABI. - if (privateFuncDynamicOwnership && isPrivate) { + if (options.privateFuncDynamicOwnership && isPrivate) { SmallVector newOperands, ownershipIndicatorsToAdd; for (Value operand : op.getArgOperands()) { if (!isMemref(operand)) { newOperands.push_back(operand); continue; } - auto [memref, condition] = getMemrefWithUniqueOwnership(builder, operand); + auto [memref, condition] = + state.getMemrefWithUniqueOwnership(builder, operand); newOperands.push_back(memref); ownershipIndicatorsToAdd.push_back(condition); } @@ -1187,8 +799,8 @@ FailureOr BufferDeallocation::handleInterface(CallOpInterface op) { op = appendOpResults(op, ownershipTypesToAppend); for (auto result : llvm::make_filter_range(op->getResults(), isMemref)) { - joinOwnership(result, op->getResult(ownershipCounter++)); - memrefsToDeallocatePerBlock[result.getParentBlock()].push_back(result); + state.updateOwnership(result, op->getResult(ownershipCounter++)); + state.addMemrefToDeallocate(result, result.getParentBlock()); } return op.getOperation(); @@ -1199,8 +811,8 @@ FailureOr BufferDeallocation::handleInterface(CallOpInterface op) { // 'true' and remember to deallocate it. 
Value trueVal = buildBoolValue(builder, op.getLoc(), true); for (auto result : llvm::make_filter_range(op->getResults(), isMemref)) { - joinOwnership(result, trueVal); - memrefsToDeallocatePerBlock[result.getParentBlock()].push_back(result); + state.updateOwnership(result, trueVal); + state.addMemrefToDeallocate(result, result.getParentBlock()); } return op.getOperation(); @@ -1228,13 +840,13 @@ BufferDeallocation::handleInterface(MemoryEffectOpInterface op) { // `memref.alloc`. If we wouldn't set the ownership of the result here, // the default ownership population in `populateRemainingOwnerships` // would assume aliasing with the MemRef operand. - clearOwnershipOf(res, block); - joinOwnership(res, buildBoolValue(builder, op.getLoc(), false)); + state.resetOwnerships(res, block); + state.updateOwnership(res, buildBoolValue(builder, op.getLoc(), false)); continue; } - joinOwnership(res, buildBoolValue(builder, op.getLoc(), true)); - memrefsToDeallocatePerBlock[block].push_back(res); + state.updateOwnership(res, buildBoolValue(builder, op.getLoc(), true)); + state.addMemrefToDeallocate(res, block); } } @@ -1271,11 +883,11 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) { // dealloc operation. 
Block *block = op->getBlock(); SmallVector memrefs, conditions, toRetain; - if (failed(getMemrefsAndConditionsToDeallocate(builder, op.getLoc(), block, - memrefs, conditions))) + if (failed(state.getMemrefsAndConditionsToDeallocate( + builder, op.getLoc(), block, memrefs, conditions))) return failure(); - getMemrefsToRetain(block, nullptr, OperandRange(operands), toRetain); + state.getMemrefsToRetain(block, nullptr, OperandRange(operands), toRetain); if (memrefs.empty() && toRetain.empty()) return op.getOperation(); @@ -1284,10 +896,10 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) { // We want to replace the current ownership of the retained values with the // result values of the dealloc operation as they are always unique. - clearOwnershipOf(deallocOp.getRetained(), block); + state.resetOwnerships(deallocOp.getRetained(), block); for (auto [retained, ownership] : llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) - joinOwnership(retained, ownership, block); + state.updateOwnership(retained, ownership, block); // Add an additional operand for every MemRef for the ownership indicator. if (!funcWithoutDynamicOwnership) { @@ -1304,7 +916,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) { bool BufferDeallocation::isFunctionWithoutDynamicOwnership(Operation *op) { auto funcOp = dyn_cast(op); - return funcOp && (!privateFuncDynamicOwnership || + return funcOp && (!options.privateFuncDynamicOwnership || funcOp.getVisibility() != SymbolTable::Visibility::Private); } @@ -1312,14 +924,14 @@ void BufferDeallocation::populateRemainingOwnerships(Operation *op) { for (auto res : op->getResults()) { if (!isMemref(res)) continue; - if (ownershipMap.count({res, op->getBlock()})) + if (!state.getOwnership(res, op->getBlock()).isUninitialized()) continue; // Don't take ownership of a returned memref if no allocate side-effect is // present, relevant for memref.get_global, for example. 
if (op->getNumOperands() == 0) { OpBuilder builder(op); - joinOwnership(res, buildBoolValue(builder, op->getLoc(), false)); + state.updateOwnership(res, buildBoolValue(builder, op->getLoc(), false)); continue; } @@ -1329,8 +941,9 @@ void BufferDeallocation::populateRemainingOwnerships(Operation *op) { if (!isMemref(operand)) continue; - ownershipMap[{res, op->getBlock()}].combine( - ownershipMap[{operand, operand.getParentBlock()}]); + state.updateOwnership( + res, state.getOwnership(operand, operand.getParentBlock()), + op->getBlock()); } } } diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp new file mode 100644 index 0000000000000..e847e946eef1b --- /dev/null +++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp @@ -0,0 +1,163 @@ +//===- BufferDeallocationOpInterfaceImpl.cpp ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; +using namespace mlir::bufferization; + +static bool isMemref(Value v) { return v.getType().isa(); } + +namespace { +/// While CondBranchOp also implements the BranchOpInterface, we add a +/// special-case implementation here because the BranchOpInterface does not +/// offer all of the functionality we need to insert dealloc operations in an +/// efficient way. More precisely, there is no way to extract the branch +/// condition without casting to CondBranchOp specifically. It is still +/// possible to implement deallocation for cases where we don't know to which +/// successor the terminator branches before the actual branch happens by +/// inserting auxiliary blocks and putting the dealloc op there, however, this +/// can lead to less efficient code. +/// This function inserts two dealloc operations (one for each successor) and +/// adjusts the dealloc conditions according to the branch condition, then the +/// ownerships of the retained MemRefs are updated by combining the result +/// values of the two dealloc operations. 
+/// +/// Example: +/// ``` +/// ^bb1: +/// +/// cf.cond_br cond, ^bb2(), ^bb3() +/// ``` +/// becomes +/// ``` +/// // let (m, c) = getMemrefsAndConditionsToDeallocate(bb1) +/// // let r0 = getMemrefsToRetain(bb1, bb2, ) +/// // let r1 = getMemrefsToRetain(bb1, bb3, ) +/// ^bb1: +/// +/// let thenCond = map(c, (c) -> arith.andi cond, c) +/// let elseCond = map(c, (c) -> arith.andi (arith.xori cond, true), c) +/// o0 = bufferization.dealloc m if thenCond retain r0 +/// o1 = bufferization.dealloc m if elseCond retain r1 +/// // replace ownership(r0) with o0 element-wise +/// // replace ownership(r1) with o1 element-wise +/// // let ownership0 := (r) -> o in o0 corresponding to r +/// // let ownership1 := (r) -> o in o1 corresponding to r +/// // let cmn := intersection(r0, r1) +/// foreach (a, b) in zip(map(cmn, ownership0), map(cmn, ownership1)): +/// forall r in r0: replace ownership0(r) with arith.select cond, a, b) +/// forall r in r1: replace ownership1(r) with arith.select cond, a, b) +/// cf.cond_br cond, ^bb2(, o0), ^bb3(, o1) +/// ``` +struct CondBranchOpInterface + : public BufferDeallocationOpInterface::ExternalModel { + FailureOr process(Operation *op, DeallocationState &state, + const DeallocationOptions &options) const { + OpBuilder builder(op); + auto condBr = cast(op); + + // The list of memrefs to deallocate in this block is independent of which + // branch is taken. + SmallVector memrefs, conditions; + if (failed(state.getMemrefsAndConditionsToDeallocate( + builder, condBr.getLoc(), condBr->getBlock(), memrefs, conditions))) + return failure(); + + // Helper lambda to factor out common logic for inserting the dealloc + // operations for each successor. 
+ auto insertDeallocForBranch = + [&](Block *target, MutableOperandRange destOperands, + const std::function &conditionModifier, + DenseMap &mapping) -> DeallocOp { + SmallVector toRetain; + state.getMemrefsToRetain(condBr->getBlock(), target, + OperandRange(destOperands), toRetain); + SmallVector adaptedConditions( + llvm::map_range(conditions, conditionModifier)); + auto deallocOp = builder.create( + condBr.getLoc(), memrefs, adaptedConditions, toRetain); + state.resetOwnerships(deallocOp.getRetained(), condBr->getBlock()); + for (auto [retained, ownership] : llvm::zip( + deallocOp.getRetained(), deallocOp.getUpdatedConditions())) { + state.updateOwnership(retained, ownership, condBr->getBlock()); + mapping[retained] = ownership; + } + SmallVector replacements, ownerships; + for (Value operand : destOperands) { + replacements.push_back(operand); + if (isMemref(operand)) { + assert(mapping.contains(operand) && + "Should be contained at this point"); + ownerships.push_back(mapping[operand]); + } + } + replacements.append(ownerships); + destOperands.assign(replacements); + return deallocOp; + }; + + // Call the helper lambda and make sure the dealloc conditions are properly + // modified to reflect the branch condition as well. 
+ DenseMap thenMapping, elseMapping; + DeallocOp thenTakenDeallocOp = insertDeallocForBranch( + condBr.getTrueDest(), condBr.getTrueDestOperandsMutable(), + [&](Value cond) { + return builder.create(condBr.getLoc(), cond, + condBr.getCondition()); + }, + thenMapping); + DeallocOp elseTakenDeallocOp = insertDeallocForBranch( + condBr.getFalseDest(), condBr.getFalseDestOperandsMutable(), + [&](Value cond) { + Value trueVal = builder.create( + condBr.getLoc(), builder.getBoolAttr(true)); + Value negation = builder.create( + condBr.getLoc(), trueVal, condBr.getCondition()); + return builder.create(condBr.getLoc(), cond, negation); + }, + elseMapping); + + // We specifically need to update the ownerships of values that are retained + // in both dealloc operations again to get a combined 'Unique' ownership + // instead of an 'Unknown' ownership. + SmallPtrSet thenValues(thenTakenDeallocOp.getRetained().begin(), + thenTakenDeallocOp.getRetained().end()); + SetVector commonValues; + for (Value val : elseTakenDeallocOp.getRetained()) { + if (thenValues.contains(val)) + commonValues.insert(val); + } + + for (Value retained : commonValues) { + state.resetOwnerships(retained, condBr->getBlock()); + Value combinedOwnership = builder.create( + condBr.getLoc(), condBr.getCondition(), thenMapping[retained], + elseMapping[retained]); + state.updateOwnership(retained, combinedOwnership, condBr->getBlock()); + } + + return condBr.getOperation(); + } +}; + +} // namespace + +void mlir::cf::registerBufferDeallocationOpInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, ControlFlowDialect *dialect) { + CondBranchOp::attachInterface(*ctx); + }); +} diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt index b2ef59887515e..37b4cfc893879 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt @@ -1,4 
+1,5 @@ add_mlir_dialect_library(MLIRControlFlowTransforms + BufferDeallocationOpInterfaceImpl.cpp BufferizableOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS @@ -10,4 +11,4 @@ add_mlir_dialect_library(MLIRControlFlowTransforms MLIRControlFlowDialect MLIRMemRefDialect MLIRIR - ) +) diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir index d8090591c7051..66449aa2ffdb6 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir @@ -500,10 +500,10 @@ func.func @assumingOp( // CHECK: test.copy // CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0 // CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V1]]#0 -// CHECK: bufferization.dealloc ([[BASE0]] :{{.*}}) if ([[V0]]#1) -// CHECK-NOT: retain // CHECK: bufferization.dealloc ([[BASE1]] :{{.*}}) if ([[V1]]#1) // CHECK-NOT: retain +// CHECK: bufferization.dealloc ([[BASE0]] :{{.*}}) if ([[V0]]#1) +// CHECK-NOT: retain // CHECK: return // ----- diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index d4390e7651be0..2447f63bab29a 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -12066,6 +12066,7 @@ gentbl_cc_library( cc_library( name = "BufferizationDialect", srcs = [ + "lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp", "lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp", "lib/Dialect/Bufferization/IR/BufferizationDialect.cpp", "lib/Dialect/Bufferization/IR/BufferizationOps.cpp", @@ -12073,6 +12074,7 @@ cc_library( 
"lib/Dialect/Bufferization/IR/UnstructuredControlFlow.cpp", ], hdrs = [ + "include/mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h", "include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h", "include/mlir/Dialect/Bufferization/IR/Bufferization.h", "include/mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h", @@ -12083,7 +12085,9 @@ cc_library( deps = [ ":AffineDialect", ":AllocationOpInterface", + ":Analysis", ":ArithDialect", + ":BufferDeallocationOpInterfaceIncGen", ":BufferizableOpInterfaceIncGen", ":BufferizationBaseIncGen", ":BufferizationEnumsIncGen",