Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions lib/crewai/src/crewai/tools/memory_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
class RecallMemorySchema(BaseModel):
"""Schema for the recall memory tool."""

queries: list[str] = Field(
...,
queries: list[str] | None = Field(
default=None,
description=(
"One or more search queries. Pass a single item for a focused search, "
"or multiple items to search for several things at once."
"REQUIRED: A list of search query strings. "
"Examples: ['AI trends'], ['Python', 'machine learning'], ['vector databases']. "
"Pass a single item for a focused search, or multiple items to search for several things at once."
),
min_length=1,
)

model_config = {"extra": "forbid"}


class RecallMemoryTool(BaseTool):
"""Tool that lets an agent search memory for one or more queries at once."""
Expand All @@ -32,7 +36,7 @@ class RecallMemoryTool(BaseTool):

def _run(
self,
queries: list[str] | str,
queries: list[str] | str | None = None,
**kwargs: Any,
) -> str:
"""Search memory for relevant information.
Expand All @@ -43,9 +47,20 @@ def _run(
Returns:
Formatted string of matching memories, or a message if none found.
"""
# Handle None or empty input
if not queries:
return "Error: Please provide search queries. Example: search_memory(queries=['AI trends'])"

# Handle string input
if isinstance(queries, str):
queries = [queries]

# Filter out empty strings
queries = [q for q in queries if q and q.strip()]

if not queries:
return "Error: Please provide non-empty search queries."

all_lines: list[str] = []
seen_ids: set[str] = set()
for query in queries:
Expand All @@ -63,14 +78,18 @@ def _run(
class RememberSchema(BaseModel):
"""Schema for the remember tool."""

contents: list[str] = Field(
...,
contents: list[str] | None = Field(
default=None,
description=(
"One or more facts, decisions, or observations to remember. "
"REQUIRED: A list of strings to save to memory. "
"Examples: ['User prefers dark mode'], ['Project deadline is March 15', 'Budget is $50k']. "
"Pass a single item or multiple items at once."
),
min_length=1,
)

model_config = {"extra": "forbid"}


class RememberTool(BaseTool):
"""Tool that lets an agent save one or more items to memory at once."""
Expand All @@ -80,7 +99,7 @@ class RememberTool(BaseTool):
args_schema: type[BaseModel] = RememberSchema
memory: Any = Field(exclude=True)

def _run(self, contents: list[str] | str, **kwargs: Any) -> str:
def _run(self, contents: list[str] | str | None = None, **kwargs: Any) -> str:
"""Store one or more items in memory. The system infers scope, categories, and importance.

Args:
Expand All @@ -89,8 +108,19 @@ def _run(self, contents: list[str] | str, **kwargs: Any) -> str:
Returns:
Confirmation with the number of items saved.
"""
# Handle None or empty input
if not contents:
return "Error: Please provide content to save. Example: save_to_memory(contents=['fact to remember'])"

if isinstance(contents, str):
contents = [contents]

# Filter out empty strings
contents = [c for c in contents if c and c.strip()]

if not contents:
return "Error: Please provide non-empty content to save."

if len(contents) == 1:
record = self.memory.remember(contents[0])
return (
Expand Down
4 changes: 2 additions & 2 deletions lib/crewai/src/crewai/translations/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
"description": "See image to understand its content, you can optionally ask a question about the image",
"default_action": "Please provide a detailed description of this image, including all visual elements, context, and any notable details you can observe."
},
"recall_memory": "Search through the team's shared memory for relevant information. Pass one or more queries to search for multiple things at once. Use this when you need to find facts, decisions, preferences, or past results that may have been stored previously. IMPORTANT: For questions that require counting, summing, or listing items across multiple conversations (e.g. 'how many X', 'total Y', 'list all Z'), you MUST search multiple times with different phrasings to ensure you find ALL relevant items before giving a final count or total. Do not rely on a single search — items may be described differently across conversations.",
"save_to_memory": "Store one or more important facts, decisions, observations, or lessons in memory so they can be recalled later by you or other agents. Pass multiple items at once when you have several things worth remembering."
"recall_memory": "Search through the team's shared memory for relevant information. REQUIRED: You must provide a 'queries' parameter with a list of search strings, for example: {\"queries\": [\"search term\"]} or {\"queries\": [\"term1\", \"term2\"]}. Use this when you need to find facts, decisions, preferences, or past results that may have been stored previously. IMPORTANT: For questions that require counting, summing, or listing items across multiple conversations (e.g. 'how many X', 'total Y', 'list all Z'), you MUST search multiple times with different phrasings to ensure you find ALL relevant items before giving a final count or total. Do not rely on a single search — items may be described differently across conversations.",
"save_to_memory": "Store one or more important facts, decisions, observations, or lessons in memory so they can be recalled later by you or other agents. REQUIRED: You must provide a 'contents' parameter with a list of strings to save, for example: {\"contents\": [\"fact to remember\"]} or {\"contents\": [\"fact1\", \"fact2\"]}. Pass multiple items at once when you have several things worth remembering."
},
"memory": {
"query_system": "You analyze a query for searching memory.\nGiven the query and available scopes, output:\n1. keywords: Key entities or keywords that can be used to filter by category.\n2. suggested_scopes: Which available scopes are most relevant (empty for all).\n3. complexity: 'simple' or 'complex'.\n4. recall_queries: 1-3 short, targeted search phrases distilled from the query. Each should be a concise phrase optimized for semantic vector search. If the query is already short and focused, return it as-is in a single-item list. For long task descriptions, extract the distinct things worth searching for.\n5. time_filter: If the query references a time period (like 'last week', 'yesterday', 'in January'), return an ISO 8601 date string for the earliest relevant date (e.g. '2026-02-01'). Return null if no time constraint is implied.",
Expand Down
140 changes: 140 additions & 0 deletions lib/crewai/tests/tools/test_memory_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""Tests for memory tool input validation and hardening."""

from __future__ import annotations

from unittest.mock import MagicMock

import pytest

from crewai.memory.types import MemoryMatch, MemoryRecord
from crewai.tools.memory_tools import RecallMemoryTool, RememberTool


@pytest.fixture
def mock_memory() -> MagicMock:
"""Create a mock Memory instance."""
memory = MagicMock()
memory.read_only = False
memory.recall.return_value = []
memory.remember.return_value = MemoryRecord(content="test")
memory.remember_many.return_value = None
return memory


@pytest.fixture
def recall_tool(mock_memory: MagicMock) -> RecallMemoryTool:
return RecallMemoryTool(memory=mock_memory, description="test recall")


@pytest.fixture
def remember_tool(mock_memory: MagicMock) -> RememberTool:
return RememberTool(memory=mock_memory, description="test remember")


# --- RecallMemoryTool ---


class TestRecallMemoryToolValidation:
"""Tests for RecallMemoryTool input validation."""

def test_none_queries_returns_error(self, recall_tool: RecallMemoryTool) -> None:
result = recall_tool._run(queries=None)
assert "Error" in result

def test_empty_list_queries_returns_error(
self, recall_tool: RecallMemoryTool
) -> None:
result = recall_tool._run(queries=[])
assert "Error" in result

def test_list_of_empty_strings_returns_error(
self, recall_tool: RecallMemoryTool,
) -> None:
result = recall_tool._run(queries=["", " ", ""])
assert "Error" in result

def test_string_input_converted_to_list(
self, recall_tool: RecallMemoryTool, mock_memory: MagicMock
) -> None:
recall_tool._run(queries="single query")
mock_memory.recall.assert_called_once_with("single query", limit=20)

def test_valid_queries_calls_memory(
self, recall_tool: RecallMemoryTool, mock_memory: MagicMock
) -> None:
recall_tool._run(queries=["query1", "query2"])
assert mock_memory.recall.call_count == 2

def test_no_matches_returns_message(
self, recall_tool: RecallMemoryTool, mock_memory: MagicMock
) -> None:
mock_memory.recall.return_value = []
result = recall_tool._run(queries=["test"])
assert "No relevant memories found" in result

def test_matches_formatted(
self, recall_tool: RecallMemoryTool, mock_memory: MagicMock
) -> None:
record = MemoryRecord(content="important fact")
match = MemoryMatch(record=record, score=0.9)
mock_memory.recall.return_value = [match]
result = recall_tool._run(queries=["test"])
assert "important fact" in result

def test_deduplicates_across_queries(
self, recall_tool: RecallMemoryTool, mock_memory: MagicMock
) -> None:
record = MemoryRecord(id="same-id", content="fact")
match = MemoryMatch(record=record, score=0.9)
mock_memory.recall.return_value = [match]
result = recall_tool._run(queries=["q1", "q2"])
# Should only appear once despite two queries returning same record
assert result.count("fact") == 1


# --- RememberTool ---


class TestRememberToolValidation:
"""Tests for RememberTool input validation."""

def test_none_contents_returns_error(self, remember_tool: RememberTool) -> None:
result = remember_tool._run(contents=None)
assert "Error" in result

def test_empty_list_returns_error(self, remember_tool: RememberTool) -> None:
result = remember_tool._run(contents=[])
assert "Error" in result

def test_list_of_empty_strings_returns_error(
self, remember_tool: RememberTool,
) -> None:
result = remember_tool._run(contents=["", " "])
assert "Error" in result

def test_string_input_converted_to_list(
self, remember_tool: RememberTool, mock_memory: MagicMock
) -> None:
remember_tool._run(contents="single fact")
mock_memory.remember.assert_called_once_with("single fact")

def test_single_item_calls_remember(
self, remember_tool: RememberTool, mock_memory: MagicMock
) -> None:
result = remember_tool._run(contents=["a fact"])
mock_memory.remember.assert_called_once_with("a fact")
assert "Saved to memory" in result

def test_multiple_items_calls_remember_many(
self, remember_tool: RememberTool, mock_memory: MagicMock
) -> None:
result = remember_tool._run(contents=["fact1", "fact2"])
mock_memory.remember_many.assert_called_once_with(["fact1", "fact2"])
assert "2 items" in result

def test_filters_empty_strings_from_list(
self, remember_tool: RememberTool, mock_memory: MagicMock
) -> None:
result = remember_tool._run(contents=["real fact", "", " "])
mock_memory.remember.assert_called_once_with("real fact")
assert "Saved to memory" in result
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading