Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 106 additions & 44 deletions python/packages/mem0/agent_framework_mem0/_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,28 @@

from __future__ import annotations

import asyncio
import logging
import sys
from collections.abc import Awaitable
from contextlib import AbstractAsyncContextManager
from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias

from agent_framework import Message
from agent_framework._sessions import AgentSession, ContextProvider, SessionContext
from mem0 import AsyncMemory, AsyncMemoryClient

if sys.version_info >= (3, 11):
from typing import NotRequired, Self, TypedDict # pragma: no cover
from typing import Self # pragma: no cover
else:
from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover
from typing_extensions import Self # pragma: no cover

if TYPE_CHECKING:
from agent_framework._agents import SupportsAgentRun


class _MemorySearchResponse_v1_1(TypedDict):
results: list[dict[str, Any]]
relations: NotRequired[list[dict[str, Any]]]


_MemorySearchResponse_v2 = list[dict[str, Any]]
logger = logging.getLogger(__name__)
MemoryRecord: TypeAlias = dict[str, Any]
SearchResponse: TypeAlias = list[MemoryRecord] | MemoryRecord


class Mem0ContextProvider(ContextProvider):
Expand Down Expand Up @@ -106,28 +105,76 @@ async def before_run(
if not input_text.strip():
return

filters = self._build_filters()
# Query entity partitions independently to bypass strict logical AND limitations
# Mem0 OSS and Platform SDKs expose inconsistent search typings.
search_tasks: list[Awaitable[Any]] = []

# AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs
# AsyncMemoryClient (Platform) expects them in a filters dict
search_kwargs: dict[str, Any] = {"query": input_text}
if isinstance(self.mem0_client, AsyncMemory):
search_kwargs.update(filters)
else:
search_kwargs["filters"] = filters

search_response: _MemorySearchResponse_v1_1 | _MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc]
**search_kwargs,
)
# 1. Query User partition independently
if self.user_id:
user_kwargs = self._build_search_kwargs(input_text, "user_id", self.user_id)
search_tasks.append(self.mem0_client.search(**user_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]

if isinstance(search_response, list):
memories = search_response
elif isinstance(search_response, dict) and "results" in search_response:
memories = search_response["results"]
else:
memories = [search_response]
# 2. Query Agent partition independently
if self.agent_id:
agent_kwargs = self._build_search_kwargs(input_text, "agent_id", self.agent_id)
search_tasks.append(self.mem0_client.search(**agent_kwargs)) # type: ignore[reportUnknownMemberType, reportUnknownArgumentType]

# Fall back to an app-scoped search when only application_id is configured
if not search_tasks and self.application_id:
app_kwargs: dict[str, Any] = {"query": input_text}
if isinstance(self.mem0_client, AsyncMemory):
app_kwargs["app_id"] = self.application_id
else:
app_kwargs["filters"] = {"app_id": self.application_id}
search_tasks.append(self.mem0_client.search(**app_kwargs)) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType]
if not search_tasks:
return

line_separated_memories = "\n".join(memory.get("memory", "") for memory in memories)
results: list[SearchResponse | BaseException] = await asyncio.gather(*search_tasks, return_exceptions=True)

# Merge and deduplicate results
memories: list[MemoryRecord] = []
seen_memory_ids: set[str] = set()
failed_tasks_count: int = 0

for search_response in results:
if isinstance(search_response, BaseException):
failed_tasks_count += 1
logger.error(
"Mem0 partition search task failed: %s",
search_response,
)
continue

current_memories: list[MemoryRecord] = []
if isinstance(search_response, list):
current_memories = [mem for mem in search_response if isinstance(mem, dict)]
elif isinstance(search_response, dict):
results_field = search_response.get("results")
if isinstance(results_field, list):
current_memories = [
item for item in results_field if isinstance(item, dict) # pyright: ignore[reportUnknownVariableType]
]
else:
current_memories = [search_response]

for mem in current_memories:
mem_id = mem.get("id")
if mem_id is not None and not isinstance(mem_id, str):
mem_id = str(mem_id)

if mem_id is not None and mem_id in seen_memory_ids:
continue

if mem_id is not None:
seen_memory_ids.add(mem_id)

memories.append(mem)

if failed_tasks_count == len(search_tasks):
logger.error("All Mem0 retrieval tasks failed. Context provider is unable to verify memory state.")

line_separated_memories = "\n".join(str(memory.get("memory", "")) for memory in memories)
if line_separated_memories:
context.extend_messages(
self.source_id,
Expand Down Expand Up @@ -159,12 +206,21 @@ def get_role_value(role: Any) -> str:
]

if messages:
await self.mem0_client.add( # type: ignore[misc]
messages=messages,
user_id=self.user_id,
agent_id=self.agent_id,
metadata={"application_id": self.application_id},
)
add_kwargs: dict[str, Any] = {
"messages": messages,
"user_id": self.user_id,
"agent_id": self.agent_id,
}

# Inject the application scope using the matching signature format for each SDK variant
if isinstance(self.mem0_client, AsyncMemory):
if self.application_id:
add_kwargs["app_id"] = self.application_id
else:
if self.application_id:
add_kwargs["filters"] = {"app_id": self.application_id}

await self.mem0_client.add(**add_kwargs) # type: ignore[misc, call-arg]

# -- Internal methods ------------------------------------------------------

Expand All @@ -173,15 +229,21 @@ def _validate_filters(self) -> None:
if not self.agent_id and not self.user_id and not self.application_id:
raise ValueError("At least one of the filters: agent_id, user_id, or application_id is required.")

def _build_filters(self) -> dict[str, Any]:
"""Build search filters from initialization parameters."""
filters: dict[str, Any] = {}
if self.user_id:
filters["user_id"] = self.user_id
if self.agent_id:
filters["agent_id"] = self.agent_id
if self.application_id:
filters["app_id"] = self.application_id
def _build_search_kwargs(self, input_text: str, entity_key: str, entity_value: str) -> dict[str, Any]:
"""Build search keyword arguments formatted for OSS vs Platform clients."""
filters: dict[str, Any] = {"query": input_text}

if isinstance(self.mem0_client, AsyncMemory):
# AsyncMemory (OSS) expects direct kwargs
filters[entity_key] = entity_value
if self.application_id:
filters["app_id"] = self.application_id
else:
# AsyncMemoryClient (Platform) expects a filters dict
filters["filters"] = {entity_key: entity_value}
if self.application_id:
filters["filters"]["app_id"] = self.application_id

return filters


Expand Down
139 changes: 102 additions & 37 deletions python/packages/mem0/tests/test_mem0_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from agent_framework import AgentResponse, Message
Expand Down Expand Up @@ -193,39 +193,61 @@ async def test_oss_client_passes_direct_kwargs(self, mock_oss_mem0_client: Async
assert call_kwargs["user_id"] == "u1"
assert "filters" not in call_kwargs

@pytest.mark.asyncio
async def test_oss_client_all_scoping_params(self, mock_oss_mem0_client: AsyncMock) -> None:
"""OSS client with all scoping parameters passes them as direct kwargs."""
"""OSS client with all scoping parameters passes them as isolated concurrent kwargs."""
mock_oss_mem0_client.search.return_value = []

provider = Mem0ContextProvider(
source_id="mem0", mem0_client=mock_oss_mem0_client, user_id="u1", agent_id="a1", application_id="app1"
source_id="mem0",
mem0_client=mock_oss_mem0_client,
user_id="u1",
agent_id="a1",
# application_id="app1"
)
session = AgentSession(session_id="test-session")
ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1")

mock_context = MagicMock(spec=SessionContext)
mock_msg = MagicMock()
mock_msg.text = "hello"
mock_context.input_messages = [mock_msg]
mock_context.response = None

await provider.before_run(
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
)

call_kwargs = mock_oss_mem0_client.search.call_args.kwargs
assert call_kwargs["user_id"] == "u1"
assert call_kwargs["agent_id"] == "a1"
assert "filters" not in call_kwargs
# Re-aligned assertion: We expect 2 separate concurrent calls instead of 1 combined call
assert mock_oss_mem0_client.search.call_count == 2
mock_oss_mem0_client.search.assert_any_call(query="hello", user_id="u1")
mock_oss_mem0_client.search.assert_any_call(query="hello", agent_id="a1")

@pytest.mark.asyncio
async def test_platform_client_passes_filters_dict(self, mock_mem0_client: AsyncMock) -> None:
"""Platform AsyncMemoryClient should receive scoping params in a filters dict."""
"""Platform client passes scoping parameters concurrently inside the nested filters dictionary."""
mock_mem0_client.search.return_value = []
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
session = AgentSession(session_id="test-session")
ctx = SessionContext(input_messages=[Message(role="user", contents=["Hello"])], session_id="s1")

provider = Mem0ContextProvider(
source_id="mem0",
mem0_client=mock_mem0_client,
user_id="u1",
agent_id="a1",
# application_id="app1"
)

mock_context = MagicMock(spec=SessionContext)
mock_msg = MagicMock()
mock_msg.text = "hello"
mock_context.input_messages = [mock_msg]
mock_context.response = None

await provider.before_run(
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
)

call_kwargs = mock_mem0_client.search.call_args.kwargs
assert call_kwargs["query"] == "Hello"
assert "filters" in call_kwargs
assert call_kwargs["filters"]["user_id"] == "u1"
# Re-aligned assertion: Platform client isolates filters per call to bypass AND limitations
assert mock_mem0_client.search.call_count == 2
mock_mem0_client.search.assert_any_call(query="hello", filters={"user_id": "u1"})
mock_mem0_client.search.assert_any_call(query="hello", filters={"agent_id": "a1"})


# -- after_run tests -----------------------------------------------------------
Expand Down Expand Up @@ -331,7 +353,7 @@ async def test_stores_with_application_id_metadata(self, mock_mem0_client: Async
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
) # type: ignore[arg-type]

assert mock_mem0_client.add.call_args.kwargs["metadata"] == {"application_id": "app1"}
assert mock_mem0_client.add.call_args.kwargs["filters"] == {"app_id": "app1"}


# -- _validate_filters tests --------------------------------------------------
Expand All @@ -358,15 +380,20 @@ def test_passes_with_application_id(self, mock_mem0_client: AsyncMock) -> None:
provider._validate_filters()


# -- _build_filters tests -----------------------------------------------------
# -- _build_search_kwargs tests -----------------------------------------------------


class TestBuildFilters:
"""Test _build_filters method."""
class TestBuildSearchKwargs:
"""Test _build_search_kwargs method."""

def test_user_id_only(self, mock_mem0_client: AsyncMock) -> None:
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
assert provider._build_filters() == {"user_id": "u1"}

# Pass the 3 required arguments
result = provider._build_search_kwargs("test query", "user_id", "u1")

# AsyncMock triggers the Platform client nested 'filters' structure
assert result == {"query": "test query", "filters": {"user_id": "u1"}}

def test_all_params(self, mock_mem0_client: AsyncMock) -> None:
provider = Mem0ContextProvider(
Expand All @@ -376,28 +403,66 @@ def test_all_params(self, mock_mem0_client: AsyncMock) -> None:
agent_id="a1",
application_id="app1",
)
assert provider._build_filters() == {
"user_id": "u1",
"agent_id": "a1",
"app_id": "app1",

# Test that app_id correctly merges with the isolated target entity
result = provider._build_search_kwargs("test query", "agent_id", "a1")

assert result == {
"query": "test query",
"filters": {
"agent_id": "a1",
"app_id": "app1",
},
}

def test_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None:
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
filters = provider._build_filters()
assert "agent_id" not in filters
assert "run_id" not in filters
assert "app_id" not in filters

# application_id is None by default, it should not appear in the dictionary
result = provider._build_search_kwargs("test query", "user_id", "u1")

assert "app_id" not in result.get("filters", {})

def test_no_run_id_in_search_filters(self, mock_mem0_client: AsyncMock) -> None:
"""run_id is excluded from search filters so memories work across sessions."""
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client, user_id="u1")
filters = provider._build_filters()
assert "run_id" not in filters

result = provider._build_search_kwargs("test query", "user_id", "u1")

assert "run_id" not in result.get("filters", {})
assert "run_id" not in result

def test_empty_when_no_params(self, mock_mem0_client: AsyncMock) -> None:
# Validates base query payload generation
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
assert provider._build_filters() == {}

result = provider._build_search_kwargs("test query", "custom_key", "custom_val")

assert result == {"query": "test query", "filters": {"custom_key": "custom_val"}}

@pytest.mark.asyncio
async def test_before_run_application_only_fallback(self, mock_mem0_client: AsyncMock) -> None:

provider = Mem0ContextProvider(
source_id="mem0", mem0_client=mock_mem0_client, application_id="app_fallback_test"
)

# Mock a valid message list and session container setup
mock_context = MagicMock(spec=SessionContext)
mock_msg = MagicMock()
mock_msg.text = "Retrieve systemic fallback memory traces"
mock_context.input_messages = [mock_msg]
mock_context.response = None

mock_mem0_client.search = AsyncMock(return_value=[{"id": "m1", "memory": "System configuration template"}])

await provider.before_run(
agent=MagicMock(), session=MagicMock(spec=AgentSession), context=mock_context, state={}
)

# Verify that an application-scoped search task executed successfully
assert mock_mem0_client.search.call_count == 1
mock_context.extend_messages.assert_called_once()


# -- Context manager tests -----------------------------------------------------
Expand Down