Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 67 additions & 5 deletions src/google/adk/memory/vertex_ai_rag_memory_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

from __future__ import annotations

import base64
import binascii
from collections import OrderedDict
import json
import os
Expand All @@ -35,6 +37,58 @@
from ..sessions.session import Session


# Versioned marker for display names whose app name, user ID and session ID are
# each base64url-encoded. Display names that do not start with this prefix are
# parsed as legacy dot-delimited names.
_SOURCE_DISPLAY_NAME_PREFIX = "adk-memory-v1."


def _encode_source_display_name_part(value: str) -> str:
  """Encodes one display-name component as unpadded URL-safe base64.

  The URL-safe base64 alphabet contains no "." characters, so encoded
  components can be joined with "." and split apart again unambiguously.

  Args:
    value: The raw component text.

  Returns:
    The URL-safe base64 text of the UTF-8 bytes of `value`, with any trailing
    "=" padding removed.
  """
  raw_bytes = value.encode("utf-8")
  encoded_text = base64.urlsafe_b64encode(raw_bytes).decode("ascii")
  # Padding carries no information, so it is dropped from the stored name.
  return encoded_text.rstrip("=")


def _decode_source_display_name_part(value: str) -> str:
  """Decodes one URL-safe base64 display-name component.

  Args:
    value: URL-safe base64 text; trailing "=" padding may be omitted.

  Returns:
    The decoded UTF-8 string.

  Raises:
    binascii.Error: If `value` is not valid base64.
    UnicodeEncodeError: If `value` contains non-ASCII characters.
    UnicodeDecodeError: If the decoded bytes are not valid UTF-8.
  """
  # Restore the "=" padding needed to reach a multiple of four characters.
  missing_padding = -len(value) % 4
  padded_bytes = (value + "=" * missing_padding).encode("ascii")
  decoded_bytes = base64.b64decode(
      padded_bytes, altchars=b"-_", validate=True
  )
  return decoded_bytes.decode("utf-8")


def _build_source_display_name(
    app_name: str, user_id: str, session_id: str
) -> str:
  """Builds the prefixed, encoded display name for a session.

  Args:
    app_name: The application name of the session.
    user_id: The user ID of the session.
    session_id: The session ID.

  Returns:
    `_SOURCE_DISPLAY_NAME_PREFIX` followed by the three encoded components
    joined with ".", in the order app name, user ID, session ID.
  """
  encoded_parts = [
      _encode_source_display_name_part(component)
      for component in (app_name, user_id, session_id)
  ]
  return _SOURCE_DISPLAY_NAME_PREFIX + ".".join(encoded_parts)


def _parse_source_display_name(
    source_display_name: str,
) -> tuple[str, str, str] | None:
  """Extracts (app_name, user_id, session_id) from a display name.

  Args:
    source_display_name: Either a prefixed, encoded display name or a legacy
      dot-delimited one.

  Returns:
    The (app_name, user_id, session_id) tuple, or None when the name does not
    have exactly three parts or an encoded part cannot be decoded.
  """
  if not source_display_name.startswith(_SOURCE_DISPLAY_NAME_PREFIX):
    # Legacy display names were dot-delimited. Only the exact three-part form
    # is unambiguous, so dotted app/user/session IDs are intentionally ignored.
    legacy_fields = source_display_name.split(".")
    if len(legacy_fields) == 3:
      app_name, user_id, session_id = legacy_fields
      return app_name, user_id, session_id
    return None

  payload = source_display_name[len(_SOURCE_DISPLAY_NAME_PREFIX) :]
  encoded_fields = payload.split(".")
  if len(encoded_fields) != 3:
    return None

  try:
    app_name, user_id, session_id = [
        _decode_source_display_name_part(field) for field in encoded_fields
    ]
  except (binascii.Error, UnicodeDecodeError, UnicodeEncodeError):
    # A part that fails to decode makes the whole name unusable.
    return None
  return app_name, user_id, session_id


class VertexAiRagMemoryService(BaseMemoryService):
"""A memory service that uses Vertex AI RAG for storage and retrieval."""

Expand Down Expand Up @@ -63,7 +117,7 @@ def __init__(
)

@override
async def add_session_to_memory(self, session: Session):
async def add_session_to_memory(self, session: Session) -> None:
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=".txt"
) as temp_file:
Expand Down Expand Up @@ -100,7 +154,9 @@ async def add_session_to_memory(self, session: Session):
path=temp_file_path,
# this is the temp workaround as upload file does not support
# adding metadata, thus use display_name to store the session info.
display_name=f"{session.app_name}.{session.user_id}.{session.id}",
display_name=_build_source_display_name(
session.app_name, session.user_id, session.id
),
)

os.remove(temp_file_path)
Expand All @@ -122,13 +178,19 @@ async def search_memory(
)

memory_results = []
session_events_map = OrderedDict()
session_events_map: OrderedDict[str, list[list[Event]]] = OrderedDict()
for context in response.contexts.contexts:
# filter out context that is not related
# TODO: Add server side filtering by app_name and user_id.
if not context.source_display_name.startswith(f"{app_name}.{user_id}."):
source_display_name = getattr(context, "source_display_name", "")
if not isinstance(source_display_name, str):
continue
session_info = _parse_source_display_name(source_display_name)
if not session_info:
continue
source_app_name, source_user_id, session_id = session_info
if source_app_name != app_name or source_user_id != user_id:
continue
session_id = context.source_display_name.split(".")[-1]
events = []
if context.text:
lines = context.text.split("\n")
Expand Down
117 changes: 117 additions & 0 deletions tests/unittests/memory/test_vertex_ai_rag_memory_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from types import SimpleNamespace

from google.adk.events.event import Event
from google.adk.memory.vertex_ai_rag_memory_service import _build_source_display_name
from google.adk.memory.vertex_ai_rag_memory_service import _SOURCE_DISPLAY_NAME_PREFIX
from google.adk.memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
from google.adk.sessions.session import Session
from google.genai import types
import pytest


def _rag_context(source_display_name: str, text: str) -> SimpleNamespace:
  """Builds a stand-in retrieval context for the memory-service tests.

  Args:
    source_display_name: The display name to expose on the context.
    text: The event text to embed in the JSON payload.

  Returns:
    A namespace with `source_display_name` and a JSON `text` attribute that
    holds a single user-authored event with timestamp 1.
  """
  payload = {"author": "user", "timestamp": 1, "text": text}
  return SimpleNamespace(
      source_display_name=source_display_name,
      text=json.dumps(payload),
  )


@pytest.mark.asyncio
async def test_search_memory_rejects_ambiguous_legacy_display_names(mocker):
  """Ensures dotted user IDs cannot match another user's legacy memory."""
  service = VertexAiRagMemoryService(rag_corpus="unused")

  retrieved_contexts = [
      # Four dot-separated parts: an ambiguous legacy name.
      _rag_context(
          "demo.alice.smith.session_secret", "SECRET_FROM_ALICE_SMITH"
      ),
      # Name produced by the current encoder for demo/alice.
      _rag_context(
          _build_source_display_name("demo", "alice", "session_ok"),
          "NORMAL_ALICE_MEMORY",
      ),
      # Exact three-part legacy name for demo/alice.
      _rag_context("demo.alice.legacy_session", "LEGACY_ALICE_MEMORY"),
      # Three-part legacy name that belongs to a different user.
      _rag_context("demo.bob.session_other", "BOB_MEMORY"),
  ]
  retrieval_result = SimpleNamespace(
      contexts=SimpleNamespace(contexts=retrieved_contexts)
  )
  rag_stub = SimpleNamespace(
      retrieval_query=mocker.Mock(return_value=retrieval_result)
  )
  mocker.patch("google.adk.dependencies.vertexai.rag", rag_stub)

  response = await service.search_memory(
      app_name="demo", user_id="alice", query="secret"
  )

  recalled_texts = [
      memory.content.parts[0].text for memory in response.memories
  ]
  assert recalled_texts == ["NORMAL_ALICE_MEMORY", "LEGACY_ALICE_MEMORY"]


@pytest.mark.asyncio
async def test_add_and_search_memory_uses_unambiguous_display_names(mocker):
  """Checks that dotted IDs survive an upload followed by a search."""
  service = VertexAiRagMemoryService(rag_corpus="unused")
  upload_mock = mocker.Mock()
  rag_stub = SimpleNamespace(upload_file=upload_mock)
  mocker.patch("google.adk.dependencies.vertexai.rag", rag_stub)

  # Every identifier contains a dot, which the legacy naming could not
  # represent unambiguously.
  dotted_session = Session(
      app_name="demo.app",
      user_id="alice.smith",
      id="session.secret",
      last_update_time=1,
      events=[
          Event(
              id="event-1",
              author="user",
              timestamp=1,
              content=types.Content(
                  parts=[types.Part(text="sensitive memory")]
              ),
          )
      ],
  )
  await service.add_session_to_memory(dotted_session)

  uploaded_name = upload_mock.call_args.kwargs["display_name"]
  assert uploaded_name.startswith(_SOURCE_DISPLAY_NAME_PREFIX)
  assert uploaded_name != "demo.app.alice.smith.session.secret"

  # Feed the uploaded display name back through the retrieval path.
  search_result = SimpleNamespace(
      contexts=SimpleNamespace(
          contexts=[_rag_context(uploaded_name, "sensitive memory")]
      )
  )
  rag_stub.retrieval_query = mocker.Mock(return_value=search_result)

  response = await service.search_memory(
      app_name="demo.app", user_id="alice.smith", query="sensitive"
  )

  recalled_texts = [
      memory.content.parts[0].text for memory in response.memories
  ]
  assert recalled_texts == ["sensitive memory"]