diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index b2436d93c7..492fa96584 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -461,6 +461,12 @@ async def get_runner_async(self, app_name: str) -> Runner: self.runner_dict[app_name] = runner return runner + def _get_root_agent(self, agent_or_app: BaseAgent | App) -> BaseAgent: + """Extract root agent from either a BaseAgent or App object.""" + if isinstance(agent_or_app, App): + return agent_or_app.root_agent + return agent_or_app + def _create_runner(self, agentic_app: App) -> Runner: """Create a runner with common services.""" return Runner( @@ -803,9 +809,8 @@ async def add_session_to_eval_set( # Populate the session with initial session state. agent_or_app = self.agent_loader.load_agent(app_name) - if isinstance(agent_or_app, App): - agent_or_app = agent_or_app.root_agent - initial_session_state = create_empty_state(agent_or_app) + root_agent = self._get_root_agent(agent_or_app) + initial_session_state = create_empty_state(root_agent) new_eval_case = EvalCase( eval_id=req.eval_id, @@ -966,7 +971,8 @@ async def run_eval( status_code=400, detail=f"Eval set `{eval_set_id}` not found." ) - root_agent = self.agent_loader.load_agent(app_name) + agent_or_app = self.agent_loader.load_agent(app_name) + root_agent = self._get_root_agent(agent_or_app) eval_case_results = [] @@ -1305,7 +1311,8 @@ async def get_event_graph( function_calls = event.get_function_calls() function_responses = event.get_function_responses() - root_agent = self.agent_loader.load_agent(app_name) + agent_or_app = self.agent_loader.load_agent(app_name) + root_agent = self._get_root_agent(agent_or_app) dot_graph = None if function_calls: function_call_highlights = []