Skip to content

Commit

Permalink
[mlir][linalg][bufferize][NFC] Merge AllocationCallbacks into Bufferi…
Browse files Browse the repository at this point in the history
…zationOptions

Also move `createAlloc` and related helper functions out of BufferizationState. The goal is to make BufferizationState as small as possible. (Code cleanup)

Differential Revision: https://reviews.llvm.org/D117476
  • Loading branch information
matthias-springer committed Jan 19, 2022
1 parent b44defa commit be8742b
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 111 deletions.
Expand Up @@ -38,34 +38,6 @@ struct BufferizationOptions;
class BufferizationState;
struct PostAnalysisStep;

/// Bundle of the hooks used to allocate, deallocate and copy memory buffers.
/// Comprehensive Bufferize provides default implementations of these functions.
// TODO: Could be replaced with a "bufferization strategy" object with virtual
// functions in the future.
struct AllocationCallbacks {
  /// Hook signatures. An allocation may fail, hence the `FailureOr` result;
  /// deallocation and copy hooks return nothing.
  using AllocationFn = std::function<FailureOr<Value>(
      OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;

  /// Hook that allocates memory.
  AllocationFn allocationFn;

  /// Hook that deallocates memory. The buffer handed to it must have been
  /// allocated by `allocationFn`.
  DeallocationFn deallocationFn;

  /// Hook that copies memory between two allocations.
  MemCpyFn memCpyFn;

  /// Takes ownership of the three hooks by moving them into the members.
  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
                      MemCpyFn copyFn)
      : allocationFn(std::move(allocFn)),
        deallocationFn(std::move(deallocFn)),
        memCpyFn(std::move(copyFn)) {}
};

/// Return the default allocation callbacks: an `AllocationCallbacks` object
/// built from `defaultAllocationFn`, `defaultDeallocationFn` and
/// `defaultMemCpyFn`.
std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();

/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used to
/// implement custom dialect-specific optimizations.
Expand All @@ -84,6 +56,13 @@ using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;

/// Options for ComprehensiveBufferize.
struct BufferizationOptions {
using AllocationFn = std::function<FailureOr<Value>(
OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
using DeallocationFn =
std::function<LogicalResult(OpBuilder &, Location, Value)>;
using MemCpyFn =
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;

BufferizationOptions();

// BufferizationOptions cannot be copied.
Expand Down Expand Up @@ -126,7 +105,9 @@ struct BufferizationOptions {
BufferizableOpInterface dynCastBufferizableOp(Value value) const;

/// Helper functions for allocation, deallocation, memory copying.
std::unique_ptr<AllocationCallbacks> allocationFns;
Optional<AllocationFn> allocationFn;
Optional<DeallocationFn> deallocationFn;
Optional<MemCpyFn> memCpyFn;

/// Specifies whether returning newly allocated memrefs should be allowed.
/// Otherwise, a pass failure is triggered.
Expand Down Expand Up @@ -362,24 +343,6 @@ class BufferizationState {
/// is returned regardless of whether it is a memory write or not.
SetVector<Value> findLastPrecedingWrite(Value value) const;

/// Creates a memref allocation.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape) const;

/// Creates a memref allocation for the given shaped value. This function may
/// perform additional optimizations such as buffer allocation hoisting. If
/// `deallocMemref` is set, a deallocation op is inserted at the point where
/// the allocation goes out of scope.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
bool deallocMemref) const;

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`.
void createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer) const;

/// Creates a memcpy between two given buffers.
void createMemCpy(OpBuilder &b, Location loc, Value from, Value to) const;

/// Return `true` if the given OpOperand has been decided to bufferize inplace.
bool isInPlace(OpOperand &opOperand) const;

Expand Down Expand Up @@ -458,6 +421,28 @@ UnrankedMemRefType getUnrankedMemRefType(Type elementType,
MemRefType getDynamicMemRefType(RankedTensorType tensorType,
unsigned addressSpace = 0);

/// Creates a memref allocation of the given `type`, passing `dynShape` along
/// to the allocation. If `options.allocationFn` is set, the allocation is
/// delegated to that callback and its result (value or failure) is returned;
/// otherwise a `memref::AllocOp` is created.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape,
const BufferizationOptions &options);

/// Creates a memref allocation for the given shaped value. This function may
/// perform additional optimizations such as buffer allocation hoisting. If
/// `deallocMemref` is set, a deallocation is also inserted right before the
/// terminator of the block that contains the allocation. Returns failure if
/// creating the allocation or the deallocation fails.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, Value shapedValue,
bool deallocMemref,
const BufferizationOptions &options);

/// Creates a memref deallocation. The given memref buffer must have been
/// allocated using `createAlloc`. If `options.deallocationFn` is set, the
/// work is delegated to that callback and its result is returned; otherwise a
/// `memref::DeallocOp` is created and success is returned.
LogicalResult createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer,
const BufferizationOptions &options);

/// Creates a memcpy from buffer `from` to buffer `to`. If `options.memCpyFn`
/// is set, the copy is delegated to that callback and its result is returned;
/// otherwise a `memref::CopyOp` is created and success is returned.
LogicalResult createMemCpy(OpBuilder &b, Location loc, Value from, Value to,
const BufferizationOptions &options);

} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir
Expand Down
Expand Up @@ -39,40 +39,8 @@ using namespace linalg::comprehensive_bufferize;
// BufferizationOptions
//===----------------------------------------------------------------------===//

/// Allocation function the comprehensive bufferization pass falls back to.
/// It currently creates a ranked memref using `memref.alloc`, passing
/// `dynShape` along and attaching `kBufferAlignments` as the alignment.
static FailureOr<Value> defaultAllocationFn(OpBuilder &b, Location loc,
                                            MemRefType type,
                                            ArrayRef<Value> dynShape) {
  // Alignment attribute attached to every default allocation.
  auto alignment = b.getI64IntegerAttr(kBufferAlignments);
  Value buffer = b.create<memref::AllocOp>(loc, type, dynShape, alignment);
  return buffer;
}

/// Default deallocation function that is used by the comprehensive
/// bufferization pass. It expects to receive back the value returned from
/// `defaultAllocationFn` and creates a `memref.dealloc` op for it.
static void defaultDeallocationFn(OpBuilder &b, Location loc,
Value allocatedBuffer) {
b.create<memref::DeallocOp>(loc, allocatedBuffer);
}

/// Default memory copy function that is used by the comprehensive bufferization
/// pass. Creates a `memref.copy` op that copies buffer `from` into buffer `to`.
static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) {
b.create<memref::CopyOp>(loc, from, to);
}

/// Builds the callback bundle that wires in the three default hooks defined
/// in this file (allocation, deallocation, memory copy).
std::unique_ptr<AllocationCallbacks>
mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
  auto callbacks = std::make_unique<AllocationCallbacks>(
      defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
  return callbacks;
}

// Default constructor for BufferizationOptions that sets all allocation
// callbacks to their default functions.
BufferizationOptions::BufferizationOptions()
: allocationFns(defaultAllocationCallbacks()) {}
// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() {}

BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
Expand Down Expand Up @@ -393,8 +361,8 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
// allocation should be inserted (in the absence of allocation hoisting).
setInsertionPointAfter(rewriter, operandBuffer);
// Allocate the result buffer.
FailureOr<Value> resultBuffer =
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
FailureOr<Value> resultBuffer = createAlloc(rewriter, loc, operandBuffer,
options.createDeallocs, options);
if (failed(resultBuffer))
return failure();
// Do not copy if the last preceding writes of `operand` are ops that do
Expand Down Expand Up @@ -425,7 +393,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
// The copy happens right before the op that is bufferized.
rewriter.setInsertionPoint(op);
}
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer);
if (failed(
createMemCpy(rewriter, loc, operandBuffer, *resultBuffer, options)))
return failure();

return resultBuffer;
}
Expand Down Expand Up @@ -545,9 +515,9 @@ static MemRefType getAllocationTypeAndShape(OpBuilder &b, Location loc,
/// Create an AllocOp/DeallocOp pair, where the AllocOp is after
/// `shapedValue.getDefiningOp` (or at the top of the block in case of a
/// bbArg) and the DeallocOp is at the end of the block.
FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref) const {
FailureOr<Value> mlir::linalg::comprehensive_bufferize::createAlloc(
OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref,
const BufferizationOptions &options) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(b);

Expand All @@ -558,7 +528,8 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
// Note: getAllocationTypeAndShape also sets the insertion point.
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
FailureOr<Value> allocated = createAlloc(b, loc, allocMemRefType, dynShape);
FailureOr<Value> allocated =
createAlloc(b, loc, allocMemRefType, dynShape, options);
if (failed(allocated))
return failure();
Value casted = allocated.getValue();
Expand All @@ -572,30 +543,47 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
if (deallocMemref) {
// 2. Create memory deallocation.
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
createDealloc(b, loc, allocated.getValue());
if (failed(createDealloc(b, loc, allocated.getValue(), options)))
return failure();
}

return casted;
}

/// Create a memref allocation.
FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::createAlloc(
OpBuilder &b, Location loc, MemRefType type,
ArrayRef<Value> dynShape) const {
return options.allocationFns->allocationFn(b, loc, type, dynShape);
FailureOr<Value> mlir::linalg::comprehensive_bufferize::createAlloc(
OpBuilder &b, Location loc, MemRefType type, ArrayRef<Value> dynShape,
const BufferizationOptions &options) {
if (options.allocationFn)
return (*options.allocationFn)(b, loc, type, dynShape);

// Default buffer allocation via AllocOp.
Value allocated = b.create<memref::AllocOp>(
loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
return allocated;
}

/// Create a memref deallocation.
void mlir::linalg::comprehensive_bufferize::BufferizationState::createDealloc(
OpBuilder &b, Location loc, Value allocatedBuffer) const {
return options.allocationFns->deallocationFn(b, loc, allocatedBuffer);
LogicalResult mlir::linalg::comprehensive_bufferize::createDealloc(
OpBuilder &b, Location loc, Value allocatedBuffer,
const BufferizationOptions &options) {
if (options.deallocationFn)
return (*options.deallocationFn)(b, loc, allocatedBuffer);

// Default buffer deallocation via DeallocOp.
b.create<memref::DeallocOp>(loc, allocatedBuffer);
return success();
}

/// Create a memory copy between two memref buffers.
void mlir::linalg::comprehensive_bufferize::BufferizationState::createMemCpy(
OpBuilder &b, Location loc, Value from, Value to) const {
return options.allocationFns->memCpyFn(b, loc, from, to);
LogicalResult mlir::linalg::comprehensive_bufferize::createMemCpy(
OpBuilder &b, Location loc, Value from, Value to,
const BufferizationOptions &options) {
if (options.memCpyFn)
return (*options.memCpyFn)(b, loc, from, to);

b.create<memref::CopyOp>(loc, from, to);
return success();
}

//===----------------------------------------------------------------------===//
Expand Down
Expand Up @@ -221,9 +221,9 @@ struct InitTensorOpInterface
if (initTensorOp->getUses().empty())
return success();

FailureOr<Value> alloc = state.createAlloc(
rewriter, initTensorOp->getLoc(), initTensorOp.result(),
state.getOptions().createDeallocs);
FailureOr<Value> alloc =
createAlloc(rewriter, initTensorOp->getLoc(), initTensorOp.result(),
state.getOptions().createDeallocs, state.getOptions());
if (failed(alloc))
return failure();
replaceOpWithBufferizedValues(rewriter, op, *alloc);
Expand Down Expand Up @@ -367,7 +367,9 @@ struct TiledLoopOpInterface
Value output = std::get<1>(it);
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
newTerminator.getLoc(), output.getType(), std::get<0>(it));
state.createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp, output);
if (failed(createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp,
output, state.getOptions())))
return failure();
}

// Erase old terminator.
Expand Down
Expand Up @@ -158,8 +158,8 @@ struct ExtractSliceOpInterface
Value alloc;
if (!inplace) {
FailureOr<Value> allocOrFailure =
state.createAlloc(rewriter, loc, extractSliceOp.result(),
state.getOptions().createDeallocs);
createAlloc(rewriter, loc, extractSliceOp.result(),
state.getOptions().createDeallocs, state.getOptions());
if (failed(allocOrFailure))
return failure();
alloc = *allocOrFailure;
Expand Down Expand Up @@ -191,7 +191,9 @@ struct ExtractSliceOpInterface
if (!inplace) {
// Do not copy if the copied data is never read.
if (state.isValueRead(extractSliceOp.result()))
state.createMemCpy(rewriter, extractSliceOp.getLoc(), subView, alloc);
if (failed(createMemCpy(rewriter, extractSliceOp.getLoc(), subView,
alloc, state.getOptions())))
return failure();
subView = alloc;
}

Expand Down Expand Up @@ -461,7 +463,9 @@ struct InsertSliceOpInterface
// tensor.extract_slice, the copy operation will eventually fold away.
Value srcMemref =
*state.getBuffer(rewriter, insertSliceOp->getOpOperand(0) /*source*/);
state.createMemCpy(rewriter, loc, srcMemref, subView);
if (failed(createMemCpy(rewriter, loc, srcMemref, subView,
state.getOptions())))
return failure();

replaceOpWithBufferizedValues(rewriter, op, *dstMemref);
return success();
Expand Down
Expand Up @@ -77,9 +77,10 @@ static FailureOr<Value> allocationFnUsingAlloca(OpBuilder &b, Location loc,
void LinalgComprehensiveModuleBufferize::runOnOperation() {
auto options = std::make_unique<BufferizationOptions>();
if (useAlloca) {
options->allocationFns->allocationFn = allocationFnUsingAlloca;
options->allocationFns->deallocationFn = [](OpBuilder &b, Location loc,
Value v) {};
options->allocationFn = allocationFnUsingAlloca;
options->deallocationFn = [](OpBuilder &b, Location loc, Value v) {
return success();
};
}

options->allowReturnMemref = allowReturnMemref;
Expand Down

0 comments on commit be8742b

Please sign in to comment.