diff --git a/python/packages/mem0/agent_framework_mem0/_context_provider.py b/python/packages/mem0/agent_framework_mem0/_context_provider.py index c6be708990..87f97430f3 100644 --- a/python/packages/mem0/agent_framework_mem0/_context_provider.py +++ b/python/packages/mem0/agent_framework_mem0/_context_provider.py @@ -8,29 +8,28 @@ from __future__ import annotations +import asyncio +import logging import sys +from collections.abc import Awaitable from contextlib import AbstractAsyncContextManager -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias from agent_framework import Message from agent_framework._sessions import AgentSession, ContextProvider, SessionContext from mem0 import AsyncMemory, AsyncMemoryClient if sys.version_info >= (3, 11): - from typing import NotRequired, Self, TypedDict # pragma: no cover + from typing import Self # pragma: no cover else: - from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover + from typing_extensions import Self # pragma: no cover if TYPE_CHECKING: from agent_framework._agents import SupportsAgentRun - -class _MemorySearchResponse_v1_1(TypedDict): - results: list[dict[str, Any]] - relations: NotRequired[list[dict[str, Any]]] - - -_MemorySearchResponse_v2 = list[dict[str, Any]] +logger = logging.getLogger(__name__) +MemoryRecord: TypeAlias = dict[str, Any] +SearchResponse: TypeAlias = list[MemoryRecord] | MemoryRecord class Mem0ContextProvider(ContextProvider): @@ -106,28 +105,76 @@ async def before_run( if not input_text.strip(): return - filters = self._build_filters() + # Query entity partitions independently to bypass strict logical AND limitations + # Mem0 OSS and Platform SDKs expose inconsistent search typings. + search_tasks: list[Awaitable[Any]] = [] - # AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs - # AsyncMemoryClient (Platform) expects them in a filters dict - search_kwargs: dict[str, Any] = {"query": input_text} - if isinstance(self.mem0_client, AsyncMemory): - search_kwargs.update(filters) - else: - search_kwargs["filters"] = filters - - search_response: _MemorySearchResponse_v1_1 | _MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc] - **search_kwargs, - ) + # 1. Query User partition independently + if self.user_id: + user_kwargs = self._build_search_kwargs(input_text, "user_id", self.user_id) + search_tasks.append(self.mem0_client.search(**user_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType] - if isinstance(search_response, list): - memories = search_response - elif isinstance(search_response, dict) and "results" in search_response: - memories = search_response["results"] - else: - memories = [search_response] + # 2. Query Agent partition independently + if self.agent_id: + agent_kwargs = self._build_search_kwargs(input_text, "agent_id", self.agent_id) + search_tasks.append(self.mem0_client.search(**agent_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType] + + # Fall back to an app-scoped search when only application_id is configured + if not search_tasks and self.application_id: + app_kwargs: dict[str, Any] = {"query": input_text} + if isinstance(self.mem0_client, AsyncMemory): + app_kwargs["app_id"] = self.application_id + else: + app_kwargs["filters"] = {"app_id": self.application_id} + search_tasks.append(self.mem0_client.search(**app_kwargs)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType] + if not search_tasks: + return - line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories) + results: list[SearchResponse | BaseException] = await asyncio.gather(*search_tasks, return_exceptions=True) + + # Merge and deduplicate results + memories: list[MemoryRecord] = [] + seen_memory_ids: set[str] = set() + failed_tasks_count: int = 0 + + for search_response in results: + if isinstance(search_response, BaseException): + failed_tasks_count += 1 + logger.error( + "Mem0 partition search task failed: %s", + search_response, + ) + continue + + current_memories: list[MemoryRecord] = [] + if isinstance(search_response, list): + current_memories = [mem for mem in search_response if isinstance(mem, dict)] + elif isinstance(search_response, dict): + results_field = search_response.get("results") + if isinstance(results_field, list): + current_memories = [ + item for item in results_field if isinstance(item, dict) # pyright: ignore[reportUnknownVariableType] + ] + else: + current_memories = [search_response] + + for mem in current_memories: + mem_id = mem.get("id") + if mem_id is not None and not isinstance(mem_id, str): + mem_id = str(mem_id) + + if mem_id is not None and mem_id in seen_memory_ids: + continue + + if mem_id is not None: + seen_memory_ids.add(mem_id) + + memories.append(mem) + + if failed_tasks_count == len(search_tasks): + logger.error("All Mem0 retrieval tasks failed. Context provider is unable to verify memory state.") + + line_separated_memories = "\n".join(str(memory.get("memory", "")) for memory in memories) if line_separated_memories: context.extend_messages( self.source_id, @@ -159,12 +206,21 @@ def get_role_value(role: Any) -> str: ] if messages: - await self.mem0_client.add( # type: ignore[misc] - messages=messages, - user_id=self.user_id, - agent_id=self.agent_id, - metadata={"application_id": self.application_id}, - ) + add_kwargs: dict[str, Any] = { + "messages": messages, + "user_id": self.user_id, + "agent_id": self.agent_id, + } + + # Inject the application scope using the matching signature format for each SDK variant + if isinstance(self.mem0_client, AsyncMemory): + if self.application_id: + add_kwargs["app_id"] = self.application_id + else: + if self.application_id: + add_kwargs["filters"] = {"app_id": self.application_id} + + await self.mem0_client.add(**add_kwargs) # type: ignore[misc, call-arg] # -- Internal methods ------------------------------------------------------ @@ -173,15 +229,21 @@ def _validate_filters(self) -> None: if not self.agent_id and not self.user_id and not self.application_id: raise ValueError("At least one of the filters: agent_id, user_id, or application_id is required.") - def _build_filters(self) -> dict[str, Any]: - """Build search filters from initialization parameters.""" - filters: dict[str, Any] = {} - if self.user_id: - filters["user_id"] = self.user_id - if self.agent_id: - filters["agent_id"] = self.agent_id - if self.application_id: - filters["app_id"] = self.application_id + def _build_search_kwargs(self, input_text: str, entity_key: str, entity_value: str) -> dict[str, Any]: + """Build search keyword arguments formatted for OSS vs Platform clients.""" + filters: dict[str, Any] = {"query": input_text} + + if isinstance(self.mem0_client, AsyncMemory): + # AsyncMemory (OSS) expects direct kwargs + filters[entity_key] = entity_value + if self.application_id: + filters["app_id"] = self.application_id + else: + # AsyncMemoryClient (Platform) expects a filters dict + filters["filters"] = {entity_key: entity_value} + if self.application_id: + filters["filters"]["app_id"] = self.application_id + return filters diff --git a/python/packages/mem0/tests/test_mem0_context_provider.py b/python/packages/mem0/tests/test_mem0_context_provider.py index bf40577878..a8938ae0e4 100644 --- a/python/packages/mem0/tests/test_mem0_context_provider.py +++ b/python/packages/mem0/tests/test_mem0_context_provider.py @@ -3,7 +3,7 @@ from __future__ import annotations -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from agent_framework import AgentResponse, Message @@ -193,39 +193,61 @@ async def test_oss_client_passes_direct_kwargs(self, mock_oss_mem0_client: Async assert call_kwargs["user_id"] == "u1" assert "filters" not in call_kwargs + @pytest.mark.asyncio async def test_oss_client_all_scoping_params(self, mock_oss_mem0_client: AsyncMock) -> None: - """OSS client with all scoping parameters passes them as direct kwargs.""" + """OSS client with all scoping parameters passes them as isolated concurrent kwargs.""" mock_oss_mem0_client.search.return_value = [] + provider = Mem0ContextProvider( - source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1", application_id="app1" + source_id="mem0", + mem0_client=mock_oss_mem0_client, + user_id="u1", + agent_id="a1", + # application_id="app1" ) - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1") + + mock_context = MagicMock(spec=SessionContext) + mock_msg = MagicMock() + mock_msg.text = "hello" + mock_context.input_messages = [mock_msg] + mock_context.response = None await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) - ) # type: ignore[arg-type] + agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={} + ) - call_kwargs = mock_oss_mem0_client.search.call_args.kwargs - assert call_kwargs["user_id"] == "u1" - assert call_kwargs["agent_id"] == "a1" - assert "filters" not in call_kwargs + # Re-aligned assertion: We expect 2 separate concurrent calls instead of 1 combined call + assert mock_oss_mem0_client.search.call_count == 2 + mock_oss_mem0_client.search.assert_any_call(query="hello", user_id="u1") + mock_oss_mem0_client.search.assert_any_call(query="hello", agent_id="a1") + @pytest.mark.asyncio async def test_platform_client_passes_filters_dict(self, mock_mem0_client: AsyncMock) -> None: - """Platform AsyncMemoryClient should receive scoping params in a filters dict.""" + """Platform client passes scoping parameters concurrently inside the nested filters dictionary.""" mock_mem0_client.search.return_value = [] - provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - session = AgentSession(session_id="test-session") - ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1") + + provider = Mem0ContextProvider( + source_id="mem0", + mem0_client=mock_mem0_client, + user_id="u1", + agent_id="a1", + # application_id="app1" + ) + + mock_context = MagicMock(spec=SessionContext) + mock_msg = MagicMock() + mock_msg.text = "hello" + mock_context.input_messages = [mock_msg] + mock_context.response = None await provider.before_run( - agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) - ) # type: ignore[arg-type] + agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={} + ) - call_kwargs = mock_mem0_client.search.call_args.kwargs - assert call_kwargs["query"] == "Hello" - assert "filters" in call_kwargs - assert call_kwargs["filters"]["user_id"] == "u1" + # Re-aligned assertion: Platform client isolates filters per call to bypass AND limitations + assert mock_mem0_client.search.call_count == 2 + mock_mem0_client.search.assert_any_call(query="hello", filters={"user_id": "u1"}) + mock_mem0_client.search.assert_any_call(query="hello", filters={"agent_id": "a1"}) # -- after_run tests ----------------------------------------------------------- @@ -331,7 +353,7 @@ async def test_stores_with_application_id_metadata(self, mock_mem0_client: Async agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {}) ) # type: ignore[arg-type] - assert mock_mem0_client.add.call_args.kwargs["metadata"] == {"application_id": "app1"} + assert mock_mem0_client.add.call_args.kwargs["filters"] == {"app_id": "app1"} # -- _validate_filters tests -------------------------------------------------- @@ -358,15 +380,20 @@ def test_passes_with_application_id(self, mock_mem0_client: AsyncMock) -> None: provider._validate_filters() -# -- _build_filters tests ----------------------------------------------------- +# -- _build_search_kwargs tests ----------------------------------------------------- -class TestBuildFilters: - """Test _build_filters method.""" +class TestBuildSearchKwargs: + """Test _build_search_kwargs method.""" def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None: provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - assert provider._build_filters() == {"user_id": "u1"} + + # Pass the 3 required arguments + result = provider._build_search_kwargs("test query", "user_id", "u1") + + # AsyncMock triggers the Platform client nested 'filters' structure + assert result == {"query": "test query", "filters": {"user_id": "u1"}} def test_all_params(self, mock_mem0_client: AsyncMock) -> None: provider = Mem0ContextProvider( @@ -376,28 +403,66 @@ def test_all_params(self, mock_mem0_client: AsyncMock) -> None: agent_id="a1", application_id="app1", ) - assert provider._build_filters() == { - "user_id": "u1", - "agent_id": "a1", - "app_id": "app1", + + # Test that app_id correctly merges with the isolated target entity + result = provider._build_search_kwargs("test query", "agent_id", "a1") + + assert result == { + "query": "test query", + "filters": { + "agent_id": "a1", + "app_id": "app1", + }, } def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None: provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - filters = provider._build_filters() - assert "agent_id" not in filters - assert "run_id" not in filters - assert "app_id" not in filters + + # application_id is None by default, it should not appear in the dictionary + result = provider._build_search_kwargs("test query", "user_id", "u1") + + assert "app_id" not in result.get("filters", {}) def test_no_run_id_in_search_filters(self, mock_mem0_client: AsyncMock) -> None: """run_id is excluded from search filters so memories work across sessions.""" provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1") - filters = provider._build_filters() - assert "run_id" not in filters + + result = provider._build_search_kwargs("test query", "user_id", "u1") + + assert "run_id" not in result.get("filters", {}) + assert "run_id" not in result def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None: + # Validates base query payload generation provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client) - assert provider._build_filters() == {} + + result = provider._build_search_kwargs("test query", "custom_key", "custom_val") + + assert result == {"query": "test query", "filters": {"custom_key": "custom_val"}} + + @pytest.mark.asyncio + async def test_before_run_application_only_fallback(self, mock_mem0_client: AsyncMock) -> None: + + provider = Mem0ContextProvider( + source_id="mem0", mem0_client=mock_mem0_client, application_id="app_fallback_test" + ) + + # Mock a valid message list and session container setup + mock_context = MagicMock(spec=SessionContext) + mock_msg = MagicMock() + mock_msg.text = "Retrieve systemic fallback memory traces" + mock_context.input_messages = [mock_msg] + mock_context.response = None + + mock_mem0_client.search = AsyncMock(return_value=[{"id": "m1", "memory": "System configuration template"}]) + + await provider.before_run( + agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={} + ) + + # Verify that an application-scoped search task executed successfully + assert mock_mem0_client.search.call_count == 1 + mock_context.extend_messages.assert_called_once() # -- Context manager tests -----------------------------------------------------