diff --git a/pyproject.toml b/pyproject.toml index 4f8e42bcf7..452502bdd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ eval = [ test = [ # go/keep-sorted start "a2a-sdk>=0.3.0,<0.4.0", + "aioboto3>=15.5.0", # For S3 Artifact Service tests "anthropic>=0.43.0", # For anthropic model tests "crewai[tools];python_version>='3.10' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+ "kubernetes>=29.0.0", # For GkeCodeExecutor @@ -144,6 +145,7 @@ docs = [ # Optional extensions extensions = [ + "aioboto3>=15.5.0", # For S3 Artifact Service "anthropic>=0.43.0", # For anthropic model support "beautifulsoup4>=3.2.2", # For load_web_page tool. "crewai[tools];python_version>='3.10' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+ diff --git a/src/google/adk/artifacts/__init__.py b/src/google/adk/artifacts/__init__.py index 90a8063fae..88dd05dd7c 100644 --- a/src/google/adk/artifacts/__init__.py +++ b/src/google/adk/artifacts/__init__.py @@ -16,10 +16,12 @@ from .file_artifact_service import FileArtifactService from .gcs_artifact_service import GcsArtifactService from .in_memory_artifact_service import InMemoryArtifactService +from .s3_artifact_service import S3ArtifactService __all__ = [ 'BaseArtifactService', 'FileArtifactService', 'GcsArtifactService', 'InMemoryArtifactService', + 'S3ArtifactService', ] diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index 97b2fb147d..203e2eac41 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -106,7 +106,7 @@ def _resolve_scoped_artifact_path( pure_path = _to_posix_path(stripped) scope_root_resolved = scope_root.resolve(strict=False) - if pure_path.is_absolute(): + if Path(stripped).is_absolute(): raise ValueError( f"Absolute artifact filename {filename!r} is not permitted; " "provide a path relative to the storage scope." diff --git a/src/google/adk/artifacts/s3_artifact_service.py b/src/google/adk/artifacts/s3_artifact_service.py new file mode 100644 index 0000000000..fa2cad65c4 --- /dev/null +++ b/src/google/adk/artifacts/s3_artifact_service.py @@ -0,0 +1,374 @@ +"""An artifact service implementation using Amazon S3 or other S3-compatible services. + +The blob/key name format depends on whether the filename has a user namespace: + - For files with user namespace (starting with "user:"): + {app_name}/{user_id}/user/{filename}/{version} + - For regular session-scoped files: + {app_name}/{user_id}/{session_id}/{filename}/{version} + +This service supports storing and retrieving artifacts with inline data or text. +Artifacts can also have optional custom metadata, which is serialized as JSON +when stored in S3. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from google.genai import types +from pydantic import BaseModel +from typing_extensions import override + +from .base_artifact_service import ArtifactVersion +from .base_artifact_service import BaseArtifactService + +logger = logging.getLogger("google_adk." + __name__) + + +class S3ArtifactService(BaseArtifactService, BaseModel): + """An artifact service implementation using Amazon S3 or other S3-compatible services. + + Attributes: + bucket_name: The name of the S3 bucket to use for storing and retrieving artifacts. + aws_configs: A dictionary of AWS configuration options to pass to the boto3 client. 
+ save_artifact_max_retries: The maximum number of retries to attempt when saving an artifact with version conflicts. + If set to -1, the service will retry indefinitely. + """ + + bucket_name: str + aws_configs: dict[str, Any] = {} + save_artifact_max_retries: int = -1 + _s3_client: Any = None + + async def _client(self): + """Creates or returns the aioboto3 S3 client.""" + import aioboto3 + + if self._s3_client is None: + self._s3_client = ( + await aioboto3.Session() + .client(service_name="s3", **self.aws_configs) + .__aenter__() + ) + return self._s3_client + + async def close(self): + """Closes the underlying S3 client session.""" + if self._s3_client: + await self._s3_client.__aexit__(None, None, None) + self._s3_client = None + + def _flatten_metadata(self, metadata: dict[str, Any]) -> dict[str, str]: + return {k: json.dumps(v) for k, v in (metadata or {}).items()} + + def _unflatten_metadata(self, metadata: dict[str, str]) -> dict[str, Any]: + results = {} + for k, v in (metadata or {}).items(): + try: + results[k] = json.loads(v) + except json.JSONDecodeError: + logger.warning( + f"Failed to decode metadata value for key {k}. Using raw string." + ) + results[k] = v + return results + + def _file_has_user_namespace(self, filename: str) -> bool: + return filename.startswith("user:") + + def _get_blob_prefix( + self, app_name: str, user_id: str, session_id: str | None, filename: str + ) -> str: + if self._file_has_user_namespace(filename): + return f"{app_name}/{user_id}/user/{filename}" + if session_id: + return f"{app_name}/{user_id}/{session_id}/{filename}" + raise ValueError("session_id is required for session-scoped artifacts.") + + def _get_blob_name( + self, + app_name: str, + user_id: str, + session_id: str | None, + filename: str, + version: int, + ) -> str: + return ( + f"{self._get_blob_prefix(app_name, user_id, session_id, filename)}/{version}" + ) + + @override + async def save_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + artifact: types.Part, + session_id: str | None = None, + custom_metadata: dict[str, Any] | None = None, + ) -> int: + """Saves an artifact to S3 with atomic versioning using If-None-Match.""" + from botocore.exceptions import ClientError + + s3 = await self._client() + + if self.save_artifact_max_retries < 0: + retry_iter = iter(int, 1) + else: + retry_iter = range(self.save_artifact_max_retries + 1) + for _ in retry_iter: + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + version = 0 if not versions else max(versions) + 1 + key = self._get_blob_name( + app_name, user_id, session_id, filename, version + ) + if artifact.inline_data: + body = artifact.inline_data.data + mime_type = artifact.inline_data.mime_type + elif artifact.text: + body = artifact.text + mime_type = "text/plain" + elif artifact.file_data: + raise NotImplementedError( + "Saving artifact with file_data is not supported yet in" + " S3ArtifactService." 
+ ) + else: + raise ValueError("Artifact must have either inline_data or text.") + + try: + await s3.put_object( + Bucket=self.bucket_name, + Key=key, + Body=body, + ContentType=mime_type, + Metadata=self._flatten_metadata(custom_metadata), + IfNoneMatch="*", + ) + return version + except ClientError as e: + if e.response["Error"]["Code"] in ( + "PreconditionFailed", + "ObjectAlreadyExists", + ): + continue + raise e + raise RuntimeError( + "Failed to save artifact due to version conflicts after retries" + ) + + @override + async def load_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + version: int | None = None, + ) -> types.Part | None: + """Loads a specific version of an artifact from S3. + + If version is not provided, the latest version is loaded. + + Returns: + A types.Part instance (always with inline_data), or None if the artifact does not exist. + """ + from botocore.exceptions import ClientError + + s3 = await self._client() + if version is None: + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not versions: + return None + version = max(versions) + + key = self._get_blob_name(app_name, user_id, session_id, filename, version) + try: + response = await s3.get_object(Bucket=self.bucket_name, Key=key) + async with response["Body"] as stream: + data = await stream.read() + mime_type = response["ContentType"] + except ClientError as e: + if e.response["Error"]["Code"] in ("NoSuchKey", "404"): + return None + raise + return types.Part.from_bytes(data=data, mime_type=mime_type) + + @override + async def list_artifact_keys( + self, *, app_name: str, user_id: str, session_id: str | None = None + ) -> list[str]: + """Lists all artifact keys for a user, optionally filtered by session.""" + s3 = await self._client() + keys = set() + prefixes = [ + f"{app_name}/{user_id}/{session_id}/" if session_id else None, + f"{app_name}/{user_id}/user/", + ] + + for prefix in filter(None, prefixes): + paginator = s3.get_paginator("list_objects_v2") + async for page in paginator.paginate( + Bucket=self.bucket_name, Prefix=prefix + ): + for obj in page.get("Contents", []): + relative = obj["Key"][len(prefix) :] + filename = "/".join(relative.split("/")[:-1]) + keys.add(filename) + return sorted(keys) + + @override + async def delete_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + ) -> None: + """Deletes all versions of a specified artifact efficiently using batch delete.""" + s3 = await self._client() + versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not versions: + return + + keys_to_delete = [ + {"Key": self._get_blob_name(app_name, user_id, session_id, filename, v)} + for v in versions + ] + for i in range(0, len(keys_to_delete), 1000): + batch = keys_to_delete[i : i + 1000] + await s3.delete_objects( + Bucket=self.bucket_name, Delete={"Objects": batch} + ) + + @override + async def list_versions( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + ) -> list[int]: + """Lists all available versions of a specified artifact.""" + s3 = await self._client() + prefix = ( + self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" + ) + versions = [] + paginator = s3.get_paginator("list_objects_v2") + async for page in paginator.paginate( + Bucket=self.bucket_name, 
Prefix=prefix + ): + for obj in page.get("Contents", []): + try: + versions.append(int(obj["Key"].split("/")[-1])) + except ValueError: + continue + return sorted(versions) + + @override + async def list_artifact_versions( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + ) -> list[ArtifactVersion]: + """Lists all artifact versions with their metadata.""" + s3 = await self._client() + prefix = ( + self._get_blob_prefix(app_name, user_id, session_id, filename) + "/" + ) + results: list[ArtifactVersion] = [] + + paginator = s3.get_paginator("list_objects_v2") + async for page in paginator.paginate( + Bucket=self.bucket_name, Prefix=prefix + ): + for obj in page.get("Contents", []): + try: + version = int(obj["Key"].split("/")[-1]) + except ValueError: + continue + + head = await s3.head_object(Bucket=self.bucket_name, Key=obj["Key"]) + mime_type = head["ContentType"] + metadata = head.get("Metadata", {}) + + canonical_uri = f"s3://{self.bucket_name}/{obj['Key']}" + + results.append( + ArtifactVersion( + version=version, + canonical_uri=canonical_uri, + custom_metadata=self._unflatten_metadata(metadata), + create_time=obj["LastModified"].timestamp(), + mime_type=mime_type, + ) + ) + + return sorted(results, key=lambda a: a.version) + + @override + async def get_artifact_version( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: str | None = None, + version: int | None = None, + ) -> ArtifactVersion | None: + """Retrieves a specific artifact version, or the latest if version is None.""" + s3 = await self._client() + if version is None: + all_versions = await self.list_versions( + app_name=app_name, + user_id=user_id, + filename=filename, + session_id=session_id, + ) + if not all_versions: + return None + version = max(all_versions) + + key = self._get_blob_name(app_name, user_id, session_id, filename, version) + + from botocore.exceptions import ClientError + + try: + head = await s3.head_object(Bucket=self.bucket_name, Key=key) + except ClientError as e: + if e.response["Error"]["Code"] in ("NoSuchKey", "404"): + return None + raise + + return ArtifactVersion( + version=version, + canonical_uri=f"s3://{self.bucket_name}/{key}", + custom_metadata=self._unflatten_metadata(head.get("Metadata", {})), + create_time=head["LastModified"].timestamp(), + mime_type=head["ContentType"], + ) diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index 007b18ecf7..a0468809de 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -16,10 +16,13 @@ """Tests for the artifact service.""" +import asyncio from datetime import datetime import enum import json from pathlib import Path +import random +import sys from typing import Any from typing import Optional from typing import Union @@ -27,11 +30,14 @@ from unittest.mock import patch from urllib.parse import unquote from urllib.parse import urlparse +from urllib.request import url2pathname +from botocore.exceptions import ClientError from google.adk.artifacts.base_artifact_service import ArtifactVersion from google.adk.artifacts.file_artifact_service import FileArtifactService from google.adk.artifacts.gcs_artifact_service import GcsArtifactService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.artifacts.s3_artifact_service import S3ArtifactService from google.genai import types import pytest @@ -45,6 
+51,7 @@ class ArtifactServiceType(Enum): FILE = "FILE" IN_MEMORY = "IN_MEMORY" GCS = "GCS" + S3 = "S3" class MockBlob: @@ -167,8 +174,188 @@ def mock_gcs_artifact_service(): return GcsArtifactService(bucket_name="test_bucket") +class MockBody: + + def __init__(self, data: bytes): + self._data = data + + async def read(self) -> bytes: + return self._data + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + pass + + +class MockAsyncS3Object: + + def __init__(self, key): + self.key = key + self.data = None + self.content_type = None + self.metadata = {} + self.last_modified = FIXED_DATETIME + + async def put(self, Body, ContentType=None, Metadata=None): + self.data = Body if isinstance(Body, bytes) else Body.encode("utf-8") + self.content_type = ContentType + self.metadata = Metadata or {} + + async def get(self): + if self.data is None: + raise ClientError( + {"Error": {"Code": "NoSuchKey", "Message": "Not Found"}}, + operation_name="GetObject", + ) + return { + "Body": MockBody(self.data), + "ContentType": self.content_type, + "Metadata": self.metadata, + "LastModified": self.last_modified, + } + + +class MockAsyncS3Bucket: + + def __init__(self, name): + self.name = name + self.objects = {} + + def object(self, key): + if key not in self.objects: + self.objects[key] = MockAsyncS3Object(key) + return self.objects[key] + + async def listed_keys(self, prefix=None): + return [ + k + for k, obj in self.objects.items() + if obj.data is not None and (prefix is None or k.startswith(prefix)) + ] + + +class MockAsyncS3Client: + + def __init__(self): + self.buckets = {} + + def get_bucket(self, bucket_name): + if bucket_name not in self.buckets: + self.buckets[bucket_name] = MockAsyncS3Bucket(bucket_name) + return self.buckets[bucket_name] + + async def put_object( + self, Bucket, Key, Body, ContentType=None, Metadata=None, IfNoneMatch=None + ): + await asyncio.sleep(random.uniform(0, 0.05)) + bucket = self.get_bucket(Bucket) + obj_exists = Key in bucket.objects and bucket.objects[Key].data is not None + + if IfNoneMatch == "*" and obj_exists: + raise ClientError( + {"Error": {"Code": "PreconditionFailed", "Message": "Object exists"}}, + operation_name="PutObject", + ) + + await bucket.object(Key).put( + Body=Body, ContentType=ContentType, Metadata=Metadata + ) + + async def get_object(self, Bucket, Key): + bucket = self.get_bucket(Bucket) + obj = bucket.object(Key) + return await obj.get() + + async def delete_object(self, Bucket, Key): + bucket = self.get_bucket(Bucket) + bucket.objects.pop(Key, None) + + async def delete_objects(self, Bucket, Delete): + bucket = self.get_bucket(Bucket) + for item in Delete.get("Objects", []): + key = item.get("Key") + if key in bucket.objects: + bucket.objects.pop(key) + + async def list_objects_v2(self, Bucket, Prefix=None): + bucket = self.get_bucket(Bucket) + keys = await bucket.listed_keys(Prefix) + return { + "KeyCount": len(keys), + "Contents": [ + {"Key": k, "LastModified": bucket.objects[k].last_modified} + for k in keys + ], + } + + async def head_object(self, Bucket, Key): + obj = await self.get_object(Bucket, Key) + return { + "ContentType": obj["ContentType"], + "Metadata": obj.get("Metadata", {}), + "LastModified": obj.get("LastModified"), + } + + def get_paginator(self, operation_name): + if operation_name != "list_objects_v2": + raise NotImplementedError( + f"Paginator for {operation_name} not implemented" + ) + + class MockAsyncPaginator: + + def __init__(self, client, Bucket, Prefix=None): + 
self.client = client + self.Bucket = Bucket + self.Prefix = Prefix + + async def __aiter__(self): + response = await self.client.list_objects_v2( + Bucket=self.Bucket, Prefix=self.Prefix + ) + contents = response.get("Contents", []) + page_size = 2 + for i in range(0, len(contents), page_size): + yield { + "KeyCount": len(contents[i : i + page_size]), + "Contents": contents[i : i + page_size], + } + + class MockPaginator: + + def paginate(inner_self, Bucket, Prefix=None): + return MockAsyncPaginator(self, Bucket, Prefix) + + return MockPaginator() + + +def mock_s3_artifact_service(monkeypatch): + mock_s3_client = MockAsyncS3Client() + + class MockAioboto3: + + class Session: + + def client(self, *args, **kwargs): + class MockClientCtx: + + async def __aenter__(self_inner): + return mock_s3_client + + async def __aexit__(self_inner, exc_type, exc, tb): + pass + + return MockClientCtx() + + monkeypatch.setitem(sys.modules, "aioboto3", MockAioboto3) + artifact_service = S3ArtifactService(bucket_name="test_bucket") + return artifact_service + + @pytest.fixture -def artifact_service_factory(tmp_path: Path): +def artifact_service_factory(tmp_path: Path, monkeypatch): """Provides an artifact service constructor bound to the test tmp path.""" def factory( @@ -178,6 +365,8 @@ def factory( return mock_gcs_artifact_service() if service_type == ArtifactServiceType.FILE: return FileArtifactService(root_dir=tmp_path / "artifacts") + if service_type == ArtifactServiceType.S3: + return mock_s3_artifact_service(monkeypatch) return InMemoryArtifactService() return factory @@ -190,6 +379,7 @@ def factory( ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS, ArtifactServiceType.FILE, + ArtifactServiceType.S3, ], ) async def test_load_empty(service_type, artifact_service_factory): @@ -210,6 +400,7 @@ async def test_load_empty(service_type, artifact_service_factory): ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS, ArtifactServiceType.FILE, + ArtifactServiceType.S3, ], ) async def test_save_load_delete(service_type, artifact_service_factory): @@ -268,6 +459,7 @@ async def test_save_load_delete(service_type, artifact_service_factory): ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS, ArtifactServiceType.FILE, + ArtifactServiceType.S3, ], ) async def test_list_keys(service_type, artifact_service_factory): @@ -304,6 +496,7 @@ async def test_list_keys(service_type, artifact_service_factory): ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS, ArtifactServiceType.FILE, + ArtifactServiceType.S3, ], ) async def test_list_versions(service_type, artifact_service_factory): @@ -348,6 +541,7 @@ async def test_list_versions(service_type, artifact_service_factory): ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS, ArtifactServiceType.FILE, + ArtifactServiceType.S3, ], ) async def test_list_keys_preserves_user_prefix( @@ -398,7 +592,12 @@ async def test_list_keys_preserves_user_prefix( @pytest.mark.asyncio @pytest.mark.parametrize( - "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.S3, + ], ) async def test_list_artifact_versions_and_get_artifact_version( service_type, artifact_service_factory @@ -446,6 +645,10 @@ async def test_list_artifact_versions_and_get_artifact_version( uri = ( f"gs://test_bucket/{app_name}/{user_id}/{session_id}/{filename}/{i}" ) + elif service_type == ArtifactServiceType.S3: + uri = ( + 
f"s3://test_bucket/{app_name}/{user_id}/{session_id}/{filename}/{i}" + ) else: uri = f"memory://apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{filename}/versions/{i}" expected_artifact_versions.append( @@ -485,7 +688,12 @@ async def test_list_artifact_versions_and_get_artifact_version( @pytest.mark.asyncio @pytest.mark.parametrize( - "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.S3, + ], ) async def test_list_artifact_versions_with_user_prefix( service_type, artifact_service_factory @@ -532,6 +740,8 @@ async def test_list_artifact_versions_with_user_prefix( metadata = {"key": "value" + str(i)} if service_type == ArtifactServiceType.GCS: uri = f"gs://test_bucket/{app_name}/{user_id}/user/{user_scoped_filename}/{i}" + elif service_type == ArtifactServiceType.S3: + uri = f"s3://test_bucket/{app_name}/{user_id}/user/{user_scoped_filename}/{i}" else: uri = f"memory://apps/{app_name}/users/{user_id}/artifacts/{user_scoped_filename}/versions/{i}" expected_artifact_versions.append( @@ -548,7 +758,12 @@ async def test_list_artifact_versions_with_user_prefix( @pytest.mark.asyncio @pytest.mark.parametrize( - "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.S3, + ], ) async def test_get_artifact_version_artifact_does_not_exist( service_type, artifact_service_factory @@ -565,7 +780,12 @@ async def test_get_artifact_version_artifact_does_not_exist( @pytest.mark.asyncio @pytest.mark.parametrize( - "service_type", [ArtifactServiceType.IN_MEMORY, ArtifactServiceType.GCS] + "service_type", + [ + ArtifactServiceType.IN_MEMORY, + ArtifactServiceType.GCS, + ArtifactServiceType.S3, + ], ) async def test_get_artifact_version_out_of_index( service_type, artifact_service_factory @@ -643,7 +863,7 @@ async def test_file_metadata_camelcase(tmp_path, artifact_service_factory): "customMetadata": {}, } parsed_canonical = urlparse(metadata["canonicalUri"]) - canonical_path = Path(unquote(parsed_canonical.path)) + canonical_path = Path(url2pathname(unquote(parsed_canonical.path))) assert canonical_path.name == "report.txt" assert canonical_path.read_bytes() == b"binary-content" @@ -693,7 +913,7 @@ async def test_file_list_artifact_versions(tmp_path, artifact_service_factory): assert version_meta.canonical_uri == version_payload_path.as_uri() assert version_meta.custom_metadata == custom_metadata parsed_version_uri = urlparse(version_meta.canonical_uri) - version_uri_path = Path(unquote(parsed_version_uri.path)) + version_uri_path = Path(url2pathname(unquote(parsed_version_uri.path))) assert version_uri_path.read_bytes() == b"binary-content" fetched = await artifact_service.get_artifact_version(