diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index b609a7fd78fb6..ff8db00f7644e 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -472,9 +472,10 @@ class AnalysisState { /// Create an AllocTensorOp for the given shaped value (memref or tensor). /// If `copy` is set, the shaped value is copied. Otherwise, a tensor with /// undefined contents is allocated. -Value allocateTensorForShapedValue(OpBuilder &b, Location loc, - Value shapedValue, bool escape, - bool copy = true); +FailureOr +allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, + bool escape, const BufferizationOptions &options, + bool copy = true); /// Lookup the buffer for the given value. If the value was not bufferized /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp, diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 7e5ccd031abeb..6073c931e53ae 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -46,9 +46,9 @@ constexpr const ::llvm::StringLiteral /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the /// shaped value is copied. Otherwise, a tensor with undefined contents is /// allocated. -Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc, - Value shapedValue, - bool escape, bool copy) { +FailureOr bufferization::allocateTensorForShapedValue( + OpBuilder &b, Location loc, Value shapedValue, bool escape, + const BufferizationOptions &options, bool copy) { Value tensor; if (shapedValue.getType().isa()) { tensor = shapedValue; @@ -88,7 +88,7 @@ Value bufferization::allocateTensorForShapedValue(OpBuilder &b, Location loc, copy ? tensor : Value()); allocTensorOp->setAttr(BufferizationDialect::kEscapeAttrName, b.getBoolArrayAttr({escape})); - return allocTensorOp; + return allocTensorOp.getResult(); } LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( @@ -147,26 +147,30 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( // Insert copies of OpOperands. rewriter.setInsertionPoint(op); for (OpOperand *opOperand : outOfPlaceOpOperands) { - Value copy = allocateTensorForShapedValue( + FailureOr copy = allocateTensorForShapedValue( rewriter, op->getLoc(), opOperand->get(), - escapingOpOperandCopies.contains(opOperand), + escapingOpOperandCopies.contains(opOperand), state.getOptions(), copiedOpOperands.contains(opOperand)); - rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); }); + if (failed(copy)) + return failure(); + rewriter.updateRootInPlace(op, [&]() { opOperand->set(*copy); }); } // Insert copies of OpResults. rewriter.setInsertionPointAfter(op); for (OpResult opResult : outOfPlaceOpResults) { - Value copy = - allocateTensorForShapedValue(rewriter, op->getLoc(), opResult, - escapingOpResultCopies.contains(opResult), - copiedOpResults.count(opResult)); + FailureOr copy = allocateTensorForShapedValue( + rewriter, op->getLoc(), opResult, + escapingOpResultCopies.contains(opResult), state.getOptions(), + copiedOpResults.count(opResult)); + if (failed(copy)) + return failure(); SmallVector uses = llvm::to_vector(llvm::map_range( opResult.getUses(), [](OpOperand &use) { return &use; })); for (OpOperand *use : uses) { // Do not update the alloc_tensor op that we just created. - if (use->getOwner() != copy.getDefiningOp()) - rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); }); + if (use->getOwner() != copy->getDefiningOp()) + rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(*copy); }); } } diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 36d0f0cefbae3..0bf4fd3405dd4 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -458,9 +458,12 @@ struct ForOpInterface yieldValues.push_back(value); continue; } - Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), - value, /*escape=*/true); - yieldValues.push_back(alloc); + FailureOr alloc = + allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value, + /*escape=*/true, state.getOptions()); + if (failed(alloc)) + return failure(); + yieldValues.push_back(*alloc); } rewriter.updateRootInPlace( @@ -669,9 +672,12 @@ struct WhileOpInterface beforeYieldValues.push_back(value); continue; } - Value alloc = allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), - value, /*escape=*/true); - beforeYieldValues.push_back(alloc); + FailureOr alloc = + allocateTensorForShapedValue(rewriter, conditionOp.getLoc(), value, + /*escape=*/true, state.getOptions()); + if (failed(alloc)) + return failure(); + beforeYieldValues.push_back(*alloc); } rewriter.updateRootInPlace(conditionOp, [&]() { conditionOp.getArgsMutable().assign(beforeYieldValues); @@ -687,9 +693,12 @@ struct WhileOpInterface afterYieldValues.push_back(value); continue; } - Value alloc = allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), - value, /*escape=*/true); - afterYieldValues.push_back(alloc); + FailureOr alloc = + allocateTensorForShapedValue(rewriter, yieldOp.getLoc(), value, + /*escape=*/true, state.getOptions()); + if (failed(alloc)) + return failure(); + afterYieldValues.push_back(*alloc); } rewriter.updateRootInPlace(yieldOp, [&]() { yieldOp.getResultsMutable().assign(afterYieldValues); @@ -972,13 +981,15 @@ struct ForeachThreadOpInterface // Insert tensor allocation. bool isYielded = state.isTensorYielded(opResult); - Value alloc = allocateTensorForShapedValue(rewriter, op->getLoc(), - destOperands.front()->get(), - /*escape=*/isYielded); + FailureOr alloc = allocateTensorForShapedValue( + rewriter, op->getLoc(), destOperands.front()->get(), + /*escape=*/isYielded, state.getOptions()); + if (failed(alloc)) + return failure(); // Update terminator operand. rewriter.updateRootInPlace(destOperands.front()->getOwner(), - [&]() { destOperands.front()->set(alloc); }); + [&]() { destOperands.front()->set(*alloc); }); } return success(); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index e7e31dcd42f54..6f24843522103 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -154,15 +154,17 @@ struct CollapseShapeOpInterface if (!canBeCollapsed) { // TODO: Create alloc_tensor ops during TensorCopyInsertion. AnalysisState analysisState(options); - Value tensorAlloc = allocateTensorForShapedValue( + FailureOr tensorAlloc = allocateTensorForShapedValue( rewriter, op->getLoc(), collapseShapeOp.getSrc(), - analysisState.isTensorYielded(collapseShapeOp.getResult())); + analysisState.isTensorYielded(collapseShapeOp.getResult()), options); + if (failed(tensorAlloc)) + return failure(); auto memrefType = MemRefType::get(collapseShapeOp.getSrcType().getShape(), collapseShapeOp.getSrcType().getElementType(), AffineMap(), bufferType.getMemorySpaceAsInt()); buffer = rewriter.create( - op->getLoc(), memrefType, tensorAlloc); + op->getLoc(), memrefType, *tensorAlloc); } // Result type is inferred by the builder. @@ -383,14 +385,16 @@ struct FromElementsOpInterface auto shape = tensorType.getShape(); // TODO: Create alloc_tensor ops during TensorCopyInsertion. AnalysisState analysisState(options); - Value tensorAlloc = allocateTensorForShapedValue( + FailureOr tensorAlloc = allocateTensorForShapedValue( rewriter, loc, fromElementsOp.getResult(), - analysisState.isTensorYielded(fromElementsOp.getResult()), + analysisState.isTensorYielded(fromElementsOp.getResult()), options, /*copy=*/false); + if (failed(tensorAlloc)) + return failure(); auto memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); Value buffer = rewriter.create( - op->getLoc(), memrefType, tensorAlloc); + op->getLoc(), memrefType, *tensorAlloc); // Case: tensor<0xelem_type>. if (fromElementsOp.getElements().empty()) { @@ -436,14 +440,16 @@ struct GenerateOpInterface Location loc = op->getLoc(); // TODO: Create alloc_tensor ops during TensorCopyInsertion. AnalysisState analysisState(options); - Value tensorAlloc = allocateTensorForShapedValue( + FailureOr tensorAlloc = allocateTensorForShapedValue( rewriter, loc, generateOp.getResult(), - analysisState.isTensorYielded(generateOp.getResult()), + analysisState.isTensorYielded(generateOp.getResult()), options, /*copy=*/false); + if (failed(tensorAlloc)) + return failure(); auto memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); Value buffer = rewriter.create( - op->getLoc(), memrefType, tensorAlloc); + op->getLoc(), memrefType, *tensorAlloc); // Collect loop bounds. int64_t rank = memrefType.getRank();