Skip to content

Commit

Permalink
feat(test): DocumentArray method tests similar to list methods like r…
Browse files Browse the repository at this point in the history
…everse, sort, remove, pop (#1291)

* feat: isort format fix

Signed-off-by: agaraman0 <agaraman0@gmail.com>

* refactor: comment fixes

Signed-off-by: agaraman0 <agaraman0@gmail.com>

* refactor: comment fixes

Signed-off-by: agaraman0 <agaraman0@gmail.com>

---------

Signed-off-by: agaraman0 <agaraman0@gmail.com>
  • Loading branch information
agaraman0 committed Mar 27, 2023
1 parent 89c2a0a commit 081a03f
Showing 1 changed file with 48 additions and 1 deletion.
49 changes: 48 additions & 1 deletion tests/units/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from docarray import BaseDocument, DocumentArray
from docarray.typing import NdArray, TorchTensor
from docarray.typing import ImageUrl, NdArray, TorchTensor
from docarray.utils.misc import is_tf_available

tf_available = is_tf_available()
Expand Down Expand Up @@ -319,3 +319,50 @@ class Text(BaseDocument):
da = DocumentArray[Text].construct(docs)

assert da._data is docs


def test_reverse():
class Text(BaseDocument):
text: str

docs = [Text(text=f'hello {i}') for i in range(10)]

da = DocumentArray[Text](docs)
da.reverse()
assert da[-1].text == 'hello 0'
assert da[0].text == 'hello 9'


class Image(BaseDocument):
tensor: Optional[NdArray]
url: ImageUrl


def test_remove():
images = [Image(url=f'http://url.com/foo_{i}.png') for i in range(3)]
da = DocumentArray[Image](images)
da.remove(images[1])
assert len(da) == 2
assert da[0] == images[0]
assert da[1] == images[2]


def test_pop():
images = [Image(url=f'http://url.com/foo_{i}.png') for i in range(3)]
da = DocumentArray[Image](images)
popped = da.pop(1)
assert len(da) == 2
assert popped == images[1]
assert da[0] == images[0]
assert da[1] == images[2]


def test_sort():
images = [
Image(url=f'http://url.com/foo_{i}.png', tensor=NdArray(i)) for i in [2, 0, 1]
]
da = DocumentArray[Image](images)
da.sort(key=lambda img: len(img.tensor))
assert len(da) == 3
assert da[0].url == 'http://url.com/foo_0.png'
assert da[1].url == 'http://url.com/foo_1.png'

0 comments on commit 081a03f

Please sign in to comment.