diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py index b19747d7a9b..bc49ea8cdb5 100644 --- a/docarray/base_doc/mixins/io.py +++ b/docarray/base_doc/mixins/io.py @@ -12,6 +12,8 @@ Tuple, Type, TypeVar, + Union, + get_origin, ) import numpy as np @@ -286,9 +288,18 @@ def _get_content_from_node_proto( raise ValueError( 'field_type cannot be None when trying to deserialize a BaseDoc' ) - return_field = field_type.from_protobuf( - getattr(value, content_key) - ) # we get to the parent class + try: + return_field = field_type.from_protobuf( + getattr(value, content_key) + ) # we get to the parent class + except Exception: + if get_origin(field_type) is Union: + raise ValueError( + 'Union type is not supported for proto deserialization. Please use JSON serialization instead' + ) + raise ValueError( + f'{field_type} is not supported for proto deserialization' + ) elif content_key == 'doc_array': if field_name is None: raise ValueError( diff --git a/docarray/helper.py b/docarray/helper.py index ebb58b8378c..5db06eb6d6f 100644 --- a/docarray/helper.py +++ b/docarray/helper.py @@ -15,6 +15,8 @@ Union, ) +from docarray.utils._internal._typing import safe_issubclass + if TYPE_CHECKING: from docarray import BaseDoc @@ -147,9 +149,9 @@ def _get_field_type_by_access_path( return doc_type._get_field_type(field) else: d = doc_type._get_field_type(field) - if issubclass(d, DocList): + if safe_issubclass(d, DocList): return _get_field_type_by_access_path(d.doc_type, remaining) - elif issubclass(d, BaseDoc): + elif safe_issubclass(d, BaseDoc): return _get_field_type_by_access_path(d, remaining) else: return None diff --git a/docarray/utils/_internal/_typing.py b/docarray/utils/_internal/_typing.py index 520c0614cf3..69babecea10 100644 --- a/docarray/utils/_internal/_typing.py +++ b/docarray/utils/_internal/_typing.py @@ -1,4 +1,4 @@ -from typing import Any, ForwardRef, Optional +from typing import Any, ForwardRef, Optional, Union from typing_extensions import get_origin from typing_inspect import get_args, is_typevar, is_union_type @@ -47,7 +47,7 @@ def safe_issubclass(x: type, a_tuple: type) -> bool: Note that if the origin of 'x' is a list or tuple, the function immediately returns 'False'. """ if ( - (get_origin(x) in (list, tuple, dict, set)) + (get_origin(x) in (list, tuple, dict, set, Union)) or is_typevar(x) or (type(x) == ForwardRef) or is_typevar(x) diff --git a/docs/user_guide/sending/serialization.md b/docs/user_guide/sending/serialization.md index dd30895a3e5..6cc2b64f35e 100644 --- a/docs/user_guide/sending/serialization.md +++ b/docs/user_guide/sending/serialization.md @@ -303,4 +303,8 @@ assert dv_from_proto_numpy.tensor_type == NdArray assert isinstance(dv_from_proto_numpy.tensor, NdArray) ``` +!!! note + Serialization to protobuf is not supported for union types involving `BaseDoc` types. + + diff --git a/tests/units/array/test_array_from_to_bytes.py b/tests/units/array/test_array_from_to_bytes.py index 7cd9f0dfd8c..7530b2b82be 100644 --- a/tests/units/array/test_array_from_to_bytes.py +++ b/tests/units/array/test_array_from_to_bytes.py @@ -69,3 +69,24 @@ def test_from_to_base64(protocol, compress, show_progress): assert d1.image.url == d2.image.url assert da[1].image.url is None assert da2[1].image.url is None + + +def test_union_type_error(tmp_path): + from typing import Union + + from docarray.documents import TextDoc + + class CustomDoc(BaseDoc): + ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type') + + docs = DocList[CustomDoc]([CustomDoc(ud=TextDoc(text='union type'))]) + + with pytest.raises(ValueError): + docs.from_bytes(docs.to_bytes()) + + class BasisUnion(BaseDoc): + ud: Union[int, str] + + docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")]) + docs_copy = DocList[BasisUnion].from_bytes(docs_basic.to_bytes()) + assert docs_copy == docs_basic diff --git a/tests/units/array/test_array_from_to_csv.py b/tests/units/array/test_array_from_to_csv.py index d00ea172c4e..b04323f63d4 100644 --- a/tests/units/array/test_array_from_to_csv.py +++ b/tests/units/array/test_array_from_to_csv.py @@ -120,3 +120,26 @@ class Book(BaseDoc): tmp_file = str(tmpdir / 'tmp.csv') with pytest.raises(TypeError): docs.to_csv(tmp_file) + + +def test_union_type_error(tmp_path): + from typing import Union + + from docarray.documents import TextDoc + + class CustomDoc(BaseDoc): + ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type') + + docs = DocList[CustomDoc]([CustomDoc(ud=TextDoc(text='union type'))]) + + with pytest.raises(ValueError): + docs.to_csv(str(tmp_path) + ".csv") + DocList[CustomDoc].from_csv(str(tmp_path) + ".csv") + + class BasisUnion(BaseDoc): + ud: Union[int, str] + + docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")]) + docs_basic.to_csv(str(tmp_path) + ".csv") + docs_copy = DocList[BasisUnion].from_csv(str(tmp_path) + ".csv") + assert docs_copy == docs_basic diff --git a/tests/units/array/test_array_from_to_json.py b/tests/units/array/test_array_from_to_json.py index c36b8af92a9..5767fe9de11 100644 --- a/tests/units/array/test_array_from_to_json.py +++ b/tests/units/array/test_array_from_to_json.py @@ -28,3 +28,17 @@ def test_from_to_json(): assert d1.image.url == d2.image.url assert da[1].image.url is None assert da2[1].image.url is None + + +def test_union_type(): + from typing import Union + + from docarray.documents import TextDoc + + class CustomDoc(BaseDoc): + ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type') + + docs = DocList[CustomDoc]([CustomDoc(ud=TextDoc(text='union type'))]) + + docs_copy = docs.from_json(docs.to_json()) + assert docs == docs_copy diff --git a/tests/units/array/test_array_from_to_pandas.py b/tests/units/array/test_array_from_to_pandas.py index 6d122822d91..bef4427ca6d 100644 --- a/tests/units/array/test_array_from_to_pandas.py +++ b/tests/units/array/test_array_from_to_pandas.py @@ -99,3 +99,25 @@ class Book(BaseDoc): docs = DocList([Book(title='hello'), Book(title='world')]) with pytest.raises(TypeError): docs.to_dataframe() + + +@pytest.mark.proto +def test_union_type_error(): + from typing import Union + + from docarray.documents import TextDoc + + class CustomDoc(BaseDoc): + ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type') + + docs = DocList[CustomDoc]([CustomDoc(ud=TextDoc(text='union type'))]) + + with pytest.raises(ValueError): + DocList[CustomDoc].from_dataframe(docs.to_dataframe()) + + class BasisUnion(BaseDoc): + ud: Union[int, str] + + docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")]) + docs_copy = DocList[BasisUnion].from_dataframe(docs_basic.to_dataframe()) + assert docs_copy == docs_basic diff --git a/tests/units/array/test_array_proto.py b/tests/units/array/test_array_proto.py index e57cc3313f5..c4ac74332ef 100644 --- a/tests/units/array/test_array_proto.py +++ b/tests/units/array/test_array_proto.py @@ -91,3 +91,23 @@ class ResultTestDoc(BaseDoc): assert docs[0].matches[0].id == '0' assert len(docs[0].matches) == 2 assert len(docs) == 1 + + +@pytest.mark.proto +def test_union_type_error(): + from typing import Union + + class CustomDoc(BaseDoc): + ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type') + + docs = DocList[CustomDoc]([CustomDoc(ud=TextDoc(text='union type'))]) + + with pytest.raises(ValueError): + DocList[CustomDoc].from_protobuf(docs.to_protobuf()) + + class BasisUnion(BaseDoc): + ud: Union[int, str] + + docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")]) + docs_copy = DocList[BasisUnion].from_protobuf(docs_basic.to_protobuf()) + assert docs_copy == docs_basic