From 7e9c8e9d0ace7e2ca24f3d67ccc0b061ecddd8eb Mon Sep 17 00:00:00 2001 From: Raman369AI Date: Thu, 5 Mar 2026 22:38:02 -0600 Subject: [PATCH] feat(memory): add DatabaseMemoryService with SQL backend and scratchpad MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a durable, RDBMS-backed memory service that works with any SQLAlchemy-supported database (SQLite, PostgreSQL, MySQL, MariaDB) as an alternative to the volatile InMemoryMemoryService. Key additions: - DatabaseMemoryService: implements BaseMemoryService with lazy table creation, idempotent session ingest, and delta event ingestion - MemorySearchBackend ABC + KeywordSearchBackend: LIKE/ILIKE search with AND-first → OR-fallback tokenization strategy - Scratchpad KV store and append-only log for intermediate agent state - Four agent-callable BaseTool subclasses: scratchpad_get_tool, scratchpad_set_tool, scratchpad_append_log_tool, scratchpad_get_log_tool - 38 unit tests covering all methods, tool happy-paths, wrong-service errors, multi-user isolation, and session scoping --- .gitignore | 3 + src/google/adk/memory/__init__.py | 17 + .../adk/memory/database_memory_service.py | 562 ++++++++++++++ .../adk/memory/memory_search_backend.py | 127 +++ src/google/adk/memory/schemas/__init__.py | 13 + .../adk/memory/schemas/memory_schema.py | 153 ++++ src/google/adk/tools/scratchpad_tool.py | 246 ++++++ .../memory/test_database_memory_service.py | 734 ++++++++++++++++++ 8 files changed, 1855 insertions(+) create mode 100644 src/google/adk/memory/database_memory_service.py create mode 100644 src/google/adk/memory/memory_search_backend.py create mode 100644 src/google/adk/memory/schemas/__init__.py create mode 100644 src/google/adk/memory/schemas/memory_schema.py create mode 100644 src/google/adk/tools/scratchpad_tool.py create mode 100644 tests/unittests/memory/test_database_memory_service.py diff --git a/.gitignore b/.gitignore index 47f633c5c5..7473842e51 100644 --- 
a/.gitignore +++ b/.gitignore @@ -99,6 +99,9 @@ Thumbs.db *.tmp *.temp +# Agent handoff / session notes (not for version control) +AGENT_HANDOFF.md + # AI Coding Tools - Project-specific configs # Developers should symlink or copy AGENTS.md and add their own overrides locally .adk/ diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index c47fb8ec40..2cfb657842 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -35,3 +35,20 @@ ' VertexAiRagMemoryService please install it. If not, you can ignore this' ' warning.' ) + +try: + from .database_memory_service import DatabaseMemoryService + from .memory_search_backend import KeywordSearchBackend + from .memory_search_backend import MemorySearchBackend + + __all__ += [ + 'DatabaseMemoryService', + 'KeywordSearchBackend', + 'MemorySearchBackend', + ] +except ImportError: + logger.debug( + 'SQLAlchemy or an async DB driver is not installed. If you want to use' + ' DatabaseMemoryService please install sqlalchemy and an async driver' + ' (e.g. aiosqlite for SQLite). If not, you can ignore this warning.' + ) diff --git a/src/google/adk/memory/database_memory_service.py b/src/google/adk/memory/database_memory_service.py new file mode 100644 index 0000000000..f00fd3b483 --- /dev/null +++ b/src/google/adk/memory/database_memory_service.py @@ -0,0 +1,562 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""SQL-backed memory service with scratchpad support for ADK agents.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +from collections.abc import Sequence +from contextlib import asynccontextmanager +import logging +from typing import Any +from typing import AsyncIterator +from typing import Optional +from typing import TYPE_CHECKING +import uuid + +from google.genai import types +from sqlalchemy import delete +from sqlalchemy import select +from sqlalchemy.engine import make_url +from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.asyncio import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.pool import StaticPool +from typing_extensions import override + +from . import _utils +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry +from .memory_search_backend import KeywordSearchBackend +from .memory_search_backend import MemorySearchBackend +from .schemas.memory_schema import Base +from .schemas.memory_schema import StorageMemoryEntry +from .schemas.memory_schema import StorageScratchpadKV +from .schemas.memory_schema import StorageScratchpadLog + +if TYPE_CHECKING: + from ..events.event import Event + from ..sessions.session import Session + +logger = logging.getLogger("google_adk." + __name__) + +_SQLITE_DIALECT = "sqlite" + + +class DatabaseMemoryService(BaseMemoryService): + """A durable, SQL-backed memory service for any SQLAlchemy-supported DB. + + Works with SQLite, PostgreSQL, MySQL, MariaDB, and Spanner. Exposes a + scratchpad (KV store + append-log) for agents to use as intermediate + working memory during task execution. 
+ + Usage:: + + from google.adk.memory import DatabaseMemoryService + + # SQLite (no external DB needed): + svc = DatabaseMemoryService("sqlite+aiosqlite:///:memory:") + + # PostgreSQL: + svc = DatabaseMemoryService( + "postgresql+asyncpg://user:pass@host/dbname" + ) + """ + + def __init__( + self, + db_url: str, + search_backend: Optional[MemorySearchBackend] = None, + **kwargs: Any, + ): + """Initialises the service and creates a DB engine. + + Args: + db_url: SQLAlchemy async connection URL. + search_backend: Optional custom search backend. Defaults to + KeywordSearchBackend. + **kwargs: Extra keyword arguments forwarded to + sqlalchemy.ext.asyncio.create_async_engine. + + Raises: + ValueError: If the db_url is invalid or the required DB driver is + not installed. + """ + try: + engine_kwargs = dict(kwargs) + url = make_url(db_url) + backend = url.get_backend_name() + if backend == _SQLITE_DIALECT and url.database == ":memory:": + engine_kwargs.setdefault("poolclass", StaticPool) + connect_args = dict(engine_kwargs.get("connect_args", {})) + connect_args.setdefault("check_same_thread", False) + engine_kwargs["connect_args"] = connect_args + elif backend != _SQLITE_DIALECT: + engine_kwargs.setdefault("pool_pre_ping", True) + + self.db_engine: AsyncEngine = create_async_engine(db_url, **engine_kwargs) + except ArgumentError as exc: + raise ValueError( + f"Invalid database URL format or argument '{db_url}'." + ) from exc + except ImportError as exc: + raise ValueError( + f"Database-related module not found for URL '{db_url}'." 
+ ) from exc + + self._session_factory: async_sessionmaker[AsyncSession] = ( + async_sessionmaker(bind=self.db_engine, expire_on_commit=False) + ) + self._search_backend: MemorySearchBackend = ( + search_backend or KeywordSearchBackend() + ) + self._tables_created = False + self._table_creation_lock = asyncio.Lock() + + # --------------------------------------------------------------------------- + # Internal helpers + # --------------------------------------------------------------------------- + + @asynccontextmanager + async def _session(self) -> AsyncIterator[AsyncSession]: + """Yield an AsyncSession; roll back on exception.""" + async with self._session_factory() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + async def _prepare_tables(self) -> None: + """Lazy, double-checked table initialisation.""" + if self._tables_created: + return + async with self._table_creation_lock: + if self._tables_created: + return + async with self.db_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + self._tables_created = True + + @staticmethod + def _extract_search_text(content: types.Content) -> str: + """Join all text parts of a Content into a single searchable string.""" + if not content or not content.parts: + return "" + return " ".join(part.text for part in content.parts if part.text) + + @staticmethod + def _should_skip_event(event: Event) -> bool: + """Return True if the event has no usable text content.""" + if not event.content or not event.content.parts: + return True + return not any(part.text for part in event.content.parts if part.text) + + # --------------------------------------------------------------------------- + # BaseMemoryService implementation + # --------------------------------------------------------------------------- + + @override + async def add_session_to_memory(self, session: Session) -> None: + """Idempotently ingest all events from a session. 
+ + Deletes any existing rows for this session, then re-inserts from scratch. + + Args: + session: The session whose events should be stored in memory. + """ + await self._prepare_tables() + async with self._session() as sql: + # Delete existing rows for this session so re-ingest is idempotent. + await sql.execute( + delete(StorageMemoryEntry).where( + StorageMemoryEntry.app_name == session.app_name, + StorageMemoryEntry.user_id == session.user_id, + StorageMemoryEntry.session_id == session.id, + ) + ) + for event in session.events: + if self._should_skip_event(event): + continue + content_dict = event.content.model_dump(mode="json", exclude_none=True) + sql.add( + StorageMemoryEntry( + id=str(uuid.uuid4()), + app_name=session.app_name, + user_id=session.user_id, + session_id=session.id, + event_id=event.id, + author=event.author, + timestamp=_utils.format_timestamp(event.timestamp), + content_json=content_dict, + search_text=self._extract_search_text(event.content), + custom_metadata={}, + ) + ) + + @override + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence[Event], + session_id: Optional[str] = None, + custom_metadata: Optional[Mapping[str, object]] = None, + ) -> None: + """Delta-insert events; skips duplicate event_id within the same session. + + Args: + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + events: The events to add to memory. + session_id: Optional session ID for memory scope/partitioning. + custom_metadata: Optional metadata attached to each stored entry. + """ + await self._prepare_tables() + async with self._session() as sql: + # Fetch existing event IDs for this session to avoid duplicates. 
+ stmt = select(StorageMemoryEntry.event_id).where( + StorageMemoryEntry.app_name == app_name, + StorageMemoryEntry.user_id == user_id, + StorageMemoryEntry.session_id == session_id, + StorageMemoryEntry.event_id.isnot(None), + ) + result = await sql.execute(stmt) + existing_event_ids = {row[0] for row in result.fetchall()} + + meta = dict(custom_metadata) if custom_metadata else {} + for event in events: + if self._should_skip_event(event): + continue + if event.id and event.id in existing_event_ids: + continue + content_dict = event.content.model_dump(mode="json", exclude_none=True) + sql.add( + StorageMemoryEntry( + id=str(uuid.uuid4()), + app_name=app_name, + user_id=user_id, + session_id=session_id, + event_id=event.id, + author=event.author, + timestamp=_utils.format_timestamp(event.timestamp), + content_json=content_dict, + search_text=self._extract_search_text(event.content), + custom_metadata=meta, + ) + ) + if event.id: + existing_event_ids.add(event.id) + + @override + async def add_memory( + self, + *, + app_name: str, + user_id: str, + memories: Sequence[MemoryEntry], + custom_metadata: Optional[Mapping[str, object]] = None, + ) -> None: + """Directly insert MemoryEntry objects (not tied to session events). + + Args: + app_name: The application name for memory scope. + user_id: The user ID for memory scope. + memories: Explicit memory items to add. + custom_metadata: Optional metadata attached to each stored entry. 
+ """ + await self._prepare_tables() + meta = dict(custom_metadata) if custom_metadata else {} + async with self._session() as sql: + for entry in memories: + entry_id = entry.id or str(uuid.uuid4()) + content_dict = entry.content.model_dump(mode="json", exclude_none=True) + sql.add( + StorageMemoryEntry( + id=entry_id, + app_name=app_name, + user_id=user_id, + session_id=None, + event_id=None, + author=entry.author, + timestamp=entry.timestamp, + content_json=content_dict, + search_text=self._extract_search_text(entry.content), + custom_metadata={**entry.custom_metadata, **meta}, + ) + ) + + @override + async def search_memory( + self, + *, + app_name: str, + user_id: str, + query: str, + ) -> SearchMemoryResponse: + """Search stored memories using the configured search backend. + + Args: + app_name: The name of the application. + user_id: The id of the user. + query: The query to search for. + + Returns: + A SearchMemoryResponse containing the matching memories. + """ + await self._prepare_tables() + async with self._session() as sql: + rows = await self._search_backend.search( + sql_session=sql, + app_name=app_name, + user_id=user_id, + query=query, + ) + memories = [] + for row in rows: + try: + content = types.Content.model_validate(row.content_json) + except Exception: # pylint: disable=broad-except + logger.warning( + "Skipping memory entry %s: invalid content JSON", row.id + ) + continue + memories.append( + MemoryEntry( + id=row.id, + content=content, + author=row.author, + timestamp=row.timestamp, + custom_metadata=row.custom_metadata or {}, + ) + ) + return SearchMemoryResponse(memories=memories) + + # --------------------------------------------------------------------------- + # Scratchpad KV methods + # --------------------------------------------------------------------------- + + async def set_scratchpad( + self, + *, + app_name: str, + user_id: str, + session_id: str = "", + key: str, + value: Any, + ) -> None: + """Write a key-value pair to the 
scratchpad. + + Overwrites any existing value for the same composite key. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level (non-session) KV. + key: The key to write. + value: The JSON-serialisable value to store. + """ + await self._prepare_tables() + async with self._session() as sql: + existing = await sql.get( + StorageScratchpadKV, (app_name, user_id, session_id, key) + ) + if existing is not None: + existing.value_json = value + else: + sql.add( + StorageScratchpadKV( + app_name=app_name, + user_id=user_id, + session_id=session_id, + key=key, + value_json=value, + ) + ) + + async def get_scratchpad( + self, + *, + app_name: str, + user_id: str, + session_id: str = "", + key: str, + ) -> Any | None: + """Read a value from the scratchpad. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level (non-session) KV. + key: The key to read. + + Returns: + The stored value, or None if the key does not exist. + """ + await self._prepare_tables() + async with self._session() as sql: + row = await sql.get( + StorageScratchpadKV, (app_name, user_id, session_id, key) + ) + return row.value_json if row is not None else None + + async def delete_scratchpad( + self, + *, + app_name: str, + user_id: str, + session_id: str = "", + key: str, + ) -> None: + """Delete a key-value pair from the scratchpad. No-op if not found. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level (non-session) KV. + key: The key to delete. 
+ """ + await self._prepare_tables() + async with self._session() as sql: + await sql.execute( + delete(StorageScratchpadKV).where( + StorageScratchpadKV.app_name == app_name, + StorageScratchpadKV.user_id == user_id, + StorageScratchpadKV.session_id == session_id, + StorageScratchpadKV.key == key, + ) + ) + + async def list_scratchpad_keys( + self, + *, + app_name: str, + user_id: str, + session_id: str = "", + ) -> list[str]: + """Return all keys present in the scratchpad for the given scope. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level (non-session) KV. + + Returns: + A list of key strings. + """ + await self._prepare_tables() + async with self._session() as sql: + result = await sql.execute( + select(StorageScratchpadKV.key).where( + StorageScratchpadKV.app_name == app_name, + StorageScratchpadKV.user_id == user_id, + StorageScratchpadKV.session_id == session_id, + ) + ) + return [row[0] for row in result.fetchall()] + + # --------------------------------------------------------------------------- + # Scratchpad log methods + # --------------------------------------------------------------------------- + + async def append_log( + self, + *, + app_name: str, + user_id: str, + session_id: str = "", + content: str, + tag: Optional[str] = None, + agent_name: Optional[str] = None, + extra: Optional[Any] = None, + ) -> None: + """Append an entry to the append-only scratchpad log. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level log. + content: The text content to log. + tag: Optional category label for filtering. + agent_name: Optional name of the agent appending this entry. + extra: Optional JSON-serialisable extra data. 
+ """ + await self._prepare_tables() + async with self._session() as sql: + sql.add( + StorageScratchpadLog( + app_name=app_name, + user_id=user_id, + session_id=session_id, + tag=tag, + agent_name=agent_name, + content=content, + extra_json=extra, + ) + ) + + async def get_log( + self, + *, + app_name: str, + user_id: str, + session_id: str = "", + tag: Optional[str] = None, + limit: int = 50, + ) -> list[dict]: + """Read the most recent log entries, optionally filtered by tag. + + Args: + app_name: Application name scope. + user_id: User ID scope. + session_id: Session ID scope. Use '' for user-level log. + tag: Optional tag to filter results by. + limit: Maximum number of entries to return. + + Returns: + A list of dicts with keys: id, tag, agent_name, content, extra. + """ + await self._prepare_tables() + async with self._session() as sql: + stmt = ( + select(StorageScratchpadLog) + .where( + StorageScratchpadLog.app_name == app_name, + StorageScratchpadLog.user_id == user_id, + StorageScratchpadLog.session_id == session_id, + ) + .order_by(StorageScratchpadLog.id.desc()) + .limit(limit) + ) + if tag is not None: + stmt = stmt.where(StorageScratchpadLog.tag == tag) + result = await sql.execute(stmt) + rows = result.scalars().all() + return [ + { + "id": r.id, + "tag": r.tag, + "agent_name": r.agent_name, + "content": r.content, + "extra": r.extra_json, + } + for r in reversed(rows) + ] diff --git a/src/google/adk/memory/memory_search_backend.py b/src/google/adk/memory/memory_search_backend.py new file mode 100644 index 0000000000..c77d55982c --- /dev/null +++ b/src/google/adk/memory/memory_search_backend.py @@ -0,0 +1,127 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
class MemorySearchBackend(ABC):
  """Abstract base class for memory search strategies.

  Implementations receive an already-open AsyncSession bound to the
  service's engine and must restrict results to the given app/user scope.
  """

  @abstractmethod
  async def search(
      self,
      *,
      sql_session: AsyncSession,
      app_name: str,
      user_id: str,
      query: str,
      limit: int = 10,
  ) -> Sequence[StorageMemoryEntry]:
    """Search for memory entries matching the query.

    Args:
      sql_session: The active async SQLAlchemy session.
      app_name: Application name scope.
      user_id: User ID scope.
      query: Natural-language or keyword query string.
      limit: Maximum number of results to return.

    Returns:
      A sequence of matching StorageMemoryEntry rows.
    """
+ """ + + async def search( + self, + *, + sql_session: AsyncSession, + app_name: str, + user_id: str, + query: str, + limit: int = 10, + ) -> Sequence[StorageMemoryEntry]: + """Search for memory entries using LIKE/ILIKE keyword matching.""" + if not query or not query.strip(): + return [] + + tokens = [ + cleaned + for raw in query.split() + if raw.strip() + for cleaned in [re.sub(r"[^\w]", "", raw).lower()] + if cleaned + ] + if not tokens: + return [] + + # Determine dialect via the engine bound to this session. + dialect_name = sql_session.get_bind().dialect.name + use_ilike = dialect_name in _ILIKE_DIALECTS + + def _like_expr(token: str): + pattern = f"%{token}%" + col = StorageMemoryEntry.search_text + return col.ilike(pattern) if use_ilike else col.like(pattern) + + base_stmt = ( + select(StorageMemoryEntry) + .where( + StorageMemoryEntry.app_name == app_name, + StorageMemoryEntry.user_id == user_id, + StorageMemoryEntry.search_text.isnot(None), + ) + .limit(limit) + ) + + # AND predicate: all tokens must match. + and_stmt = base_stmt.where(*[_like_expr(t) for t in tokens]) + result = await sql_session.execute(and_stmt) + rows = result.scalars().all() + if rows: + return rows + + # OR fallback: any token matches. + or_stmt = base_stmt.where(or_(*[_like_expr(t) for t in tokens])) + result = await sql_session.execute(or_stmt) + return result.scalars().all() diff --git a/src/google/adk/memory/schemas/__init__.py b/src/google/adk/memory/schemas/__init__.py new file mode 100644 index 0000000000..58d482ea38 --- /dev/null +++ b/src/google/adk/memory/schemas/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/google/adk/memory/schemas/memory_schema.py b/src/google/adk/memory/schemas/memory_schema.py new file mode 100644 index 0000000000..4a7d36f531 --- /dev/null +++ b/src/google/adk/memory/schemas/memory_schema.py @@ -0,0 +1,153 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""SQLAlchemy ORM schema for DatabaseMemoryService tables.""" + +from __future__ import annotations + +from typing import Any +from typing import Optional + +from sqlalchemy import func +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import Text +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.types import String + +from ...sessions.schemas.shared import DynamicJSON +from ...sessions.schemas.shared import PreciseTimestamp + +DEFAULT_MAX_KEY_LENGTH = 128 +DEFAULT_MAX_VARCHAR_LENGTH = 256 + + +class Base(DeclarativeBase): + """Declarative base for memory schema tables.""" + + pass + + +class StorageMemoryEntry(Base): + """ORM model for the adk_memory_entries table.""" + + __tablename__ = "adk_memory_entries" + + id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=False, index=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=False, index=True + ) + session_id: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=True + ) + event_id: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=True + ) + author: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=True + ) + timestamp: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + content_json: Mapped[Any] = mapped_column(DynamicJSON, nullable=True) + search_text: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + custom_metadata: Mapped[Any] = mapped_column( + MutableDict.as_mutable(DynamicJSON), nullable=True + ) + created_at: Mapped[Any] = mapped_column( + PreciseTimestamp, server_default=func.now() + ) + + __table_args__ = ( + 
Index("ix_memory_entries_app_user", "app_name", "user_id"), + Index("ix_memory_entries_session", "app_name", "user_id", "session_id"), + ) + + +class StorageScratchpadKV(Base): + """ORM model for the adk_scratchpad_kv table. + + Composite PK: (app_name, user_id, session_id, key). + Use session_id='' as a sentinel for user-level (non-session) KV. + """ + + __tablename__ = "adk_scratchpad_kv" + + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + session_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + key: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + value_json: Mapped[Any] = mapped_column(DynamicJSON, nullable=False) + updated_at: Mapped[Any] = mapped_column( + PreciseTimestamp, + server_default=func.now(), + onupdate=func.now(), + ) + + +class StorageScratchpadLog(Base): + """ORM model for the adk_scratchpad_log table. + + Append-only. id is autoincrement int to preserve insertion order. + Use session_id='' as a sentinel for user-level (non-session) log. 
+ """ + + __tablename__ = "adk_scratchpad_log" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + app_name: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=False + ) + user_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=False + ) + session_id: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=False + ) + tag: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=True, index=True + ) + agent_name: Mapped[Optional[str]] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), nullable=True + ) + content: Mapped[str] = mapped_column(Text, nullable=False) + extra_json: Mapped[Optional[Any]] = mapped_column(DynamicJSON, nullable=True) + created_at: Mapped[Any] = mapped_column( + PreciseTimestamp, server_default=func.now() + ) + + __table_args__ = ( + Index( + "ix_scratchpad_log_scope", + "app_name", + "user_id", + "session_id", + ), + ) diff --git a/src/google/adk/tools/scratchpad_tool.py b/src/google/adk/tools/scratchpad_tool.py new file mode 100644 index 0000000000..caebdb059e --- /dev/null +++ b/src/google/adk/tools/scratchpad_tool.py @@ -0,0 +1,246 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Agent-callable tools for reading/writing the scratchpad.""" + +from __future__ import annotations + +from typing import Any +from typing import Optional + +from google.genai import types +from typing_extensions import override + +from .base_tool import BaseTool +from .tool_context import ToolContext + + +def _get_db_memory_service(tool_context: ToolContext): + """Return the DatabaseMemoryService from the invocation context, or raise.""" + # Import here to avoid circular imports at module load time. + # pylint: disable=g-import-not-at-top + from ..memory.database_memory_service import DatabaseMemoryService + + svc = tool_context._invocation_context.memory_service + if not isinstance(svc, DatabaseMemoryService): + raise ValueError( + "Scratchpad tools require the agent's memory_service to be a " + f"DatabaseMemoryService, got: {type(svc).__name__}" + ) + return svc + + +def _session_scope(tool_context: ToolContext) -> tuple[str, str, str]: + """Return (app_name, user_id, session_id) from the invocation context.""" + ic = tool_context._invocation_context + return ic.app_name, ic.session.user_id, ic.session.id + + +class ScratchpadGetTool(BaseTool): + """Read a value from the agent scratchpad KV store.""" + + def __init__(self): + super().__init__( + name="scratchpad_get", + description=( + "Read a value stored in the scratchpad KV store by key." + " Returns null if the key does not exist." 
+ ), + ) + + @override + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "key": types.Schema( + type=types.Type.STRING, + description="The key to read.", + ), + }, + required=["key"], + ), + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + svc = _get_db_memory_service(tool_context) + app_name, user_id, session_id = _session_scope(tool_context) + return await svc.get_scratchpad( + app_name=app_name, + user_id=user_id, + session_id=session_id, + key=args["key"], + ) + + +class ScratchpadSetTool(BaseTool): + """Write a value to the agent scratchpad KV store.""" + + def __init__(self): + super().__init__( + name="scratchpad_set", + description=( + "Write a value to the scratchpad KV store. " + "Overwrites any existing value for the same key." + ), + ) + + @override + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "key": types.Schema( + type=types.Type.STRING, + description="The key to write.", + ), + "value": types.Schema( + description=( + "The value to store (any JSON-serialisable type)." 
+ ), + ), + }, + required=["key", "value"], + ), + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> str: + svc = _get_db_memory_service(tool_context) + app_name, user_id, session_id = _session_scope(tool_context) + await svc.set_scratchpad( + app_name=app_name, + user_id=user_id, + session_id=session_id, + key=args["key"], + value=args["value"], + ) + return "ok" + + +class ScratchpadAppendLogTool(BaseTool): + """Append an observation or note to the agent scratchpad log.""" + + def __init__(self): + super().__init__( + name="scratchpad_append_log", + description=( + "Append a text observation or note to the scratchpad log. " + "Entries are stored in insertion order and can be filtered by tag." + ), + ) + + @override + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "content": types.Schema( + type=types.Type.STRING, + description="The text content to log.", + ), + "tag": types.Schema( + type=types.Type.STRING, + description="Optional category label for filtering.", + ), + }, + required=["content"], + ), + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> str: + svc = _get_db_memory_service(tool_context) + app_name, user_id, session_id = _session_scope(tool_context) + await svc.append_log( + app_name=app_name, + user_id=user_id, + session_id=session_id, + content=args["content"], + tag=args.get("tag"), + agent_name=tool_context.agent_name, + ) + return "ok" + + +class ScratchpadGetLogTool(BaseTool): + """Read recent entries from the agent scratchpad log.""" + + def __init__(self): + super().__init__( + name="scratchpad_get_log", + description=( + "Read recent entries from the scratchpad log, " + "optionally filtered by tag." 
+ ), + ) + + @override + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "tag": types.Schema( + type=types.Type.STRING, + description="Optional category label to filter by.", + ), + "limit": types.Schema( + type=types.Type.INTEGER, + description=( + "Maximum number of entries to return (default 50)." + ), + ), + }, + ), + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> list[dict]: + svc = _get_db_memory_service(tool_context) + app_name, user_id, session_id = _session_scope(tool_context) + return await svc.get_log( + app_name=app_name, + user_id=user_id, + session_id=session_id, + tag=args.get("tag"), + limit=int(args.get("limit", 50)), + ) + + +# Ready-to-use singleton instances +scratchpad_get_tool = ScratchpadGetTool() +scratchpad_set_tool = ScratchpadSetTool() +scratchpad_append_log_tool = ScratchpadAppendLogTool() +scratchpad_get_log_tool = ScratchpadGetLogTool() diff --git a/tests/unittests/memory/test_database_memory_service.py b/tests/unittests/memory/test_database_memory_service.py new file mode 100644 index 0000000000..38e74848fc --- /dev/null +++ b/tests/unittests/memory/test_database_memory_service.py @@ -0,0 +1,734 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for DatabaseMemoryService.""" + +from __future__ import annotations + +from collections.abc import Sequence +import time +from typing import Any +from unittest.mock import MagicMock + +from google.adk.events.event import Event +from google.adk.memory.base_memory_service import SearchMemoryResponse +from google.adk.memory.database_memory_service import DatabaseMemoryService +from google.adk.memory.memory_entry import MemoryEntry +from google.adk.memory.memory_search_backend import MemorySearchBackend +from google.adk.memory.schemas.memory_schema import StorageMemoryEntry +from google.adk.sessions.session import Session +from google.genai import types +import pytest +import pytest_asyncio + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_DB_URL = "sqlite+aiosqlite:///:memory:" +_APP = "test_app" +_USER = "user_1" +_SESSION = "session_1" + + +def _make_content(text: str) -> types.Content: + return types.Content(role="user", parts=[types.Part(text=text)]) + + +def _make_event( + text: str, event_id: str = "ev1", author: str = "user" +) -> Event: + return Event( + id=event_id, + author=author, + content=_make_content(text), + timestamp=time.time(), + invocation_id="inv1", + ) + + +def _make_session(events: list[Event], session_id: str = _SESSION) -> Session: + return Session( + id=session_id, + app_name=_APP, + user_id=_USER, + events=events, + ) + + +@pytest.fixture +def svc() -> DatabaseMemoryService: + return DatabaseMemoryService(_DB_URL) + + +# --------------------------------------------------------------------------- +# 1. 
add_session_to_memory — filters empty events, persists content/author/ts +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_session_to_memory_persists_text_events(svc): + session = _make_session([_make_event("hello world")]) + await svc.add_session_to_memory(session) + + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query="hello") + assert len(resp.memories) == 1 + assert resp.memories[0].author == "user" + assert resp.memories[0].timestamp is not None + + +@pytest.mark.asyncio +async def test_add_session_to_memory_skips_empty_events(svc): + empty_event = Event( + id="empty", + author="user", + content=types.Content(role="user", parts=[]), + timestamp=time.time(), + invocation_id="inv1", + ) + session = _make_session([empty_event]) + await svc.add_session_to_memory(session) + + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query="anything") + assert resp.memories == [] + + +# --------------------------------------------------------------------------- +# 2. Re-ingest same session → idempotent (no duplicates) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_session_to_memory_idempotent(svc): + session = _make_session([_make_event("idempotent test")]) + await svc.add_session_to_memory(session) + await svc.add_session_to_memory(session) + + resp = await svc.search_memory( + app_name=_APP, user_id=_USER, query="idempotent" + ) + assert len(resp.memories) == 1 + + +# --------------------------------------------------------------------------- +# 3. 
add_events_to_memory — delta, skips duplicate event_id +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_events_to_memory_delta(svc): + ev = _make_event("delta event", event_id="ev_delta") + await svc.add_events_to_memory( + app_name=_APP, + user_id=_USER, + events=[ev], + session_id=_SESSION, + ) + # Second call with same event_id should not create a duplicate + await svc.add_events_to_memory( + app_name=_APP, + user_id=_USER, + events=[ev], + session_id=_SESSION, + ) + + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query="delta") + assert len(resp.memories) == 1 + + +@pytest.mark.asyncio +async def test_add_events_to_memory_skips_empty(svc): + empty = Event( + id="empty2", + author="agent", + content=types.Content(role="model", parts=[]), + timestamp=time.time(), + invocation_id="inv1", + ) + await svc.add_events_to_memory( + app_name=_APP, user_id=_USER, events=[empty], session_id=_SESSION + ) + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query="anything") + assert resp.memories == [] + + +# --------------------------------------------------------------------------- +# 4. 
add_memory — direct MemoryEntry persist, auto-UUID +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_memory_direct(svc): + entry = MemoryEntry( + content=_make_content("direct memory fact"), + author="system", + ) + await svc.add_memory(app_name=_APP, user_id=_USER, memories=[entry]) + + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query="direct") + assert len(resp.memories) == 1 + assert resp.memories[0].author == "system" + assert resp.memories[0].id is not None # auto-UUID + + +@pytest.mark.asyncio +async def test_add_memory_preserves_explicit_id(svc): + entry = MemoryEntry( + id="explicit-id-123", + content=_make_content("explicit id memory"), + ) + await svc.add_memory(app_name=_APP, user_id=_USER, memories=[entry]) + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query="explicit") + assert resp.memories[0].id == "explicit-id-123" + + +# --------------------------------------------------------------------------- +# 5. 
search_memory — AND match, OR fallback, no results for empty query +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_search_and_match(svc): + await svc.add_memory( + app_name=_APP, + user_id=_USER, + memories=[MemoryEntry(content=_make_content("cats and dogs"))], + ) + # Both tokens present → AND match + resp = await svc.search_memory( + app_name=_APP, user_id=_USER, query="cats dogs" + ) + assert len(resp.memories) == 1 + + +@pytest.mark.asyncio +async def test_search_or_fallback(svc): + await svc.add_memory( + app_name=_APP, + user_id=_USER, + memories=[MemoryEntry(content=_make_content("cats are great"))], + ) + # Only one token matches → OR fallback should still find it + resp = await svc.search_memory( + app_name=_APP, user_id=_USER, query="cats fish" + ) + assert len(resp.memories) == 1 + + +@pytest.mark.asyncio +async def test_search_empty_query_returns_empty(svc): + await svc.add_memory( + app_name=_APP, + user_id=_USER, + memories=[MemoryEntry(content=_make_content("something"))], + ) + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query="") + assert resp.memories == [] + + +@pytest.mark.asyncio +async def test_search_no_match(svc): + await svc.add_memory( + app_name=_APP, + user_id=_USER, + memories=[MemoryEntry(content=_make_content("hello world"))], + ) + resp = await svc.search_memory( + app_name=_APP, user_id=_USER, query="zzznomatch" + ) + assert resp.memories == [] + + +# --------------------------------------------------------------------------- +# 6. 
Scratchpad KV: set/get/overwrite/delete/list +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_scratchpad_kv_set_get(svc): + await svc.set_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="k1", value="v1" + ) + val = await svc.get_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="k1" + ) + assert val == "v1" + + +@pytest.mark.asyncio +async def test_scratchpad_kv_overwrite(svc): + await svc.set_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="k2", value="old" + ) + await svc.set_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="k2", value="new" + ) + val = await svc.get_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="k2" + ) + assert val == "new" + + +@pytest.mark.asyncio +async def test_scratchpad_kv_missing_returns_none(svc): + val = await svc.get_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="nonexistent" + ) + assert val is None + + +@pytest.mark.asyncio +async def test_scratchpad_kv_delete(svc): + await svc.set_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="k3", value="v3" + ) + await svc.delete_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="k3" + ) + val = await svc.get_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="k3" + ) + assert val is None + + +@pytest.mark.asyncio +async def test_scratchpad_kv_list_keys(svc): + for k in ("a", "b", "c"): + await svc.set_scratchpad( + app_name=_APP, + user_id=_USER, + session_id=_SESSION, + key=k, + value=k, + ) + keys = await svc.list_scratchpad_keys( + app_name=_APP, user_id=_USER, session_id=_SESSION + ) + assert set(keys) == {"a", "b", "c"} + + +@pytest.mark.asyncio +async def test_scratchpad_kv_json_types(svc): + payload = {"nested": [1, 2, 3], "flag": True} + await svc.set_scratchpad( + app_name=_APP, + user_id=_USER, + session_id=_SESSION, + 
key="json_key", + value=payload, + ) + val = await svc.get_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="json_key" + ) + assert val == payload + + +# --------------------------------------------------------------------------- +# 7. Scratchpad log: append/get, filter by tag, limit +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_scratchpad_log_append_get(svc): + await svc.append_log( + app_name=_APP, user_id=_USER, session_id=_SESSION, content="entry 1" + ) + await svc.append_log( + app_name=_APP, user_id=_USER, session_id=_SESSION, content="entry 2" + ) + entries = await svc.get_log(app_name=_APP, user_id=_USER, session_id=_SESSION) + assert len(entries) == 2 + assert entries[0]["content"] == "entry 1" + assert entries[1]["content"] == "entry 2" + + +@pytest.mark.asyncio +async def test_scratchpad_log_filter_by_tag(svc): + await svc.append_log( + app_name=_APP, + user_id=_USER, + session_id=_SESSION, + content="tagged", + tag="mytag", + ) + await svc.append_log( + app_name=_APP, user_id=_USER, session_id=_SESSION, content="untagged" + ) + tagged = await svc.get_log( + app_name=_APP, user_id=_USER, session_id=_SESSION, tag="mytag" + ) + assert len(tagged) == 1 + assert tagged[0]["content"] == "tagged" + + +@pytest.mark.asyncio +async def test_scratchpad_log_limit(svc): + for i in range(10): + await svc.append_log( + app_name=_APP, + user_id=_USER, + session_id=_SESSION, + content=f"msg {i}", + ) + entries = await svc.get_log( + app_name=_APP, user_id=_USER, session_id=_SESSION, limit=3 + ) + assert len(entries) == 3 + + +# --------------------------------------------------------------------------- +# 8. 
Custom search backend +# --------------------------------------------------------------------------- + + +class _AlwaysReturnOneBackend(MemorySearchBackend): + """Stub backend that always returns a single hard-coded row.""" + + async def search( + self, + *, + sql_session, + app_name, + user_id, + query, + limit=10, + ) -> Sequence[StorageMemoryEntry]: + row = StorageMemoryEntry( + id="stub-id", + app_name=app_name, + user_id=user_id, + content_json={"role": "user", "parts": [{"text": "stub result"}]}, + author="stub", + timestamp=None, + custom_metadata={}, + ) + return [row] + + +@pytest.mark.asyncio +async def test_custom_search_backend(): + svc = DatabaseMemoryService(_DB_URL, search_backend=_AlwaysReturnOneBackend()) + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query="anything") + assert len(resp.memories) == 1 + assert resp.memories[0].id == "stub-id" + assert resp.memories[0].author == "stub" + + +# --------------------------------------------------------------------------- +# 9. Engine construction errors raise ValueError +# --------------------------------------------------------------------------- + + +def test_bad_url_raises_value_error(): + with pytest.raises(ValueError, match="Invalid database URL"): + DatabaseMemoryService("not_a_valid_url://") + + +def test_missing_driver_raises_value_error(): + with pytest.raises(ValueError): + # Use a driver that definitely isn't installed + DatabaseMemoryService("sqlite+nonexistentdriver:///:memory:") + + +# --------------------------------------------------------------------------- +# 10. 
Multi-user isolation — user A results must not leak to user B +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_search_user_isolation(svc): + """User B's search must not return user A's memories even on matching text.""" + await svc.add_memory( + app_name=_APP, + user_id="user_a", + memories=[MemoryEntry(content=_make_content("secret data alpha"))], + ) + + # User B has no memories; searching the same text should yield nothing. + resp = await svc.search_memory( + app_name=_APP, user_id="user_b", query="secret" + ) + assert resp.memories == [], "User B should not see user A's memories" + + +@pytest.mark.asyncio +async def test_add_session_user_isolation(svc): + """Session ingestion from user A must not appear in user B's search.""" + session_a = Session( + id="sess_a", + app_name=_APP, + user_id="user_a", + events=[_make_event("shared keyword")], + ) + await svc.add_session_to_memory(session_a) + + resp = await svc.search_memory( + app_name=_APP, user_id="user_b", query="shared" + ) + assert resp.memories == [] + + +# --------------------------------------------------------------------------- +# 11. 
Scratchpad KV scoping — session A key invisible in session B or user-level +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_scratchpad_kv_session_scoping(svc): + """A key written for s1 must not be visible under s2 or the user-level scope.""" + await svc.set_scratchpad( + app_name=_APP, user_id=_USER, session_id="s1", key="scoped", value="yes" + ) + + val_s2 = await svc.get_scratchpad( + app_name=_APP, user_id=_USER, session_id="s2", key="scoped" + ) + assert val_s2 is None, "Key from s1 must not appear in s2" + + val_user = await svc.get_scratchpad( + app_name=_APP, user_id=_USER, session_id="", key="scoped" + ) + assert val_user is None, "Key from s1 must not appear in user-level scope" + + +@pytest.mark.asyncio +async def test_scratchpad_log_session_scoping(svc): + """Log entries appended to s1 must not appear when querying s2.""" + await svc.append_log( + app_name=_APP, + user_id=_USER, + session_id="s1", + content="session-one log", + ) + + entries = await svc.get_log(app_name=_APP, user_id=_USER, session_id="s2") + assert entries == [], "Log from s1 must not appear in s2" + + +# --------------------------------------------------------------------------- +# 12. 
add_memory with custom_metadata — verify merge with entry.custom_metadata +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_add_memory_custom_metadata_merge(svc): + """custom_metadata passed to add_memory should merge with entry.custom_metadata.""" + entry = MemoryEntry( + content=_make_content("metadata test"), + author="agent", + custom_metadata={"entry_key": "entry_val"}, + ) + await svc.add_memory( + app_name=_APP, + user_id=_USER, + memories=[entry], + custom_metadata={"call_key": "call_val"}, + ) + resp = await svc.search_memory(app_name=_APP, user_id=_USER, query="metadata") + assert len(resp.memories) == 1 + meta = resp.memories[0].custom_metadata + assert ( + meta.get("entry_key") == "entry_val" + ), "entry.custom_metadata must survive" + assert ( + meta.get("call_key") == "call_val" + ), "call-site custom_metadata must survive" + + +# --------------------------------------------------------------------------- +# 13. delete_scratchpad no-op — deleting a non-existent key must not raise +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_scratchpad_delete_noop(svc): + """Deleting a key that does not exist must be a silent no-op.""" + # Should not raise any exception. + await svc.delete_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="ghost" + ) + val = await svc.get_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="ghost" + ) + assert val is None + + +# --------------------------------------------------------------------------- +# 14. 
list_scratchpad_keys on empty scope returns [] +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_scratchpad_list_keys_empty_scope(svc): + """list_scratchpad_keys on a brand-new scope returns an empty list.""" + keys = await svc.list_scratchpad_keys( + app_name=_APP, user_id=_USER, session_id="brand_new_session" + ) + assert keys == [] + + +# --------------------------------------------------------------------------- +# 15. Scratchpad tool tests — all 4 BaseTool subclasses +# --------------------------------------------------------------------------- + +# Helper: build a minimal mock ToolContext backed by a real DatabaseMemoryService. + + +def _make_tool_context(svc: DatabaseMemoryService, session_id: str = _SESSION): + """Return a ToolContext mock wired to *svc*.""" + session_mock = MagicMock() + session_mock.user_id = _USER + session_mock.id = session_id + + ic_mock = MagicMock() + ic_mock.app_name = _APP + ic_mock.session = session_mock + ic_mock.memory_service = svc + + ctx = MagicMock() + ctx._invocation_context = ic_mock + ctx.agent_name = "test_agent" + return ctx + + +@pytest.mark.asyncio +async def test_scratchpad_set_tool_happy_path(svc): + from google.adk.tools.scratchpad_tool import ScratchpadSetTool + + tool = ScratchpadSetTool() + ctx = _make_tool_context(svc) + result = await tool.run_async( + args={"key": "tool_key", "value": "tool_value"}, tool_context=ctx + ) + assert result == "ok" + # Verify persistence. 
+ val = await svc.get_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="tool_key" + ) + assert val == "tool_value" + + +@pytest.mark.asyncio +async def test_scratchpad_get_tool_happy_path(svc): + from google.adk.tools.scratchpad_tool import ScratchpadGetTool + + await svc.set_scratchpad( + app_name=_APP, user_id=_USER, session_id=_SESSION, key="gt_key", value=42 + ) + tool = ScratchpadGetTool() + ctx = _make_tool_context(svc) + val = await tool.run_async(args={"key": "gt_key"}, tool_context=ctx) + assert val == 42 + + +@pytest.mark.asyncio +async def test_scratchpad_append_log_tool_happy_path(svc): + from google.adk.tools.scratchpad_tool import ScratchpadAppendLogTool + + tool = ScratchpadAppendLogTool() + ctx = _make_tool_context(svc) + result = await tool.run_async( + args={"content": "observation logged", "tag": "obs"}, tool_context=ctx + ) + assert result == "ok" + entries = await svc.get_log( + app_name=_APP, user_id=_USER, session_id=_SESSION, tag="obs" + ) + assert len(entries) == 1 + assert entries[0]["content"] == "observation logged" + assert entries[0]["agent_name"] == "test_agent" + + +@pytest.mark.asyncio +async def test_scratchpad_get_log_tool_happy_path(svc): + from google.adk.tools.scratchpad_tool import ScratchpadGetLogTool + + for i in range(5): + await svc.append_log( + app_name=_APP, + user_id=_USER, + session_id=_SESSION, + content=f"log {i}", + ) + tool = ScratchpadGetLogTool() + ctx = _make_tool_context(svc) + entries = await tool.run_async(args={"limit": 3}, tool_context=ctx) + assert len(entries) == 3 + + +# --------------------------------------------------------------------------- +# 15b. 
Wrong-service-type error for all 4 tools +# --------------------------------------------------------------------------- + + +def _make_wrong_service_context(): + """Return a ToolContext backed by a plain InMemoryMemoryService.""" + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + + session_mock = MagicMock() + session_mock.user_id = _USER + session_mock.id = _SESSION + + ic_mock = MagicMock() + ic_mock.app_name = _APP + ic_mock.session = session_mock + ic_mock.memory_service = InMemoryMemoryService() + + ctx = MagicMock() + ctx._invocation_context = ic_mock + ctx.agent_name = "test_agent" + return ctx + + +@pytest.mark.asyncio +async def test_scratchpad_get_tool_wrong_service(): + from google.adk.tools.scratchpad_tool import ScratchpadGetTool + + tool = ScratchpadGetTool() + with pytest.raises(ValueError, match="DatabaseMemoryService"): + await tool.run_async( + args={"key": "x"}, tool_context=_make_wrong_service_context() + ) + + +@pytest.mark.asyncio +async def test_scratchpad_set_tool_wrong_service(): + from google.adk.tools.scratchpad_tool import ScratchpadSetTool + + tool = ScratchpadSetTool() + with pytest.raises(ValueError, match="DatabaseMemoryService"): + await tool.run_async( + args={"key": "x", "value": 1}, + tool_context=_make_wrong_service_context(), + ) + + +@pytest.mark.asyncio +async def test_scratchpad_append_log_tool_wrong_service(): + from google.adk.tools.scratchpad_tool import ScratchpadAppendLogTool + + tool = ScratchpadAppendLogTool() + with pytest.raises(ValueError, match="DatabaseMemoryService"): + await tool.run_async( + args={"content": "x"}, tool_context=_make_wrong_service_context() + ) + + +@pytest.mark.asyncio +async def test_scratchpad_get_log_tool_wrong_service(): + from google.adk.tools.scratchpad_tool import ScratchpadGetLogTool + + tool = ScratchpadGetLogTool() + with pytest.raises(ValueError, match="DatabaseMemoryService"): + await tool.run_async(args={}, tool_context=_make_wrong_service_context())