Skip to content

Commit

Permalink
[mlir][bufferize][NFC] Change signature of allocateTensorForShapedValue
Browse files Browse the repository at this point in the history
Add a failure return value and a bufferization options argument. This keeps a subsequent change smaller.

Differential Revision: https://reviews.llvm.org/D128278
  • Loading branch information
matthias-springer committed Jun 27, 2022
1 parent f5d781d commit 45b995c
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 38 deletions.
Expand Up @@ -472,9 +472,10 @@ class AnalysisState {
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
Value shapedValue, bool escape,
bool copy = true);
FailureOr<Value>
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
bool escape, const BufferizationOptions &options,
bool copy = true);

/// Lookup the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
Expand Down
30 changes: 17 additions & 13 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
Expand Up @@ -46,9 +46,9 @@ constexpr const ::llvm::StringLiteral
/// Create an AllocTensorOp for the given shaped value. If `copy` is set, the
/// shaped value is copied. Otherwise, a tensor with undefined contents is
/// allocated.
Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
Value shapedValue,
bool escape, bool copy) {
FailureOr<Value> bufferization::allocateTensorForShapedValue(
OpBuilder &b, Location loc, Value shapedValue, bool escape,
const BufferizationOptions &options, bool copy) {
Value tensor;
if (shapedValue.getType().isa<RankedTensorType>()) {
tensor = shapedValue;
Expand Down Expand Up @@ -88,7 +88,7 @@ Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc,
copy ? tensor : Value());
allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName,
b.getBoolArrayAttr({escape}));
return allocTensorOp;
return allocTensorOp.getResult();
}

LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
Expand Down Expand Up @@ -147,26 +147,30 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// Insert copies of OpOperands.
rewriter.setInsertionPoint(op);
for (OpOperand *opOperand : outOfPlaceOpOperands) {
Value copy = allocateTensorForShapedValue(
FailureOr<Value> copy = allocateTensorForShapedValue(
rewriter, op->getLoc(), opOperand->get(),
escapingOpOperandCopies.contains(opOperand),
escapingOpOperandCopies.contains(opOperand), state.getOptions(),
copiedOpOperands.contains(opOperand));
rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); });
if (failed(copy))
return failure();
rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); });
}

// Insert copies of OpResults.
rewriter.setInsertionPointAfter(op);
for (OpResult opResult : outOfPlaceOpResults) {
Value copy =
allocateTensorForShapedValue(rewriter, op->getLoc(), opResult,
escapingOpResultCopies.contains(opResult),
copiedOpResults.count(opResult));
FailureOr<Value> copy = allocateTensorForShapedValue(
rewriter, op->getLoc(), opResult,
escapingOpResultCopies.contains(opResult), state.getOptions(),
copiedOpResults.count(opResult));
if (failed(copy))
return failure();
SmallVector<OpOperand *> uses = llvm::to_vector(llvm::map_range(
opResult.getUses(), [](OpOperand &use) { return &use; }));
for (OpOperand *use : uses) {
// Do not update the alloc_tensor op that we just created.
if (use->getOwner() != copy.getDefiningOp())
rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); });
if (use->getOwner() != copy->getDefiningOp())
rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); });
}
}

Expand Down
37 changes: 24 additions & 13 deletions mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
Expand Up @@ -458,9 +458,12 @@ struct ForOpInterface
yieldValues.push_back(value);
continue;
}
Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
value, /*escape=*/true);
yieldValues.push_back(alloc);
FailureOr<Value> alloc =
allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
/*escape=*/true, state.getOptions());
if (failed(alloc))
return failure();
yieldValues.push_back(*alloc);
}

rewriter.updateRootInPlace(
Expand Down Expand Up @@ -669,9 +672,12 @@ struct WhileOpInterface
beforeYieldValues.push_back(value);
continue;
}
Value alloc = allocateTensorForShapedValue(rewriter, conditionOp.getLoc(),
value, /*escape=*/true);
beforeYieldValues.push_back(alloc);
FailureOr<Value> alloc =
allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value,
/*escape=*/true, state.getOptions());
if (failed(alloc))
return failure();
beforeYieldValues.push_back(*alloc);
}
rewriter.updateRootInPlace(conditionOp, [&]() {
conditionOp.getArgsMutable().assign(beforeYieldValues);
Expand All @@ -687,9 +693,12 @@ struct WhileOpInterface
afterYieldValues.push_back(value);
continue;
}
Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(),
value, /*escape=*/true);
afterYieldValues.push_back(alloc);
FailureOr<Value> alloc =
allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value,
/*escape=*/true, state.getOptions());
if (failed(alloc))
return failure();
afterYieldValues.push_back(*alloc);
}
rewriter.updateRootInPlace(yieldOp, [&]() {
yieldOp.getResultsMutable().assign(afterYieldValues);
Expand Down Expand Up @@ -972,13 +981,15 @@ struct ForeachThreadOpInterface

// Insert tensor allocation.
bool isYielded = state.isTensorYielded(opResult);
Value alloc = allocateTensorForShapedValue(rewriter, op->getLoc(),
destOperands.front()->get(),
/*escape=*/isYielded);
FailureOr<Value> alloc = allocateTensorForShapedValue(
rewriter, op->getLoc(), destOperands.front()->get(),
/*escape=*/isYielded, state.getOptions());
if (failed(alloc))
return failure();

// Update terminator operand.
rewriter.updateRootInPlace(destOperands.front()->getOwner(),
[&]() { destOperands.front()->set(alloc); });
[&]() { destOperands.front()->set(*alloc); });
}

return success();
Expand Down
24 changes: 15 additions & 9 deletions mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Expand Up @@ -154,15 +154,17 @@ struct CollapseShapeOpInterface
if (!canBeCollapsed) {
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
Value tensorAlloc = allocateTensorForShapedValue(
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
rewriter, op->getLoc(), collapseShapeOp.getSrc(),
analysisState.isTensorYielded(collapseShapeOp.getResult()));
analysisState.isTensorYielded(collapseShapeOp.getResult()), options);
if (failed(tensorAlloc))
return failure();
auto memrefType =
MemRefType::get(collapseShapeOp.getSrcType().getShape(),
collapseShapeOp.getSrcType().getElementType(),
AffineMap(), bufferType.getMemorySpaceAsInt());
buffer = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, tensorAlloc);
op->getLoc(), memrefType, *tensorAlloc);
}

// Result type is inferred by the builder.
Expand Down Expand Up @@ -383,14 +385,16 @@ struct FromElementsOpInterface
auto shape = tensorType.getShape();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
Value tensorAlloc = allocateTensorForShapedValue(
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
rewriter, loc, fromElementsOp.getResult(),
analysisState.isTensorYielded(fromElementsOp.getResult()),
analysisState.isTensorYielded(fromElementsOp.getResult()), options,
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, tensorAlloc);
op->getLoc(), memrefType, *tensorAlloc);

// Case: tensor<0xelem_type>.
if (fromElementsOp.getElements().empty()) {
Expand Down Expand Up @@ -436,14 +440,16 @@ struct GenerateOpInterface
Location loc = op->getLoc();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
Value tensorAlloc = allocateTensorForShapedValue(
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
rewriter, loc, generateOp.getResult(),
analysisState.isTensorYielded(generateOp.getResult()),
analysisState.isTensorYielded(generateOp.getResult()), options,
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
auto memrefType =
MemRefType::get(tensorType.getShape(), tensorType.getElementType());
Value buffer = rewriter.create<bufferization::ToMemrefOp>(
op->getLoc(), memrefType, tensorAlloc);
op->getLoc(), memrefType, *tensorAlloc);

// Collect loop bounds.
int64_t rank = memrefType.getRank();
Expand Down

0 comments on commit 45b995c

Please sign in to comment.