[mlir][bufferization][NFC] Add error handling to getBuffer
This is in preparation for adding memory space support.

Differential Revision: https://reviews.llvm.org/D128277
matthias-springer committed Jun 27, 2022
1 parent 0d0a94a · commit 5d50f51
Showing 11 changed files with 218 additions and 102 deletions.
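Every hunk below follows the same pattern: `getBuffer` and `getBufferType` now return `FailureOr<...>` instead of a bare value, and each caller checks the result and propagates failures through its `LogicalResult` return type. A minimal caller-side sketch of the pattern (the function and value names here are illustrative, not from the patch):

    #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

    using namespace mlir;
    using namespace mlir::bufferization;

    // Sketch: how a bufferization pattern consumes the new fallible API.
    static LogicalResult bufferizeExample(RewriterBase &rewriter, Value tensor,
                                          const BufferizationOptions &options) {
      // Previously: Value buffer = getBuffer(rewriter, tensor, options);
      FailureOr<Value> maybeBuffer = getBuffer(rewriter, tensor, options);
      if (failed(maybeBuffer))
        return failure(); // propagate instead of asserting
      Value buffer = *maybeBuffer; // dereference to get the wrapped Value
      // ... rewrite the op using `buffer` ...
      (void)buffer;
      return success();
    }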
@@ -479,14 +479,15 @@ Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
 /// Lookup the buffer for the given value. If the value was not bufferized
 /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
 /// from which the memref operand is returned.
-Value getBuffer(RewriterBase &rewriter, Value value,
-                const BufferizationOptions &options);
+FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
+                           const BufferizationOptions &options);
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
 ///
 /// Note: Op implementations should preferrably call `getBuffer()->getType()`.
 /// This function should only be used if `getBuffer` cannot be used.
-BaseMemRefType getBufferType(Value value, const BufferizationOptions &options);
+FailureOr<BaseMemRefType> getBufferType(Value value,
+                                        const BufferizationOptions &options);
 
 /// Replace an op with replacement values. The op is deleted. Tensor OpResults
 /// must be replaced with memref values.
@@ -343,7 +343,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Return the bufferized type of the given tensor block argument. The
           block argument is guaranteed to belong to a block of this op.
         }],
-        /*retType=*/"BaseMemRefType",
+        /*retType=*/"FailureOr<BaseMemRefType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "BlockArgument":$bbArg,
                       "const BufferizationOptions &":$options),
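With this `retType` change, TableGen emits the interface method as `FailureOr<BaseMemRefType> getBufferType(...)`, so implementations can report failure rather than assert. A sketch of an external-model implementation against the new signature (`MyOpModel` and `MyOp` are illustrative names; compare the real `ForOpInterface` change further below):

    struct MyOpModel
        : public bufferization::BufferizableOpInterface::ExternalModel<
              MyOpModel, MyOp> {
      FailureOr<BaseMemRefType>
      getBufferType(Operation *op, BlockArgument bbArg,
                    const BufferizationOptions &options) const {
        // Delegate to the (now also fallible) free function; any failure
        // propagates to the caller through the FailureOr return value.
        return bufferization::getBufferType(bbArg, options);
      }
    };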
@@ -84,8 +84,10 @@ struct IndexCastOpInterface
     auto castOp = cast<arith::IndexCastOp>(op);
     auto resultTensorType = castOp.getType().cast<TensorType>();
 
-    Value source = getBuffer(rewriter, castOp.getIn(), options);
-    auto sourceType = source.getType().cast<BaseMemRefType>();
+    FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
+    if (failed(source))
+      return failure();
+    auto sourceType = source->getType().cast<BaseMemRefType>();
 
     // Result type should have same layout and address space as the source type.
     BaseMemRefType resultType;
@@ -100,7 +102,7 @@ }
     }
 
     replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
-                                                     source);
+                                                     *source);
     return success();
   }
 };
@@ -140,8 +142,14 @@ struct SelectOpInterface
     // instead of its OpOperands. In the worst case, 2 copies are inserted at
     // the moment (one for each tensor). When copying the op result, only one
    // copy would be needed.
-    Value trueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options);
-    Value falseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options);
+    FailureOr<Value> maybeTrueBuffer =
+        getBuffer(rewriter, selectOp.getTrueValue(), options);
+    FailureOr<Value> maybeFalseBuffer =
+        getBuffer(rewriter, selectOp.getFalseValue(), options);
+    if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
+      return failure();
+    Value trueBuffer = *maybeTrueBuffer;
+    Value falseBuffer = *maybeFalseBuffer;
 
     // The "true" and the "false" operands must have the same type. If the
     // buffers have different types, they differ only in their layout map. Cast
17 changes: 10 additions & 7 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -480,8 +480,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
 #endif
 }
 
-Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
-                               const BufferizationOptions &options) {
+FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
+                                          const BufferizationOptions &options) {
 #ifndef NDEBUG
   auto tensorType = value.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");
@@ -494,14 +494,17 @@ Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
   // Insert to_memref op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  Type memrefType = getBufferType(value, options);
-  ensureToMemrefOpIsValid(value, memrefType);
-  return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
-                                                    value);
+  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+  if (failed(memrefType))
+    return failure();
+  ensureToMemrefOpIsValid(value, *memrefType);
+  return rewriter
+      .create<bufferization::ToMemrefOp>(value.getLoc(), *memrefType, value)
+      .getResult();
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-BaseMemRefType
+FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
   auto tensorType = value.getType().dyn_cast<TensorType>();
   assert(tensorType && "unexpected non-tensor type");
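Note the explicit `.getResult()` in the new `getBuffer` above. A single-result op such as `ToMemrefOp` converts implicitly to `Value`, which sufficed when the function returned `Value` directly; constructing a `FailureOr<Value>` from the op would require chaining two user-defined conversions (op to `Value`, then `Value` to `FailureOr<Value>`), which C++ does not do, so the result is extracted explicitly. A condensed illustration (function names are illustrative):

    // Returns Value: the single-result op converts implicitly.
    static Value makeBuffer(RewriterBase &rewriter, Location loc,
                            BaseMemRefType memrefType, Value tensor) {
      return rewriter.create<bufferization::ToMemrefOp>(loc, memrefType, tensor);
    }

    // Returns FailureOr<Value>: the implicit chain op -> Value ->
    // FailureOr<Value> is not performed, so extract the result explicitly.
    static FailureOr<Value> makeBufferChecked(RewriterBase &rewriter,
                                              Location loc,
                                              BaseMemRefType memrefType,
                                              Value tensor) {
      return rewriter
          .create<bufferization::ToMemrefOp>(loc, memrefType, tensor)
          .getResult();
    }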
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -165,8 +165,12 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
 
   // Get "copy" buffer.
   Value copyBuffer;
-  if (getCopy())
-    copyBuffer = getBuffer(rewriter, getCopy(), options);
+  if (getCopy()) {
+    FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
+    if (failed(maybeCopyBuffer))
+      return failure();
+    copyBuffer = *maybeCopyBuffer;
+  }
 
   // Compute memory space of this allocation.
   unsigned memorySpace;
@@ -305,8 +305,13 @@ struct CallOpInterface
 
     // Retrieve buffers for tensor operands.
     Value buffer = newOperands[idx];
-    if (!buffer)
-      buffer = getBuffer(rewriter, opOperand.get(), options);
+    if (!buffer) {
+      FailureOr<Value> maybeBuffer =
+          getBuffer(rewriter, opOperand.get(), options);
+      if (failed(maybeBuffer))
+        return failure();
+      buffer = *maybeBuffer;
+    }
 
     // Caller / callee type mismatch is handled with a CastOp.
     auto memRefType = funcType.getInput(idx);
@@ -44,15 +44,21 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
       newInputBuffers.push_back(opOperand->get());
       continue;
     }
-    newInputBuffers.push_back(getBuffer(rewriter, opOperand->get(), options));
+    FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
+    if (failed(buffer))
+      return failure();
+    newInputBuffers.push_back(*buffer);
   }
 
   // New output operands for the cloned op.
   SmallVector<Value> newOutputBuffers;
   for (OpResult opResult : op->getOpResults()) {
     OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
-    Value resultBuffer = getBuffer(rewriter, opOperand->get(), options);
-    newOutputBuffers.push_back(resultBuffer);
+    FailureOr<Value> resultBuffer =
+        getBuffer(rewriter, opOperand->get(), options);
+    if (failed(resultBuffer))
+      return failure();
+    newOutputBuffers.push_back(*resultBuffer);
   }
 
   // Merge input/output operands.
121 changes: 83 additions & 38 deletions mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -281,14 +281,17 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
 
 /// Helper function for loop bufferization. Return the bufferized values of the
 /// given OpOperands. If an operand is not a tensor, return the original value.
-static SmallVector<Value> getBuffers(RewriterBase &rewriter,
-                                     MutableArrayRef<OpOperand> operands,
-                                     const BufferizationOptions &options) {
+static FailureOr<SmallVector<Value>>
+getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
+           const BufferizationOptions &options) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
     if (opOperand.get().getType().isa<TensorType>()) {
-      Value resultBuffer = getBuffer(rewriter, opOperand.get(), options);
-      result.push_back(resultBuffer);
+      FailureOr<Value> resultBuffer =
+          getBuffer(rewriter, opOperand.get(), options);
+      if (failed(resultBuffer))
+        return failure();
+      result.push_back(*resultBuffer);
     } else {
       result.push_back(opOperand.get());
     }
@@ -298,36 +301,46 @@ static SmallVector<Value> getBuffers(RewriterBase &rewriter,
 
 /// Helper function for loop bufferization. Compute the buffer that should be
 /// yielded from a loop block (loop body or loop condition).
-static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
-                              BaseMemRefType type,
-                              const BufferizationOptions &options) {
+static FailureOr<Value> getYieldedBuffer(RewriterBase &rewriter, Value tensor,
+                                         BaseMemRefType type,
+                                         const BufferizationOptions &options) {
   assert(tensor.getType().isa<TensorType>() && "expected tensor");
   ensureToMemrefOpIsValid(tensor, type);
-  Value yieldedVal = getBuffer(rewriter, tensor, options);
-  return castBuffer(rewriter, yieldedVal, type);
+  FailureOr<Value> yieldedVal = getBuffer(rewriter, tensor, options);
+  if (failed(yieldedVal))
+    return failure();
+  return castBuffer(rewriter, *yieldedVal, type);
 }
 
 /// Helper function for loop bufferization. Given a range of values, apply
 /// `func` to those marked in `tensorIndices`. Otherwise, store the unmodified
 /// value in the result vector.
-static SmallVector<Value>
+static FailureOr<SmallVector<Value>>
 convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
-                    llvm::function_ref<Value(Value, int64_t)> func) {
+                    llvm::function_ref<FailureOr<Value>(Value, int64_t)> func) {
   SmallVector<Value> result;
   for (const auto &it : llvm::enumerate(values)) {
     size_t idx = it.index();
     Value val = it.value();
-    result.push_back(tensorIndices.contains(idx) ? func(val, idx) : val);
+    if (tensorIndices.contains(idx)) {
+      FailureOr<Value> maybeVal = func(val, idx);
+      if (failed(maybeVal))
+        return failure();
+      result.push_back(*maybeVal);
+    } else {
+      result.push_back(val);
+    }
   }
   return result;
 }
 
 /// Helper function for loop bufferization. Given a list of pre-bufferization
 /// yielded values, compute the list of bufferized yielded values.
-SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
-                                    TypeRange bufferizedTypes,
-                                    const DenseSet<int64_t> &tensorIndices,
-                                    const BufferizationOptions &options) {
+FailureOr<SmallVector<Value>>
+getYieldedValues(RewriterBase &rewriter, ValueRange values,
                 TypeRange bufferizedTypes,
                 const DenseSet<int64_t> &tensorIndices,
                 const BufferizationOptions &options) {
   return convertTensorValues(
       values, tensorIndices, [&](Value val, int64_t index) {
         return getYieldedBuffer(rewriter, val,
@@ -342,10 +355,19 @@ SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
 SmallVector<Value>
 getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
                      const DenseSet<int64_t> &tensorIndices) {
-  return convertTensorValues(
-      bbArgs, tensorIndices, [&](Value val, int64_t index) {
-        return rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val);
-      });
+  SmallVector<Value> result;
+  for (const auto &it : llvm::enumerate(bbArgs)) {
+    size_t idx = it.index();
+    Value val = it.value();
+    if (tensorIndices.contains(idx)) {
+      result.push_back(
+          rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val)
+              .getResult());
+    } else {
+      result.push_back(val);
+    }
+  }
+  return result;
 }
 
 /// Bufferization of scf.for. Replace with a new scf.for that operates on
@@ -445,8 +467,9 @@ struct ForOpInterface
     return success();
   }
 
-  BaseMemRefType getBufferType(Operation *op, BlockArgument bbArg,
-                               const BufferizationOptions &options) const {
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, BlockArgument bbArg,
+                const BufferizationOptions &options) const {
     auto forOp = cast<scf::ForOp>(op);
     return bufferization::getBufferType(
         forOp.getOpOperandForRegionIterArg(bbArg).get(), options);
@@ -462,8 +485,11 @@ struct ForOpInterface
     DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs());
 
     // The new memref init_args of the loop.
-    SmallVector<Value> initArgs =
+    FailureOr<SmallVector<Value>> maybeInitArgs =
         getBuffers(rewriter, forOp.getIterOpOperands(), options);
+    if (failed(maybeInitArgs))
+      return failure();
+    SmallVector<Value> initArgs = *maybeInitArgs;
 
     // Construct a new scf.for op with memref instead of tensor values.
     auto newForOp = rewriter.create<scf::ForOp>(
@@ -689,13 +715,17 @@ struct WhileOpInterface
         getTensorIndices(whileOp.getAfterArguments());
 
     // The new memref init_args of the loop.
-    SmallVector<Value> initArgs =
+    FailureOr<SmallVector<Value>> maybeInitArgs =
         getBuffers(rewriter, whileOp->getOpOperands(), options);
+    if (failed(maybeInitArgs))
+      return failure();
+    SmallVector<Value> initArgs = *maybeInitArgs;
 
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::to_vector(
         llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-          return bufferization::getBufferType(bbArg, options).cast<Type>();
+          // TODO: error handling
+          return bufferization::getBufferType(bbArg, options)->cast<Type>();
         }));
 
     // Construct a new scf.while op with memref instead of tensor values.
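The `// TODO: error handling` above exists because the lambda passed to `llvm::map_range` has no way to propagate a failure, so the `FailureOr` is dereferenced unconditionally. One way the TODO could later be resolved (an assumption, not part of this patch) is an explicit loop with early return:

    // Hypothetical follow-up: compute the "after" bbArg types with explicit
    // failure propagation instead of llvm::map_range.
    SmallVector<Type> argsTypesAfter;
    for (BlockArgument bbArg : whileOp.getAfterArguments()) {
      FailureOr<BaseMemRefType> bufferType =
          bufferization::getBufferType(bbArg, options);
      if (failed(bufferType))
        return failure();
      argsTypesAfter.push_back(*bufferType);
    }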
@@ -727,10 +757,12 @@ struct WhileOpInterface
     // Only equivalent buffers or new buffer allocations may be yielded to the
     // "after" region.
     // TODO: This could be relaxed for better bufferization results.
-    SmallVector<Value> newConditionArgs =
+    FailureOr<SmallVector<Value>> newConditionArgs =
         getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
                          indicesAfter, options);
-    newConditionOp.getArgsMutable().assign(newConditionArgs);
+    if (failed(newConditionArgs))
+      return failure();
+    newConditionOp.getArgsMutable().assign(*newConditionArgs);
 
     // Set up new iter_args and move the loop body block to the new op.
     // The old block uses tensors, so wrap the (memref) bbArgs of the new block
@@ -746,10 +778,12 @@ struct WhileOpInterface
     // Only equivalent buffers or new buffer allocations may be yielded to the
     // "before" region.
     // TODO: This could be relaxed for better bufferization results.
-    SmallVector<Value> newYieldValues =
+    FailureOr<SmallVector<Value>> newYieldValues =
         getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
                          indicesBefore, options);
-    newYieldOp.getResultsMutable().assign(newYieldValues);
+    if (failed(newYieldValues))
+      return failure();
+    newYieldOp.getResultsMutable().assign(*newYieldValues);
 
     // Replace loop results.
     replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults());
@@ -849,13 +883,18 @@ struct YieldOpInterface
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
       Value value = it.value();
       if (value.getType().isa<TensorType>()) {
-        Value buffer = getBuffer(rewriter, value, options);
+        FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+        if (failed(maybeBuffer))
+          return failure();
+        Value buffer = *maybeBuffer;
         if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
-          BaseMemRefType resultType =
+          FailureOr<BaseMemRefType> resultType =
               cast<BufferizableOpInterface>(forOp.getOperation())
                   .getBufferType(forOp.getRegionIterArgs()[it.index()],
                                  options);
-          buffer = castBuffer(rewriter, buffer, resultType);
+          if (failed(resultType))
+            return failure();
+          buffer = castBuffer(rewriter, buffer, *resultType);
         }
         newResults.push_back(buffer);
       } else {
@@ -1078,24 +1117,30 @@ struct ParallelInsertSliceOpInterface
     // If the op bufferizes out-of-place, allocate the copy before the
     // ForeachThreadOp.
     rewriter.setInsertionPoint(foreachThreadOp);
-    Value destBuffer = getBuffer(rewriter, insertOp.getDest(), options);
+    FailureOr<Value> destBuffer =
+        getBuffer(rewriter, insertOp.getDest(), options);
+    if (failed(destBuffer))
+      return failure();
 
     // Bufferize the ParallelInsertSliceOp outside of the PerformConcurrentlyOp.
     rewriter.setInsertionPoint(performConcurrentlyOp);
-    Value srcBuffer = getBuffer(rewriter, insertOp.getSource(), options);
+    FailureOr<Value> srcBuffer =
+        getBuffer(rewriter, insertOp.getSource(), options);
+    if (failed(srcBuffer))
+      return failure();
     Value subview = rewriter.create<memref::SubViewOp>(
-        insertOp.getLoc(), destBuffer, insertOp.getMixedOffsets(),
+        insertOp.getLoc(), *destBuffer, insertOp.getMixedOffsets(),
         insertOp.getMixedSizes(), insertOp.getMixedStrides());
     // This memcpy will fold away if everything bufferizes in-place.
-    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), srcBuffer,
+    if (failed(options.createMemCpy(rewriter, insertOp.getLoc(), *srcBuffer,
                                     subview)))
       return failure();
     rewriter.eraseOp(op);
 
     // Replace all uses of ForeachThreadOp (just the corresponding result).
     rewriter.setInsertionPointAfter(foreachThreadOp);
     Value toTensorOp =
-        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), destBuffer);
+        rewriter.create<ToTensorOp>(foreachThreadOp.getLoc(), *destBuffer);
     unsigned resultNum = 0;
     for (Operation &nextOp : performConcurrentlyOp.yieldingOps()) {
       if (&nextOp == op)
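Taken together, each `bufferize` implementation touched by this commit now returns `failure()` where it previously would have produced an invalid buffer or asserted, which becomes important once buffer types can differ by memory space. A sketch of how a driver might consume these results, assuming the usual `LogicalResult` idiom (the driver function itself is illustrative, not part of the patch):

    // Illustrative driver: stop as soon as any op fails to bufferize.
    static LogicalResult bufferizeAll(RewriterBase &rewriter,
                                      ArrayRef<Operation *> ops,
                                      const BufferizationOptions &options) {
      for (Operation *op : ops) {
        auto bufferizableOp = dyn_cast<bufferization::BufferizableOpInterface>(op);
        if (!bufferizableOp)
          continue; // nothing to bufferize for this op in this sketch
        if (failed(bufferizableOp.bufferize(rewriter, options)))
          return failure(); // failure propagated from getBuffer & co.
      }
      return success();
    }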