Skip to content

Commit

Permalink
feat: add config to load more fields than the schema (#1437)
Browse files Browse the repository at this point in the history
Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
Co-authored-by: Joan Fontanals Martinez <joan.martinez@jina.ai>
  • Loading branch information
samsja and JoanFM committed Apr 26, 2023
1 parent 9bf0512 commit 83e7384
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 8 deletions.
10 changes: 10 additions & 0 deletions docarray/base_doc/any_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ class AnyDoc(BaseDoc):
AnyDoc is a Document that is not tied to any schema
"""

class Config:
    # Opt in to loading every field present in a protobuf message, not only
    # the fields declared in the schema (AnyDoc is not tied to any schema).
    # TODO: document this behavior publicly once the fix is confirmed.
    _load_extra_fields_from_protobuf = True

def __init__(self, **kwargs):
super().__init__()
self.__dict__.update(kwargs)
Expand All @@ -22,3 +26,9 @@ def _get_field_type(cls, field: str) -> Type['BaseDoc']:
:return:
"""
return AnyDoc

@classmethod
def _get_field_type_array(cls, field: str) -> Type:
    """
    Return the container type used for array-like (`doc_array`) fields.

    AnyDoc is not tied to a schema, so the plain `DocList` class is returned
    whatever the field name is.

    :param field: name of the field (unused)
    :return: the `DocList` class
    """
    # Imported locally rather than at module level — presumably to avoid a
    # circular import; TODO confirm.
    from docarray import DocList

    return DocList
24 changes: 23 additions & 1 deletion docarray/base_doc/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,28 @@

class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode):
"""
The base class for Documents
BaseDoc is the base class for all Documents. This class should be subclassed
to create new Document types with a specific schema.
The schema of a Document is defined by the fields of the class.
Example:
```python
from docarray import BaseDoc
from docarray.typing import NdArray, ImageUrl
import numpy as np
class MyDoc(BaseDoc):
embedding: NdArray[512]
image: ImageUrl
doc = MyDoc(embedding=np.zeros(512), image='https://example.com/image.jpg')
```
BaseDoc is a subclass of [pydantic.BaseModel](https://docs.pydantic.dev/usage/models/) and can be used in a similar way.
"""

id: Optional[ID] = Field(default_factory=lambda: ID(os.urandom(16).hex()))
Expand All @@ -50,6 +71,7 @@ class Config:
json_encoders = {AbstractTensor: lambda x: x}

validate_assignment = True
_load_extra_fields_from_protobuf = False

@classmethod
def from_view(cls: Type[T], storage_view: 'ColumnStorageView') -> T:
Expand Down
31 changes: 25 additions & 6 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
if torch is not None:
from docarray.typing import TorchTensor


T = TypeVar('T', bound='IOMixin')


Expand Down Expand Up @@ -122,11 +121,18 @@ class IOMixin(Iterable[Tuple[str, Any]]):

__fields__: Dict[str, 'ModelField']

class Config:
    # Flag read by `from_protobuf`: when falsy, protobuf entries whose name is
    # not in `cls.__fields__` are skipped instead of being loaded.
    _load_extra_fields_from_protobuf: bool

@classmethod
@abstractmethod
def _get_field_type(cls, field: str) -> Type:
    """
    Return the type associated with a field of this class.

    The returned type is the one whose `from_protobuf` is called when a
    nested document stored under `field` is deserialized.

    :param field: name of the field
    :return: the type to use for that field
    """
    ...

@classmethod
def _get_field_type_array(cls, field: str) -> Type:
    """
    Return the type used to deserialize a `doc_array` value stored under a field.

    The default implementation delegates to `_get_field_type`.

    :param field: name of the field
    :return: the type to use for that field
    """
    return cls._get_field_type(field)

def __bytes__(self) -> bytes:
return self.to_bytes()

Expand All @@ -149,7 +155,8 @@ def to_bytes(
bstr = self.to_protobuf().SerializePartialToString()
else:
raise ValueError(
f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.'
f'protocol={protocol} is not supported. Can be only `protobuf` or '
f'pickle protocols 0-5.'
)
return _compress_bytes(bstr, algorithm=compress)

Expand Down Expand Up @@ -178,7 +185,8 @@ def from_bytes(
return cls.from_protobuf(pb_msg)
else:
raise ValueError(
f'protocol={protocol} is not supported. Can be only `protobuf` or pickle protocols 0-5.'
f'protocol={protocol} is not supported. Can be only `protobuf` or '
f'pickle protocols 0-5.'
)

def to_base64(
Expand Down Expand Up @@ -219,7 +227,10 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T:
fields: Dict[str, Any] = {}

for field_name in pb_msg.data:
if field_name not in cls.__fields__.keys():
if (
not (cls.Config._load_extra_fields_from_protobuf)
and field_name not in cls.__fields__.keys()
):
continue # optimization we don't even load the data if the key does not
# match any field in the cls or in the mapping

Expand Down Expand Up @@ -253,14 +264,22 @@ def _get_content_from_node_proto(
return_field = content_type_dict[docarray_type].from_protobuf(
getattr(value, content_key)
)
elif content_key in ['doc', 'doc_array']:
elif content_key == 'doc':
if field_name is None:
raise ValueError(
'field_name cannot be None when trying to deseriliaze a Document or a DocList'
'field_name cannot be None when trying to deseriliaze a BaseDoc'
)
return_field = cls._get_field_type(field_name).from_protobuf(
getattr(value, content_key)
) # we get to the parent class
elif content_key == 'doc_array':
if field_name is None:
raise ValueError(
'field_name cannot be None when trying to deseriliaze a BaseDoc'
)
return_field = cls._get_field_type_array(field_name).from_protobuf(
getattr(value, content_key)
) # we get to the parent class
elif content_key is None:
return_field = None
elif docarray_type is None:
Expand Down
32 changes: 32 additions & 0 deletions tests/units/array/test_array_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

from docarray import BaseDoc, DocList
from docarray.base_doc import AnyDoc
from docarray.documents import ImageDoc, TextDoc
from docarray.typing import NdArray

Expand Down Expand Up @@ -59,3 +60,34 @@ class CustomDocument(BaseDoc):
)

DocList.from_protobuf(da.to_protobuf())


@pytest.mark.proto
def test_any_doc_list_proto():
    """A DocList holding an AnyDoc keeps its extra field across a protobuf round trip."""
    original = DocList([AnyDoc(hello='world')])
    restored = DocList.from_protobuf(original.to_protobuf())

    first_as_dict = restored[0].dict()
    assert first_as_dict['hello'] == 'world'


@pytest.mark.proto
def test_any_nested_doc_list_proto():
    """A nested DocList field survives a round trip through an untyped DocList."""
    from docarray import BaseDoc, DocList

    class TextDocWithId(BaseDoc):
        id: str
        text: str

    class ResultTestDoc(BaseDoc):
        matches: DocList[TextDocWithId]

    # Ten indexed documents; only the first two are attached as matches.
    indexed = DocList[TextDocWithId](
        [TextDocWithId(id=f'{i}', text=f'ID {i}') for i in range(10)]
    )
    result = ResultTestDoc(matches=indexed[0:2])
    results = DocList[ResultTestDoc]([result])

    # Deserialize without a schema: the generic DocList has to rebuild the
    # nested list on its own.
    restored = DocList.from_protobuf(results.to_protobuf())

    restored_matches = restored[0].matches
    assert restored_matches[0].id == '0'
    assert len(restored_matches) == 2
    assert len(restored) == 1
10 changes: 9 additions & 1 deletion tests/units/document/proto/test_document_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from docarray import DocList
from docarray.base_doc import BaseDoc
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.typing import NdArray, TorchTensor
from docarray.utils._internal.misc import is_tf_available

Expand Down Expand Up @@ -296,3 +296,11 @@ class MyDoc(BaseDoc):
doc = MyDoc(data=data)

MyDoc.from_protobuf(doc.to_protobuf())


@pytest.mark.proto
def test_any_doc_proto():
    """An AnyDoc keeps its schema-less field across a protobuf round trip."""
    serialized = AnyDoc(hello='world').to_protobuf()
    restored = AnyDoc.from_protobuf(serialized)

    restored_fields = restored.dict()
    assert restored_fields['hello'] == 'world'

0 comments on commit 83e7384

Please sign in to comment.