diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index 23846346d6386..6cdd8b4942158 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -111,9 +111,14 @@ getFuncAnalysisState(const AnalysisState &state) { /// Return the state (phase) of analysis of the FuncOp. static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp) { - const FuncAnalysisState &funcState = getFuncAnalysisState(state); - auto it = funcState.analyzedFuncOps.find(funcOp); - if (it == funcState.analyzedFuncOps.end()) + Optional maybeState = + state.getDialectState( + func::FuncDialect::getDialectNamespace()); + if (!maybeState.hasValue()) + return FuncOpAnalysisState::NotAnalyzed; + const auto &analyzedFuncOps = maybeState.getValue()->analyzedFuncOps; + auto it = analyzedFuncOps.find(funcOp); + if (it == analyzedFuncOps.end()) return FuncOpAnalysisState::NotAnalyzed; return it->second; } @@ -145,11 +150,11 @@ struct CallOpInterface FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const FuncAnalysisState &funcState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Assume that OpOperand is read. return true; + const FuncAnalysisState &funcState = getFuncAnalysisState(state); return funcState.readBbArgs.lookup(funcOp).contains( opOperand.getOperandNumber()); } @@ -160,11 +165,11 @@ struct CallOpInterface FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const FuncAnalysisState &funcState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Assume that OpOperand is written. return true; + const FuncAnalysisState &funcState = getFuncAnalysisState(state); return funcState.writtenBbArgs.lookup(funcOp).contains( opOperand.getOperandNumber()); } @@ -174,7 +179,6 @@ struct CallOpInterface func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const FuncAnalysisState &funcState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) { // FuncOp not analyzed yet. Any OpResult may be aliasing. @@ -186,6 +190,7 @@ struct CallOpInterface } // Get aliasing results from state. + const FuncAnalysisState &funcState = getFuncAnalysisState(state); auto aliasingReturnVals = funcState.aliasingReturnVals.lookup(funcOp).lookup( opOperand.getOperandNumber()); @@ -201,7 +206,6 @@ struct CallOpInterface func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const FuncAnalysisState &funcState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) { // FuncOp not analyzed yet. Any OpOperand may be aliasing. @@ -213,6 +217,7 @@ struct CallOpInterface } // Get aliasing bbArgs from state. + const FuncAnalysisState &funcState = getFuncAnalysisState(state); auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup( opResult.getResultNumber()); SmallVector result; @@ -226,13 +231,13 @@ struct CallOpInterface func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const FuncAnalysisState &funcState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) { // Function not analyzed yet. The conservative answer is "None". return BufferRelation::None; } + const FuncAnalysisState &funcState = getFuncAnalysisState(state); Optional maybeEquiv = getEquivalentFuncArgIdx(funcOp, funcState, opResult.getResultNumber()); if (maybeEquiv.hasValue()) {