diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py
index 824cd26be1..b958dfc3c8 100644
--- a/src/google/adk/flows/llm_flows/base_llm_flow.py
+++ b/src/google/adk/flows/llm_flows/base_llm_flow.py
@@ -66,6 +66,7 @@
 DEFAULT_REQUEST_QUEUE_TIMEOUT = 0.25
 DEFAULT_TRANSFER_AGENT_DELAY = 1.0
 DEFAULT_TASK_COMPLETION_DELAY = 1.0
+DEFAULT_LLM_CLEANUP_TIMEOUT = 5.0
 
 # Statistics configuration
 DEFAULT_ENABLE_CACHE_STATISTICS = False
@@ -751,6 +752,22 @@ async def _call_llm_async(
     # Calls the LLM.
     llm = self.__get_llm(invocation_context)
 
+    # Determine if this LLM instance was created just for this request
+    # (needs cleanup) or is a reused instance from the agent (no cleanup).
+    from ...agents.llm_agent import LlmAgent
+    from ...models.base_llm import BaseLlm
+
+    needs_cleanup = False
+    if isinstance(invocation_context.agent, LlmAgent):
+      agent_model = invocation_context.agent.model
+      # If agent.model is a string, canonical_model creates a new instance
+      # that needs cleanup. If agent.model is a BaseLlm instance, it's reused.
+      needs_cleanup = not isinstance(agent_model, BaseLlm)
+      logger.debug(
+          f'LLM cleanup check: agent.model type={type(agent_model).__name__}, '
+          f'needs_cleanup={needs_cleanup}, llm type={type(llm).__name__}'
+      )
+
     async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
       with tracer.start_as_current_span('call_llm'):
         if invocation_context.run_config.support_cfc:
@@ -812,9 +829,35 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
 
           yield llm_response
 
-    async with Aclosing(_call_llm_with_tracing()) as agen:
-      async for event in agen:
-        yield event
+    try:
+      async with Aclosing(_call_llm_with_tracing()) as agen:
+        async for event in agen:
+          yield event
+    finally:
+      # Clean up the LLM instance if it was created for this request
+      if needs_cleanup:
+        try:
+          import asyncio
+
+          logger.info(f'Cleaning up LLM instance: {type(llm).__name__}')
+          # Use timeout to prevent hanging on cleanup
+          await asyncio.wait_for(
+              llm.aclose(), timeout=DEFAULT_LLM_CLEANUP_TIMEOUT
+          )
+          logger.info(
+              f'Successfully cleaned up LLM instance: {type(llm).__name__}'
+          )
+        except asyncio.TimeoutError:
+          logger.warning(
+              'LLM cleanup timed out after'
+              f' {DEFAULT_LLM_CLEANUP_TIMEOUT} seconds'
+          )
+        except Exception as e:
+          logger.warning(f'Error closing LLM instance: {e}')
+      else:
+        logger.debug(
+            f'Skipping LLM cleanup (reused instance): {type(llm).__name__}'
+        )
 
   async def _handle_before_model_callback(
       self,
diff --git a/src/google/adk/models/base_llm.py b/src/google/adk/models/base_llm.py
index 0f419a9b06..f070edf38f 100644
--- a/src/google/adk/models/base_llm.py
+++ b/src/google/adk/models/base_llm.py
@@ -203,3 +203,14 @@ def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
     raise NotImplementedError(
         f'Live connection is not supported for {self.model}.'
     )
+
+  async def aclose(self) -> None:
+    """Closes the LLM and releases resources.
+
+    This method provides a lifecycle hook for cleanup when the LLM is no longer
+    needed. The default implementation is a no-op for backward compatibility.
+
+    Subclasses that manage resources (e.g., HTTP clients) should override this
+    method to perform proper cleanup.
+    """
+    pass
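
For illustration, a third-party subclass that owns its own HTTP transport could adopt the new hook roughly as sketched below. ProxiedLlm, its _client private attribute, and the use of httpx are hypothetical and not part of this change; the sketch assumes the pydantic-v2 BaseLlm that ADK uses, hence PrivateAttr for the lazily created client:

    from typing import AsyncGenerator, Optional

    import httpx
    from pydantic import PrivateAttr

    from google.adk.models.base_llm import BaseLlm
    from google.adk.models.llm_request import LlmRequest
    from google.adk.models.llm_response import LlmResponse


    class ProxiedLlm(BaseLlm):
      """Hypothetical model that lazily creates its own httpx.AsyncClient."""

      _client: Optional[httpx.AsyncClient] = PrivateAttr(default=None)

      async def generate_content_async(
          self, llm_request: LlmRequest, stream: bool = False
      ) -> AsyncGenerator[LlmResponse, None]:
        if self._client is None:
          self._client = httpx.AsyncClient()
        # ... call the backend through self._client, translate the reply ...
        yield LlmResponse()

      async def aclose(self) -> None:
        # Close only what this instance created; the base implementation is
        # a no-op, so subclasses without resources need no change.
        if self._client is not None:
          await self._client.aclose()
          self._client = None
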
+ """ + pass diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 1bdd311104..6ce08fc797 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -394,6 +394,46 @@ def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: headers[key] = ' '.join(value_parts) return headers + @override + async def aclose(self) -> None: + """Closes API clients if they were accessed. + + Checks if the cached_property clients have been instantiated and closes + them if necessary. Uses asyncio.gather to ensure all cleanup attempts + complete even if some fail. + """ + import asyncio + + _CLIENT_CLOSE_TIMEOUT = 10.0 + close_tasks = [] + + def _add_close_task(client): + """Appends the appropriate aclose coroutine to close_tasks.""" + if hasattr(client, 'aio') and hasattr(client.aio, 'aclose'): + close_tasks.append(client.aio.aclose()) + elif hasattr(client, 'aclose'): + close_tasks.append(client.aclose()) + + # Check if api_client was accessed and close it + if 'api_client' in self.__dict__: + _add_close_task(self.__dict__['api_client']) + + # Check if _live_api_client was accessed and close it + if '_live_api_client' in self.__dict__: + _add_close_task(self.__dict__['_live_api_client']) + + # Execute all close operations concurrently with timeout + if close_tasks: + try: + await asyncio.wait_for( + asyncio.gather(*close_tasks, return_exceptions=True), + timeout=_CLIENT_CLOSE_TIMEOUT, + ) + except asyncio.TimeoutError: + logger.warning('Timeout waiting for API clients to close') + except Exception as e: + logger.warning(f'Error during API client cleanup: {e}') + def _build_function_declaration_log( func_decl: types.FunctionDeclaration, diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 2bb0168928..17463e65c9 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -65,6 +65,9 @@ logger = logging.getLogger('google_adk.' + __name__) +# LLM cleanup configuration +_LLM_MODEL_CLEANUP_TIMEOUT = 5.0 + class Runner: """The Runner class is used to run agents. @@ -1311,6 +1314,40 @@ def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]: toolsets.update(self._collect_toolset(sub_agent)) return toolsets + def _collect_llm_models(self, agent: BaseAgent) -> list: + """Recursively collects all LLM model instances from the agent tree. + + Args: + agent: The root agent to collect LLM models from. + + Returns: + A list of unique BaseLlm instances found in the agent tree. 
+ """ + from google.adk.models.base_llm import BaseLlm + + llm_models = [] + seen_ids = set() + + def _collect(current_agent: BaseAgent): + """Helper to recursively collect models.""" + if isinstance(current_agent, LlmAgent): + try: + canonical = current_agent.canonical_model + if isinstance(canonical, BaseLlm): + model_id = id(canonical) + if model_id not in seen_ids: + llm_models.append(canonical) + seen_ids.add(model_id) + except (ValueError, AttributeError): + # Agent might not have a model configured or canonical_model fails + pass + + for sub_agent in current_agent.sub_agents: + _collect(sub_agent) + + _collect(agent) + return llm_models + async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]): """Clean up toolsets with proper task context management.""" if not toolsets_to_close: @@ -1341,12 +1378,50 @@ async def _cleanup_toolsets(self, toolsets_to_close: set[BaseToolset]): except Exception as e: logger.error('Error closing toolset %s: %s', type(toolset).__name__, e) + async def _cleanup_llm_models(self, llm_models_to_close: list): + """Clean up LLM models with proper error handling and timeout. + + Args: + llm_models_to_close: List of BaseLlm instances to close. + """ + if not llm_models_to_close: + return + + for llm_model in llm_models_to_close: + try: + logger.info('Closing LLM model: %s', type(llm_model).__name__) + # Use asyncio.wait_for to add timeout protection + await asyncio.wait_for( + llm_model.aclose(), timeout=_LLM_MODEL_CLEANUP_TIMEOUT + ) + logger.info( + 'Successfully closed LLM model: %s', type(llm_model).__name__ + ) + except asyncio.TimeoutError: + logger.warning( + 'LLM model %s cleanup timed out after %s seconds', + type(llm_model).__name__, + _LLM_MODEL_CLEANUP_TIMEOUT, + ) + except Exception as e: + logger.error( + 'Error closing LLM model %s: %s', type(llm_model).__name__, e + ) + async def close(self): - """Closes the runner.""" + """Closes the runner and cleans up all resources. + + Cleans up toolsets first, then LLM models, to ensure proper resource + cleanup order. + """ logger.info('Closing runner...') - # Close Toolsets + # Clean up toolsets first await self._cleanup_toolsets(self._collect_toolset(self.agent)) + # Then clean up LLM models + llm_models_to_close = self._collect_llm_models(self.agent) + await self._cleanup_llm_models(llm_models_to_close) + # Close Plugins if self.plugin_manager: await self.plugin_manager.close()