Skip to content

Commit

Permalink
fix: find_and_filter for inmemory (#1642)
Browse files Browse the repository at this point in the history
Signed-off-by: jupyterjazz <saba.sturua@jina.ai>
  • Loading branch information
jupyterjazz committed Jun 13, 2023
1 parent eedd83c commit f36c621
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
13 changes: 11 additions & 2 deletions docarray/index/backends/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,22 @@ def inner(self, *args, **kwargs):


def _execute_find_and_filter_query(
doc_index: BaseDocIndex, query: List[Tuple[str, Dict]]
doc_index: BaseDocIndex, query: List[Tuple[str, Dict]], reverse_order: bool = False
) -> FindResult:
"""
Executes all find calls from query first using `doc_index.find()`,
and filtering queries after that using DocArray's `filter_docs()`.
Text search is not supported.
:param doc_index: Document index instance.
Either InMemoryExactNNIndex or HnswDocumentIndex.
:param query: List of `(str, dict)` tuples containing search and filtering configuration.
:param reverse_order: Flag indicating whether to sort in descending order.
If set to False (default), the sorting will be in ascending order.
This option is necessary because, depending on the index, lower scores
can correspond to better matches, and vice versa.
:return: Sorted documents and their corresponding scores.
"""
docs_found = DocList.__class_getitem__(cast(Type[BaseDoc], doc_index._schema))([])
filter_conditions = []
Expand Down Expand Up @@ -57,7 +66,7 @@ def _execute_find_and_filter_query(
docs_and_scores = zip(
docs_filtered, (doc_to_score[doc.id] for doc in docs_filtered)
)
docs_sorted = sorted(docs_and_scores, key=lambda x: x[1])
docs_sorted = sorted(docs_and_scores, key=lambda x: x[1], reverse=reverse_order)
out_docs, out_scores = zip(*docs_sorted)

return FindResult(documents=out_docs, scores=out_scores)
1 change: 1 addition & 0 deletions docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
find_res = _execute_find_and_filter_query(
doc_index=self,
query=query,
reverse_order=True,
)
return find_res

Expand Down
6 changes: 4 additions & 2 deletions tests/index/in_memory/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ class SchemaDoc(BaseDoc):
def docs():
docs = DocList[SchemaDoc](
[
SchemaDoc(text=f'hello {i}', price=i, tensor=np.array([i] * 10))
SchemaDoc(
text=f'hello {i}', price=i, tensor=np.array([i + j for j in range(10)])
)
for i in range(9)
]
)
Expand Down Expand Up @@ -126,7 +128,7 @@ def test_concatenated_queries(doc_index):


@pytest.mark.parametrize(
'find_limit, filter_limit, expected_docs', [(10, 3, 3), (5, None, 3)]
'find_limit, filter_limit, expected_docs', [(10, 3, 3), (5, None, 1)]
)
def test_query_builder_limits(doc_index, find_limit, filter_limit, expected_docs):
query = SchemaDoc(text='query', price=3, tensor=np.array([3] * 10))
Expand Down

0 comments on commit f36c621

Please sign in to comment.