diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 1b422fe335..117b0e1aa9 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -192,6 +192,10 @@ class CreateSessionRequest(common.BaseModel): default=None, description="A list of events to initialize the session with.", ) + display_name: Optional[str] = Field( + default=None, + description="The display name of the session.", + ) class AddSessionToEvalSetRequest(common.BaseModel): @@ -594,6 +598,7 @@ async def _create_session( user_id: str, session_id: Optional[str] = None, state: Optional[dict[str, Any]] = None, + display_name: Optional[str] = None, ) -> Session: try: session = await self.session_service.create_session( @@ -601,6 +606,7 @@ async def _create_session( user_id=user_id, state=state, session_id=session_id, + display_name=display_name, ) logger.info("New session created: %s", session.id) return session @@ -795,6 +801,7 @@ async def create_session( user_id=user_id, state=req.state, session_id=req.session_id, + display_name=req.display_name, ) if req.events: diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index f2f6f9f22d..a753968268 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -56,6 +56,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: """Creates a new session. @@ -65,6 +66,7 @@ async def create_session( state: the initial state of the session. session_id: the client-provided id of the session. If not provided, a generated ID will be used. + display_name: the display name of the session. Returns: session: The newly created session instance. diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 91c22fd21e..d0fd66d68b 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -174,6 +174,10 @@ class StorageSession(Base): PreciseTimestamp, default=func.now(), onupdate=func.now() ) + display_name: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + storage_events: Mapped[list[StorageEvent]] = relationship( "StorageEvent", back_populates="storage_session", @@ -215,6 +219,7 @@ def to_session( state=state, events=events, last_update_time=self.update_timestamp_tz, + display_name=self.display_name, ) @@ -477,6 +482,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: # 1. Populate states. # 2. Build storage session object @@ -526,6 +532,7 @@ async def create_session( user_id=user_id, id=session_id, state=session_state, + display_name=display_name, ) sql_session.add(storage_session) await sql_session.commit() diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 6ba7f0bb01..c81a2f74e0 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -58,12 +58,14 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: return self._create_session_impl( app_name=app_name, user_id=user_id, state=state, session_id=session_id, + display_name=display_name, ) def create_session_sync( @@ -73,6 +75,7 @@ def create_session_sync( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: logger.warning('Deprecated. Please migrate to the async method.') return self._create_session_impl( @@ -80,6 +83,7 @@ def create_session_sync( user_id=user_id, state=state, session_id=session_id, + display_name=display_name, ) def _create_session_impl( @@ -89,6 +93,7 @@ def _create_session_impl( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: if session_id and self._get_session_impl( app_name=app_name, user_id=user_id, session_id=session_id @@ -116,6 +121,7 @@ def _create_session_impl( id=session_id, state=session_state or {}, last_update_time=time.time(), + display_name=display_name, ) if app_name not in self.sessions: diff --git a/src/google/adk/sessions/session.py b/src/google/adk/sessions/session.py index e674dd3778..6e664f93d2 100644 --- a/src/google/adk/sessions/session.py +++ b/src/google/adk/sessions/session.py @@ -15,6 +15,7 @@ from __future__ import annotations from typing import Any +from typing import Optional from pydantic import alias_generators from pydantic import BaseModel @@ -48,3 +49,5 @@ class Session(BaseModel): call/response, etc.""" last_update_time: float = 0.0 """The last update time of the session.""" + display_name: Optional[str] = None + """The display name of the session.""" diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 10d05f6dfd..8aa96b2caa 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -66,6 +66,7 @@ state TEXT NOT NULL, create_time REAL NOT NULL, update_time REAL NOT NULL, + display_name TEXT, PRIMARY KEY (app_name, user_id, id) ); """ @@ -121,6 +122,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, ) -> Session: if session_id: session_id = session_id.strip() @@ -160,8 +162,8 @@ async def create_session( # Store the session await db.execute( """ - INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time) - VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time, display_name) + VALUES (?, ?, ?, ?, ?, ?, ?) """, ( app_name, @@ -170,6 +172,7 @@ async def create_session( json.dumps(session_state), now, now, + display_name, ), ) await db.commit() @@ -185,6 +188,7 @@ async def create_session( state=merged_state, events=[], last_update_time=now, + display_name=display_name, ) @override @@ -198,8 +202,8 @@ async def get_session( ) -> Optional[Session]: async with self._get_db_connection() as db: async with db.execute( - "SELECT state, update_time FROM sessions WHERE app_name=? AND" - " user_id=? AND id=?", + "SELECT state, update_time, display_name FROM sessions WHERE" + " app_name=? AND user_id=? AND id=?", (app_name, user_id, session_id), ) as cursor: session_row = await cursor.fetchone() @@ -207,6 +211,7 @@ async def get_session( return None session_state = json.loads(session_row["state"]) last_update_time = session_row["update_time"] + display_name = session_row["display_name"] # Build events query query_parts = [ @@ -248,6 +253,7 @@ async def get_session( state=merged_state, events=events, last_update_time=last_update_time, + display_name=display_name, ) @override @@ -259,14 +265,14 @@ async def list_sessions( # Fetch sessions if user_id: session_rows = await db.execute_fetchall( - "SELECT id, user_id, state, update_time FROM sessions WHERE" - " app_name=? AND user_id=?", + "SELECT id, user_id, state, update_time, display_name FROM sessions" + " WHERE app_name=? AND user_id=?", (app_name, user_id), ) else: session_rows = await db.execute_fetchall( - "SELECT id, user_id, state, update_time FROM sessions WHERE" - " app_name=?", + "SELECT id, user_id, state, update_time, display_name FROM sessions" + " WHERE app_name=?", (app_name,), ) @@ -301,6 +307,7 @@ async def list_sessions( state=merged_state, events=[], last_update_time=row["update_time"], + display_name=row["display_name"], ) ) return ListSessionsResponse(sessions=sessions_list) diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 571ae53a59..77aec4d923 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -83,6 +83,7 @@ async def create_session( user_id: str, state: Optional[dict[str, Any]] = None, session_id: Optional[str] = None, + display_name: Optional[str] = None, **kwargs: Any, ) -> Session: """Creates a new session. @@ -92,6 +93,7 @@ async def create_session( user_id: The ID of the user. state: The initial state of the session. session_id: The ID of the session. + display_name: An optional display name for the session. **kwargs: Additional arguments to pass to the session creation. E.g. set expire_time='2025-10-01T00:00:00Z' to set the session expiration time. See https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1beta1/projects.locations.reasoningEngines.sessions @@ -109,6 +111,8 @@ async def create_session( reasoning_engine_id = self._get_reasoning_engine_id(app_name) config = {'session_state': state} if state else {} + if display_name is not None: + config['display_name'] = display_name config.update(kwargs) async with self._get_api_client() as api_client: api_response = await api_client.agent_engines.sessions.create( @@ -126,6 +130,7 @@ async def create_session( id=session_id, state=getattr(get_session_response, 'session_state', None) or {}, last_update_time=get_session_response.update_time.timestamp(), + display_name=getattr(get_session_response, 'display_name', None), ) return session @@ -175,6 +180,7 @@ async def get_session( id=session_id, state=getattr(get_session_response, 'session_state', None) or {}, last_update_time=update_timestamp, + display_name=getattr(get_session_response, 'display_name', None), ) session.events += [ _from_api_event(event) @@ -213,6 +219,7 @@ async def list_sessions( id=api_session.name.split('/')[-1], state=getattr(api_session, 'session_state', None) or {}, last_update_time=api_session.update_time.timestamp(), + display_name=getattr(api_session, 'display_name', None), ) ) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 7fb91c9db6..7f5cc0d1ff 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -83,6 +83,7 @@ async def test_create_get_session(service_type, tmp_path): assert session.user_id == user_id assert session.id assert session.state == state + assert session.display_name is None assert ( session.last_update_time <= datetime.now().astimezone(timezone.utc).timestamp() @@ -618,3 +619,41 @@ async def test_partial_events_are_not_persisted(service_type, tmp_path): app_name=app_name, user_id=user_id, session_id=session.id ) assert len(session_got.events) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', + [ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, + ], +) +async def test_create_session_with_display_name(service_type, tmp_path): + """Test that display_name is properly stored and retrieved.""" + session_service = get_session_service(service_type, tmp_path) + app_name = 'my_app' + user_id = 'test_user' + display_name = 'My Test Session' + + # Create a session with a display_name + session = await session_service.create_session( + app_name=app_name, + user_id=user_id, + display_name=display_name, + ) + assert session.display_name == display_name + + # Verify display_name is persisted when fetching the session + got_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert got_session.display_name == display_name + + # Verify display_name appears in list_sessions + list_response = await session_service.list_sessions( + app_name=app_name, user_id=user_id + ) + assert len(list_response.sessions) == 1 + assert list_response.sessions[0].display_name == display_name diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 29c3c74c0d..6f14db37ca 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -576,6 +576,22 @@ async def test_create_session_with_custom_config(mock_api_client_instance): ) +@pytest.mark.asyncio +@pytest.mark.usefixtures('mock_get_api_client') +async def test_create_session_with_display_name(mock_api_client_instance): + """Test that display_name parameter is passed through to the API.""" + session_service = mock_vertex_ai_session_service() + + display_name = 'Display Name' + await session_service.create_session( + app_name='123', user_id='user', display_name=display_name + ) + assert ( + mock_api_client_instance.last_create_session_config['display_name'] + == display_name + ) + + @pytest.mark.asyncio @pytest.mark.usefixtures('mock_get_api_client') async def test_append_event():