Skip to content

Commit

Permalink
fix: fix issue serializing and deserializing complex schemas (#1836)
Browse files Browse the repository at this point in the history
Signed-off-by: Joan Martinez <joan.fontanals.martinez@jina.ai>
  • Loading branch information
JoanFM committed Dec 19, 2023
1 parent 3cfa0b8 commit 21e107b
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 10 deletions.
18 changes: 10 additions & 8 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ def _get_content_from_node_proto(
)

return_field: Any

if docarray_type in content_type_dict:
return_field = content_type_dict[docarray_type].from_protobuf(
getattr(value, content_key)
Expand All @@ -308,13 +307,18 @@ def _get_content_from_node_proto(
f'{field_type} is not supported for proto deserialization'
)
elif content_key == 'doc_array':
if field_name is None:
if field_type is not None and field_name is None:
return_field = field_type.from_protobuf(getattr(value, content_key))
elif field_name is not None:
return_field = cls._get_field_annotation_array(
field_name
).from_protobuf(
getattr(value, content_key)
) # we get to the parent class
else:
raise ValueError(
'field_name cannot be None when trying to deserialize a BaseDoc'
'field_name and field_type cannot be None when trying to deserialize a DocArray'
)
return_field = cls._get_field_annotation_array(field_name).from_protobuf(
getattr(value, content_key)
) # we get to the parent class
elif content_key is None:
return_field = None
elif docarray_type is None:
Expand All @@ -330,8 +334,6 @@ def _get_content_from_node_proto(
elif content_key in arg_to_container.keys():
if field_name and field_name in cls._docarray_fields():
field_type = cls._get_field_inner_type(field_name)
else:
field_type = None

if isinstance(field_type, GenericAlias):
field_type = get_args(field_type)[0]
Expand Down
2 changes: 1 addition & 1 deletion tests/units/array/test_array_from_to_json.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Dict, List

import numpy as np
import pytest
Expand Down
39 changes: 39 additions & 0 deletions tests/units/array/test_array_proto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from typing import Dict, List

from docarray import BaseDoc, DocList
from docarray.base_doc import AnyDoc
Expand Down Expand Up @@ -111,3 +112,41 @@ class BasisUnion(BaseDoc):
docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")])
docs_copy = DocList[BasisUnion].from_protobuf(docs_basic.to_protobuf())
assert docs_copy == docs_basic


class MySimpleDoc(BaseDoc):
    # Minimal document with a single string field; nested inside
    # MyComplexDoc's container fields by the round-trip test below.
    title: str


class MyComplexDoc(BaseDoc):
    # Document whose fields are all dictionaries keyed by str, each with a
    # different kind of value, to exercise nested-container serialization.

    # str -> DocList of MySimpleDoc
    content_dict_doclist: Dict[str, DocList[MySimpleDoc]]
    # str -> plain Python list of MySimpleDoc
    content_dict_list: Dict[str, List[MySimpleDoc]]
    # str -> int (no documents involved)
    aux_dict: Dict[str, int]


def test_to_from_proto_complex():
    """Round-trip a DocList[MyComplexDoc] through protobuf.

    Checks that the int dictionary and both nested document containers
    (a DocList and a plain list, each stored under a dict key) come back
    with the same length and titles.
    """
    titles = ('123', '456')

    # Build independent MySimpleDoc instances for each container so the
    # DocList and the plain list do not share objects.
    nested_doclist = DocList[MySimpleDoc]([MySimpleDoc(title=t) for t in titles])
    nested_list = [MySimpleDoc(title=t) for t in titles]

    complex_doc = MyComplexDoc(
        content_dict_doclist={'test1': nested_doclist},
        content_dict_list={'test1': nested_list},
        aux_dict={'a': 0},
    )
    da = DocList[MyComplexDoc]([complex_doc])

    restored = DocList[MyComplexDoc].from_protobuf(da.to_protobuf())

    assert len(restored) == 1
    restored_doc = restored[0]
    assert restored_doc.aux_dict == {'a': 0}

    # Same checks for both container fields, DocList first, then list.
    for field in ('content_dict_doclist', 'content_dict_list'):
        entries = getattr(restored_doc, field)['test1']
        assert len(entries) == 2
        assert entries[0].title == '123'
        assert entries[1].title == '456'
63 changes: 62 additions & 1 deletion tests/units/document/test_from_to_bytes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from typing import Dict, List

from docarray import BaseDoc
from docarray import BaseDoc, DocList
from docarray.documents import ImageDoc
from docarray.typing import NdArray

Expand All @@ -11,6 +12,16 @@ class MyDoc(BaseDoc):
image: ImageDoc


class MySimpleDoc(BaseDoc):
    # Minimal document with a single string field; used as the element
    # type of MyComplexDoc's nested containers in the tests below.
    title: str


class MyComplexDoc(BaseDoc):
    # Document whose fields are all dictionaries keyed by str, each with a
    # different kind of value, to exercise nested-container serialization.

    # str -> DocList of MySimpleDoc
    content_dict_doclist: Dict[str, DocList[MySimpleDoc]]
    # str -> plain Python list of MySimpleDoc
    content_dict_list: Dict[str, List[MySimpleDoc]]
    # str -> int (no documents involved)
    aux_dict: Dict[str, int]


@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
def test_to_from_bytes(protocol, compress):
Expand Down Expand Up @@ -39,3 +50,53 @@ def test_to_from_base64(protocol, compress):
assert d2.text == 'hello'
assert d2.embedding.tolist() == [1, 2, 3, 4, 5]
assert d2.image.url == 'aux.png'


@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
def test_to_from_bytes_complex(protocol, compress):
    """Round-trip a MyComplexDoc through ``to_bytes`` / ``from_bytes``.

    Parametrized over every protocol/compression combination listed in the
    decorators; verifies the int dictionary and both nested document
    containers are restored with the same length and titles.
    """
    titles = ('123', '456')

    # Independent MySimpleDoc instances per container (no shared objects).
    nested_doclist = DocList[MySimpleDoc]([MySimpleDoc(title=t) for t in titles])
    nested_list = [MySimpleDoc(title=t) for t in titles]

    source_doc = MyComplexDoc(
        content_dict_doclist={'test1': nested_doclist},
        content_dict_list={'test1': nested_list},
        aux_dict={'a': 0},
    )

    payload = source_doc.to_bytes(protocol=protocol, compress=compress)
    restored_doc = MyComplexDoc.from_bytes(
        payload, protocol=protocol, compress=compress
    )

    assert restored_doc.aux_dict == {'a': 0}

    # Same checks for both container fields, DocList first, then list.
    for field in ('content_dict_doclist', 'content_dict_list'):
        entries = getattr(restored_doc, field)['test1']
        assert len(entries) == 2
        assert entries[0].title == '123'
        assert entries[1].title == '456'


@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
def test_to_from_base64_complex(protocol, compress):
    """Round-trip a MyComplexDoc through ``to_base64`` / ``from_base64``.

    Parametrized over every protocol/compression combination listed in the
    decorators; verifies the int dictionary and both nested document
    containers are restored with the same length and titles.
    """
    titles = ('123', '456')

    # Independent MySimpleDoc instances per container (no shared objects).
    nested_doclist = DocList[MySimpleDoc]([MySimpleDoc(title=t) for t in titles])
    nested_list = [MySimpleDoc(title=t) for t in titles]

    source_doc = MyComplexDoc(
        content_dict_doclist={'test1': nested_doclist},
        content_dict_list={'test1': nested_list},
        aux_dict={'a': 0},
    )

    encoded = source_doc.to_base64(protocol=protocol, compress=compress)
    restored_doc = MyComplexDoc.from_base64(
        encoded, protocol=protocol, compress=compress
    )

    assert restored_doc.aux_dict == {'a': 0}

    # Same checks for both container fields, DocList first, then list.
    for field in ('content_dict_doclist', 'content_dict_list'):
        entries = getattr(restored_doc, field)['test1']
        assert len(entries) == 2
        assert entries[0].title == '123'
        assert entries[1].title == '456'

0 comments on commit 21e107b

Please sign in to comment.