Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from ..agents.run_config import RunConfig
from ..agents.run_config import StreamingMode
from ..apps.app import App
from ..artifacts.base_artifact_service import ArtifactVersion
from ..artifacts.base_artifact_service import BaseArtifactService
from ..auth.credential_service.base_credential_service import BaseCredentialService
from ..errors.already_exists_error import AlreadyExistsError
Expand Down Expand Up @@ -1294,6 +1295,24 @@ async def load_artifact(
raise HTTPException(status_code=404, detail="Artifact not found")
return artifact

@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/metadata",
response_model=list[ArtifactVersion],
response_model_exclude_none=True,
)
async def list_artifact_versions_metadata(
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
) -> list[ArtifactVersion]:
return await self.artifact_service.list_artifact_versions(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
)

@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
response_model_exclude_none=True,
Expand All @@ -1316,6 +1335,31 @@ async def load_artifact_version(
raise HTTPException(status_code=404, detail="Artifact not found")
return artifact

@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}/metadata",
response_model=ArtifactVersion,
response_model_exclude_none=True,
)
async def get_artifact_version_metadata(
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
version_id: int,
) -> ArtifactVersion:
artifact_version = await self.artifact_service.get_artifact_version(
app_name=app_name,
user_id=user_id,
session_id=session_id,
filename=artifact_name,
version=version_id,
)
if not artifact_version:
raise HTTPException(
status_code=404, detail="Artifact version not found"
)
return artifact_version

@app.get(
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
response_model_exclude_none=True,
Expand Down
42 changes: 42 additions & 0 deletions src/google/adk/cli/conformance/adk_web_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import httpx

from ...artifacts.base_artifact_service import ArtifactVersion
from ...events.event import Event
from ...sessions.session import Session
from ..adk_web_server import RunAgentRequest
Expand Down Expand Up @@ -265,3 +266,44 @@ async def run_agent(
yield Event.model_validate(event_data)
else:
logger.debug("Non data line received: %s", line)

async def get_artifact_version_metadata(
self,
*,
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
version: int,
) -> ArtifactVersion:
"""Retrieve metadata for a specific artifact version."""
async with self._get_client() as client:
response = await client.get(
(
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}"
f"/artifacts/{artifact_name}/versions/{version}/metadata"
)
)
response.raise_for_status()
return ArtifactVersion.model_validate(response.json())

async def list_artifact_versions_metadata(
self,
*,
app_name: str,
user_id: str,
session_id: str,
artifact_name: str,
) -> list[ArtifactVersion]:
"""List metadata for all versions of an artifact."""
async with self._get_client() as client:
response = await client.get(
(
f"/apps/{app_name}/users/{user_id}/sessions/{session_id}"
f"/artifacts/{artifact_name}/versions/metadata"
)
)
response.raise_for_status()
return [
ArtifactVersion.model_validate(item) for item in response.json()
]
79 changes: 79 additions & 0 deletions tests/unittests/cli/conformance/test_adk_web_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from unittest.mock import MagicMock
from unittest.mock import patch

from google.adk.artifacts.base_artifact_service import ArtifactVersion
from google.adk.cli.adk_web_server import RunAgentRequest
from google.adk.cli.conformance.adk_web_server_client import AdkWebServerClient
from google.adk.events.event import Event
Expand Down Expand Up @@ -224,6 +225,84 @@ def mock_stream(*_args, **_kwargs):
assert events[1].invocation_id == "test_invocation_2"


@pytest.mark.asyncio
async def test_get_artifact_version_metadata():
client = AdkWebServerClient()
mock_response = MagicMock()
mock_response.json.return_value = {
"version": 2,
"canonicalUri": (
"artifact://apps/app/users/user/sessions/session/"
"artifacts/report/versions/2"
),
"customMetadata": {"foo": "bar"},
"createTime": 123.4,
"mimeType": "text/plain",
}

with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
mock_client_class.return_value = mock_client

metadata = await client.get_artifact_version_metadata(
app_name="app",
user_id="user",
session_id="session",
artifact_name="report",
version=2,
)

assert isinstance(metadata, ArtifactVersion)
assert metadata.version == 2
assert metadata.custom_metadata == {"foo": "bar"}
mock_client.get.assert_called_once_with(
"/apps/app/users/user/sessions/session/artifacts/report/versions/2/metadata"
)
mock_response.raise_for_status.assert_called_once()


@pytest.mark.asyncio
async def test_list_artifact_versions_metadata():
client = AdkWebServerClient()
mock_response = MagicMock()
mock_response.json.return_value = [
{
"version": 0,
"canonicalUri": "artifact://.../versions/0",
"customMetadata": {},
"createTime": 100.0,
},
{
"version": 1,
"canonicalUri": "artifact://.../versions/1",
"customMetadata": {"foo": "bar"},
"createTime": 200.0,
"mimeType": "application/json",
},
]

with patch("httpx.AsyncClient") as mock_client_class:
mock_client = AsyncMock()
mock_client.get.return_value = mock_response
mock_client_class.return_value = mock_client

metadata_list = await client.list_artifact_versions_metadata(
app_name="app",
user_id="user",
session_id="session",
artifact_name="report",
)

assert len(metadata_list) == 2
assert all(isinstance(item, ArtifactVersion) for item in metadata_list)
assert metadata_list[1].custom_metadata == {"foo": "bar"}
mock_client.get.assert_called_once_with(
"/apps/app/users/user/sessions/session/artifacts/report/versions/metadata"
)
mock_response.raise_for_status.assert_called_once()


@pytest.mark.asyncio
async def test_close():
client = AdkWebServerClient()
Expand Down
Loading
Loading