Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions lib/crewai/src/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -1933,22 +1933,52 @@ def replay(self, task_id: str, inputs: dict[str, Any] | None = None) -> CrewOutp
return self._execute_tasks(self.tasks, start_index, True)

def query_knowledge(
    self,
    query: list[str],
    results_limit: int = 3,
    score_threshold: float = 0.35,
    metadata_filter: dict[str, Any] | None = None,
) -> list[SearchResult] | None:
    """Query the crew's knowledge base for relevant information.

    Args:
        query: List of query strings.
        results_limit: Maximum number of results to return.
        score_threshold: Minimum similarity score for results.
        metadata_filter: Optional metadata filter forwarded to the
            underlying knowledge storage.

    Returns:
        The value returned by ``self.knowledge.query``, or ``None`` when
        ``self.knowledge`` is falsy (no knowledge configured on the crew).
    """
    if self.knowledge:
        # Every option is passed by keyword so the call does not depend on
        # the parameter order of ``knowledge.query``.
        return self.knowledge.query(
            query,
            results_limit=results_limit,
            score_threshold=score_threshold,
            metadata_filter=metadata_filter,
        )
    return None

async def aquery_knowledge(
    self,
    query: list[str],
    results_limit: int = 3,
    score_threshold: float = 0.35,
    metadata_filter: dict[str, Any] | None = None,
) -> list[SearchResult] | None:
    """Query the crew's knowledge base for relevant information asynchronously.

    Args:
        query: List of query strings.
        results_limit: Maximum number of results to return.
        score_threshold: Minimum similarity score for results.
        metadata_filter: Optional metadata filter forwarded to the
            underlying knowledge storage.

    Returns:
        The awaited result of ``self.knowledge.aquery``, or ``None`` when
        ``self.knowledge`` is falsy (no knowledge configured on the crew).
    """
    if self.knowledge:
        # Every option is passed by keyword so the call does not depend on
        # the parameter order of ``knowledge.aquery``.
        return await self.knowledge.aquery(
            query,
            results_limit=results_limit,
            score_threshold=score_threshold,
            metadata_filter=metadata_filter,
        )
    return None

Expand Down
27 changes: 23 additions & 4 deletions lib/crewai/src/crewai/knowledge/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,23 @@ def __init__(
self.sources = sources

def query(
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
self,
query: list[str],
results_limit: int = 5,
score_threshold: float = 0.6,
metadata_filter: dict[str, Any] | None = None,
) -> list[SearchResult]:
"""
Query across all knowledge sources to find the most relevant information.
"""Query across all knowledge sources to find the most relevant information.

Returns the top_k most relevant chunks.

Args:
query: List of query strings.
results_limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
metadata_filter: Optional metadata filter forwarded to the
underlying storage.

Raises:
ValueError: If storage is not initialized.
"""
Expand All @@ -140,6 +151,7 @@ def query(
return self.storage.search(
query,
limit=results_limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)

Expand All @@ -158,14 +170,20 @@ def reset(self) -> None:
raise ValueError("Storage is not initialized.")

async def aquery(
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
self,
query: list[str],
results_limit: int = 5,
score_threshold: float = 0.6,
metadata_filter: dict[str, Any] | None = None,
) -> list[SearchResult]:
"""Query across all knowledge sources asynchronously.

Args:
query: List of query strings.
results_limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
metadata_filter: Optional metadata filter forwarded to the
underlying storage.

Returns:
The top results matching the query.
Expand All @@ -179,6 +197,7 @@ async def aquery(
return await self.storage.asearch(
query,
limit=results_limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)

Expand Down
26 changes: 23 additions & 3 deletions lib/crewai/src/crewai/knowledge/knowledge_config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
from typing import Any

from pydantic import BaseModel, Field


class KnowledgeConfig(BaseModel):
    """Configuration for knowledge retrieval.

    Args:
        results_limit (int): The number of relevant documents to return. Must be
            at least 1.
        score_threshold (float): The minimum score for a document to be
            considered relevant. Must be greater than 0 and less than or equal
            to 1.
        metadata_filter (dict[str, Any] | None): Optional metadata filter
            forwarded to the underlying knowledge storage so retrieval can be
            narrowed to documents whose stored metadata matches these keys and
            values.
    """

    # ``ge=1`` rejects zero or negative limits at validation time.
    results_limit: int = Field(
        default=5, ge=1, description="The number of results to return"
    )
    # Scores are constrained to the half-open interval (0, 1].
    score_threshold: float = Field(
        default=0.6,
        gt=0,
        le=1,
        description="The minimum score for a result to be considered relevant",
    )
    # ``None`` means no metadata filter is supplied.
    metadata_filter: dict[str, Any] | None = Field(
        default=None,
        description=(
            "Optional metadata filter passed to knowledge storage to restrict "
            "retrieval to documents whose metadata matches these key/value pairs."
        ),
    )
34 changes: 25 additions & 9 deletions lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@ class BaseKnowledgeSource(BaseModel, ABC):

model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
metadata: dict[str, Any] = Field(
default_factory=dict,
description=(
"Metadata merged into each stored document's metadata so it can be "
"matched by ``metadata_filter`` at query time."
),
)
collection_name: str | None = Field(default=None)

@abstractmethod
Expand All @@ -44,15 +50,20 @@ def _chunk_text(self, text: str) -> list[str]:
def _save_documents(self) -> None:
"""Save the documents to the storage.

This method should be called after the chunks and embeddings are generated.
This method should be called after the chunks and embeddings are
generated. When ``self.metadata`` is non-empty it is forwarded to the
underlying storage so each stored document carries the source-level
metadata and can be matched by ``metadata_filter`` at query time.

Raises:
ValueError: If no storage is configured.
"""
if self.storage:
self.storage.save(self.chunks)
else:
if self.storage is None:
raise ValueError("No storage found to save documents.")
if self.metadata:
self.storage.save(self.chunks, metadata=self.metadata)
else:
self.storage.save(self.chunks)

@abstractmethod
async def aadd(self) -> None:
Expand All @@ -61,12 +72,17 @@ async def aadd(self) -> None:
async def _asave_documents(self) -> None:
"""Save the documents to the storage asynchronously.

This method should be called after the chunks and embeddings are generated.
This method should be called after the chunks and embeddings are
generated. When ``self.metadata`` is non-empty it is forwarded to the
underlying storage so each stored document carries the source-level
metadata and can be matched by ``metadata_filter`` at query time.

Raises:
ValueError: If no storage is configured.
"""
if self.storage:
await self.storage.asave(self.chunks)
else:
if self.storage is None:
raise ValueError("No storage found to save documents.")
if self.metadata:
await self.storage.asave(self.chunks, metadata=self.metadata)
else:
await self.storage.asave(self.chunks)
30 changes: 26 additions & 4 deletions lib/crewai/src/crewai/knowledge/storage/base_knowledge_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,34 @@ async def asearch(
"""Search for documents in the knowledge base asynchronously."""

@abstractmethod
def save(
    self,
    documents: list[str],
    metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
) -> None:
    """Save documents to the knowledge base.

    Abstract hook with no body of its own.

    Args:
        documents: List of document strings to save.
        metadata: Optional metadata to attach to each stored document.
            A single dict is applied to every document; a list of dicts
            must match the number of documents.
    """

@abstractmethod
async def asave(
    self,
    documents: list[str],
    metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
) -> None:
    """Save documents to the knowledge base asynchronously.

    Abstract hook with no body of its own.

    Args:
        documents: List of document strings to save.
        metadata: Optional metadata to attach to each stored document.
            A single dict is applied to every document; a list of dicts
            must match the number of documents.
    """

@abstractmethod
def reset(self) -> None:
Expand Down
59 changes: 55 additions & 4 deletions lib/crewai/src/crewai/knowledge/storage/knowledge_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@
from crewai.utilities.logger import Logger


def _build_rag_documents(
documents: list[str],
metadata: dict[str, Any] | list[dict[str, Any]] | None,
) -> list[BaseRecord]:
"""Build BaseRecord entries, optionally attaching per-document metadata.

A single ``metadata`` dict is applied to every document. A list of dicts is
paired positionally with ``documents`` and must match its length.
"""
if metadata is None:
return [{"content": doc} for doc in documents]
if isinstance(metadata, list):
if len(metadata) != len(documents):
raise ValueError(
"metadata list length does not match documents length: "
f"got {len(metadata)} metadata entries for {len(documents)} documents."
)
return [
{"content": doc, "metadata": dict(meta)}
for doc, meta in zip(documents, metadata, strict=True)
]
return [{"content": doc, "metadata": dict(metadata)} for doc in documents]


class KnowledgeStorage(BaseKnowledgeStorage):
"""
Extends Storage to handle embeddings for memory entries, improving
Expand Down Expand Up @@ -102,7 +126,23 @@ def reset(self) -> None:
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
)

def save(self, documents: list[str]) -> None:
def save(
self,
documents: list[str],
metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
) -> None:
"""Save documents to the knowledge base.

Args:
documents: List of document strings to save.
metadata: Optional metadata attached to each stored document.
A single dict is applied to every document; a list of dicts
must have the same length as ``documents``.

Raises:
ValueError: If a metadata list is supplied and its length does not
match ``documents``.
"""
if not documents:
return

Expand All @@ -115,7 +155,7 @@ def save(self, documents: list[str]) -> None:
)
client.get_or_create_collection(collection_name=collection_name)

rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
rag_documents: list[BaseRecord] = _build_rag_documents(documents, metadata)

client.add_documents(
collection_name=collection_name, documents=rag_documents
Expand Down Expand Up @@ -178,11 +218,22 @@ async def asearch(
)
return []

async def asave(self, documents: list[str]) -> None:
async def asave(
self,
documents: list[str],
metadata: dict[str, Any] | list[dict[str, Any]] | None = None,
) -> None:
"""Save documents to the knowledge base asynchronously.

Args:
documents: List of document strings to save.
metadata: Optional metadata attached to each stored document.
A single dict is applied to every document; a list of dicts
must have the same length as ``documents``.

Raises:
ValueError: If a metadata list is supplied and its length does not
match ``documents``.
"""
if not documents:
return
Expand All @@ -196,7 +247,7 @@ async def asave(self, documents: list[str]) -> None:
)
await client.aget_or_create_collection(collection_name=collection_name)

rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
rag_documents: list[BaseRecord] = _build_rag_documents(documents, metadata)

await client.aadd_documents(
collection_name=collection_name, documents=rag_documents
Expand Down
1 change: 1 addition & 0 deletions lib/crewai/tests/knowledge/test_async_knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ async def test_aquery_calls_storage_asearch(self):
mock_storage.asearch.assert_called_once_with(
["test query"],
limit=5,
metadata_filter=None,
score_threshold=0.6,
)

Expand Down
Loading