src/google/adk/sessions/vertex_ai_session_service.py (17 additions, 18 deletions)
@@ -162,26 +162,25 @@ async def get_session(
              **list_events_kwargs,
          ),
      )
      if get_session_response.user_id != user_id:
        raise ValueError(
            f'Session {session_id} does not belong to user {user_id}.'
        )

      update_timestamp = get_session_response.update_time.timestamp()
      session = Session(
          app_name=app_name,
          user_id=user_id,
          id=session_id,
          state=getattr(get_session_response, 'session_state', None) or {},
          last_update_time=update_timestamp,
      )

      # Preserve the entire event stream that Vertex returns rather than trying
      # to discard events written milliseconds after the session resource was
      # updated. Clock skew between those writes can otherwise drop tool_result
      # events and permanently break the replayed conversation.
      async for event in events_iterator:
        session.events.append(_from_api_event(event))

if config:
# Filter events based on num_recent_events.
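For readers unfamiliar with the failure mode, here is a minimal, self-contained sketch of why the event iteration has to stay inside the `async with` block. It is not ADK code: `FakeAsyncClient` and `list_paginated` are invented stand-ins for an httpx-style client and a lazily paginated listing call; only the pattern matches the hunk above.

import asyncio


class FakeAsyncClient:
  """Stand-in for an httpx-style async client that rejects requests once closed."""

  def __init__(self):
    self._closed = True

  async def __aenter__(self):
    self._closed = False
    return self

  async def __aexit__(self, exc_type, exc, tb):
    self._closed = True

  async def list_paginated(self, pages):
    # Each page is "fetched" lazily; a real client would issue an HTTP request
    # here, which is why the connection must still be open.
    for page in pages:
      if self._closed:
        raise RuntimeError(
            'Cannot send a request, as the client has been closed.'
        )
      for item in page:
        yield item


async def main():
  pages = [['event_0', 'event_1'], ['event_2', 'event_3']]

  # Buggy pattern: the iterator is created inside the context but consumed
  # after it exits, so the lazy page fetch fails.
  async with FakeAsyncClient() as client:
    iterator = client.list_paginated(pages)
  try:
    _ = [e async for e in iterator]
  except RuntimeError as err:
    print('consumed outside the block:', err)

  # Fixed pattern (what the hunk above does): iterate while the client is open.
  async with FakeAsyncClient() as client:
    events = [e async for e in client.list_paginated(pages)]
  print('consumed inside the block:', events)


if __name__ == '__main__':
  asyncio.run(main())

In a real client the first page may already be in memory when iteration starts, so the error typically surfaces only when a later page is requested; the regression test below reproduces exactly that by spreading 250 events over three pages.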
tests/unittests/sessions/test_vertex_ai_session_service.py (96 additions, 0 deletions)
@@ -396,6 +396,102 @@ async def _append_event(
else:
self.event_dict[session_id] = ([event_json], None)

class MockAsyncClientWithPagination:
"""Mock client that simulates pagination requiring an open client connection.

This mock tracks whether the client context is active and raises RuntimeError
if iteration occurs outside the context, simulating the real httpx behavior.
"""

def __init__(self, session_data: dict, events_pages: list[list[dict]]):
self._session_data = session_data
self._events_pages = events_pages
self._context_active = False
self.agent_engines = mock.AsyncMock()
self.agent_engines.sessions.get.side_effect = self._get_session
self.agent_engines.sessions.events.list.side_effect = self._list_events

async def __aenter__(self):
self._context_active = True
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
self._context_active = False

async def _get_session(self, name: str):
return _convert_to_object(self._session_data)

async def _list_events(self, name: str, **kwargs):
return self._paginated_events_iterator()

async def _paginated_events_iterator(self):
for page in self._events_pages:
for event in page:
if not self._context_active:
raise RuntimeError(
'Cannot send a request, as the client has been closed.'
)
yield _convert_to_object(event)


def _generate_events_for_page(session_id: str, start_idx: int, count: int):
events = []
start_time = isoparse('2024-12-12T12:12:12.123456Z')
for i in range(count):
idx = start_idx + i
event_time = start_time + datetime.timedelta(microseconds=idx * 1000)
events.append({
'name': (
'projects/test-project/locations/test-location/'
f'reasoningEngines/123/sessions/{session_id}/events/{idx}'
),
'invocation_id': f'invocation_{idx}',
'author': 'pagination_user',
'timestamp': event_time.isoformat().replace('+00:00', 'Z'),
})
return events


@pytest.mark.asyncio
async def test_get_session_pagination_keeps_client_open():
"""Regression test: event iteration must occur inside the api_client context.

This test verifies that get_session() keeps the API client open while
iterating through paginated events. Before the fix, the events_iterator
was consumed outside the async with block, causing RuntimeError when
fetching subsequent pages.
"""
session_data = {
'name': (
'projects/test-project/locations/test-location/'
'reasoningEngines/123/sessions/pagination_test'
),
'update_time': '2024-12-12T12:12:12.123456Z',
'user_id': 'pagination_user',
}
page1_events = _generate_events_for_page('pagination_test', 0, 100)
page2_events = _generate_events_for_page('pagination_test', 100, 100)
page3_events = _generate_events_for_page('pagination_test', 200, 50)

mock_client = MockAsyncClientWithPagination(
session_data=session_data,
events_pages=[page1_events, page2_events, page3_events],
)

session_service = mock_vertex_ai_session_service()

with mock.patch.object(
session_service, '_get_api_client', return_value=mock_client
):
session = await session_service.get_session(
app_name='123', user_id='pagination_user', session_id='pagination_test'
)

assert session is not None
assert len(session.events) == 250
assert session.events[0].invocation_id == 'invocation_0'
assert session.events[249].invocation_id == 'invocation_249'


def mock_vertex_ai_session_service(
project: Optional[str] = 'test-project',
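As a side note, the mock above can be exercised on its own to confirm it really enforces the open-connection invariant. The sketch below is not part of the PR; it assumes it lives in the same test module, so MockAsyncClientWithPagination, _generate_events_for_page, and _convert_to_object are in scope, and it reproduces the pre-fix failure by consuming the iterator after the context has closed.

import asyncio


async def _demo_mock_detects_closed_client():
  mock_client = MockAsyncClientWithPagination(
      session_data={},  # unused here; only the events iterator is exercised
      events_pages=[_generate_events_for_page('demo', 0, 2)],
  )
  # Create the iterator inside the context...
  async with mock_client:
    events_iterator = await mock_client.agent_engines.sessions.events.list(
        name='demo'
    )
  # ...but consume it after the context has closed: the mock raises the same
  # RuntimeError an httpx client would, which is what makes the pre-fix bug
  # reproducible in test_get_session_pagination_keeps_client_open.
  try:
    async for _ in events_iterator:
      pass
  except RuntimeError as err:
    print(err)  # Cannot send a request, as the client has been closed.


if __name__ == '__main__':
  asyncio.run(_demo_mock_detects_closed_client())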