Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
logger = logging.getLogger('google_adk.' + __name__)

_COMPACTION_CUSTOM_METADATA_KEY = '_compaction'
_REWIND_CUSTOM_METADATA_KEY = '_rewind_before_invocation_id'
_USAGE_METADATA_CUSTOM_METADATA_KEY = '_usage_metadata'


Expand Down Expand Up @@ -284,7 +285,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
},
# TODO: add requested_tool_confirmations, agent_state once
# they are available in the API.
# Note: compaction is stored via event_metadata.custom_metadata.
# Note: compaction and rewind_before_invocation_id are stored via
# event_metadata.custom_metadata.
}
if event.error_code:
config['error_code'] = event.error_code
Expand Down Expand Up @@ -320,6 +322,16 @@ async def append_event(self, session: Session, event: Event) -> Event:
key=_COMPACTION_CUSTOM_METADATA_KEY,
value=compaction_dict,
)
# Store rewind_before_invocation_id in custom_metadata since the Vertex AI
# service does not yet support the field in EventActions.
# TODO: Stop writing to custom_metadata once the Vertex AI service
# supports rewind_before_invocation_id natively in EventActions.
if event.actions and event.actions.rewind_before_invocation_id:
_set_internal_custom_metadata(
metadata_dict,
key=_REWIND_CUSTOM_METADATA_KEY,
value=event.actions.rewind_before_invocation_id,
)
# Store usage_metadata in custom_metadata since the Vertex AI service
# does not persist it in EventMetadata.
if event.usage_metadata:
Expand Down Expand Up @@ -405,15 +417,20 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
# written before native compaction support store compaction data
# in custom_metadata under the compaction metadata key.
compaction_data = None
rewind_data = None
usage_metadata_data = None
if custom_metadata and (
_COMPACTION_CUSTOM_METADATA_KEY in custom_metadata
or _REWIND_CUSTOM_METADATA_KEY in custom_metadata
or _USAGE_METADATA_CUSTOM_METADATA_KEY in custom_metadata
):
custom_metadata = dict(custom_metadata) # avoid mutating the API response
compaction_data = custom_metadata.pop(
_COMPACTION_CUSTOM_METADATA_KEY, None
)
rewind_data = custom_metadata.pop(
_REWIND_CUSTOM_METADATA_KEY, None
)
usage_metadata_data = custom_metadata.pop(
_USAGE_METADATA_CUSTOM_METADATA_KEY, None
)
Expand All @@ -431,6 +448,7 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
branch = None
custom_metadata = None
compaction_data = None
rewind_data = None
usage_metadata_data = None
grounding_metadata = None

Expand All @@ -442,11 +460,18 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event:
}
if compaction_data:
renamed_actions_dict['compaction'] = compaction_data
if rewind_data:
renamed_actions_dict['rewind_before_invocation_id'] = rewind_data
event_actions = EventActions.model_validate(renamed_actions_dict)
else:
if compaction_data:
if compaction_data or rewind_data:
event_actions = EventActions(
compaction=EventCompaction.model_validate(compaction_data)
compaction=(
EventCompaction.model_validate(compaction_data)
if compaction_data
else None
),
rewind_before_invocation_id=rewind_data,
)
else:
event_actions = EventActions()
Expand Down
29 changes: 29 additions & 0 deletions tests/unittests/sessions/test_vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,35 @@ async def test_append_event():
assert retrieved_session.events[1] == event_to_append


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event_with_rewind():
"""rewind_before_invocation_id round-trips through append_event and get_session."""
session_service = mock_vertex_ai_session_service()
session = await session_service.get_session(
app_name='123', user_id='user', session_id='1'
)
event_to_append = Event(
invocation_id='rewind_invocation',
author='model',
timestamp=1734005533.0,
actions=EventActions(
rewind_before_invocation_id='target_invocation',
),
)

await session_service.append_event(session, event_to_append)

retrieved_session = await session_service.get_session(
app_name='123', user_id='user', session_id='1'
)

appended_event = retrieved_session.events[-1]
assert (
appended_event.actions.rewind_before_invocation_id == 'target_invocation'
)


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event_with_compaction():
Expand Down