Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
7 changes: 7 additions & 0 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,10 @@ class CreateSessionRequest(common.BaseModel):
default=None,
description="A list of events to initialize the session with.",
)
display_name: Optional[str] = Field(
default=None,
description="The display name of the session.",
)


class AddSessionToEvalSetRequest(common.BaseModel):
Expand Down Expand Up @@ -594,13 +598,15 @@ async def _create_session(
user_id: str,
session_id: Optional[str] = None,
state: Optional[dict[str, Any]] = None,
display_name: Optional[str] = None,
) -> Session:
try:
session = await self.session_service.create_session(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
display_name=display_name,
)
logger.info("New session created: %s", session.id)
return session
Expand Down Expand Up @@ -795,6 +801,7 @@ async def create_session(
user_id=user_id,
state=req.state,
session_id=req.session_id,
display_name=req.display_name,
)

if req.events:
Expand Down
2 changes: 2 additions & 0 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
"""Creates a new session.

Expand All @@ -65,6 +66,7 @@ async def create_session(
state: the initial state of the session.
session_id: the client-provided id of the session. If not provided, a
generated ID will be used.
display_name: the display name of the session.

Returns:
session: The newly created session instance.
Expand Down
7 changes: 7 additions & 0 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ class StorageSession(Base):
PreciseTimestamp, default=func.now(), onupdate=func.now()
)

display_name: Mapped[Optional[str]] = mapped_column(
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
)

storage_events: Mapped[list[StorageEvent]] = relationship(
"StorageEvent",
back_populates="storage_session",
Expand Down Expand Up @@ -215,6 +219,7 @@ def to_session(
state=state,
events=events,
last_update_time=self.update_timestamp_tz,
display_name=self.display_name,
)


Expand Down Expand Up @@ -477,6 +482,7 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
# 1. Populate states.
# 2. Build storage session object
Expand Down Expand Up @@ -526,6 +532,7 @@ async def create_session(
user_id=user_id,
id=session_id,
state=session_state,
display_name=display_name,
)
sql_session.add(storage_session)
await sql_session.commit()
Expand Down
6 changes: 6 additions & 0 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
return self._create_session_impl(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
display_name=display_name,
)

def create_session_sync(
Expand All @@ -73,13 +75,15 @@ def create_session_sync(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
logger.warning('Deprecated. Please migrate to the async method.')
return self._create_session_impl(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
display_name=display_name,
)

def _create_session_impl(
Expand All @@ -89,6 +93,7 @@ def _create_session_impl(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
if session_id and self._get_session_impl(
app_name=app_name, user_id=user_id, session_id=session_id
Expand Down Expand Up @@ -116,6 +121,7 @@ def _create_session_impl(
id=session_id,
state=session_state or {},
last_update_time=time.time(),
display_name=display_name,
)

if app_name not in self.sessions:
Expand Down
3 changes: 3 additions & 0 deletions src/google/adk/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from typing import Any
from typing import Optional

from pydantic import alias_generators
from pydantic import BaseModel
Expand Down Expand Up @@ -48,3 +49,5 @@ class Session(BaseModel):
call/response, etc."""
last_update_time: float = 0.0
"""The last update time of the session."""
display_name: Optional[str] = None
"""The display name of the session."""
23 changes: 15 additions & 8 deletions src/google/adk/sessions/sqlite_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
state TEXT NOT NULL,
create_time REAL NOT NULL,
update_time REAL NOT NULL,
display_name TEXT,
PRIMARY KEY (app_name, user_id, id)
);
"""
Expand Down Expand Up @@ -121,6 +122,7 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
) -> Session:
if session_id:
session_id = session_id.strip()
Expand Down Expand Up @@ -160,8 +162,8 @@ async def create_session(
# Store the session
await db.execute(
"""
INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO sessions (app_name, user_id, id, state, create_time, update_time, display_name)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
(
app_name,
Expand All @@ -170,6 +172,7 @@ async def create_session(
json.dumps(session_state),
now,
now,
display_name,
),
)
await db.commit()
Expand All @@ -185,6 +188,7 @@ async def create_session(
state=merged_state,
events=[],
last_update_time=now,
display_name=display_name,
)

@override
Expand All @@ -198,15 +202,16 @@ async def get_session(
) -> Optional[Session]:
async with self._get_db_connection() as db:
async with db.execute(
"SELECT state, update_time FROM sessions WHERE app_name=? AND"
" user_id=? AND id=?",
"SELECT state, update_time, display_name FROM sessions WHERE"
" app_name=? AND user_id=? AND id=?",
(app_name, user_id, session_id),
) as cursor:
session_row = await cursor.fetchone()
if session_row is None:
return None
session_state = json.loads(session_row["state"])
last_update_time = session_row["update_time"]
display_name = session_row["display_name"]

# Build events query
query_parts = [
Expand Down Expand Up @@ -248,6 +253,7 @@ async def get_session(
state=merged_state,
events=events,
last_update_time=last_update_time,
display_name=display_name,
)

@override
Expand All @@ -259,14 +265,14 @@ async def list_sessions(
# Fetch sessions
if user_id:
session_rows = await db.execute_fetchall(
"SELECT id, user_id, state, update_time FROM sessions WHERE"
" app_name=? AND user_id=?",
"SELECT id, user_id, state, update_time, display_name FROM sessions"
" WHERE app_name=? AND user_id=?",
(app_name, user_id),
)
else:
session_rows = await db.execute_fetchall(
"SELECT id, user_id, state, update_time FROM sessions WHERE"
" app_name=?",
"SELECT id, user_id, state, update_time, display_name FROM sessions"
" WHERE app_name=?",
(app_name,),
)

Expand Down Expand Up @@ -301,6 +307,7 @@ async def list_sessions(
state=merged_state,
events=[],
last_update_time=row["update_time"],
display_name=row["display_name"],
)
)
return ListSessionsResponse(sessions=sessions_list)
Expand Down
7 changes: 7 additions & 0 deletions src/google/adk/sessions/vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def create_session(
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
display_name: Optional[str] = None,
**kwargs: Any,
) -> Session:
"""Creates a new session.
Expand All @@ -92,6 +93,7 @@ async def create_session(
user_id: The ID of the user.
state: The initial state of the session.
session_id: The ID of the session.
display_name: An optional display name for the session.
**kwargs: Additional arguments to pass to the session creation. E.g. set
expire_time='2025-10-01T00:00:00Z' to set the session expiration time.
See https://cloud.google.com/vertex-ai/generative-ai/docs/reference/rest/v1beta1/projects.locations.reasoningEngines.sessions
Expand All @@ -109,6 +111,8 @@ async def create_session(
reasoning_engine_id = self._get_reasoning_engine_id(app_name)

config = {'session_state': state} if state else {}
if display_name is not None:
config['display_name'] = display_name
config.update(kwargs)
async with self._get_api_client() as api_client:
api_response = await api_client.agent_engines.sessions.create(
Expand All @@ -126,6 +130,7 @@ async def create_session(
id=session_id,
state=getattr(get_session_response, 'session_state', None) or {},
last_update_time=get_session_response.update_time.timestamp(),
display_name=getattr(get_session_response, 'display_name', None),
)
return session

Expand Down Expand Up @@ -175,6 +180,7 @@ async def get_session(
id=session_id,
state=getattr(get_session_response, 'session_state', None) or {},
last_update_time=update_timestamp,
display_name=getattr(get_session_response, 'display_name', None),
)
session.events += [
_from_api_event(event)
Expand Down Expand Up @@ -213,6 +219,7 @@ async def list_sessions(
id=api_session.name.split('/')[-1],
state=getattr(api_session, 'session_state', None) or {},
last_update_time=api_session.update_time.timestamp(),
display_name=getattr(api_session, 'display_name', None),
)
)

Expand Down
39 changes: 39 additions & 0 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ async def test_create_get_session(service_type, tmp_path):
assert session.user_id == user_id
assert session.id
assert session.state == state
assert session.display_name is None
assert (
session.last_update_time
<= datetime.now().astimezone(timezone.utc).timestamp()
Expand Down Expand Up @@ -618,3 +619,41 @@ async def test_partial_events_are_not_persisted(service_type, tmp_path):
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(session_got.events) == 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
'service_type',
[
SessionServiceType.IN_MEMORY,
SessionServiceType.DATABASE,
SessionServiceType.SQLITE,
],
)
async def test_create_session_with_display_name(service_type, tmp_path):
"""Test that display_name is properly stored and retrieved."""
session_service = get_session_service(service_type, tmp_path)
app_name = 'my_app'
user_id = 'test_user'
display_name = 'My Test Session'

# Create a session with a display_name
session = await session_service.create_session(
app_name=app_name,
user_id=user_id,
display_name=display_name,
)
assert session.display_name == display_name

# Verify display_name is persisted when fetching the session
got_session = await session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert got_session.display_name == display_name

# Verify display_name appears in list_sessions
list_response = await session_service.list_sessions(
app_name=app_name, user_id=user_id
)
assert len(list_response.sessions) == 1
assert list_response.sessions[0].display_name == display_name
16 changes: 16 additions & 0 deletions tests/unittests/sessions/test_vertex_ai_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,22 @@ async def test_create_session_with_custom_config(mock_api_client_instance):
)


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_create_session_with_display_name(mock_api_client_instance):
"""Test that display_name parameter is passed through to the API."""
session_service = mock_vertex_ai_session_service()

display_name = 'Display Name'
await session_service.create_session(
app_name='123', user_id='user', display_name=display_name
)
assert (
mock_api_client_instance.last_create_session_config['display_name']
== display_name
)


@pytest.mark.asyncio
@pytest.mark.usefixtures('mock_get_api_client')
async def test_append_event():
Expand Down