diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 1e0c321376..740f568bd4 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -74,6 +74,7 @@ from ..memory.base_memory_service import BaseMemoryService from ..runners import Runner from ..sessions.base_session_service import BaseSessionService +from ..sessions.base_session_service import GetSessionConfig from ..sessions.session import Session from ..utils.context_utils import Aclosing from .cli_eval import EVAL_SESSION_ID_PREFIX @@ -164,6 +165,7 @@ class RunAgentRequest(common.BaseModel): new_message: types.Content streaming: bool = False state_delta: Optional[dict[str, Any]] = None + get_session_config: Optional[GetSessionConfig] = None class CreateSessionRequest(common.BaseModel): @@ -1021,6 +1023,7 @@ async def run_agent(req: RunAgentRequest) -> list[Event]: session_id=req.session_id, new_message=req.new_message, state_delta=req.state_delta, + get_session_config=req.get_session_config, ) ) as agen: events = [event async for event in agen] @@ -1051,6 +1054,7 @@ async def event_generator(): new_message=req.new_message, state_delta=req.state_delta, run_config=RunConfig(streaming_mode=stream_mode), + get_session_config=req.get_session_config, ) ) as agen: async for event in agen: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index bcf29839d1..eafcfe4fa3 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -47,6 +47,7 @@ from .plugins.base_plugin import BasePlugin from .plugins.plugin_manager import PluginManager from .sessions.base_session_service import BaseSessionService +from .sessions.base_session_service import GetSessionConfig from .sessions.in_memory_session_service import InMemorySessionService from .sessions.session import Session from .telemetry import tracer @@ -125,6 +126,7 @@ def run( session_id: str, new_message: types.Content, run_config: RunConfig = RunConfig(), + get_session_config: Optional[GetSessionConfig] = None, ) -> Generator[Event, None, None]: """Runs the agent. @@ -137,6 +139,8 @@ def run( session_id: The session ID of the session. new_message: A new message to append to the session. run_config: The run config for the agent. + get_session_config: Configuration for retrieving the session, allowing for + limiting the number of events returned. Yields: The events generated by the agent. @@ -151,6 +155,7 @@ async def _invoke_run_async(): session_id=session_id, new_message=new_message, run_config=run_config, + get_session_config=get_session_config, ) ) as agen: async for event in agen: @@ -185,6 +190,7 @@ async def run_async( new_message: types.Content, state_delta: Optional[dict[str, Any]] = None, run_config: RunConfig = RunConfig(), + get_session_config: Optional[GetSessionConfig] = None, ) -> AsyncGenerator[Event, None]: """Main entry method to run the agent in this runner. @@ -193,6 +199,8 @@ async def run_async( session_id: The session ID of the session. new_message: A new message to append to the session. run_config: The run config for the agent. + get_session_config: Configuration for retrieving the session, allowing for + limiting the number of events returned. Yields: The events generated by the agent. @@ -203,7 +211,10 @@ async def _run_with_trace( ) -> AsyncGenerator[Event, None]: with tracer.start_as_current_span('invocation'): session = await self.session_service.get_session( - app_name=self.app_name, user_id=user_id, session_id=session_id + app_name=self.app_name, + user_id=user_id, + session_id=session_id, + config=get_session_config, ) if not session: raise ValueError(f'Session not found: {session_id}')