From f36c621104f64d4e88aeb6673e1a8ba34c3472d1 Mon Sep 17 00:00:00 2001 From: Saba Sturua <45267439+jupyterjazz@users.noreply.github.com> Date: Tue, 13 Jun 2023 10:06:43 +0200 Subject: [PATCH] fix: find_and_filter for inmemory (#1642) Signed-off-by: jupyterjazz --- docarray/index/backends/helper.py | 13 +++++++++++-- docarray/index/backends/in_memory.py | 1 + tests/index/in_memory/test_in_memory.py | 6 ++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/docarray/index/backends/helper.py b/docarray/index/backends/helper.py index 7f914d47353..268f623ab18 100644 --- a/docarray/index/backends/helper.py +++ b/docarray/index/backends/helper.py @@ -21,13 +21,22 @@ def inner(self, *args, **kwargs): def _execute_find_and_filter_query( - doc_index: BaseDocIndex, query: List[Tuple[str, Dict]] + doc_index: BaseDocIndex, query: List[Tuple[str, Dict]], reverse_order: bool = False ) -> FindResult: """ Executes all find calls from query first using `doc_index.find()`, and filtering queries after that using DocArray's `filter_docs()`. Text search is not supported. + + :param doc_index: Document index instance. + Either InMemoryExactNNIndex or HnswDocumentIndex. + :param query: Dictionary containing search and filtering configuration. + :param reverse_order: Flag indicating whether to sort in descending order. + If set to False (default), the sorting will be in ascending order. + This option is necessary because, depending on the index, lower scores + can correspond to better matches, and vice versa. + :return: Sorted documents and their corresponding scores. """ docs_found = DocList.__class_getitem__(cast(Type[BaseDoc], doc_index._schema))([]) filter_conditions = [] @@ -57,7 +66,7 @@ def _execute_find_and_filter_query( docs_and_scores = zip( docs_filtered, (doc_to_score[doc.id] for doc in docs_filtered) ) - docs_sorted = sorted(docs_and_scores, key=lambda x: x[1]) + docs_sorted = sorted(docs_and_scores, key=lambda x: x[1], reverse=reverse_order) out_docs, out_scores = zip(*docs_sorted) return FindResult(documents=out_docs, scores=out_scores) diff --git a/docarray/index/backends/in_memory.py b/docarray/index/backends/in_memory.py index 20f1af5dfc8..390c7d8de9c 100644 --- a/docarray/index/backends/in_memory.py +++ b/docarray/index/backends/in_memory.py @@ -299,6 +299,7 @@ def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any: find_res = _execute_find_and_filter_query( doc_index=self, query=query, + reverse_order=True, ) return find_res diff --git a/tests/index/in_memory/test_in_memory.py b/tests/index/in_memory/test_in_memory.py index 2b6b172119f..acac38d8226 100644 --- a/tests/index/in_memory/test_in_memory.py +++ b/tests/index/in_memory/test_in_memory.py @@ -20,7 +20,9 @@ class SchemaDoc(BaseDoc): def docs(): docs = DocList[SchemaDoc]( [ - SchemaDoc(text=f'hello {i}', price=i, tensor=np.array([i] * 10)) + SchemaDoc( + text=f'hello {i}', price=i, tensor=np.array([i + j for j in range(10)]) + ) for i in range(9) ] ) @@ -126,7 +128,7 @@ def test_concatenated_queries(doc_index): @pytest.mark.parametrize( - 'find_limit, filter_limit, expected_docs', [(10, 3, 3), (5, None, 3)] + 'find_limit, filter_limit, expected_docs', [(10, 3, 3), (5, None, 1)] ) def test_query_builder_limits(doc_index, find_limit, filter_limit, expected_docs): query = SchemaDoc(text='query', price=3, tensor=np.array([3] * 10))