diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index cce7e99b32..8b5622745c 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -162,26 +162,25 @@ async def get_session( **list_events_kwargs, ), ) + if get_session_response.user_id != user_id: + raise ValueError( + f'Session {session_id} does not belong to user {user_id}.' + ) - if get_session_response.user_id != user_id: - raise ValueError( - f'Session {session_id} does not belong to user {user_id}.' + update_timestamp = get_session_response.update_time.timestamp() + session = Session( + app_name=app_name, + user_id=user_id, + id=session_id, + state=getattr(get_session_response, 'session_state', None) or {}, + last_update_time=update_timestamp, ) - - update_timestamp = get_session_response.update_time.timestamp() - session = Session( - app_name=app_name, - user_id=user_id, - id=session_id, - state=getattr(get_session_response, 'session_state', None) or {}, - last_update_time=update_timestamp, - ) - # Preserve the entire event stream that Vertex returns rather than trying - # to discard events written milliseconds after the session resource was - # updated. Clock skew between those writes can otherwise drop tool_result - # events and permanently break the replayed conversation. - async for event in events_iterator: - session.events.append(_from_api_event(event)) + # Preserve the entire event stream that Vertex returns rather than trying + # to discard events written milliseconds after the session resource was + # updated. Clock skew between those writes can otherwise drop tool_result + # events and permanently break the replayed conversation. + async for event in events_iterator: + session.events.append(_from_api_event(event)) if config: # Filter events based on num_recent_events. diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 14d2b15b6e..b30a0cffcd 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -397,6 +397,103 @@ async def _append_event( self.event_dict[session_id] = ([event_json], None) +class MockAsyncClientWithPagination: + """Mock client that simulates pagination requiring an open client connection. + + This mock tracks whether the client context is active and raises RuntimeError + if iteration occurs outside the context, simulating the real httpx behavior. + """ + + def __init__(self, session_data: dict, events_pages: list[list[dict]]): + self._session_data = session_data + self._events_pages = events_pages + self._context_active = False + self.agent_engines = mock.AsyncMock() + self.agent_engines.sessions.get.side_effect = self._get_session + self.agent_engines.sessions.events.list.side_effect = self._list_events + + async def __aenter__(self): + self._context_active = True + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self._context_active = False + + async def _get_session(self, name: str): + return _convert_to_object(self._session_data) + + async def _list_events(self, name: str, **kwargs): + return self._paginated_events_iterator() + + async def _paginated_events_iterator(self): + for page in self._events_pages: + for event in page: + if not self._context_active: + raise RuntimeError( + 'Cannot send a request, as the client has been closed.' + ) + yield _convert_to_object(event) + + +def _generate_events_for_page(session_id: str, start_idx: int, count: int): + events = [] + start_time = isoparse('2024-12-12T12:12:12.123456Z') + for i in range(count): + idx = start_idx + i + event_time = start_time + datetime.timedelta(microseconds=idx * 1000) + events.append({ + 'name': ( + 'projects/test-project/locations/test-location/' + f'reasoningEngines/123/sessions/{session_id}/events/{idx}' + ), + 'invocation_id': f'invocation_{idx}', + 'author': 'pagination_user', + 'timestamp': event_time.isoformat().replace('+00:00', 'Z'), + }) + return events + + +@pytest.mark.asyncio +async def test_get_session_pagination_keeps_client_open(): + """Regression test: event iteration must occur inside the api_client context. + + This test verifies that get_session() keeps the API client open while + iterating through paginated events. Before the fix, the events_iterator + was consumed outside the async with block, causing RuntimeError when + fetching subsequent pages. + """ + session_data = { + 'name': ( + 'projects/test-project/locations/test-location/' + 'reasoningEngines/123/sessions/pagination_test' + ), + 'update_time': '2024-12-12T12:12:12.123456Z', + 'user_id': 'pagination_user', + } + page1_events = _generate_events_for_page('pagination_test', 0, 100) + page2_events = _generate_events_for_page('pagination_test', 100, 100) + page3_events = _generate_events_for_page('pagination_test', 200, 50) + + mock_client = MockAsyncClientWithPagination( + session_data=session_data, + events_pages=[page1_events, page2_events, page3_events], + ) + + session_service = mock_vertex_ai_session_service() + + with mock.patch.object( + session_service, '_get_api_client', return_value=mock_client + ): + session = await session_service.get_session( + app_name='123', user_id='pagination_user', session_id='pagination_test' + ) + + assert session is not None + assert len(session.events) == 250 + assert session.events[0].invocation_id == 'invocation_0' + assert session.events[249].invocation_id == 'invocation_249' + + def mock_vertex_ai_session_service( project: Optional[str] = 'test-project', location: Optional[str] = 'test-location',