Skip to content

Commit

Permalink
feat: sub-document support for indexer
Browse files Browse the repository at this point in the history
Signed-off-by: maxwelljin2 <gejin@berkeley.edu>
  • Loading branch information
maxwelljin committed Jun 14, 2023
1 parent 7889270 commit a6fdd80
Show file tree
Hide file tree
Showing 14 changed files with 268 additions and 17 deletions.
43 changes: 43 additions & 0 deletions docarray/index/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,3 +1166,46 @@ def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
id, fields[0], '__'.join(fields[1:])
)
return self._get_root_doc_id(cur_root_id, root, '')

def __contains__(self, item: BaseDoc) -> bool:
    """Checks if a given BaseDoc item is contained in the index.

    Backends must override this hook. The base implementation raises instead
    of falling through, because an empty body would return ``None`` and make
    ``item in index`` silently evaluate to ``False``.

    :param item: the given BaseDoc
    :return: if the given BaseDoc item is contained in the index
    :raises NotImplementedError: if the backend does not override this method
    """
    raise NotImplementedError(
        f'{type(self).__name__} does not implement `__contains__`'
    )

def _get_all_documents(self) -> Union[AnyDocArray, List]:
    """Retrieve all documents from the index

    Backends must override this hook. The base implementation raises instead
    of returning ``None``, which callers that iterate over the result (such as
    `subindex_contains`) would only report as an unrelated "not iterable" error.

    :return: a DocArray or list of documents
    :raises NotImplementedError: if the backend does not override this method
    """
    raise NotImplementedError(
        f'{type(self).__name__} does not implement `_get_all_documents`'
    )

def subindex_contains(self, item: BaseDoc) -> bool:
    """Checks if a given BaseDoc item is contained in the index or any of its subindices.

    The item counts as contained if any of the following holds:

    - it is a member of this index (``item in self``),
    - one of the subindices reports it through its own `subindex_contains`,
    - a stored document has a field whose value is a BaseDoc with the same id.

    :param item: the given BaseDoc
    :return: if the given BaseDoc item is contained in the index/subindices
    :raises TypeError: if `item` is not an instance of BaseDoc or its subclass
    """
    # Validate before anything else so an invalid argument is rejected the
    # same way whether or not the index currently holds documents.
    if not safe_issubclass(type(item), BaseDoc):
        raise TypeError(
            f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
        )

    if self.num_docs() == 0:
        return False

    # Cheap lookups first: direct membership, then recursion into subindices.
    if item in self or any(
        index.subindex_contains(item) for index in self._subindices.values()
    ):
        return True

    # Last resort: load every stored document and compare the ids of the
    # BaseDoc values held directly in its fields.
    for doc in self._get_all_documents():
        for field_name in doc.__fields__:
            sub_doc = getattr(doc, field_name)
            if safe_issubclass(type(sub_doc), BaseDoc) and sub_doc.id == item.id:
                return True

    return False
26 changes: 17 additions & 9 deletions docarray/index/backends/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,23 @@ def _format_response(self, response: Any) -> Tuple[List[Dict], List[Any]]:
def _refresh(self, index_name: str):
    """Ask the client to refresh the given index.

    :param index_name: name of the index to refresh
    """
    indices_api = self._client.indices
    indices_api.refresh(index=index_name)

def __contains__(self, item: BaseDoc) -> bool:
    """Checks if a given BaseDoc item is stored in this index.

    :param item: the given BaseDoc
    :return: `True` if a document with the item's id is found, `False` if it
        is not found or the item has no usable id
    :raises TypeError: if `item` is not an instance of BaseDoc or its subclass
    """
    if not safe_issubclass(type(item), BaseDoc):
        raise TypeError(
            f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
        )

    # A missing (`None`) or empty id can never match a stored document, so
    # skip the lookup. The truthiness test also avoids `len(None)` failing.
    if not item.id:
        return False

    ret = self._client_mget([item.id])
    return ret["docs"][0]["found"]

def _get_all_documents(self) -> Union[AnyDocArray, List]:
    """Retrieve all documents from the index.

    A plain search only returns the first page of hits, so the results are
    walked with the scroll API until a page comes back empty.

    :return: the stored documents, converted via `_dict_list_to_docarray`
    """
    keep_alive = '1m'
    sources = []

    response = self._client.search(
        index=self.index_name, scroll=keep_alive, size=1000
    )
    scroll_id = response['_scroll_id']
    try:
        while response['hits']['hits']:
            sources.extend(hit['_source'] for hit in response['hits']['hits'])
            response = self._client.scroll(scroll_id=scroll_id, scroll=keep_alive)
            scroll_id = response['_scroll_id']
    finally:
        # Release the server-side scroll context even if a page request fails.
        self._client.clear_scroll(scroll_id=scroll_id)

    return self._dict_list_to_docarray(sources)

###############################################
# API Wrappers #
###############################################
Expand All @@ -694,12 +711,3 @@ def _client_search(self, **kwargs):

def _client_msearch(self, request: List[Dict[str, Any]]):
    """Run a multi-search request against this index.

    :param request: the search entries, forwarded to the client as `searches`
    :return: the response of the client's `msearch` call
    """
    response = self._client.msearch(index=self.index_name, searches=request)
    return response

def __contains__(self, item: BaseDoc) -> bool:
    """Checks if a given BaseDoc item is stored in this index.

    :param item: the given BaseDoc
    :return: `True` if a document with the item's id is found, `False` if it
        is not found or the item has no usable id
    :raises TypeError: if `item` is not an instance of BaseDoc or its subclass
    """
    if not safe_issubclass(type(item), BaseDoc):
        raise TypeError(
            f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
        )

    # A missing (`None`) or empty id can never match a stored document, so
    # there is nothing to look up.
    if not item.id:
        return False

    ret = self._client_mget([item.id])
    return ret["docs"][0]["found"]
5 changes: 5 additions & 0 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,11 @@ def __contains__(self, item: BaseDoc):
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
)

def _get_all_documents(self) -> Union[AnyDocArray, List]:
    """Retrieve all documents from the index.

    :return: a list with one document per row of the `docs` table, each
        built from the row's `data` column via `_doc_from_bytes`
    """
    cursor = self._sqlite_cursor
    cursor.execute("SELECT data FROM docs")
    documents = []
    for record in cursor.fetchall():
        documents.append(self._doc_from_bytes(record[0]))
    return documents

def num_docs(self) -> int:
"""
Get the number of documents.
Expand Down
3 changes: 3 additions & 0 deletions docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ def __contains__(self, item: BaseDoc):
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
)

def _get_all_documents(self) -> Union[AnyDocArray, List]:
    """Retrieve all documents from the index.

    :return: the documents held in `self._docs` (the same object, not a copy)
    """
    all_docs = self._docs
    return all_docs

def persist(self, file: str = 'in_memory_index.bin') -> None:
"""Persist InMemoryExactNNIndex into a binary file."""
self._docs.save_binary(file=file)
Expand Down
10 changes: 10 additions & 0 deletions docarray/index/backends/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,16 @@ def __contains__(self, item: BaseDoc) -> bool:
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
)

def _get_all_documents(self) -> Union[AnyDocArray, List]:
    """Retrieve all documents from the index.

    One `scroll` call only returns a single page of points, so the returned
    next-page offset is followed until the collection is exhausted.

    :return: the stored documents, converted via `_dict_list_to_docarray`
    """
    points = []
    next_offset = None
    while True:
        page, next_offset = self._client.scroll(
            collection_name=self.index_name,
            limit=256,
            offset=next_offset,
            with_payload=True,
            with_vectors=True,
        )
        points.extend(page)
        # The client reports `None` as next offset once the last page is read.
        if next_offset is None:
            break

    return self._dict_list_to_docarray(
        [self._convert_to_doc(point) for point in points]
    )

def _del_items(self, doc_ids: Sequence[str]):
items = self._get_items(doc_ids)
if len(items) < len(doc_ids):
Expand Down
15 changes: 14 additions & 1 deletion docarray/index/backends/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,12 +776,25 @@ def __contains__(self, item: BaseDoc) -> bool:
)
.do()
)
return len(result["data"]["Get"][self.index_name]) > 0
docs = result["data"]["Get"][self.index_name]
return docs is not None and len(docs) > 0
else:
raise TypeError(
f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'"
)

def _get_all_documents(self) -> Union[AnyDocArray, List]:
    """Retrieve all documents from the index.

    Queries the `docarrayid` of the stored objects, then loads the matching
    documents through `_get_items`.

    :return: the stored documents, converted via `_dict_list_to_docarray`
    """
    # NOTE(review): no explicit limit is passed, so the server's default
    # query limit presumably applies - confirm large indices come back in full.
    result = self._client.query.get(self.index_name, ['docarrayid']).do()

    # The entry for this index can be `None` (the `__contains__` check of this
    # class guards against it as well), which must not be iterated.
    hits = result["data"]["Get"][self.index_name]
    if not hits:
        return self._dict_list_to_docarray([])

    doc_ids = [hit['docarrayid'] for hit in hits]
    return self._dict_list_to_docarray(self._get_items(doc_ids))

class QueryBuilder(BaseDocIndex.QueryBuilder):
def __init__(self, document_index):
self._queries = [
Expand Down
3 changes: 2 additions & 1 deletion docarray/utils/_internal/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing_extensions import get_origin
from typing_inspect import get_args, is_typevar, is_union_type

from docarray.typing.id import ID
from docarray.typing.tensor.abstract_tensor import AbstractTensor


Expand Down Expand Up @@ -45,6 +46,6 @@ def safe_issubclass(x: type, a_tuple: type) -> bool:
:return: A boolean value - 'True' if 'x' is a subclass of 'A_tuple', 'False' otherwise.
Note that if the origin of 'x' is a list or tuple, the function immediately returns 'False'.
"""
if (get_origin(x) in (list, tuple, dict, set)) or is_typevar(x):
if (get_origin(x) in (list, tuple, dict, set)) or is_typevar(x) or x == ID:
return False
return issubclass(x, a_tuple)
29 changes: 29 additions & 0 deletions tests/index/elastic/v7/test_subindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,32 @@ def test_subindex_del(index):
assert index._subindices['docs'].num_docs() == 20
assert index._subindices['list_docs'].num_docs() == 20
assert index._subindices['list_docs']._subindices['docs'].num_docs() == 100


def test_subindex_contain(index):
    """Check `subindex_contains` for nested docs, an unknown doc and an empty-id doc."""
    # Each doc in `list_docs`, and each doc nested below it, must be reported
    for root_id in ('1', '2', '3', '4'):
        root_doc = index[root_id]
        for list_item in root_doc.list_docs:
            assert index.subindex_contains(list_item) is True
            for inner_doc in list_item.docs:
                assert index.subindex_contains(inner_doc) is True

    # A doc whose id was never indexed
    unknown_doc = SimpleDoc(
        id='non_existent', simple_tens=np.zeros(10), simple_text='invalid'
    )
    assert index.subindex_contains(unknown_doc) is False

    # A doc with an empty id
    blank_doc = SimpleDoc(id='', simple_tens=np.zeros(10), simple_text='')
    assert index.subindex_contains(blank_doc) is False

    # Plain membership on a freshly created index
    fresh_index = ElasticV7DocIndex[MyDoc]()
    assert (blank_doc in fresh_index) is False
33 changes: 27 additions & 6 deletions tests/index/elastic/v8/test_subindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,30 @@ def test_subindex_filter(index):
assert doc.id.split('-')[-1] == '0'


def test_subindex_del(index):
    """Delete root doc '0' and verify the doc counts of the index and its subindices."""
    del index['0']
    assert index.num_docs() == 4

    subindices = index._subindices
    assert subindices['docs'].num_docs() == 20
    assert subindices['list_docs'].num_docs() == 20
    assert subindices['list_docs']._subindices['docs'].num_docs() == 100
def test_subindex_contain(index):
    """Check `subindex_contains` for nested docs, an unknown doc and an empty-id doc."""
    # Each doc in `list_docs`, and each doc nested below it, must be reported
    for root_id in ('1', '2', '3', '4'):
        root_doc = index[root_id]
        for list_item in root_doc.list_docs:
            assert index.subindex_contains(list_item) is True
            for inner_doc in list_item.docs:
                assert index.subindex_contains(inner_doc) is True

    # A doc whose id was never indexed
    unknown_doc = SimpleDoc(
        id='non_existent', simple_tens=np.zeros(10), simple_text='invalid'
    )
    assert index.subindex_contains(unknown_doc) is False

    # A doc with an empty id
    blank_doc = SimpleDoc(id='', simple_tens=np.zeros(10), simple_text='')
    assert index.subindex_contains(blank_doc) is False

    # Plain membership on a freshly created index
    fresh_index = ElasticDocIndex[MyDoc]()
    assert (blank_doc in fresh_index) is False
2 changes: 2 additions & 0 deletions tests/index/hnswlib/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,5 @@ class SimpleSchema(BaseDoc):
index_docs_new = [SimpleDoc(tens=np.zeros(10)) for _ in range(10)]
for doc in index_docs_new:
assert (doc in index) is False

print(index._get_all_documents())
29 changes: 29 additions & 0 deletions tests/index/hnswlib/test_subindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,32 @@ def test_subindex_del(index):
assert index._subindices['docs'].num_docs() == 20
assert index._subindices['list_docs'].num_docs() == 20
assert index._subindices['list_docs']._subindices['docs'].num_docs() == 100


def test_subindex_contain(index):
    """Check `subindex_contains` for nested docs, an unknown doc and an empty-id doc."""
    # Each doc in `list_docs`, and each doc nested below it, must be reported
    for root_id in ('1', '2', '3', '4'):
        root_doc = index[root_id]
        for list_item in root_doc.list_docs:
            assert index.subindex_contains(list_item) is True
            for inner_doc in list_item.docs:
                assert index.subindex_contains(inner_doc) is True

    # A doc whose id was never indexed
    unknown_doc = SimpleDoc(
        id='non_existent', simple_tens=np.zeros(10), simple_text='invalid'
    )
    assert index.subindex_contains(unknown_doc) is False

    # A doc with an empty id
    blank_doc = SimpleDoc(id='', simple_tens=np.zeros(10), simple_text='')
    assert index.subindex_contains(blank_doc) is False

    # Plain membership on a freshly created index
    fresh_index = HnswDocumentIndex[MyDoc]()
    assert (blank_doc in fresh_index) is False
29 changes: 29 additions & 0 deletions tests/index/in_memory/test_subindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,32 @@ def test_subindex_del(index):
assert index._subindices['docs'].num_docs() == 20
assert index._subindices['list_docs'].num_docs() == 20
assert index._subindices['list_docs']._subindices['docs'].num_docs() == 100


def test_subindex_contain(index):
    """Check `subindex_contains` for nested docs, an unknown doc and an empty-id doc."""
    # Each doc in `list_docs`, and each doc nested below it, must be reported
    for root_id in ('1', '2', '3', '4'):
        root_doc = index[root_id]
        for list_item in root_doc.list_docs:
            assert index.subindex_contains(list_item) is True
            for inner_doc in list_item.docs:
                assert index.subindex_contains(inner_doc) is True

    # A doc whose id was never indexed
    unknown_doc = SimpleDoc(
        id='non_existent', simple_tens=np.zeros(10), simple_text='invalid'
    )
    assert index.subindex_contains(unknown_doc) is False

    # A doc with an empty id
    blank_doc = SimpleDoc(id='', simple_tens=np.zeros(10), simple_text='')
    assert index.subindex_contains(blank_doc) is False

    # Plain membership on a freshly created index
    fresh_index = InMemoryExactNNIndex[MyDoc]()
    assert (blank_doc in fresh_index) is False
29 changes: 29 additions & 0 deletions tests/index/qdrant/test_subindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,32 @@ def test_subindex_del(index):
assert index._subindices['docs'].num_docs() == 20
assert index._subindices['list_docs'].num_docs() == 20
assert index._subindices['list_docs']._subindices['docs'].num_docs() == 100


def test_subindex_contain(index):
    """Check `subindex_contains` for nested docs, an unknown doc and an empty-id doc."""
    # Each doc in `list_docs`, and each doc nested below it, must be reported
    for root_id in ('1', '2', '3', '4'):
        root_doc = index[root_id]
        for list_item in root_doc.list_docs:
            assert index.subindex_contains(list_item) is True
            for inner_doc in list_item.docs:
                assert index.subindex_contains(inner_doc) is True

    # A doc whose id was never indexed
    unknown_doc = SimpleDoc(
        id='non_existent', simple_tens=np.zeros(10), simple_text='invalid'
    )
    assert index.subindex_contains(unknown_doc) is False

    # A doc with an empty id
    blank_doc = SimpleDoc(id='', simple_tens=np.zeros(10), simple_text='')
    assert index.subindex_contains(blank_doc) is False

    # Plain membership on a freshly created index
    fresh_index = QdrantDocumentIndex[MyDoc]()
    assert (blank_doc in fresh_index) is False
29 changes: 29 additions & 0 deletions tests/index/weaviate/test_subindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,32 @@ def test_subindex_del(index):
assert index._subindices['docs'].num_docs() == 20
assert index._subindices['list_docs'].num_docs() == 20
assert index._subindices['list_docs']._subindices['docs'].num_docs() == 100


def test_subindex_contain(index):
    """Check `subindex_contains` for nested docs, an unknown doc and an empty-id doc."""
    # Each doc in `list_docs`, and each doc nested below it, must be reported
    for root_id in ('1', '2', '3', '4'):
        root_doc = index[root_id]
        for list_item in root_doc.list_docs:
            assert index.subindex_contains(list_item) is True
            for inner_doc in list_item.docs:
                assert index.subindex_contains(inner_doc) is True

    # A doc whose id was never indexed
    unknown_doc = SimpleDoc(
        id='non_existent', simple_tens=np.zeros(10), simple_text='invalid'
    )
    assert index.subindex_contains(unknown_doc) is False

    # A doc with an empty id
    blank_doc = SimpleDoc(id='', simple_tens=np.zeros(10), simple_text='')
    assert index.subindex_contains(blank_doc) is False

    # Plain membership on a freshly created index
    fresh_index = WeaviateDocumentIndex[MyDoc]()
    assert (blank_doc in fresh_index) is False

0 comments on commit a6fdd80

Please sign in to comment.