Skip to content

Commit

Permalink
Implement FilterRetriever and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bglearning committed Jan 30, 2024
1 parent 922aa0b commit d7ad9b9
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 24 deletions.
46 changes: 22 additions & 24 deletions haystack/components/retrievers/filter_retriever.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,30 @@
import importlib
import logging

from typing import Dict, List, Any, Optional

from haystack import component, Document, default_to_dict, default_from_dict, DeserializationError
from haystack.document_stores.types import DocumentStore


logger = logging.getLogger(__name__)


@component
class FilterRetriever:
"""
Retrieves documents that match the provided filters.
"""

def __init__(
self, document_store: DocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None
):
def __init__(self, document_store: DocumentStore, filters: Optional[Dict[str, Any]] = None):
"""
Create the FilterRetriever component.
:param document_store: An instance of InMemoryDocumentStore.
:param filters: A dictionary with filters to narrow down the search space. Defaults to `None`.
:param top_k: The maximum number of documents to retrieve. Defaults to `None` in which case all documents are retrieved.
:raises ValueError: If the `top_k` is specified but is not > 0.
"""
self.document_store = document_store

if top_k is not None and top_k <= 0:
raise ValueError(f"top_k must be greater than 0. Currently, the top_k is {top_k}")

self.filters = filters
self.top_k = top_k

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -41,7 +37,7 @@ def to_dict(self) -> Dict[str, Any]:
Serialize this component to a dictionary.
"""
docstore = self.document_store.to_dict()
return default_to_dict(self, document_store=docstore, filters=self.filters, top_k=self.top_k)
return default_to_dict(self, document_store=docstore, filters=self.filters)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FilterRetriever":
Expand All @@ -53,27 +49,29 @@ def from_dict(cls, data: Dict[str, Any]) -> "FilterRetriever":
raise DeserializationError("Missing 'document_store' in serialization data")
if "type" not in init_params["document_store"]:
raise DeserializationError("Missing 'type' in document store's serialization data")
data["init_parameters"]["document_store"] = FilterRetriever.from_dict(data["init_parameters"]["document_store"])
try:
module_name, type_ = init_params["document_store"]["type"].rsplit(".", 1)
logger.debug("Trying to import %s", module_name)
module = importlib.import_module(module_name)
except (ImportError, DeserializationError) as e:
raise DeserializationError(
f"DocumentStore of type '{init_params['document_store']['type']}' not correctly imported"
) from e

docstore_class = getattr(module, type_)
data["init_parameters"]["document_store"] = docstore_class.from_dict(data["init_parameters"]["document_store"])
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
def run(self, filters: Optional[Dict[str, Any]] = None):
"""
Run the FilterRetriever on the given input data.
:param query: The query string for the Retriever. It is ignored by the retriever.
:param filters: A dictionary with filters to narrow down the search space.
:param top_k: The maximum number of documents to return.
If not specified, the value provided at initialization is used.
:return: The retrieved documents.
:raises ValueError: If the specified DocumentStore is not found
"""
if filters is None:
filters = self.filters
if top_k is None:
top_k = self.top_k

docs = self.document_store.filter_documents(filters=filters)
if top_k is not None:
docs = docs[:top_k]
return {"documents": docs}
return {"documents": self.document_store.filter_documents(filters=filters)}
5 changes: 5 additions & 0 deletions releasenotes/notes/add-filter-retriever-8901af26144d1a17.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Add FilterRetriever.
It retrieves documents that match the provided (either at init or runtime) filters.
136 changes: 136 additions & 0 deletions test/components/retrievers/test_filter_retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import Dict, Any, List

import pytest

from haystack import Pipeline, DeserializationError
from haystack.testing.factory import document_store_class
from haystack.components.retrievers.filter_retriever import FilterRetriever
from haystack.dataclasses import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore


@pytest.fixture()
def sample_docs():
en_docs = [
Document(content="Javascript is a popular programming language", meta={"lang": "en"}),
Document(content="Python is a popular programming language", meta={"lang": "en"}),
Document(content="A chromosome is a package of DNA ", meta={"lang": "en"}),
]
de_docs = [
Document(content="python ist eine beliebte Programmiersprache", meta={"lang": "de"}),
Document(content="javascript ist eine beliebte Programmiersprache", meta={"lang": "de"}),
]
all_docs = en_docs + de_docs
return {"en_docs": en_docs, "de_docs": de_docs, "all_docs": all_docs}


@pytest.fixture()
def sample_document_store(sample_docs):
doc_store = InMemoryDocumentStore()
doc_store.write_documents(sample_docs["all_docs"])
return doc_store


class TestFilterRetriever:
@classmethod
def _documents_equal(cls, docs1: List[Document], docs2: List[Document]) -> bool:
return set(d.content for d in docs1) == set(d.content for d in docs2) # noqa: C401

def test_init_default(self):
retriever = FilterRetriever(InMemoryDocumentStore())
assert retriever.filters is None

def test_init_with_parameters(self):
retriever = FilterRetriever(InMemoryDocumentStore(), filters={"lang": "en"})
assert retriever.filters == {"lang": "en"}

def test_to_dict(self):
FilterDocStore = document_store_class("MyFakeStore", bases=(InMemoryDocumentStore,))
document_store = FilterDocStore()
document_store.to_dict = lambda: {"type": "FilterDocStore", "init_parameters": {}}
component = FilterRetriever(document_store=document_store)

data = component.to_dict()
assert data == {
"type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
"init_parameters": {"document_store": {"type": "FilterDocStore", "init_parameters": {}}, "filters": None},
}

def test_to_dict_with_custom_init_parameters(self):
ds = InMemoryDocumentStore()
serialized_ds = ds.to_dict()

component = FilterRetriever(document_store=InMemoryDocumentStore(), filters={"lang": "en"})
data = component.to_dict()
assert data == {
"type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
"init_parameters": {"document_store": serialized_ds, "filters": {"lang": "en"}},
}

def test_from_dict(self):
valid_data = {
"type": "haystack.components.retrievers.filter_retriever.FilterRetriever",
"init_parameters": {
"document_store": {
"type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
"init_parameters": {},
},
"filters": {"lang": "en"},
},
}
component = FilterRetriever.from_dict(valid_data)
assert isinstance(component.document_store, InMemoryDocumentStore)
assert component.filters == {"lang": "en"}

def test_from_dict_without_docstore(self):
data = {"type": "InMemoryBM25Retriever", "init_parameters": {}}
with pytest.raises(DeserializationError, match="Missing 'document_store' in serialization data"):
FilterRetriever.from_dict(data)

def test_retriever_init_filter(self, sample_document_store, sample_docs):
retriever = FilterRetriever(sample_document_store, filters={"field": "lang", "operator": "==", "value": "en"})
result = retriever.run()

assert "documents" in result
assert len(result["documents"]) == 3
assert TestFilterRetriever._documents_equal(result["documents"], sample_docs["en_docs"])

def test_retriever_runtime_filter(self, sample_document_store, sample_docs):
retriever = FilterRetriever(sample_document_store)
result = retriever.run(filters={"field": "lang", "operator": "==", "value": "en"})

assert "documents" in result
assert len(result["documents"]) == 3
assert TestFilterRetriever._documents_equal(result["documents"], sample_docs["en_docs"])

def test_retriever_init_filter_run_filter_override(self, sample_document_store, sample_docs):
retriever = FilterRetriever(sample_document_store, filters={"field": "lang", "operator": "==", "value": "en"})
result = retriever.run(filters={"field": "lang", "operator": "==", "value": "de"})

assert "documents" in result
assert len(result["documents"]) == 2
assert TestFilterRetriever._documents_equal(result["documents"], sample_docs["de_docs"])

@pytest.mark.integration
def test_run_with_pipeline(self, sample_document_store, sample_docs):
retriever = FilterRetriever(sample_document_store, filters={"field": "lang", "operator": "==", "value": "de"})

pipeline = Pipeline()
pipeline.add_component("retriever", retriever)
result: Dict[str, Any] = pipeline.run(data={"retriever": {}})

assert result
assert "retriever" in result
results_docs = result["retriever"]["documents"]
assert results_docs
assert TestFilterRetriever._documents_equal(results_docs, sample_docs["de_docs"])

result: Dict[str, Any] = pipeline.run(
data={"retriever": {"filters": {"field": "lang", "operator": "==", "value": "en"}}}
)

assert result
assert "retriever" in result
results_docs = result["retriever"]["documents"]
assert results_docs
assert TestFilterRetriever._documents_equal(results_docs, sample_docs["en_docs"])

0 comments on commit d7ad9b9

Please sign in to comment.