diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index df89f682fae38..243db9651c4f0 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -258,7 +258,8 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) { // TODO: This does not handle cyclic function call graphs etc. static void equivalenceAnalysis(func::FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - FuncAnalysisState &funcState) { + OneShotAnalysisState &state) { + FuncAnalysisState &funcState = getFuncAnalysisState(state); funcOp->walk([&](func::CallOp callOp) { func::FuncOp calledFunction = getCalledFunction(callOp); assert(calledFunction && "could not retrieved called func::FuncOp"); @@ -270,6 +271,8 @@ static void equivalenceAnalysis(func::FuncOp funcOp, for (auto it : funcState.equivalentFuncArgs[calledFunction]) { int64_t returnIdx = it.first; int64_t bbargIdx = it.second; + if (!state.isInPlace(callOp->getOpOperand(bbargIdx))) + continue; Value returnVal = callOp.getResult(returnIdx); Value argVal = callOp->getOperand(bbargIdx); aliasInfo.unionEquivalenceClasses(returnVal, argVal); @@ -409,7 +412,7 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, funcState.startFunctionAnalysis(funcOp); // Gather equivalence info for CallOps. - equivalenceAnalysis(funcOp, aliasInfo, funcState); + equivalenceAnalysis(funcOp, aliasInfo, state); // Analyze funcOp. if (failed(analyzeOp(funcOp, state))) diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir index d617a29c03642..beb3b38da7b0e 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir @@ -196,8 +196,9 @@ func.func @call_func_with_non_tensor_return( // CHECK: %[[call:.*]] = call @inner_func(%[[casted]]) %0, %1 = call @inner_func(%t0) : (tensor) -> (tensor, f32) - // Note: The tensor return value has folded away. - // CHECK: return %[[call]] : f32 + // Note: The tensor return value cannot fold away because the CallOp + // bufferized out-of-place. + // CHECK: return %[[call]], %[[alloc]] : f32, memref return %1, %0 : f32, tensor }