Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add FilterRetriever #6836

Merged
merged 12 commits into from
Feb 8, 2024
9 changes: 7 additions & 2 deletions docs/pydoc/config/retrievers_api.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
loaders:
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
search_path: [../../../haystack/components/retrievers/in_memory]
modules: ["bm25_retriever", "embedding_retriever"]
search_path: [../../../haystack/components/retrievers]
modules:
[
"in_memory/bm25_retriever",
"in_memory/embedding_retriever",
"filter_retriever",
]
ignore_when_discovered: ["__init__"]
processors:
- type: filter
Expand Down
3 changes: 3 additions & 0 deletions haystack/components/retrievers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from haystack.components.retrievers.filter_retriever import FilterRetriever

__all__ = ["FilterRetriever"]
97 changes: 97 additions & 0 deletions haystack/components/retrievers/filter_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import importlib
import logging

from typing import Dict, List, Any, Optional

from haystack import component, Document, default_to_dict, default_from_dict, DeserializationError
from haystack.document_stores.types import DocumentStore


logger = logging.getLogger(__name__)


@component
class FilterRetriever:
    """
    Retrieves documents that match the provided filters.

    Filters can be given at initialization and/or at run time. When non-empty filters
    are passed to `run`, they are used instead of the ones given at initialization.

    Usage example:
    ```python
    from haystack import Document
    from haystack.components.retrievers import FilterRetriever
    from haystack.document_stores.in_memory import InMemoryDocumentStore

    docs = [
        Document(content="Python is a popular programming language", meta={"lang": "en"}),
        Document(content="python ist eine beliebte Programmiersprache", meta={"lang": "de"}),
    ]

    doc_store = InMemoryDocumentStore()
    doc_store.write_documents(docs)
    retriever = FilterRetriever(doc_store, filters={"field": "lang", "operator": "==", "value": "en"})

    # if passed in the run method, filters will override those provided at initialization
    result = retriever.run(filters={"field": "lang", "operator": "==", "value": "de"})

    assert "documents" in result
    assert len(result["documents"]) == 1
    assert result["documents"][0].content == "python ist eine beliebte Programmiersprache"
    ```
    """

    def __init__(self, document_store: DocumentStore, filters: Optional[Dict[str, Any]] = None):
        """
        Create the FilterRetriever component.

        :param document_store: An instance of a DocumentStore.
        :param filters: A dictionary with filters to narrow down the search space. Defaults to `None`.
        """
        self.document_store = document_store
        self.filters = filters

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.

        :return: A dictionary holding only the class name of the configured document store.
        """
        return {"document_store": type(self.document_store).__name__}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        The document store is serialized through its own `to_dict` method and nested
        under the `document_store` init parameter.

        :return: The serialized component.
        """
        docstore = self.document_store.to_dict()
        return default_to_dict(self, document_store=docstore, filters=self.filters)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "FilterRetriever":
        """
        Deserialize this component from a dictionary.

        The passed dictionary is not modified, so the same serialized data can be
        deserialized more than once.

        :param data: The serialized component. Its `init_parameters` must contain a
            `document_store` entry with a fully qualified `type` (`package.module.ClassName`).
        :return: The deserialized component.
        :raises DeserializationError: If `document_store` or its `type` is missing, or if the
            document store class cannot be resolved from its `type`.
        """
        init_params = data.get("init_parameters", {})
        if "document_store" not in init_params:
            raise DeserializationError("Missing 'document_store' in serialization data")

        docstore_data = init_params["document_store"]
        if "type" not in docstore_data:
            raise DeserializationError("Missing 'type' in document store's serialization data")

        docstore_type = docstore_data["type"]
        try:
            # A type without a dot makes the unpacking raise ValueError; a non-string type or
            # a class missing from the module raises AttributeError. All of them mean the
            # document store class could not be resolved, so they are reported uniformly.
            module_name, type_ = docstore_type.rsplit(".", 1)
            logger.debug("Trying to import %s", module_name)
            module = importlib.import_module(module_name)
            docstore_class = getattr(module, type_)
        except (ImportError, ValueError, AttributeError, DeserializationError) as e:
            raise DeserializationError(f"DocumentStore of type '{docstore_type}' not correctly imported") from e

        # Build shallow copies rather than writing the instance back into the caller's dictionary.
        resolved_params = {**init_params, "document_store": docstore_class.from_dict(docstore_data)}
        return default_from_dict(cls, {**data, "init_parameters": resolved_params})

    @component.output_types(documents=List[Document])
    def run(self, filters: Optional[Dict[str, Any]] = None):
        """
        Run the FilterRetriever on the given input data.

        :param filters: A dictionary with filters to narrow down the search space.
            If not specified, the FilterRetriever uses the value provided at initialization.
            Note that any falsy value (`None` or an empty dictionary) falls back to the
            initialization filters.
        :return: A dictionary with a single `documents` key holding the retrieved documents.
        """
        return {"documents": self.document_store.filter_documents(filters=filters or self.filters)}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Add FilterRetriever.
It retrieves documents that match the filters provided either at initialization or at runtime.
139 changes: 139 additions & 0 deletions test/components/retrievers/test_filter_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from typing import Dict, Any, List

import pytest

from haystack import Pipeline, DeserializationError
from haystack.testing.factory import document_store_class
from haystack.components.retrievers.filter_retriever import FilterRetriever
from haystack.dataclasses import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore


@pytest.fixture()
def sample_docs():
    """
    Provide English and German sample documents, grouped by language.

    :return: A dictionary with the keys `en_docs`, `de_docs` and `all_docs`
        (the English documents followed by the German ones).
    """
    english_texts = (
        "Javascript is a popular programming language",
        "Python is a popular programming language",
        "A chromosome is a package of DNA ",
    )
    german_texts = (
        "python ist eine beliebte Programmiersprache",
        "javascript ist eine beliebte Programmiersprache",
    )

    # Every document gets its own meta dictionary carrying the language tag.
    en_docs = [Document(content=text, meta={"lang": "en"}) for text in english_texts]
    de_docs = [Document(content=text, meta={"lang": "de"}) for text in german_texts]

    return {"en_docs": en_docs, "de_docs": de_docs, "all_docs": [*en_docs, *de_docs]}


@pytest.fixture()
def sample_document_store(sample_docs):
    """
    Provide an in-memory document store pre-populated with all sample documents.

    :param sample_docs: The `sample_docs` fixture; its `all_docs` entry is written to the store.
    :return: The populated `InMemoryDocumentStore`.
    """
    store = InMemoryDocumentStore()
    documents = sample_docs["all_docs"]
    store.write_documents(documents)
    return store


class TestFilterRetriever:
    """Unit and integration tests for the FilterRetriever component."""

    @staticmethod
    def _documents_equal(docs1: List[Document], docs2: List[Document]) -> bool:
        """
        Check that two lists hold the same documents, regardless of order.

        :param docs1: First list of documents.
        :param docs2: Second list of documents.
        :return: `True` if both lists are equal once ordered by document id.
        """
        # Order doesn't matter; compare sorted copies so the callers' lists
        # (fixture data and retriever results) are left untouched.
        return sorted(docs1, key=lambda doc: doc.id) == sorted(docs2, key=lambda doc: doc.id)

    def test_init_default(self):
        retriever = FilterRetriever(InMemoryDocumentStore())
        assert retriever.filters is None

    def test_init_with_parameters(self):
        retriever = FilterRetriever(InMemoryDocumentStore(), filters={"lang": "en"})
        assert retriever.filters == {"lang": "en"}

    def test_to_dict(self):
        FilterDocStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
        document_store = FilterDocStore()
        # Stub the store's serialization so the expected output does not depend on its internals.
        document_store.to_dict = lambda: {"type": "FilterDocStore", "init_parameters": {}}
        component = FilterRetriever(document_store=document_store)

        data = component.to_dict()
        assert data == {
            "type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
            "init_parameters": {"document_store": {"type": "FilterDocStore", "init_parameters": {}}, "filters": None},
        }

    def test_to_dict_with_custom_init_parameters(self):
        ds = InMemoryDocumentStore()
        serialized_ds = ds.to_dict()

        component = FilterRetriever(document_store=InMemoryDocumentStore(), filters={"lang": "en"})
        data = component.to_dict()
        assert data == {
            "type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
            "init_parameters": {"document_store": serialized_ds, "filters": {"lang": "en"}},
        }

    def test_from_dict(self):
        valid_data = {
            "type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
            "init_parameters": {
                "document_store": {
                    "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                    "init_parameters": {},
                },
                "filters": {"lang": "en"},
            },
        }
        component = FilterRetriever.from_dict(valid_data)
        assert isinstance(component.document_store, InMemoryDocumentStore)
        assert component.filters == {"lang": "en"}

    def test_from_dict_without_docstore(self):
        data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}
        with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
            FilterRetriever.from_dict(data)

    def test_from_dict_without_docstore_type(self):
        # The document store entry exists but does not say which class to instantiate.
        data = {
            "type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
            "init_parameters": {"document_store": {"init_parameters": {}}},
        }
        with pytest.raises(DeserializationError, match="Missing 'type' in document store's serialization data"):
            FilterRetriever.from_dict(data)

    def test_from_dict_nonexistent_docstore(self):
        # The module part of the type cannot be imported, so deserialization must fail cleanly.
        data = {
            "type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
            "init_parameters": {"document_store": {"type": "Nonexistent.Docstore", "init_parameters": {}}},
        }
        with pytest.raises(
            DeserializationError, match="DocumentStore of type 'Nonexistent.Docstore' not correctly imported"
        ):
            FilterRetriever.from_dict(data)

    def test_retriever_init_filter(self, sample_document_store, sample_docs):
        retriever = FilterRetriever(sample_document_store, filters={"field": "lang", "operator": "==", "value": "en"})
        result = retriever.run()

        assert "documents" in result
        assert len(result["documents"]) == 3
        assert TestFilterRetriever._documents_equal(result["documents"], sample_docs["en_docs"])

    def test_retriever_runtime_filter(self, sample_document_store, sample_docs):
        retriever = FilterRetriever(sample_document_store)
        result = retriever.run(filters={"field": "lang", "operator": "==", "value": "en"})

        assert "documents" in result
        assert len(result["documents"]) == 3
        assert TestFilterRetriever._documents_equal(result["documents"], sample_docs["en_docs"])

    def test_retriever_init_filter_run_filter_override(self, sample_document_store, sample_docs):
        retriever = FilterRetriever(sample_document_store, filters={"field": "lang", "operator": "==", "value": "en"})
        result = retriever.run(filters={"field": "lang", "operator": "==", "value": "de"})

        assert "documents" in result
        assert len(result["documents"]) == 2
        assert TestFilterRetriever._documents_equal(result["documents"], sample_docs["de_docs"])

    @pytest.mark.integration
    def test_run_with_pipeline(self, sample_document_store, sample_docs):
        retriever = FilterRetriever(sample_document_store, filters={"field": "lang", "operator": "==", "value": "de"})

        pipeline = Pipeline()
        pipeline.add_component("retriever", retriever)

        # No runtime filters: the filters given at initialization apply.
        result: Dict[str, Any] = pipeline.run(data={"retriever": {}})

        assert result
        assert "retriever" in result
        results_docs = result["retriever"]["documents"]
        assert results_docs
        assert TestFilterRetriever._documents_equal(results_docs, sample_docs["de_docs"])

        # Runtime filters take precedence over the initialization filters.
        result = pipeline.run(data={"retriever": {"filters": {"field": "lang", "operator": "==", "value": "en"}}})

        assert result
        assert "retriever" in result
        results_docs = result["retriever"]["documents"]
        assert results_docs
        assert TestFilterRetriever._documents_equal(results_docs, sample_docs["en_docs"])