Skip to content

Commit

Permalink
refactor(da): remove tensor type from DocumentArray init (#1268)
Browse files Browse the repository at this point in the history
* fix: remove tensor type from DocumentArray

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: fix test

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: fix tensorflow test

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: docstring

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* feat: apply charlotte suggestion

Co-authored-by: Charlotte Gerhaher <charlotte.gerhaher@jina.ai>
Signed-off-by: samsja <55492238+samsja@users.noreply.github.com>

* feat: apply saba suggestion

Co-authored-by: Saba Sturua <45267439+jupyterjazz@users.noreply.github.com>
Signed-off-by: samsja <55492238+samsja@users.noreply.github.com>

---------

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
Signed-off-by: samsja <55492238+samsja@users.noreply.github.com>
Co-authored-by: Charlotte Gerhaher <charlotte.gerhaher@jina.ai>
Co-authored-by: Saba Sturua <45267439+jupyterjazz@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 22, 2023
1 parent 6707f4c commit 64532dd
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 90 deletions.
2 changes: 0 additions & 2 deletions docarray/array/abstract_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from docarray.base_document import BaseDocument
from docarray.display.document_array_summary import DocumentArraySummary
from docarray.typing import NdArray
from docarray.typing.abstract_type import AbstractType
from docarray.utils._typing import change_cls_name

Expand All @@ -36,7 +35,6 @@

class AnyDocumentArray(Sequence[T_doc], Generic[T_doc], AbstractType):
document_type: Type[BaseDocument]
tensor_type: Type['AbstractTensor'] = NdArray
__typed_da__: Dict[Type['AnyDocumentArray'], Dict[Type[BaseDocument], Type]] = {}

def __repr__(self):
Expand Down
18 changes: 9 additions & 9 deletions docarray/array/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ class Image(BaseDocument):
del da[0:5] # remove elements for 0 to 5 from DocumentArray
:param docs: iterable of Document
:param tensor_type: Class used to wrap the tensors of the Documents when stacked
"""

Expand All @@ -126,27 +125,22 @@ class Image(BaseDocument):
def __init__(
self,
docs: Optional[Iterable[T_doc]] = None,
tensor_type: Type['AbstractTensor'] = NdArray,
):
self._data: List[T_doc] = list(self._validate_docs(docs)) if docs else []
self.tensor_type = tensor_type

@classmethod
def construct(
cls: Type[T],
docs: Sequence[T_doc],
tensor_type: Type['AbstractTensor'] = NdArray,
) -> T:
"""
Create a DocumentArray without validating any data. The data must come from a
trusted source
:param docs: a Sequence (list) of Document with the same schema
:param tensor_type: Class used to wrap the tensors of the Documents when stacked
:return:
"""
da = cls.__new__(cls)
da._data = docs if isinstance(docs, list) else list(docs)
da.tensor_type = tensor_type
return da

def _validate_docs(self, docs: Iterable[T_doc]) -> Iterable[T_doc]:
Expand Down Expand Up @@ -227,7 +221,7 @@ def _get_data_column(
# most likely a bug in mypy though
# bug reported here https://github.com/python/mypy/issues/14111
return DocumentArray.__class_getitem__(field_type)(
(getattr(doc, field) for doc in self), tensor_type=self.tensor_type
(getattr(doc, field) for doc in self),
)
else:
return [getattr(doc, field) for doc in self]
Expand All @@ -247,15 +241,21 @@ def _set_data_column(
for doc, value in zip(self, values):
setattr(doc, field, value)

def stack(self) -> 'DocumentArrayStacked':
def stack(
self,
tensor_type: Type['AbstractTensor'] = NdArray,
) -> 'DocumentArrayStacked':
"""
Convert the DocumentArray into a DocumentArrayStacked. `Self` cannot be used
afterwards
:param tensor_type: Tensor Class used to wrap the stacked tensors. This is useful
if the BaseDocument has some undefined tensor type like AnyTensor or Union of NdArray and TorchTensor
:return: A DocumentArrayStacked of the same document type as self
"""
from docarray.array.stacked.array_stacked import DocumentArrayStacked

return DocumentArrayStacked.__class_getitem__(self.document_type)(
self, tensor_type=self.tensor_type
self, tensor_type=tensor_type
)

@classmethod
Expand Down
24 changes: 15 additions & 9 deletions docarray/array/stacked/array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ class DocumentArrayStacked(AnyDocumentArray[T_doc]):
numpy/PyTorch.
:param docs: a DocumentArray
:param tensor_type: Class used to wrap the stacked tensors
:param tensor_type: Tensor Class used to wrap the stacked tensors. This is useful
if the BaseDocument of this DocumentArrayStacked has some undefined tensor type like
AnyTensor or Union of NdArray and TorchTensor
"""

document_type: Type[T_doc]
Expand Down Expand Up @@ -158,12 +159,17 @@ def __init__(
cast(AbstractTensor, tensor_columns[field_name])[i] = val

elif issubclass(field_type, BaseDocument):
doc_columns[field_name] = getattr(docs, field_name).stack()
doc_columns[field_name] = getattr(docs, field_name).stack(
tensor_type=self.tensor_type
)

elif issubclass(field_type, DocumentArray):
elif issubclass(field_type, AnyDocumentArray):
docs_list = list()
for doc in docs:
docs_list.append(getattr(doc, field_name).stack())
da = getattr(doc, field_name)
if isinstance(da, DocumentArray):
da = da.stack(tensor_type=self.tensor_type)
docs_list.append(da)
da_columns[field_name] = ListAdvancedIndexing(docs_list)
else:
any_columns[field_name] = ListAdvancedIndexing(
Expand Down Expand Up @@ -318,7 +324,9 @@ def _set_data_and_columns(
f'{value} schema : {value.document_type} is not compatible with '
f'this DocumentArrayStacked schema : {self.document_type}'
)
processed_value = cast(T, value.stack()) # we need to copy data here
processed_value = cast(
T, value.stack(tensor_type=self.tensor_type)
) # we need to copy data here

elif isinstance(value, DocumentArrayStacked):
if not issubclass(value.document_type, self.document_type):
Expand Down Expand Up @@ -507,9 +515,7 @@ def unstack(self: T) -> DocumentArray[T_doc]:

del self._storage

return DocumentArray.__class_getitem__(self.document_type).construct(
docs, tensor_type=self.tensor_type
)
return DocumentArray.__class_getitem__(self.document_type).construct(docs)

def traverse_flat(
self,
Expand Down
8 changes: 4 additions & 4 deletions docarray/data/torch_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from torch.utils.data import Dataset

from docarray import BaseDocument, DocumentArray
from docarray import BaseDocument, DocumentArray, DocumentArrayStacked
from docarray.typing import TorchTensor
from docarray.utils._typing import change_cls_name

Expand Down Expand Up @@ -123,13 +123,13 @@ def __getitem__(self, item: int):
def collate_fn(cls, batch: List[T_doc]):
doc_type = cls.document_type
if doc_type:
batch_da = DocumentArray[doc_type]( # type: ignore
batch_da = DocumentArrayStacked[doc_type]( # type: ignore
batch,
tensor_type=TorchTensor,
)
else:
batch_da = DocumentArray(batch, tensor_type=TorchTensor)
return batch_da.stack()
batch_da = DocumentArrayStacked(batch, tensor_type=TorchTensor)
return batch_da

@classmethod
def __class_getitem__(cls, item: Type[BaseDocument]) -> Type['MultiModalDataset']:
Expand Down
52 changes: 22 additions & 30 deletions tests/units/array/stack/test_array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,10 @@ def test_any_tensor_with_torch(tensor_type, tensor):
class ImageDoc(BaseDocument):
tensor: AnyTensor

da = DocumentArray[ImageDoc](
da = DocumentArrayStacked[ImageDoc](
[ImageDoc(tensor=tensor) for _ in range(10)],
tensor_type=tensor_type,
).stack()
)

for i in range(len(da)):
assert (da[i].tensor == tensor).all()
Expand All @@ -284,10 +284,10 @@ class ImageDoc(BaseDocument):
class TopDoc(BaseDocument):
img: ImageDoc

da = DocumentArray[TopDoc](
da = DocumentArrayStacked[TopDoc](
[TopDoc(img=ImageDoc(tensor=tensor)) for _ in range(10)],
tensor_type=TorchTensor,
).stack()
)

for i in range(len(da)):
assert (da.img[i].tensor == tensor).all()
Expand All @@ -300,9 +300,9 @@ def test_dict_stack():
class MyDoc(BaseDocument):
my_dict: Dict[str, int]

da = DocumentArray[MyDoc](
da = DocumentArrayStacked[MyDoc](
[MyDoc(my_dict={'a': 1, 'b': 2}) for _ in range(10)]
).stack()
)

da.my_dict

Expand All @@ -314,9 +314,9 @@ class Doc(BaseDocument):

N = 10

da = DocumentArray[Doc](
da = DocumentArrayStacked[Doc](
[Doc(text=f'hello{i}', tensor=np.zeros((3, 224, 224))) for i in range(N)]
).stack()
)

da_sliced = da[0:10:2]
assert isinstance(da_sliced, DocumentArrayStacked)
Expand All @@ -334,9 +334,7 @@ def test_stack_embedding():
class MyDoc(BaseDocument):
embedding: AnyEmbedding

da = DocumentArray[MyDoc](
[MyDoc(embedding=np.zeros(10)) for _ in range(10)]
).stack()
da = DocumentArrayStacked[MyDoc]([MyDoc(embedding=np.zeros(10)) for _ in range(10)])

assert 'embedding' in da._storage.tensor_columns.keys()
assert (da.embedding == np.zeros((10, 10))).all()
Expand All @@ -347,18 +345,17 @@ def test_stack_none(tensor_backend):
class MyDoc(BaseDocument):
tensor: Optional[AnyTensor]

da = DocumentArray[MyDoc](
da = DocumentArrayStacked[MyDoc](
[MyDoc(tensor=None) for _ in range(10)], tensor_type=tensor_backend
).stack()
)

assert 'tensor' in da._storage.tensor_columns.keys()


def test_to_device():
da = DocumentArray[ImageDoc](
da = DocumentArrayStacked[ImageDoc](
[ImageDoc(tensor=torch.zeros(3, 5))], tensor_type=TorchTensor
)
da = da.stack()
assert da.tensor.device == torch.device('cpu')
da.to('meta')
assert da.tensor.device == torch.device('meta')
Expand All @@ -368,12 +365,11 @@ def test_to_device_with_nested_da():
class Video(BaseDocument):
images: DocumentArray[ImageDoc]

da_image = DocumentArray[ImageDoc](
da_image = DocumentArrayStacked[ImageDoc](
[ImageDoc(tensor=torch.zeros(3, 5))], tensor_type=TorchTensor
)

da = DocumentArray[Video]([Video(images=da_image)])
da = da.stack()
da = DocumentArrayStacked[Video]([Video(images=da_image)])
assert da.images[0].tensor.device == torch.device('cpu')
da.to('meta')
assert da.images[0].tensor.device == torch.device('meta')
Expand All @@ -384,11 +380,10 @@ class MyDoc(BaseDocument):
tensor: TorchTensor
docs: ImageDoc

da = DocumentArray[MyDoc](
da = DocumentArrayStacked[MyDoc](
[MyDoc(tensor=torch.zeros(3, 5), docs=ImageDoc(tensor=torch.zeros(3, 5)))],
tensor_type=TorchTensor,
)
da = da.stack()
assert da.tensor.device == torch.device('cpu')
assert da.docs.tensor.device == torch.device('cpu')
da.to('meta')
Expand All @@ -397,10 +392,9 @@ class MyDoc(BaseDocument):


def test_to_device_numpy():
da = DocumentArray[ImageDoc](
da = DocumentArrayStacked[ImageDoc](
[ImageDoc(tensor=np.zeros((3, 5)))], tensor_type=NdArray
)
da = da.stack()
with pytest.raises(NotImplementedError):
da.to('meta')

Expand Down Expand Up @@ -444,9 +438,7 @@ def test_np_scalar():
class MyDoc(BaseDocument):
scalar: NdArray

da = DocumentArray[MyDoc](
[MyDoc(scalar=np.array(2.0)) for _ in range(3)], tensor_type=NdArray
)
da = DocumentArray[MyDoc]([MyDoc(scalar=np.array(2.0)) for _ in range(3)])
assert all(doc.scalar.ndim == 0 for doc in da)
assert all(doc.scalar == 2.0 for doc in da)

Expand All @@ -467,11 +459,11 @@ class MyDoc(BaseDocument):
scalar: TorchTensor

da = DocumentArray[MyDoc](
[MyDoc(scalar=torch.tensor(2.0)) for _ in range(3)], tensor_type=TorchTensor
[MyDoc(scalar=torch.tensor(2.0)) for _ in range(3)],
)
assert all(doc.scalar.ndim == 0 for doc in da)
assert all(doc.scalar == 2.0 for doc in da)
stacked_da = da.stack()
stacked_da = da.stack(tensor_type=TorchTensor)
assert type(stacked_da.scalar) == TorchTensor

assert all(type(doc.scalar) == TorchTensor for doc in stacked_da)
Expand All @@ -486,7 +478,7 @@ def test_np_nan():
class MyDoc(BaseDocument):
scalar: Optional[NdArray]

da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)], tensor_type=NdArray)
da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)])
assert all(doc.scalar is None for doc in da)
assert all(doc.scalar == doc.scalar for doc in da)
stacked_da = da.stack()
Expand All @@ -505,10 +497,10 @@ def test_torch_nan():
class MyDoc(BaseDocument):
scalar: Optional[TorchTensor]

da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)], tensor_type=TorchTensor)
da = DocumentArray[MyDoc]([MyDoc() for _ in range(3)])
assert all(doc.scalar is None for doc in da)
assert all(doc.scalar == doc.scalar for doc in da)
stacked_da = da.stack()
stacked_da = da.stack(tensor_type=TorchTensor)
assert type(stacked_da.scalar) == TorchTensor

assert all(type(doc.scalar) == TorchTensor for doc in stacked_da)
Expand Down
Loading

0 comments on commit 64532dd

Please sign in to comment.