Skip to content

Commit

Permalink
bulk root doc access in hnswlib
Browse files Browse the repository at this point in the history
  • Loading branch information
Oytun Tez committed Feb 21, 2024
1 parent 2da17d0 commit 7555eb3
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,51 @@ def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
)
return self._get_root_doc_id(cur_root_id, root, '')

def _get_root_doc_ids(self, ids: Sequence[str], root: str, sub: str, return_docs: bool = False) -> Union[Sequence[str], DocList]:
"""Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib.
:param id: id of the subindex Document
:param root: root index name
:param sub: subindex name
:return: the root_id of the Document
"""
subindex: HnswDocumentIndex = self._subindices[root]

if not sub:
sub_docs = subindex._get_items(ids, out=False) # type: ignore
parent_ids = [sub_doc['parent_id'] if isinstance(sub_doc, dict) else sub_doc.parent_id for sub_doc in sub_docs]
if not return_docs:
return parent_ids
return self._get_items(parent_ids)
else:
fields = sub.split('__')
cur_root_ids = subindex._get_root_doc_ids(
ids, fields[0], '__'.join(fields[1:]), return_docs
)
if isinstance(cur_root_ids, DocList):
cur_root_ids = cur_root_ids.id
return self._get_root_doc_ids(cur_root_ids, root, '', return_docs)

def _get_root_docs(self, sub_docs: DocList, root: str, sub: str) -> DocList:
"""Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib.
:param id: id of the subindex Document
:param root: root index name
:param sub: subindex name
:return: the root_id of the Document
"""
subindex: HnswDocumentIndex = self._subindices[root]

if not sub:
parent_ids = [sub_doc['parent_id'] if isinstance(sub_doc, dict) else sub_doc.__getattr__('parent_id') for sub_doc in sub_docs]
return self._get_items(parent_ids)
else:
fields = sub.split('__')
cur_roots = subindex._get_root_docs(
sub_docs.id, fields[0], '__'.join(fields[1:])
)
return self._get_root_docs(cur_roots, root, '')

def _get_column_names(self) -> List[str]:
"""
Retrieves the column names of the 'docs' table in the SQLite database.
Expand Down

0 comments on commit 7555eb3

Please sign in to comment.