Skip to content

Commit

Permalink
fix: bugs when serialize union type (#1655)
Browse files Browse the repository at this point in the history
Signed-off-by: maxwelljin2 <gejin@berkeley.edu>
Signed-off-by: Joan Fontanals <jfontanalsmartinez@gmail.com>
Co-authored-by: Joan Fontanals <jfontanalsmartinez@gmail.com>
  • Loading branch information
maxwelljin and JoanFM committed Jun 16, 2023
1 parent dc96e38 commit 0c27fef
Show file tree
Hide file tree
Showing 9 changed files with 124 additions and 7 deletions.
17 changes: 14 additions & 3 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
Tuple,
Type,
TypeVar,
Union,
get_origin,
)

import numpy as np
Expand Down Expand Up @@ -286,9 +288,18 @@ def _get_content_from_node_proto(
raise ValueError(
'field_type cannot be None when trying to deserialize a BaseDoc'
)
return_field = field_type.from_protobuf(
getattr(value, content_key)
) # we get to the parent class
try:
return_field = field_type.from_protobuf(
getattr(value, content_key)
) # we get to the parent class
except Exception:
if get_origin(field_type) is Union:
raise ValueError(
'Union type is not supported for proto deserialization. Please use JSON serialization instead'
)
raise ValueError(
f'{field_type} is not supported for proto deserialization'
)
elif content_key == 'doc_array':
if field_name is None:
raise ValueError(
Expand Down
6 changes: 4 additions & 2 deletions docarray/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
Union,
)

from docarray.utils._internal._typing import safe_issubclass

if TYPE_CHECKING:
from docarray import BaseDoc

Expand Down Expand Up @@ -147,9 +149,9 @@ def _get_field_type_by_access_path(
return doc_type._get_field_type(field)
else:
d = doc_type._get_field_type(field)
if issubclass(d, DocList):
if safe_issubclass(d, DocList):
return _get_field_type_by_access_path(d.doc_type, remaining)
elif issubclass(d, BaseDoc):
elif safe_issubclass(d, BaseDoc):
return _get_field_type_by_access_path(d, remaining)
else:
return None
Expand Down
4 changes: 2 additions & 2 deletions docarray/utils/_internal/_typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, ForwardRef, Optional
from typing import Any, ForwardRef, Optional, Union

from typing_extensions import get_origin
from typing_inspect import get_args, is_typevar, is_union_type
Expand Down Expand Up @@ -47,7 +47,7 @@ def safe_issubclass(x: type, a_tuple: type) -> bool:
Note that if the origin of 'x' is a list, tuple, dict, set or Union, the function immediately returns 'False'.
"""
if (
(get_origin(x) in (list, tuple, dict, set))
(get_origin(x) in (list, tuple, dict, set, Union))
or is_typevar(x)
or (type(x) == ForwardRef)
or is_typevar(x)
Expand Down
4 changes: 4 additions & 0 deletions docs/user_guide/sending/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,4 +303,8 @@ assert dv_from_proto_numpy.tensor_type == NdArray
assert isinstance(dv_from_proto_numpy.tensor, NdArray)
```

!!! note
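    Protobuf serialization is not supported for union types involving `BaseDoc` types: deserializing such a document from protobuf raises a `ValueError`. Use JSON serialization for these documents instead.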



21 changes: 21 additions & 0 deletions tests/units/array/test_array_from_to_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,24 @@ def test_from_to_base64(protocol, compress, show_progress):
assert d1.image.url == d2.image.url
assert da[1].image.url is None
assert da2[1].image.url is None


def test_union_type_error(tmp_path):
    """Exercise the bytes round trip for documents that declare ``Union`` fields.

    A union of ``BaseDoc`` subclasses is expected to be rejected with a
    ``ValueError``; a union of builtin types must round-trip unchanged.
    """
    from typing import Union

    from docarray.documents import TextDoc

    # Case 1: union of document types -> the round trip must raise.
    class CustomDoc(BaseDoc):
        ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type')

    union_doc = CustomDoc(ud=TextDoc(text='union type'))
    docs = DocList[CustomDoc]([union_doc])

    with pytest.raises(ValueError):
        # Both steps stay inside the context manager, as in the one-liner form.
        payload = docs.to_bytes()
        docs.from_bytes(payload)

    # Case 2: union of builtin types -> the round trip must be lossless.
    class BasisUnion(BaseDoc):
        ud: Union[int, str]

    docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")])
    serialized = docs_basic.to_bytes()
    docs_copy = DocList[BasisUnion].from_bytes(serialized)
    assert docs_copy == docs_basic
23 changes: 23 additions & 0 deletions tests/units/array/test_array_from_to_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,26 @@ class Book(BaseDoc):
tmp_file = str(tmpdir / 'tmp.csv')
with pytest.raises(TypeError):
docs.to_csv(tmp_file)


def test_union_type_error(tmp_path):
    """Exercise the CSV round trip for documents that declare ``Union`` fields.

    Writing and re-reading a union of ``BaseDoc`` subclasses is expected to
    raise a ``ValueError``; a union of builtin types must round-trip unchanged.
    """
    from typing import Union

    from docarray.documents import TextDoc

    # Same target file for every write/read in this test.
    csv_path = str(tmp_path) + ".csv"

    # Case 1: union of document types -> the round trip must raise.
    class CustomDoc(BaseDoc):
        ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type')

    union_doc = CustomDoc(ud=TextDoc(text='union type'))
    docs = DocList[CustomDoc]([union_doc])

    with pytest.raises(ValueError):
        # Either the write or the read may be the one that raises, so both
        # are kept inside the context manager.
        docs.to_csv(csv_path)
        DocList[CustomDoc].from_csv(csv_path)

    # Case 2: union of builtin types -> the round trip must be lossless.
    class BasisUnion(BaseDoc):
        ud: Union[int, str]

    docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")])
    docs_basic.to_csv(csv_path)
    docs_copy = DocList[BasisUnion].from_csv(csv_path)
    assert docs_copy == docs_basic
14 changes: 14 additions & 0 deletions tests/units/array/test_array_from_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,17 @@ def test_from_to_json():
assert d1.image.url == d2.image.url
assert da[1].image.url is None
assert da2[1].image.url is None


def test_union_type():
    """Check that a ``Union`` of document types survives a JSON round trip."""
    from typing import Union

    from docarray.documents import TextDoc

    class CustomDoc(BaseDoc):
        ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type')

    union_doc = CustomDoc(ud=TextDoc(text='union type'))
    docs = DocList[CustomDoc]([union_doc])

    # Serialize to JSON and load it back; the result must equal the input.
    as_json = docs.to_json()
    docs_copy = docs.from_json(as_json)
    assert docs == docs_copy
22 changes: 22 additions & 0 deletions tests/units/array/test_array_from_to_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,25 @@ class Book(BaseDoc):
docs = DocList([Book(title='hello'), Book(title='world')])
with pytest.raises(TypeError):
docs.to_dataframe()


@pytest.mark.proto
def test_union_type_error():
    """Exercise the DataFrame round trip for documents with ``Union`` fields.

    A union of ``BaseDoc`` subclasses is expected to be rejected with a
    ``ValueError``; a union of builtin types must round-trip unchanged.
    """
    from typing import Union

    from docarray.documents import TextDoc

    # Case 1: union of document types -> the round trip must raise.
    class CustomDoc(BaseDoc):
        ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type')

    union_doc = CustomDoc(ud=TextDoc(text='union type'))
    docs = DocList[CustomDoc]([union_doc])

    with pytest.raises(ValueError):
        # Conversion and reload both stay inside the context manager.
        DocList[CustomDoc].from_dataframe(docs.to_dataframe())

    # Case 2: union of builtin types -> the round trip must be lossless.
    class BasisUnion(BaseDoc):
        ud: Union[int, str]

    basic_doc = BasisUnion(ud="hello")
    docs_basic = DocList[BasisUnion]([basic_doc])
    docs_copy = DocList[BasisUnion].from_dataframe(docs_basic.to_dataframe())
    assert docs_copy == docs_basic
20 changes: 20 additions & 0 deletions tests/units/array/test_array_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,23 @@ class ResultTestDoc(BaseDoc):
assert docs[0].matches[0].id == '0'
assert len(docs[0].matches) == 2
assert len(docs) == 1


@pytest.mark.proto
def test_union_type_error():
    """Exercise the protobuf round trip for documents with ``Union`` fields.

    A union of ``BaseDoc`` subclasses is expected to be rejected with a
    ``ValueError``; a union of builtin types must round-trip unchanged.
    """
    from typing import Union

    # Case 1: union of document types -> the round trip must raise.
    class CustomDoc(BaseDoc):
        ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type')

    union_doc = CustomDoc(ud=TextDoc(text='union type'))
    docs = DocList[CustomDoc]([union_doc])

    with pytest.raises(ValueError):
        # Serialization and parsing both stay inside the context manager.
        DocList[CustomDoc].from_protobuf(docs.to_protobuf())

    # Case 2: union of builtin types -> the round trip must be lossless.
    class BasisUnion(BaseDoc):
        ud: Union[int, str]

    basic_doc = BasisUnion(ud="hello")
    docs_basic = DocList[BasisUnion]([basic_doc])
    docs_copy = DocList[BasisUnion].from_protobuf(docs_basic.to_protobuf())
    assert docs_copy == docs_basic

0 comments on commit 0c27fef

Please sign in to comment.