diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td index c7775f2407ebd..34aaf7432bdfd 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -287,8 +287,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { /*methodBody=*/"", /*defaultImplementation=*/[{ // Does not have to be implemented for ops without tensor OpOperands. - assert(::llvm::isa<::mlir::TensorType>(opOperand.get().getType()) && - "expected OpOperand with tensor type"); + assert(::llvm::isa<::mlir::bufferization::TensorLikeType>(opOperand.get().getType()) && + "expected OpOperand with tensor like type"); llvm_unreachable("getAliasingValues not implemented"); }] >, @@ -358,8 +358,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { "const ::mlir::bufferization::AnalysisState &":$state), /*methodBody=*/"", /*defaultImplementation=*/[{ - assert(isa<::mlir::TensorType>(value.getType()) && - "expected tensor type"); + assert(isa<::mlir::bufferization::TensorLikeType>(value.getType()) && + "expected tensor like type"); return ::mlir::bufferization::detail::defaultGetAliasingOpOperands( value, state); }] diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 08319ef9df79a..f77edf23d4bc4 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -161,6 +161,8 @@ Operation *bufferization::getOwnerOfValue(Value value) { return llvm::cast(value).getOwner()->getParentOp(); } +// TODO: Properly support with options, for now it is hardcoded to builtin +// Tensor/MemRef types based approach /// Create an AllocTensorOp for the given shaped value. If `copy` is set, the /// shaped value is copied. Otherwise, a tensor with undefined contents is /// allocated. @@ -229,6 +231,8 @@ FailureOr bufferization::allocateTensorForShapedValue( return allocTensorOp.getResult(); } +// TODO: Properly support with options, for now it is hardcoded to builtin +// Tensor/MemRef types based approach LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( RewriterBase &rewriter, const AnalysisState &analysisState, const BufferizationState &bufferizationState) { @@ -508,7 +512,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const { /// read. Also takes into account ops that create an alias but do not read by /// themselves (e.g., ExtractSliceOp). bool AnalysisState::isValueRead(Value value) const { - assert(llvm::isa(value.getType()) && "expected TensorType"); + assert(llvm::isa(value.getType()) && + "expected TensorLikeType"); SmallVector workingSet; DenseSet visited; for (OpOperand &use : value.getUses()) @@ -948,7 +953,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands( Operation *op = getOwnerOfValue(value); SmallVector result; for (OpOperand &opOperand : op->getOpOperands()) { - if (!llvm::isa(opOperand.get().getType())) + if (!llvm::isa(opOperand.get().getType())) continue; AliasingValueList aliasingValues = state.getAliasingValues(opOperand); for (const auto &it : aliasingValues) @@ -1027,7 +1032,7 @@ bufferization::detail::unknownGetAliasingOpOperands(Value value) { // with every OpOperand. AliasingOpOperandList r; for (OpOperand &operand : value.getDefiningOp()->getOpOperands()) - if (isa(operand.get().getType())) + if (isa(operand.get().getType())) r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } @@ -1040,12 +1045,12 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) { // with every OpOperand. AliasingValueList r; for (OpResult result : opOperand.getOwner()->getOpResults()) - if (llvm::isa(result.getType())) + if (llvm::isa(result.getType())) r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false}); for (Region ®ion : opOperand.getOwner()->getRegions()) if (!region.getBlocks().empty()) for (BlockArgument bbArg : region.getBlocks().front().getArguments()) - if (isa(bbArg.getType())) + if (isa(bbArg.getType())) r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false}); return r; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index 4ebb03915348d..57ef3b88b291c 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -67,7 +67,7 @@ MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState) using namespace mlir; using namespace mlir::bufferization; -static bool isaTensor(Type t) { return isa(t); } +static bool isaTensor(Type t) { return isa(t); } //===----------------------------------------------------------------------===// // Bufferization-specific attribute manipulation. @@ -100,7 +100,7 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) { } else { inPlaceVector = SmallVector(op->getNumOperands(), "none"); for (OpOperand &opOperand : op->getOpOperands()) - if (isa(opOperand.get().getType())) + if (isa(opOperand.get().getType())) inPlaceVector[opOperand.getOperandNumber()] = "false"; } inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false"; @@ -118,12 +118,12 @@ OneShotAnalysisState::OneShotAnalysisState( // Set up alias sets. op->walk([&](Operation *op) { for (Value v : op->getResults()) - if (isa(v.getType())) + if (isa(v.getType())) createAliasInfoEntry(v); for (Region &r : op->getRegions()) for (Block &b : r.getBlocks()) for (auto bbArg : b.getArguments()) - if (isa(bbArg.getType())) + if (isa(bbArg.getType())) createAliasInfoEntry(bbArg); }); @@ -132,7 +132,7 @@ OneShotAnalysisState::OneShotAnalysisState( if (!options.isOpAllowed(bufferizableOp)) return WalkResult::skip(); for (OpOperand &opOperand : bufferizableOp->getOpOperands()) - if (isa(opOperand.get().getType())) + if (isa(opOperand.get().getType())) if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) bufferizeInPlace(opOperand); return WalkResult::advance(); @@ -195,7 +195,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) { // Check all tensor OpResults. for (OpResult opResult : op->getOpResults()) { - if (!isa(opResult.getType())) + if (!isa(opResult.getType())) continue; // If there is no preceding definition, the tensor contents are @@ -1011,7 +1011,7 @@ LogicalResult OneShotAnalysisState::analyzeSingleOp(Operation *op, const DominanceInfo &domInfo) { for (OpOperand &opOperand : op->getOpOperands()) - if (isa(opOperand.get().getType())) + if (isa(opOperand.get().getType())) if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo))) return failure(); return success(); @@ -1023,7 +1023,7 @@ static void equivalenceAnalysis(SmallVector &ops, for (Operation *op : ops) { if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) { for (OpResult opResult : op->getOpResults()) { - if (!isa(opResult.getType())) + if (!isa(opResult.getType())) continue; AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); if (aliases.getNumAliases() == 0) @@ -1095,7 +1095,7 @@ bottomUpFromTerminatorsHeuristic(Operation *op, // we stay within the same region. SmallVector worklist; for (Value v : term->getOperands()) { - if (!isa(v.getType())) + if (!isa(v.getType())) continue; auto opResult = dyn_cast(v); if (!opResult) @@ -1112,7 +1112,7 @@ bottomUpFromTerminatorsHeuristic(Operation *op, AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); for (auto alias : aliases) { Value v = alias.opOperand->get(); - if (!isa(v.getType())) + if (!isa(v.getType())) continue; auto opResult = dyn_cast(v); if (!opResult) @@ -1232,7 +1232,7 @@ checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo, } for (OpOperand &opOperand : op->getOpOperands()) { - if (isa(opOperand.get().getType())) { + if (isa(opOperand.get().getType())) { if (wouldCreateReadAfterWriteInterference( opOperand, domInfo, state, /*checkConsistencyOnly=*/true)) { @@ -1269,7 +1269,7 @@ annotateOpsWithBufferizationMarkers(Operation *op, // Add __inplace_operands_attr__. op->walk([&](Operation *op) { for (OpOperand &opOperand : op->getOpOperands()) - if (isa(opOperand.get().getType())) + if (isa(opOperand.get().getType())) setInPlaceOpOperand(opOperand, state.isInPlace(opOperand)); }); } @@ -1294,7 +1294,7 @@ static void annotateOpsWithAliasSets(Operation *op, // Build alias set array for every OpResult. SmallVector opResultAliasSets; for (OpResult opResult : op->getOpResults()) { - if (llvm::isa(opResult.getType())) { + if (llvm::isa(opResult.getType())) { opResultAliasSets.push_back(buildAliasesArray(opResult)); } } @@ -1309,7 +1309,7 @@ static void annotateOpsWithAliasSets(Operation *op, for (Block &block : r.getBlocks()) { SmallVector bbArgAliasSets; for (BlockArgument bbArg : block.getArguments()) { - if (llvm::isa(bbArg.getType())) { + if (llvm::isa(bbArg.getType())) { bbArgAliasSets.push_back(buildAliasesArray(bbArg)); hasTensorBbArg = true; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 6c5719ce6df8e..4d044bbb74df1 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -62,6 +62,7 @@ #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" @@ -122,10 +123,10 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, // return value may alias with any tensor bbArg. FunctionType type = funcOp.getFunctionType(); for (const auto &inputIt : llvm::enumerate(type.getInputs())) { - if (!isa(inputIt.value())) + if (!isa(inputIt.value())) continue; for (const auto &resultIt : llvm::enumerate(type.getResults())) { - if (!isa(resultIt.value())) + if (!isa(resultIt.value())) continue; int64_t returnIdx = resultIt.index(); int64_t bbArgIdx = inputIt.index(); @@ -145,13 +146,13 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, // Build alias sets. Merge all aliases from all func.return ops. for (BlockArgument bbArg : funcOp.getArguments()) { - if (isa(bbArg.getType())) { + if (isa(bbArg.getType())) { int64_t bbArgIdx = bbArg.getArgNumber(); // Store aliases in a set, so that we don't add the same alias twice. SetVector aliases; for (func::ReturnOp returnOp : returnOps) { for (OpOperand &returnVal : returnOp->getOpOperands()) { - if (isa(returnVal.get().getType())) { + if (isa(returnVal.get().getType())) { int64_t returnIdx = returnVal.getOperandNumber(); if (state.areAliasingBufferizedValues(returnVal.get(), bbArg)) aliases.insert(returnIdx); @@ -170,10 +171,10 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state, auto findEquivalentBlockArgIdx = [&](OpOperand &opOperand) -> std::optional { Value v = opOperand.get(); - if (!isa(v.getType())) + if (!isa(v.getType())) return std::nullopt; for (BlockArgument bbArg : funcOp.getArguments()) { - if (isa(bbArg.getType())) { + if (isa(bbArg.getType())) { if (state.areEquivalentBufferizedValues(v, bbArg)) { if (state.getOptions().testAnalysisOnly) annotateEquivalentReturnBbArg(opOperand, bbArg); @@ -243,7 +244,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state, for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e; ++idx) { // Skip non-tensor arguments. - if (!isa(funcOp.getFunctionType().getInput(idx))) + if (!isa(funcOp.getFunctionType().getInput(idx))) continue; bool isRead; bool isWritten; @@ -297,9 +298,9 @@ getCalledFunction(func::CallOp callOp, /// Return "true" if the given function signature has tensor semantics. static bool hasTensorSignature(func::FuncOp funcOp) { return llvm::any_of(funcOp.getFunctionType().getInputs(), - llvm::IsaPred) || + llvm::IsaPred) || llvm::any_of(funcOp.getFunctionType().getResults(), - llvm::IsaPred); + llvm::IsaPred); } /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 9b6a5a96fbc6b..16eb9aadc06f0 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -32,11 +32,14 @@ namespace { /// Helper function for loop bufferization. Cast the given buffer to the given /// memref type. static Value castBuffer(OpBuilder &b, Value buffer, Type type) { - assert(isa(type) && "expected BaseMemRefType"); - assert(isa(buffer.getType()) && "expected BaseMemRefType"); // If the buffer already has the correct type, no cast is needed. if (buffer.getType() == type) return buffer; + + // TODO: Properly support with options, for now it is hardcoded MemRef type + // based approach + assert(isa(type) && "expected BaseMemRefType"); + assert(isa(buffer.getType()) && "expected BaseMemRefType"); // TODO: In case `type` has a layout map that is not the fully dynamic // one, we may not be able to cast the buffer. In that case, the loop // iter_arg's layout map must be changed (see uses of `castBuffer`). @@ -102,7 +105,7 @@ struct ConditionOpInterface SmallVector newArgs; for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { Value value = it.value(); - if (isa(value.getType())) { + if (isa(value.getType())) { FailureOr maybeBuffer = getBuffer(rewriter, value, options, state); if (failed(maybeBuffer)) @@ -247,7 +250,7 @@ struct IfOpInterface // Compute bufferized result types. SmallVector newTypes; for (Value result : ifOp.getResults()) { - if (!isa(result.getType())) { + if (!isa(result.getType())) { newTypes.push_back(result.getType()); continue; } @@ -286,25 +289,23 @@ struct IfOpInterface auto opResult = cast(value); auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber()); auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber()); - BaseMemRefType thenBufferType, elseBufferType; - if (isa(thenValue.getType())) { + BufferLikeType thenBufferType, elseBufferType; + if (isa(thenValue.getType())) { // True branch was already bufferized. - thenBufferType = cast(thenValue.getType()); + thenBufferType = cast(thenValue.getType()); } else { - auto maybeBufferType = - bufferization::detail::asMemRefType(bufferization::getBufferType( - thenValue, options, state, invocationStack)); + auto maybeBufferType = bufferization::getBufferType( + thenValue, options, state, invocationStack); if (failed(maybeBufferType)) return failure(); thenBufferType = *maybeBufferType; } - if (isa(elseValue.getType())) { + if (isa(elseValue.getType())) { // False branch was already bufferized. - elseBufferType = cast(elseValue.getType()); + elseBufferType = cast(elseValue.getType()); } else { - auto maybeBufferType = - bufferization::detail::asMemRefType(bufferization::getBufferType( - elseValue, options, state, invocationStack)); + auto maybeBufferType = bufferization::getBufferType( + elseValue, options, state, invocationStack); if (failed(maybeBufferType)) return failure(); elseBufferType = *maybeBufferType; @@ -315,12 +316,19 @@ struct IfOpInterface return cast(thenBufferType); // Memory space mismatch. - if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace()) + auto thenBaseMemRefType = dyn_cast(thenBufferType); + auto elseBaseMemRefType = dyn_cast(elseBufferType); + if (thenBaseMemRefType && elseBaseMemRefType && + thenBaseMemRefType.getMemorySpace() != + elseBaseMemRefType.getMemorySpace()) return op->emitError("inconsistent memory space on then/else branches"); - // Layout maps are different: Promote to fully dynamic layout map. + // TODO: Properly support with options, for now it is hardcoded MemRef type + // based approach Layout maps are different: Promote to fully dynamic layout + // map. return cast(getMemRefTypeWithFullyDynamicLayout( - cast(opResult.getType()), thenBufferType.getMemorySpace())); + cast(opResult.getType()), + thenBaseMemRefType.getMemorySpace())); } }; @@ -399,7 +407,8 @@ struct IndexSwitchOpInterface assert(value.getDefiningOp() == op && "invalid value"); int64_t resultNum = cast(value).getResultNumber(); - // Helper function to get buffer type of a case. + // TODO: Properly support with options, for now it is hardcoded MemRef type + // based approach Helper function to get buffer type of a case. auto getYieldedBufferType = [&](Block &b) -> FailureOr { auto yieldOp = cast(b.getTerminator()); Value yieldedValue = yieldOp->getOperand(resultNum); @@ -430,7 +439,9 @@ struct IndexSwitchOpInterface if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace()) return op->emitError("inconsistent memory space on switch cases"); - // Layout maps are different: Promote to fully dynamic layout map. + // TODO: Properly support with options, for now it is hardcoded MemRef + // type based approach Layout maps are different: Promote to fully dynamic + // layout map. bufferType = getMemRefTypeWithFullyDynamicLayout( cast(value.getType()), bufferType.getMemorySpace()); } @@ -444,7 +455,7 @@ struct IndexSwitchOpInterface static DenseSet getTensorIndices(ValueRange values) { DenseSet result; for (const auto &it : llvm::enumerate(values)) - if (isa(it.value().getType())) + if (isa(it.value().getType())) result.insert(it.index()); return result; } @@ -457,8 +468,8 @@ DenseSet getEquivalentBuffers(Block::BlockArgListType bbArgs, unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size()); DenseSet result; for (unsigned int i = 0; i < minSize; ++i) { - if (!isa(bbArgs[i].getType()) || - !isa(yieldedValues[i].getType())) + if (!isa(bbArgs[i].getType()) || + !isa(yieldedValues[i].getType())) continue; if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i])) result.insert(i); @@ -473,7 +484,7 @@ getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands, const BufferizationOptions &options, BufferizationState &state) { SmallVector result; for (OpOperand &opOperand : operands) { - if (isa(opOperand.get().getType())) { + if (isa(opOperand.get().getType())) { FailureOr resultBuffer = getBuffer(rewriter, opOperand.get(), options, state); if (failed(resultBuffer)) @@ -547,7 +558,7 @@ static FailureOr computeLoopRegionIterArgBufferType( // Compute the buffer type of the yielded value. BufferLikeType yieldedValueBufferType; - if (isa(yieldedValue.getType())) { + if (isa(yieldedValue.getType())) { // scf.yield was already bufferized. yieldedValueBufferType = cast(yieldedValue.getType()); } else { @@ -582,6 +593,8 @@ static FailureOr computeLoopRegionIterArgBufferType( "expected same shape"); } #endif // NDEBUG + // TODO: Properly support with options, for now it is hardcoded MemRef type + // based approach return cast(getMemRefTypeWithFullyDynamicLayout( iterTensorType, yieldedBufferType.getMemorySpace())); } @@ -712,7 +725,7 @@ struct ForOpInterface SmallVector &invocationStack) const { auto forOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); - assert(isa(value.getType()) && "expected tensor type"); + assert(isa(value.getType()) && "expected tensor type"); if (auto opResult = dyn_cast(value)) { // The type of an OpResult must match the corresponding iter_arg type. @@ -757,7 +770,7 @@ struct ForOpInterface Value initArg = it.value(); Value result = forOp->getResult(it.index()); // If the type is not a tensor, bufferization doesn't need to touch it. - if (!isa(result.getType())) { + if (!isa(result.getType())) { castedInitArgs.push_back(initArg); continue; } @@ -809,9 +822,8 @@ struct ForOpInterface auto forOp = cast(op); auto yieldOp = cast(forOp.getBody()->getTerminator()); for (OpResult opResult : op->getOpResults()) { - if (!isa(opResult.getType())) + if (!isa(opResult.getType())) continue; - // Note: This is overly strict. We should check for aliasing bufferized // values. But we don't have a "must-alias" analysis yet. if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent) @@ -938,7 +950,7 @@ struct WhileOpInterface for (int64_t idx = 0; idx < static_cast(conditionOp.getArgs().size()); ++idx) { Value value = conditionOp.getArgs()[idx]; - if (!isa(value.getType()) || + if (!isa(value.getType()) || (equivalentYieldsAfter.contains(idx) && equivalentYieldsBefore.contains(idx))) { beforeYieldValues.push_back(value); @@ -982,7 +994,7 @@ struct WhileOpInterface Value initArg = it.value(); Value beforeArg = whileOp.getBeforeArguments()[it.index()]; // If the type is not a tensor, bufferization doesn't need to touch it. - if (!isa(beforeArg.getType())) { + if (!isa(beforeArg.getType())) { castedInitArgs.push_back(initArg); continue; } @@ -995,7 +1007,7 @@ struct WhileOpInterface // The result types of a WhileOp are the same as the "after" bbArg types. SmallVector argsTypesAfter = llvm::map_to_vector( whileOp.getAfterArguments(), [&](BlockArgument bbArg) { - if (!isa(bbArg.getType())) + if (!isa(bbArg.getType())) return bbArg.getType(); // TODO: error handling return llvm::cast( @@ -1048,7 +1060,7 @@ struct WhileOpInterface SmallVector &invocationStack) const { auto whileOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); - assert(isa(value.getType()) && "expected tensor type"); + assert(isa(value.getType()) && "expected tensor type"); // Case 1: Block argument of the "before" region. if (auto bbArg = dyn_cast(value)) { @@ -1074,7 +1086,7 @@ struct WhileOpInterface llvm_unreachable("invalid value"); } Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum]; - if (!isa(conditionYieldedVal.getType())) { + if (!isa(conditionYieldedVal.getType())) { // scf.condition was already bufferized. return cast(conditionYieldedVal.getType()); } @@ -1103,7 +1115,7 @@ struct WhileOpInterface auto conditionOp = whileOp.getConditionOp(); for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { Block *block = conditionOp->getBlock(); - if (!isa(it.value().getType())) + if (!isa(it.value().getType())) continue; if (it.index() >= block->getNumArguments() || !state.areEquivalentBufferizedValues(it.value(), @@ -1116,7 +1128,7 @@ struct WhileOpInterface auto yieldOp = whileOp.getYieldOp(); for (const auto &it : llvm::enumerate(yieldOp.getResults())) { Block *block = yieldOp->getBlock(); - if (!isa(it.value().getType())) + if (!isa(it.value().getType())) continue; if (it.index() >= block->getNumArguments() || !state.areEquivalentBufferizedValues(it.value(), @@ -1176,7 +1188,7 @@ struct YieldOpInterface SmallVector newResults; for (const auto &it : llvm::enumerate(yieldOp.getResults())) { Value value = it.value(); - if (isa(value.getType())) { + if (isa(value.getType())) { FailureOr maybeBuffer = getBuffer(rewriter, value, options, state); if (failed(maybeBuffer)) diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index d5cb7a0f14f5a..f8d5a1310ebdf 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -905,3 +905,179 @@ func.func @ranked_return_via_unranked_call(%arg0: tensor<64x20x40xf32>) -> tenso return %b : tensor<64x20x40xf32> } func.func private @relu_unranked(tensor<*xf32>) -> tensor<*xf32> + +// ----- + +// CHECK: func.func @custom_types_scf_for_inplace( +// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64>, +// CHECK-SAME: %[[lb:.+]]: index, %[[ub:.+]]: index, %[[step:.+]]: index +// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64> +func.func @custom_types_scf_for_inplace( + %arg: !test.test_tensor<[4, 4], f64>, + %lb: index, %ub: index, %step: index) + -> !test.test_tensor<[4, 4], f64> { + // CHECK: %[[loop:.+]] = scf.for %{{.*}} = %[[lb]] to %[[ub]] step %[[step]] + // CHECK-SAME: iter_args(%[[iter:.+]] = %[[arg]]) -> (!test.test_memref<[4, 4], f64>) { + // CHECK: %[[call:.+]] = "test.dummy_memref_op"(%[[iter]]) + // CHECK: scf.yield %[[call]] : !test.test_memref<[4, 4], f64> + %loop = scf.for %i = %lb to %ub step %step + iter_args(%iter = %arg) -> (!test.test_tensor<[4, 4], f64>) { + // Inside loop: use iter_args directly (this is inplace modifiable op) + %call = "test.dummy_tensor_op"(%iter) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + // Yield: return the same iter_args value (or result of inplace op on it) + scf.yield %call : !test.test_tensor<[4, 4], f64> + } + + // CHECK: return %[[loop]] : !test.test_memref<[4, 4], f64> + return %loop : !test.test_tensor<[4, 4], f64> +} + +// ----- + +func.func private @custom_types_identity_2d(%arg: !test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> { + %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + return %out : !test.test_tensor<[4, 4], f64> +} + +// Same as @custom_types_scf_for_inplace, but with an inner call to test alias analysis +// through function boundaries. +// CHECK-LABEL: func.func @custom_types_scf_for_inplace_with_call( +// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: %[[lb:.+]]: index, %[[ub:.+]]: index, %[[step:.+]]: index +// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64> +// CHECK: %[[loop:.+]] = scf.for %{{.*}} = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[iter:.+]] = %[[arg]]) -> (!test.test_memref<[4, 4], f64>) { +// CHECK: %[[call:.+]] = func.call @custom_types_identity_2d(%[[iter]]) : (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 4], f64> +// CHECK: scf.yield %[[call]] : !test.test_memref<[4, 4], f64> +// CHECK: return %[[loop]] : !test.test_memref<[4, 4], f64> +func.func @custom_types_scf_for_inplace_with_call( + %arg: !test.test_tensor<[4, 4], f64>, + %lb: index, %ub: index, %step: index) + -> !test.test_tensor<[4, 4], f64> { + %loop = scf.for %i = %lb to %ub step %step + iter_args(%iter = %arg) -> (!test.test_tensor<[4, 4], f64>) { + %call = func.call @custom_types_identity_2d(%iter) + : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64> + scf.yield %call : !test.test_tensor<[4, 4], f64> + } + + return %loop : !test.test_tensor<[4, 4], f64> +} + +// ----- + +// CHECK-LABEL: func.func @custom_types_scf_if_inplace( +// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: %[[cond:.+]]: i1 +// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64> +// CHECK: %[[res:.+]] = scf.if %[[cond]] -> (!test.test_memref<[4, 4], f64>) { +// CHECK: %[[dummy:.+]] = "test.dummy_memref_op"(%[[arg]]) : (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 4], f64> +// CHECK: scf.yield %[[dummy]] : !test.test_memref<[4, 4], f64> +// CHECK: } else { +// CHECK: scf.yield %[[arg]] : !test.test_memref<[4, 4], f64> +// CHECK: } +// CHECK: return %[[res]] : !test.test_memref<[4, 4], f64> +func.func @custom_types_scf_if_inplace( + %arg: !test.test_tensor<[4, 4], f64>, + %cond: i1) + -> !test.test_tensor<[4, 4], f64> { + %res = scf.if %cond -> (!test.test_tensor<[4, 4], f64>) { + %dummy = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + scf.yield %dummy : !test.test_tensor<[4, 4], f64> + } else { + scf.yield %arg : !test.test_tensor<[4, 4], f64> + } + return %res : !test.test_tensor<[4, 4], f64> +} + +// ----- + +func.func private @custom_types_identity_2d(%arg: !test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> { + %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + return %out : !test.test_tensor<[4, 4], f64> +} + +// CHECK-LABEL: func.func @custom_types_scf_if_inplace_with_call( +// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64> +// CHECK-SAME: %[[cond:.+]]: i1 +// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64> +// CHECK: %[[res:.+]] = scf.if %[[cond]] -> (!test.test_memref<[4, 4], f64>) { +// CHECK: %[[call:.+]] = func.call @custom_types_identity_2d(%[[arg]]) : (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 4], f64> +// CHECK: scf.yield %[[call]] : !test.test_memref<[4, 4], f64> +// CHECK: } else { +// CHECK: scf.yield %[[arg]] : !test.test_memref<[4, 4], f64> +// CHECK: } +// CHECK: return %[[res]] : !test.test_memref<[4, 4], f64> +func.func @custom_types_scf_if_inplace_with_call( + %arg: !test.test_tensor<[4, 4], f64>, + %cond: i1) + -> !test.test_tensor<[4, 4], f64> { + %res = scf.if %cond -> (!test.test_tensor<[4, 4], f64>) { + %call = func.call @custom_types_identity_2d(%arg) + : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64> + scf.yield %call : !test.test_tensor<[4, 4], f64> + } else { + scf.yield %arg : !test.test_tensor<[4, 4], f64> + } + return %res : !test.test_tensor<[4, 4], f64> +} + +// ----- + +// CHECK-LABEL: func.func @scf_while_inplace( +// CHECK-SAME: !test.test_memref<[4, 4], f64> +// CHECK: scf.while +// CHECK: scf.condition +// CHECK: scf.yield +// CHECK: return +func.func @scf_while_inplace( + %arg: !test.test_tensor<[4, 4], f64>, + %cond: i1) + -> !test.test_tensor<[4, 4], f64> { + %loop = scf.while (%iter = %arg) + : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64> { + scf.condition(%cond) %iter : !test.test_tensor<[4, 4], f64> + } do { + ^bb0(%current: !test.test_tensor<[4, 4], f64>): + %dummy = "test.dummy_tensor_op"(%current) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + scf.yield %dummy : !test.test_tensor<[4, 4], f64> + } + return %loop : !test.test_tensor<[4, 4], f64> +} + +// ----- + +func.func private @custom_types_identity_2d(%arg: !test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> { + %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>) + -> !test.test_tensor<[4, 4], f64> + return %out : !test.test_tensor<[4, 4], f64> +} + +// CHECK-LABEL: func.func @scf_while_inplace( +// CHECK-SAME: !test.test_memref<[4, 4], f64> +// CHECK: scf.while +// CHECK: scf.condition +// CHECK: scf.yield +// CHECK: return +func.func @scf_while_inplace( + %arg: !test.test_tensor<[4, 4], f64>, + %cond: i1) + -> !test.test_tensor<[4, 4], f64> { + %loop = scf.while (%iter = %arg) + : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64> { + scf.condition(%cond) %iter : !test.test_tensor<[4, 4], f64> + } do { + ^bb0(%current: !test.test_tensor<[4, 4], f64>): + %call = func.call @custom_types_identity_2d(%current) + : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64> + scf.yield %call : !test.test_tensor<[4, 4], f64> + } + return %loop : !test.test_tensor<[4, 4], f64> +} diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 5fee060689d24..340b44b14dd96 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -1796,6 +1796,19 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize( return mlir::success(); } +mlir::FailureOr +test::TestDummyTensorOp::getBufferType( + mlir::Value value, const mlir::bufferization::BufferizationOptions &, + const mlir::bufferization::BufferizationState &, + llvm::SmallVector<::mlir::Value> &) { + const auto type = dyn_cast(value.getType()); + if (type == nullptr) + return failure(); + + return cast(test::TestMemrefType::get( + getContext(), type.getShape(), type.getElementType(), nullptr)); +} + ::mlir::LogicalResult test::TestCreateTensorOp::bufferize( ::mlir::RewriterBase &rewriter, const ::mlir::bufferization::BufferizationOptions &options, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 774329b9d2736..348ff5d7f4ea0 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3936,10 +3936,13 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> { // Test Ops bufferization //===----------------------------------------------------------------------===// -def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", - [DeclareOpInterfaceMethods]> { +def TestDummyTensorOp + : TEST_Op<"dummy_tensor_op", + [DeclareOpInterfaceMethods< + BufferizableOpInterface, + ["bufferize", "getBufferType", "bufferizesToMemoryRead", + "bufferizesToMemoryWrite", "getAliasingValues", + "getAliasingOpOperands"]>]> { let arguments = (ins Arg:$input ); @@ -3959,7 +3962,23 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", ::mlir::bufferization::AliasingValueList test::TestDummyTensorOp::getAliasingValues(::mlir::OpOperand&, const ::mlir::bufferization::AnalysisState&) { - return {}; + auto relation = getInput().getType() == getOutput().getType() + ? ::mlir::bufferization::BufferRelation::Equivalent + : ::mlir::bufferization::BufferRelation::Unknown; + return {{getOutput(), relation, /*isDefinite=*/true}}; + } + + ::mlir::bufferization::AliasingOpOperandList + test::TestDummyTensorOp::getAliasingOpOperands(::mlir::Value value, + const ::mlir::bufferization::AnalysisState&) { + if (value != getOutput()) + return {}; + + auto relation = getInput().getType() == getOutput().getType() + ? ::mlir::bufferization::BufferRelation::Equivalent + : ::mlir::bufferization::BufferRelation::Unknown; + return {{&getOperation()->getOpOperand(0), relation, + /*isDefinite=*/true}}; } }]; } @@ -3973,11 +3992,13 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> { ); } -def TestCreateTensorOp : TEST_Op<"create_tensor_op", - [DeclareOpInterfaceMethods]> { +def TestCreateTensorOp + : TEST_Op<"create_tensor_op", + [DeclareOpInterfaceMethods< + BufferizableOpInterface, + ["bufferize", "getBufferType", "bufferizesToMemoryRead", + "bufferizesToMemoryWrite", "getAliasingValues", + "getAliasingOpOperands", "bufferizesToAllocation"]>]> { let arguments = (ins); let results = (outs Arg:$output); let extraClassDefinition = [{ @@ -3998,6 +4019,14 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op", const ::mlir::bufferization::AnalysisState&) { return {}; } + + ::mlir::bufferization::AliasingOpOperandList + test::TestCreateTensorOp::getAliasingOpOperands( + ::mlir::Value value, + const ::mlir::bufferization::AnalysisState&) { + (void)value; + return {}; + } }]; }