Skip to content

Commit

Permalink
feat: make DocList an actual Python List (#1457)
Browse files Browse the repository at this point in the history
Signed-off-by: Joan Fontanals Martinez <joan.martinez@jina.ai>
Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
Co-authored-by: samsja <sami.jaghouar@hotmail.fr>
  • Loading branch information
JoanFM and samsja committed Apr 27, 2023
1 parent 7ba430c commit b3649b4
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 142 deletions.
82 changes: 33 additions & 49 deletions docarray/array/doc_list/doc_list.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import io
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
List,
MutableSequence,
Expand All @@ -15,15 +13,13 @@
overload,
)

from typing_extensions import SupportsIndex
from typing_inspect import is_union_type

from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.io import IOMixinArray
from docarray.array.doc_list.pushpull import PushPullMixin
from docarray.array.doc_list.sequence_indexing_mixin import (
IndexingSequenceMixin,
IndexIterType,
)
from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.typing import NdArray

Expand All @@ -40,25 +36,11 @@
T_doc = TypeVar('T_doc', bound=BaseDoc)


def _delegate_meth_to_data(meth_name: str) -> Callable:
"""
create a function that mimic a function call to the data attribute of the
DocList
:param meth_name: name of the method
:return: a method that mimic the meth_name
"""
func = getattr(list, meth_name)

@wraps(func)
def _delegate_meth(self, *args, **kwargs):
return getattr(self._data, meth_name)(*args, **kwargs)

return _delegate_meth


class DocList(
IndexingSequenceMixin[T_doc], PushPullMixin, IOMixinArray, AnyDocArray[T_doc]
ListAdvancedIndexing[T_doc],
PushPullMixin,
IOMixinArray,
AnyDocArray[T_doc],
):
"""
DocList is a container of Documents.
Expand Down Expand Up @@ -129,8 +111,13 @@ class Image(BaseDoc):
def __init__(
self,
docs: Optional[Iterable[T_doc]] = None,
validate_input_docs: bool = True,
):
self._data: List[T_doc] = list(self._validate_docs(docs)) if docs else []
if validate_input_docs:
docs = self._validate_docs(docs) if docs else []
else:
docs = docs if docs else []
super().__init__(docs)

@classmethod
def construct(
Expand All @@ -143,9 +130,7 @@ def construct(
:param docs: a Sequence (list) of Document with the same schema
:return: a `DocList` object
"""
new_docs = cls.__new__(cls)
new_docs._data = docs if isinstance(docs, list) else list(docs)
return new_docs
return cls(docs, False)

def __eq__(self, other: Any) -> bool:
if self.__len__() != other.__len__():
Expand All @@ -168,12 +153,6 @@ def _validate_one_doc(self, doc: T_doc) -> T_doc:
raise ValueError(f'{doc} is not a {self.doc_type}')
return doc

def __len__(self):
return len(self._data)

def __iter__(self):
return iter(self._data)

def __bytes__(self) -> bytes:
with io.BytesIO() as bf:
self._write_bytes(bf=bf)
Expand All @@ -185,7 +164,7 @@ def append(self, doc: T_doc):
as the `.doc_type` of this `DocList` otherwise it will fail.
:param doc: A Document
"""
self._data.append(self._validate_one_doc(doc))
super().append(self._validate_one_doc(doc))

def extend(self, docs: Iterable[T_doc]):
"""
Expand All @@ -194,31 +173,28 @@ def extend(self, docs: Iterable[T_doc]):
fail.
:param docs: Iterable of Documents
"""
self._data.extend(self._validate_docs(docs))
super().extend(self._validate_docs(docs))

def insert(self, i: int, doc: T_doc):
def insert(self, i: SupportsIndex, doc: T_doc):
"""
Insert a Document to the `DocList`. The Document must be from the same
class as the doc_type of this `DocList` otherwise it will fail.
:param i: index to insert
:param doc: A Document
"""
self._data.insert(i, self._validate_one_doc(doc))

pop = _delegate_meth_to_data('pop')
remove = _delegate_meth_to_data('remove')
reverse = _delegate_meth_to_data('reverse')
sort = _delegate_meth_to_data('sort')
super().insert(i, self._validate_one_doc(doc))

def _get_data_column(
self: T,
field: str,
) -> Union[MutableSequence, T, 'TorchTensor', 'NdArray']:
"""Return all values of the fields from all docs this doc_list contains
:param field: name of the fields to extract
:return: Returns a list of the field value for each document
in the doc_list like container
"""Return all v @classmethod
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):alues of the fields from all docs this doc_list contains
@classmethod
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
:param field: name of the fields to extract
:return: Returns a list of the field value for each document
in the doc_list like container
"""
field_type = self.__class__.doc_type._get_field_type(field)

Expand Down Expand Up @@ -299,7 +275,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocListProto') -> T:
return super().from_protobuf(pb_msg)

@overload
def __getitem__(self, item: int) -> T_doc:
def __getitem__(self, item: SupportsIndex) -> T_doc:
...

@overload
Expand All @@ -308,3 +284,11 @@ def __getitem__(self: T, item: IndexIterType) -> T:

def __getitem__(self, item):
return super().__getitem__(item)

@classmethod
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):

if isinstance(item, type) and issubclass(item, BaseDoc):
return AnyDocArray.__class_getitem__.__func__(cls, item) # type: ignore
else:
return super().__class_getitem__(item)
10 changes: 1 addition & 9 deletions docarray/array/doc_list/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def __getitem__(self, item: slice):

class IOMixinArray(Iterable[T_doc]):
doc_type: Type[T_doc]
_data: List[T_doc]

@abstractmethod
def __len__(self):
Expand Down Expand Up @@ -329,14 +328,7 @@ def to_json(self) -> bytes:
"""Convert the object into JSON bytes. Can be loaded via `.from_json`.
:return: JSON serialization of `DocList`
"""
return orjson_dumps(self._data)

def _docarray_to_json_compatible(self) -> List[T_doc]:
"""
Convert itself into a json compatible object
:return: A list of documents
"""
return self._data
return orjson_dumps(self)

@classmethod
def from_csv(
Expand Down
2 changes: 1 addition & 1 deletion docarray/array/doc_vec/column_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Union,
)

from docarray.array.doc_vec.list_advance_indexing import ListAdvancedIndexing
from docarray.array.list_advance_indexing import ListAdvancedIndexing
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor

Expand Down
6 changes: 3 additions & 3 deletions docarray/array/doc_vec/doc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.doc_list import DocList
from docarray.array.doc_vec.column_storage import ColumnStorage, ColumnStorageView
from docarray.array.doc_vec.list_advance_indexing import ListAdvancedIndexing
from docarray.array.list_advance_indexing import ListAdvancedIndexing
from docarray.base_doc import BaseDoc
from docarray.base_doc.mixins.io import _type_to_protobuf
from docarray.typing import NdArray
Expand Down Expand Up @@ -271,9 +271,9 @@ def _get_data_column(
in the array like container
"""
if field in self._storage.any_columns.keys():
return self._storage.any_columns[field].data
return self._storage.any_columns[field]
elif field in self._storage.docs_vec_columns.keys():
return self._storage.docs_vec_columns[field].data
return self._storage.docs_vec_columns[field]
elif field in self._storage.columns.keys():
return self._storage.columns[field]
else:
Expand Down
41 changes: 0 additions & 41 deletions docarray/array/doc_vec/list_advance_indexing.py

This file was deleted.

0 comments on commit b3649b4

Please sign in to comment.