fix: proto ser and deser for nested tuple/dict/list (#1278)
* feat: add failing test

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* refactor: shorten if else statement

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* refactor: shorten if else statement

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: fix proto and list

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: fix proto and dict

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* feat: add very complex test

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* feat: fix pure tensor stuff

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* feat: fix pure tensor stuff

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: fix mypy

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: fix from protobuf

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: fix from protobuf tensorflow

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: add more test

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: fix mypy

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: add more test

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

* fix: import ndarray

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>

---------

Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
samsja committed Mar 23, 2023
1 parent 11d013e commit 3364127
Showing 6 changed files with 283 additions and 111 deletions.
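For context, a rough sketch of the behaviour this commit targets: nested containers inside a document should survive a protobuf round trip with their exact Python structure. The document class and field values below (MyDoc, tags, pairs) are purely illustrative and assume the v2 alpha BaseDocument API with to_protobuf/from_protobuf; they are not part of the change itself:

from typing import Any, Dict, List, Tuple

from docarray import BaseDocument


class MyDoc(BaseDocument):  # hypothetical document class, for illustration only
    tags: Dict[str, Any]
    pairs: List[Tuple[int, str]]


doc = MyDoc(tags={'nested': [1, 2, {'deep': 'value'}]}, pairs=[(1, 'a'), (2, 'b')])

# Previously, nested lists/tuples/dicts were pushed through google.protobuf
# ListValue/Struct, which only accept JSON-style values; after this fix every
# element is wrapped in its own NodeProto, so the structure comes back intact.
restored = MyDoc.from_protobuf(doc.to_protobuf())
assert restored.tags == doc.tags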
126 changes: 80 additions & 46 deletions docarray/base_document/mixins/io.py
@@ -14,11 +14,26 @@
TypeVar,
)

import numpy as np
from typing_inspect import is_union_type

from docarray.base_document.base_node import BaseNode
from docarray.typing import NdArray
from docarray.typing.proto_register import _PROTO_TYPE_NAME_TO_CLASS
from docarray.utils.compress import _compress_bytes, _decompress_bytes
from docarray.utils.misc import is_tf_available, is_torch_available

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore

from docarray.typing import TensorFlowTensor

torch_available = is_torch_available()
if torch_available:
import torch

from docarray.typing import TorchTensor

if TYPE_CHECKING:
from pydantic.fields import ModelField
@@ -36,60 +51,69 @@ def _type_to_protobuf(value: Any) -> 'NodeProto':
"""
from docarray.proto import NodeProto

nested_item: 'NodeProto'
if isinstance(value, BaseNode):
nested_item = value._to_node_protobuf()
basic_type_to_key = {
str: 'text',
bool: 'boolean',
int: 'integer',
float: 'float',
bytes: 'blob',
}

elif isinstance(value, str):
nested_item = NodeProto(text=value)
container_type_to_key = {list: 'list', set: 'set', tuple: 'tuple'}

elif isinstance(value, bool):
nested_item = NodeProto(boolean=value)
nested_item: 'NodeProto'

elif isinstance(value, int):
nested_item = NodeProto(integer=value)
if isinstance(value, BaseNode):
nested_item = value._to_node_protobuf()
return nested_item

elif isinstance(value, float):
nested_item = NodeProto(float=value)
base_node_wrap: BaseNode
if torch_available:
if isinstance(value, torch.Tensor):
base_node_wrap = TorchTensor._docarray_from_native(value)
return base_node_wrap._to_node_protobuf()

elif isinstance(value, bytes):
nested_item = NodeProto(blob=value)
if tf_available:
if isinstance(value, tf.Tensor):
base_node_wrap = TensorFlowTensor._docarray_from_native(value)
return base_node_wrap._to_node_protobuf()

elif isinstance(value, list):
from google.protobuf.struct_pb2 import ListValue
if isinstance(value, np.ndarray):
base_node_wrap = NdArray._docarray_from_native(value)
return base_node_wrap._to_node_protobuf()

lvalue = ListValue()
for item in value:
lvalue.append(item)
nested_item = NodeProto(list=lvalue)
for basic_type, key_name in basic_type_to_key.items():
if isinstance(value, basic_type):
nested_item = NodeProto(**{key_name: value})
return nested_item

elif isinstance(value, set):
from google.protobuf.struct_pb2 import ListValue
for container_type, key_name in container_type_to_key.items():
if isinstance(value, container_type):
from docarray.proto import ListOfAnyProto

lvalue = ListValue()
for item in value:
lvalue.append(item)
nested_item = NodeProto(set=lvalue)
lvalue = ListOfAnyProto()
for item in value:
lvalue.data.append(_type_to_protobuf(item))
nested_item = NodeProto(**{key_name: lvalue})
return nested_item

elif isinstance(value, tuple):
from google.protobuf.struct_pb2 import ListValue
if isinstance(value, dict):
from docarray.proto import DictOfAnyProto

lvalue = ListValue()
for item in value:
lvalue.append(item)
nested_item = NodeProto(tuple=lvalue)
data = {}

elif isinstance(value, dict):
from google.protobuf.struct_pb2 import Struct
for key, content in value.items():
data[key] = _type_to_protobuf(content)

struct = Struct()
struct.update(value)
struct = DictOfAnyProto(data=data)
nested_item = NodeProto(dict=struct)
return nested_item

elif value is None:
nested_item = NodeProto()
return nested_item
else:
raise ValueError(f'{type(value)} is not supported with protobuf')
return nested_item


class IOMixin(Iterable[Tuple[str, Any]]):
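To make the new serialization path concrete, here is a minimal sketch of what the rewritten _type_to_protobuf above produces for a nested value. It is an internal helper, used here only for illustration; the variable names are arbitrary and the proto field names follow the schema changes further down:

from docarray.base_document.mixins.io import _type_to_protobuf

node = _type_to_protobuf({'scores': [0.5, 1.0], 'meta': ('a', 1)})

# The dict becomes a DictOfAnyProto whose values are themselves NodeProtos,
# so nesting is preserved instead of being squashed into a flat Struct.
assert node.WhichOneof('content') == 'dict'
scores = node.dict.data['scores']
assert scores.WhichOneof('content') == 'list'
assert scores.list.data[0].float == 0.5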
@@ -208,7 +232,9 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T:
return cls(**fields)

@classmethod
def _get_content_from_node_proto(cls, value: 'NodeProto', field_name: str) -> Any:
def _get_content_from_node_proto(
cls, value: 'NodeProto', field_name: Optional[str] = None
) -> Any:
"""
load the proto data from a node proto
@@ -217,12 +243,6 @@ def _get_content_from_node_proto(cls, value: 'NodeProto', field_name: str) -> Any:
:return: the loaded field
"""
content_type_dict = _PROTO_TYPE_NAME_TO_CLASS
arg_to_container: Dict[str, Callable] = {
'list': list,
'set': set,
'tuple': tuple,
'dict': dict,
}

content_key = value.WhichOneof('content')
docarray_type = (
@@ -236,23 +256,37 @@ def _get_content_from_node_proto(cls, value: 'NodeProto', field_name: str) -> Any:
getattr(value, content_key)
)
elif content_key in ['document', 'document_array']:
if field_name is None:
raise ValueError(
'field_name cannot be None when trying to deserialize a Document or a DocumentArray'
)
return_field = cls._get_field_type(field_name).from_protobuf(
getattr(value, content_key)
) # we get to the parent class
elif content_key is None:
return_field = None
elif docarray_type is None:

arg_to_container: Dict[str, Callable] = {
'list': list,
'set': set,
'tuple': tuple,
}

if content_key in ['text', 'blob', 'integer', 'float', 'boolean']:
return_field = getattr(value, content_key)

elif content_key in arg_to_container.keys():
from google.protobuf.json_format import MessageToDict

return_field = arg_to_container[content_key](
MessageToDict(getattr(value, content_key))
cls._get_content_from_node_proto(node)
for node in getattr(value, content_key).data
)

elif content_key == 'dict':
deser_dict: Dict[str, Any] = dict()
for key_name, node in value.dict.data.items():
deser_dict[key_name] = cls._get_content_from_node_proto(node)
return_field = deser_dict
else:
raise ValueError(
f'key {content_key} is not supported for deserialization'
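On the inverse direction, _get_content_from_node_proto now calls itself for every element of a list/set/tuple and for every value of a dict, so a pure-Python nested value round-trips without needing field_name (which is only required for the document/document_array branches). A hedged sketch, again using these internal helpers purely for illustration:

from docarray.base_document.mixins.io import IOMixin, _type_to_protobuf

node = _type_to_protobuf([1, (2.5, 'x'), {'flag': True}])

# No Document/DocumentArray is involved, so field_name may stay None and the
# nested NodeProtos are resolved recursively back into plain Python objects.
value = IOMixin._get_content_from_node_proto(node)
assert value == [1, (2.5, 'x'), {'flag': True}]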
3 changes: 3 additions & 0 deletions docarray/proto/__init__.py
@@ -2,6 +2,7 @@

if __pb__version__.startswith('4'):
from docarray.proto.pb.docarray_pb2 import (
DictOfAnyProto,
DocumentArrayProto,
DocumentArrayStackedProto,
DocumentProto,
@@ -12,6 +13,7 @@
)
else:
from docarray.proto.pb2.docarray_pb2 import (
DictOfAnyProto,
DocumentArrayProto,
DocumentArrayStackedProto,
DocumentProto,
@@ -30,4 +32,5 @@
'DocumentArrayProto',
'ListOfDocumentArrayProto',
'ListOfAnyProto',
'DictOfAnyProto',
]
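Since the export is added to both the protobuf-4 (pb) and protobuf-3 (pb2) import branches, the new message types resolve regardless of which protobuf major version is installed. A quick hedged check:

from docarray.proto import DictOfAnyProto, ListOfAnyProto

# Both names point at the generated classes for whichever protobuf runtime
# (major version 3 or 4) the environment happens to use.
print(DictOfAnyProto.DESCRIPTOR.full_name, ListOfAnyProto.DESCRIPTOR.full_name)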
26 changes: 17 additions & 9 deletions docarray/proto/docarray.proto
@@ -35,7 +35,8 @@ message GenericDictValue {
repeated KeyValuePair entries = 1;
}

//


message NodeProto {

oneof content {
@@ -56,13 +57,13 @@ message NodeProto {
// a sub DocumentArray
DocumentArrayProto document_array = 8;
//any list
google.protobuf.ListValue list = 9;
ListOfAnyProto list = 9;
//any set
google.protobuf.ListValue set = 10;
ListOfAnyProto set = 10;
//any tuple
google.protobuf.ListValue tuple = 11;
ListOfAnyProto tuple = 11;
// dictionary with string as keys
google.protobuf.Struct dict = 12;
DictOfAnyProto dict = 12;
}

oneof docarray_type {
@@ -80,18 +81,25 @@ message DocumentProto {

}

message DictOfAnyProto {

map<string, NodeProto> data = 1;

}

message ListOfAnyProto {
repeated NodeProto data = 1;
}

message DocumentArrayProto {
repeated DocumentProto docs = 1; // a list of Documents
}


message ListOfDocumentArrayProto {
repeated DocumentArrayProto data = 1;
}

message ListOfAnyProto {
repeated NodeProto data = 1;
}

message DocumentArrayStackedProto{
map<string, NdArrayProto> tensor_columns = 1; // a dict of document columns
map<string, DocumentArrayStackedProto> doc_columns = 2; // a dict of tensor columns
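The two new messages are deliberately mutually recursive with NodeProto: a ListOfAnyProto holds NodeProtos, and each NodeProto may in turn hold another ListOfAnyProto or DictOfAnyProto. A hedged sketch of building such a structure by hand from Python (normally _type_to_protobuf does this for you; the values are arbitrary):

from docarray.proto import DictOfAnyProto, ListOfAnyProto, NodeProto

inner = ListOfAnyProto()
inner.data.append(NodeProto(text='hello'))
inner.data.append(NodeProto(integer=3))

# A dict value is itself a NodeProto, so arbitrary nesting is expressible.
outer = NodeProto(dict=DictOfAnyProto(data={'items': NodeProto(list=inner)}))

assert outer.dict.data['items'].list.data[1].integer == 3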
50 changes: 28 additions & 22 deletions docarray/proto/pb/docarray_pb2.py

Some generated files are not rendered by default.
