diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 0ffec4888c..fd0e60b74e 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -1933,22 +1933,52 @@ def replay(self, task_id: str, inputs: dict[str, Any] | None = None) -> CrewOutp return self._execute_tasks(self.tasks, start_index, True) def query_knowledge( - self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35 + self, + query: list[str], + results_limit: int = 3, + score_threshold: float = 0.35, + metadata_filter: dict[str, Any] | None = None, ) -> list[SearchResult] | None: - """Query the crew's knowledge base for relevant information.""" + """Query the crew's knowledge base for relevant information. + + Args: + query: List of query strings. + results_limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + metadata_filter: Optional metadata filter forwarded to the + underlying knowledge storage. + """ if self.knowledge: return self.knowledge.query( - query, results_limit=results_limit, score_threshold=score_threshold + query, + results_limit=results_limit, + score_threshold=score_threshold, + metadata_filter=metadata_filter, ) return None async def aquery_knowledge( - self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35 + self, + query: list[str], + results_limit: int = 3, + score_threshold: float = 0.35, + metadata_filter: dict[str, Any] | None = None, ) -> list[SearchResult] | None: - """Query the crew's knowledge base for relevant information asynchronously.""" + """Query the crew's knowledge base for relevant information asynchronously. + + Args: + query: List of query strings. + results_limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + metadata_filter: Optional metadata filter forwarded to the + underlying knowledge storage. + """ if self.knowledge: return await self.knowledge.aquery( - query, results_limit=results_limit, score_threshold=score_threshold + query, + results_limit=results_limit, + score_threshold=score_threshold, + metadata_filter=metadata_filter, ) return None diff --git a/lib/crewai/src/crewai/knowledge/knowledge.py b/lib/crewai/src/crewai/knowledge/knowledge.py index 8dcf38f4ea..85c7e01954 100644 --- a/lib/crewai/src/crewai/knowledge/knowledge.py +++ b/lib/crewai/src/crewai/knowledge/knowledge.py @@ -125,12 +125,23 @@ def __init__( self.sources = sources def query( - self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6 + self, + query: list[str], + results_limit: int = 5, + score_threshold: float = 0.6, + metadata_filter: dict[str, Any] | None = None, ) -> list[SearchResult]: - """ - Query across all knowledge sources to find the most relevant information. + """Query across all knowledge sources to find the most relevant information. + Returns the top_k most relevant chunks. + Args: + query: List of query strings. + results_limit: Maximum number of results to return. + score_threshold: Minimum similarity score for results. + metadata_filter: Optional metadata filter forwarded to the + underlying storage. + Raises: ValueError: If storage is not initialized. """ @@ -140,6 +151,7 @@ def query( return self.storage.search( query, limit=results_limit, + metadata_filter=metadata_filter, score_threshold=score_threshold, ) @@ -158,7 +170,11 @@ def reset(self) -> None: raise ValueError("Storage is not initialized.") async def aquery( - self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6 + self, + query: list[str], + results_limit: int = 5, + score_threshold: float = 0.6, + metadata_filter: dict[str, Any] | None = None, ) -> list[SearchResult]: """Query across all knowledge sources asynchronously. @@ -166,6 +182,8 @@ async def aquery( query: List of query strings. results_limit: Maximum number of results to return. score_threshold: Minimum similarity score for results. + metadata_filter: Optional metadata filter forwarded to the + underlying storage. Returns: The top results matching the query. @@ -179,6 +197,7 @@ async def aquery( return await self.storage.asearch( query, limit=results_limit, + metadata_filter=metadata_filter, score_threshold=score_threshold, ) diff --git a/lib/crewai/src/crewai/knowledge/knowledge_config.py b/lib/crewai/src/crewai/knowledge/knowledge_config.py index 67f0ee44b8..2938489136 100644 --- a/lib/crewai/src/crewai/knowledge/knowledge_config.py +++ b/lib/crewai/src/crewai/knowledge/knowledge_config.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, Field @@ -5,12 +7,30 @@ class KnowledgeConfig(BaseModel): """Configuration for knowledge retrieval. Args: - results_limit (int): The number of relevant documents to return. - score_threshold (float): The minimum score for a document to be considered relevant. + results_limit (int): The number of relevant documents to return. Must be + at least 1. + score_threshold (float): The minimum score for a document to be + considered relevant. Must be greater than 0 and less than or equal + to 1. + metadata_filter (dict[str, Any] | None): Optional metadata filter + forwarded to the underlying knowledge storage so retrieval can be + narrowed to documents whose stored metadata matches these keys and + values. """ - results_limit: int = Field(default=5, description="The number of results to return") + results_limit: int = Field( + default=5, ge=1, description="The number of results to return" + ) score_threshold: float = Field( default=0.6, + gt=0, + le=1, description="The minimum score for a result to be considered relevant", ) + metadata_filter: dict[str, Any] | None = Field( + default=None, + description=( + "Optional metadata filter passed to knowledge storage to restrict " + "retrieval to documents whose metadata matches these key/value pairs." + ), + ) diff --git a/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py index 8c99b47b0a..485028050e 100644 --- a/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py @@ -19,7 +19,13 @@ class BaseKnowledgeSource(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) storage: KnowledgeStorage | None = Field(default=None) - metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused + metadata: dict[str, Any] = Field( + default_factory=dict, + description=( + "Metadata merged into each stored document's metadata so it can be " + "matched by ``metadata_filter`` at query time." + ), + ) collection_name: str | None = Field(default=None) @abstractmethod @@ -44,15 +50,20 @@ def _chunk_text(self, text: str) -> list[str]: def _save_documents(self) -> None: """Save the documents to the storage. - This method should be called after the chunks and embeddings are generated. + This method should be called after the chunks and embeddings are + generated. When ``self.metadata`` is non-empty it is forwarded to the + underlying storage so each stored document carries the source-level + metadata and can be matched by ``metadata_filter`` at query time. Raises: ValueError: If no storage is configured. """ - if self.storage: - self.storage.save(self.chunks) - else: + if self.storage is None: raise ValueError("No storage found to save documents.") + if self.metadata: + self.storage.save(self.chunks, metadata=self.metadata) + else: + self.storage.save(self.chunks) @abstractmethod async def aadd(self) -> None: @@ -61,12 +72,17 @@ async def aadd(self) -> None: async def _asave_documents(self) -> None: """Save the documents to the storage asynchronously. - This method should be called after the chunks and embeddings are generated. + This method should be called after the chunks and embeddings are + generated. When ``self.metadata`` is non-empty it is forwarded to the + underlying storage so each stored document carries the source-level + metadata and can be matched by ``metadata_filter`` at query time. Raises: ValueError: If no storage is configured. """ - if self.storage: - await self.storage.asave(self.chunks) - else: + if self.storage is None: raise ValueError("No storage found to save documents.") + if self.metadata: + await self.storage.asave(self.chunks, metadata=self.metadata) + else: + await self.storage.asave(self.chunks) diff --git a/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py b/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py index ea8aff7341..30245698fd 100644 --- a/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py +++ b/lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py @@ -35,12 +35,34 @@ async def asearch( """Search for documents in the knowledge base asynchronously.""" @abstractmethod - def save(self, documents: list[str]) -> None: - """Save documents to the knowledge base.""" + def save( + self, + documents: list[str], + metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + ) -> None: + """Save documents to the knowledge base. + + Args: + documents: List of document strings to save. + metadata: Optional metadata to attach to each stored document. + A single dict is applied to every document; a list of dicts + must match the number of documents. + """ @abstractmethod - async def asave(self, documents: list[str]) -> None: - """Save documents to the knowledge base asynchronously.""" + async def asave( + self, + documents: list[str], + metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + ) -> None: + """Save documents to the knowledge base asynchronously. + + Args: + documents: List of document strings to save. + metadata: Optional metadata to attach to each stored document. + A single dict is applied to every document; a list of dicts + must match the number of documents. + """ @abstractmethod def reset(self) -> None: diff --git a/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py b/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py index 3c9615946f..812024b6bc 100644 --- a/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py +++ b/lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py @@ -19,6 +19,30 @@ from crewai.utilities.logger import Logger +def _build_rag_documents( + documents: list[str], + metadata: dict[str, Any] | list[dict[str, Any]] | None, +) -> list[BaseRecord]: + """Build BaseRecord entries, optionally attaching per-document metadata. + + A single ``metadata`` dict is applied to every document. A list of dicts is + paired positionally with ``documents`` and must match its length. + """ + if metadata is None: + return [{"content": doc} for doc in documents] + if isinstance(metadata, list): + if len(metadata) != len(documents): + raise ValueError( + "metadata list length does not match documents length: " + f"got {len(metadata)} metadata entries for {len(documents)} documents." + ) + return [ + {"content": doc, "metadata": dict(meta)} + for doc, meta in zip(documents, metadata, strict=True) + ] + return [{"content": doc, "metadata": dict(metadata)} for doc in documents] + + class KnowledgeStorage(BaseKnowledgeStorage): """ Extends Storage to handle embeddings for memory entries, improving @@ -102,7 +126,23 @@ def reset(self) -> None: f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}" ) - def save(self, documents: list[str]) -> None: + def save( + self, + documents: list[str], + metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + ) -> None: + """Save documents to the knowledge base. + + Args: + documents: List of document strings to save. + metadata: Optional metadata attached to each stored document. + A single dict is applied to every document; a list of dicts + must have the same length as ``documents``. + + Raises: + ValueError: If a metadata list is supplied and its length does not + match ``documents``. + """ if not documents: return @@ -115,7 +155,7 @@ def save(self, documents: list[str]) -> None: ) client.get_or_create_collection(collection_name=collection_name) - rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents] + rag_documents: list[BaseRecord] = _build_rag_documents(documents, metadata) client.add_documents( collection_name=collection_name, documents=rag_documents @@ -178,11 +218,22 @@ async def asearch( ) return [] - async def asave(self, documents: list[str]) -> None: + async def asave( + self, + documents: list[str], + metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + ) -> None: """Save documents to the knowledge base asynchronously. Args: documents: List of document strings to save. + metadata: Optional metadata attached to each stored document. + A single dict is applied to every document; a list of dicts + must have the same length as ``documents``. + + Raises: + ValueError: If a metadata list is supplied and its length does not + match ``documents``. """ if not documents: return @@ -196,7 +247,7 @@ async def asave(self, documents: list[str]) -> None: ) await client.aget_or_create_collection(collection_name=collection_name) - rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents] + rag_documents: list[BaseRecord] = _build_rag_documents(documents, metadata) await client.aadd_documents( collection_name=collection_name, documents=rag_documents diff --git a/lib/crewai/tests/knowledge/test_async_knowledge.py b/lib/crewai/tests/knowledge/test_async_knowledge.py index c243b3ce4b..6c38744d96 100644 --- a/lib/crewai/tests/knowledge/test_async_knowledge.py +++ b/lib/crewai/tests/knowledge/test_async_knowledge.py @@ -94,6 +94,7 @@ async def test_aquery_calls_storage_asearch(self): mock_storage.asearch.assert_called_once_with( ["test query"], limit=5, + metadata_filter=None, score_threshold=0.6, ) diff --git a/lib/crewai/tests/knowledge/test_knowledge_metadata_filter.py b/lib/crewai/tests/knowledge/test_knowledge_metadata_filter.py new file mode 100644 index 0000000000..9615007ad0 --- /dev/null +++ b/lib/crewai/tests/knowledge/test_knowledge_metadata_filter.py @@ -0,0 +1,316 @@ +"""Tests for metadata_filter threading through the knowledge query pipeline. + +Issue #5805: ``KnowledgeStorage.search`` already accepts a ``metadata_filter``, +but the public ``Knowledge.query`` / ``Crew.query_knowledge`` / +``KnowledgeConfig`` layer never forwarded it, so users could not narrow +retrieval by document metadata. These tests pin the wiring with a fake +storage so the regression cannot return silently. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest +from pydantic import PrivateAttr, ValidationError + +from crewai.crew import Crew +from crewai.knowledge.knowledge import Knowledge +from crewai.knowledge.knowledge_config import KnowledgeConfig +from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource +from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage +from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage +from crewai.rag.types import SearchResult + + +class _RecordingStorage(BaseKnowledgeStorage): + """Fake storage that records call kwargs without touching a real backend.""" + + _search_calls: list[dict[str, Any]] = PrivateAttr(default_factory=list) + _asearch_calls: list[dict[str, Any]] = PrivateAttr(default_factory=list) + _save_calls: list[dict[str, Any]] = PrivateAttr(default_factory=list) + _asave_calls: list[dict[str, Any]] = PrivateAttr(default_factory=list) + + @property + def search_calls(self) -> list[dict[str, Any]]: + return self._search_calls + + @property + def asearch_calls(self) -> list[dict[str, Any]]: + return self._asearch_calls + + @property + def save_calls(self) -> list[dict[str, Any]]: + return self._save_calls + + @property + def asave_calls(self) -> list[dict[str, Any]]: + return self._asave_calls + + def search( + self, + query: list[str], + limit: int = 5, + metadata_filter: dict[str, Any] | None = None, + score_threshold: float = 0.6, + ) -> list[SearchResult]: + self._search_calls.append( + { + "query": query, + "limit": limit, + "metadata_filter": metadata_filter, + "score_threshold": score_threshold, + } + ) + return [] + + async def asearch( + self, + query: list[str], + limit: int = 5, + metadata_filter: dict[str, Any] | None = None, + score_threshold: float = 0.6, + ) -> list[SearchResult]: + self._asearch_calls.append( + { + "query": query, + "limit": limit, + "metadata_filter": metadata_filter, + "score_threshold": score_threshold, + } + ) + return [] + + def save( + self, + documents: list[str], + metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + ) -> None: + self._save_calls.append({"documents": list(documents), "metadata": metadata}) + + async def asave( + self, + documents: list[str], + metadata: dict[str, Any] | list[dict[str, Any]] | None = None, + ) -> None: + self._asave_calls.append({"documents": list(documents), "metadata": metadata}) + + def reset(self) -> None: + pass + + async def areset(self) -> None: + pass + + +class TestKnowledgeConfigValidators: + """KnowledgeConfig should reject out-of-range bounds eagerly.""" + + def test_defaults_include_metadata_filter(self) -> None: + config = KnowledgeConfig() + + assert config.results_limit == 5 + assert config.score_threshold == 0.6 + assert config.metadata_filter is None + + def test_metadata_filter_round_trips_through_model_dump(self) -> None: + config = KnowledgeConfig(metadata_filter={"task": "translation"}) + + assert config.model_dump()["metadata_filter"] == {"task": "translation"} + + def test_results_limit_must_be_at_least_one(self) -> None: + with pytest.raises(ValidationError): + KnowledgeConfig(results_limit=0) + + def test_score_threshold_must_be_positive(self) -> None: + with pytest.raises(ValidationError): + KnowledgeConfig(score_threshold=0) + + def test_score_threshold_must_be_at_most_one(self) -> None: + with pytest.raises(ValidationError): + KnowledgeConfig(score_threshold=1.5) + + +class TestKnowledgeQueryForwardsMetadataFilter: + """Knowledge.query / aquery must forward metadata_filter to storage.""" + + def test_query_forwards_metadata_filter(self) -> None: + storage = _RecordingStorage() + knowledge = Knowledge(collection_name="t", sources=[], storage=storage) + + knowledge.query( + ["hello"], + results_limit=7, + score_threshold=0.42, + metadata_filter={"task": "translation"}, + ) + + assert storage.search_calls == [ + { + "query": ["hello"], + "limit": 7, + "metadata_filter": {"task": "translation"}, + "score_threshold": 0.42, + } + ] + + def test_query_defaults_metadata_filter_to_none(self) -> None: + storage = _RecordingStorage() + knowledge = Knowledge(collection_name="t", sources=[], storage=storage) + + knowledge.query(["hello"]) + + assert storage.search_calls[0]["metadata_filter"] is None + + @pytest.mark.asyncio + async def test_aquery_forwards_metadata_filter(self) -> None: + storage = _RecordingStorage() + knowledge = Knowledge(collection_name="t", sources=[], storage=storage) + + await knowledge.aquery( + ["hi"], + results_limit=2, + score_threshold=0.25, + metadata_filter={"agent": "translator"}, + ) + + assert storage.asearch_calls == [ + { + "query": ["hi"], + "limit": 2, + "metadata_filter": {"agent": "translator"}, + "score_threshold": 0.25, + } + ] + + +class TestCrewQueryKnowledgeForwardsMetadataFilter: + """Crew.query_knowledge / aquery_knowledge must forward metadata_filter.""" + + def _make_crew(self, storage: _RecordingStorage) -> Crew: + knowledge = Knowledge(collection_name="t", sources=[], storage=storage) + crew = Crew.__new__(Crew) + object.__setattr__(crew, "knowledge", knowledge) + return crew + + def test_query_knowledge_forwards_metadata_filter(self) -> None: + storage = _RecordingStorage() + crew = self._make_crew(storage) + + Crew.query_knowledge( + crew, + ["q"], + results_limit=4, + score_threshold=0.3, + metadata_filter={"pipeline_stage": "review"}, + ) + + assert storage.search_calls[0]["metadata_filter"] == {"pipeline_stage": "review"} + assert storage.search_calls[0]["limit"] == 4 + assert storage.search_calls[0]["score_threshold"] == 0.3 + + @pytest.mark.asyncio + async def test_aquery_knowledge_forwards_metadata_filter(self) -> None: + storage = _RecordingStorage() + crew = self._make_crew(storage) + + await Crew.aquery_knowledge( + crew, + ["q"], + results_limit=1, + score_threshold=0.5, + metadata_filter={"source_agent": "scout"}, + ) + + assert storage.asearch_calls[0]["metadata_filter"] == {"source_agent": "scout"} + assert storage.asearch_calls[0]["limit"] == 1 + assert storage.asearch_calls[0]["score_threshold"] == 0.5 + + +class TestBaseKnowledgeSourceSavesMetadata: + """BaseKnowledgeSource must merge self.metadata into stored documents.""" + + def test_save_documents_passes_metadata_when_set(self) -> None: + storage = _RecordingStorage() + source = StringKnowledgeSource( + content="hello world", metadata={"source": "manual"} + ) + source.storage = storage + source.chunks = ["hello world"] + + source._save_documents() + + assert storage.save_calls == [ + {"documents": ["hello world"], "metadata": {"source": "manual"}} + ] + + def test_save_documents_omits_metadata_when_empty(self) -> None: + storage = _RecordingStorage() + source = StringKnowledgeSource(content="hello world") + source.storage = storage + source.chunks = ["hello world"] + + source._save_documents() + + assert storage.save_calls == [{"documents": ["hello world"], "metadata": None}] + + def test_save_documents_raises_without_storage(self) -> None: + source = StringKnowledgeSource(content="x", metadata={"k": "v"}) + source.storage = None + + with pytest.raises(ValueError, match="No storage found"): + source._save_documents() + + @pytest.mark.asyncio + async def test_asave_documents_passes_metadata_when_set(self) -> None: + storage = _RecordingStorage() + source = StringKnowledgeSource( + content="hello world", metadata={"source": "manual"} + ) + source.storage = storage + source.chunks = ["hello world"] + + await source._asave_documents() + + assert storage.asave_calls == [ + {"documents": ["hello world"], "metadata": {"source": "manual"}} + ] + + +class TestKnowledgeStorageBuildsMetadataRecords: + """KnowledgeStorage.save should attach metadata to BaseRecord entries.""" + + def test_save_attaches_dict_metadata_to_each_record(self) -> None: + mock_client = MagicMock() + storage = KnowledgeStorage(collection_name="t") + storage._client = mock_client + + storage.save(["a", "b"], metadata={"task": "translation"}) + + mock_client.add_documents.assert_called_once() + records = mock_client.add_documents.call_args.kwargs["documents"] + assert records == [ + {"content": "a", "metadata": {"task": "translation"}}, + {"content": "b", "metadata": {"task": "translation"}}, + ] + + def test_save_rejects_metadata_list_length_mismatch(self) -> None: + mock_client = MagicMock() + storage = KnowledgeStorage(collection_name="t") + storage._client = mock_client + + with pytest.raises(ValueError, match="metadata list length"): + storage.save(["a", "b"], metadata=[{"k": "v"}]) + + @pytest.mark.asyncio + async def test_asave_attaches_dict_metadata_to_each_record(self) -> None: + mock_client = MagicMock() + mock_client.aget_or_create_collection = AsyncMock() + mock_client.aadd_documents = AsyncMock() + storage = KnowledgeStorage(collection_name="t") + storage._client = mock_client + + await storage.asave(["a"], metadata={"task": "translation"}) + + records = mock_client.aadd_documents.call_args.kwargs["documents"] + assert records == [{"content": "a", "metadata": {"task": "translation"}}] diff --git a/lib/crewai/tests/knowledge/test_knowledge_searchresult.py b/lib/crewai/tests/knowledge/test_knowledge_searchresult.py index 6f3db84dec..e1b84106e4 100644 --- a/lib/crewai/tests/knowledge/test_knowledge_searchresult.py +++ b/lib/crewai/tests/knowledge/test_knowledge_searchresult.py @@ -39,7 +39,7 @@ def test_knowledge_query_returns_searchresult() -> None: ) mock_storage.search.assert_called_once_with( - ["AI technology"], limit=5, score_threshold=0.3 + ["AI technology"], limit=5, metadata_filter=None, score_threshold=0.3 ) assert isinstance(results, list)