fix: fix nested doc to json (#1502)
Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
samsja committed May 8, 2023
1 parent f5f692d commit 67c7e6d
Showing 4 changed files with 108 additions and 32 deletions.
8 changes: 7 additions & 1 deletion docarray/array/doc_list/doc_list.py
@@ -13,6 +13,7 @@
overload,
)

from pydantic import parse_obj_as
from typing_extensions import SupportsIndex
from typing_inspect import is_union_type

@@ -268,8 +269,13 @@ def validate(

        if isinstance(value, (cls, DocVec)):
            return value
-        elif isinstance(value, Iterable):
+        elif isinstance(value, cls):
            return cls(value)
+        elif isinstance(value, Iterable):
+            docs = []
+            for doc in value:
+                docs.append(parse_obj_as(cls.doc_type, doc))
+            return cls(docs)
        else:
            raise TypeError(f'Expecting an Iterable of {cls.doc_type}')

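With this change, `DocList` validation runs every item of a plain iterable through `parse_obj_as(cls.doc_type, ...)` instead of handing the raw items to the constructor. A minimal sketch of what this enables (the `MyDoc` class and its field are illustrative, not part of the diff):

```python
from pydantic import parse_obj_as

from docarray import BaseDoc, DocList


class MyDoc(BaseDoc):
    title: str


# Plain dicts are now coerced into MyDoc instances during validation,
# instead of being stored in the DocList as raw dicts.
docs = parse_obj_as(DocList[MyDoc], [{'title': 'hello'}, {'title': 'world'}])
assert all(isinstance(doc, MyDoc) for doc in docs)
assert docs[0].title == 'hello'
```

This is the same behavior the new `test_validate_list_dict` test below pins down with `Image` documents.
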
106 changes: 77 additions & 29 deletions docarray/base_doc/doc.py
@@ -1,20 +1,25 @@
import os
import warnings
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Dict,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
no_type_check,
)

import orjson
from pydantic import BaseModel, Field
from pydantic.main import ROOT_KEY
from rich.console import Console

from docarray.base_doc.base_node import BaseNode
@@ -36,6 +41,9 @@
T_update = TypeVar('T_update', bound='UpdateMixin')


ExcludeType = Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]


class BaseDoc(BaseModel, IOMixin, UpdateMixin, BaseNode):
"""
BaseDoc is the base class for all Documents. This class should be subclassed
@@ -191,7 +199,7 @@ def json(
self,
*,
include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
-        exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
+        exclude: ExcludeType = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
@@ -208,19 +216,46 @@
`encoder` is an optional function to supply as `default` to json.dumps(),
other arguments as per `json.dumps()`.
"""
-        return super().json(
-            include=include,
-            exclude=exclude,
-            by_alias=by_alias,
-            skip_defaults=skip_defaults,
-            exclude_unset=exclude_unset,
-            exclude_defaults=exclude_defaults,
-            exclude_none=exclude_none,
-            encoder=encoder,
-            models_as_dict=models_as_dict,
-            **dumps_kwargs,
+        exclude, original_exclude, doclist_exclude_fields = self._exclude_doclist(
+            exclude=exclude
        )

+        # this is copy from pydantic code
+        if skip_defaults is not None:
+            warnings.warn(
+                f'{self.__class__.__name__}.json(): "skip_defaults" is deprecated and replaced by "exclude_unset"',
+                DeprecationWarning,
+            )
+            exclude_unset = skip_defaults
+        encoder = cast(Callable[[Any], Any], encoder or self.__json_encoder__)
+
+        # We don't directly call `self.dict()`, which does exactly this with `to_dict=True`
+        # because we want to be able to keep raw `BaseModel` instances and not as `dict`.
+        # This allows users to write custom JSON encoders for given `BaseModel` classes.
+        data = dict(
+            self._iter(
+                to_dict=models_as_dict,
+                by_alias=by_alias,
+                include=include,
+                exclude=exclude,
+                exclude_unset=exclude_unset,
+                exclude_defaults=exclude_defaults,
+                exclude_none=exclude_none,
+            )
+        )
+
+        # this is the custom part to deal with DocList
+        for field in doclist_exclude_fields:
+            # we need to do this because pydantic will not recognize DocList correctly
+            original_exclude = original_exclude or {}
+            if field not in original_exclude:
+                data[field] = [doc.dict() for doc in getattr(self, field)]
+
+        # this is copy from pydantic code
+        if self.__custom_root_type__:
+            data = data[ROOT_KEY]
+        return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs)
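
The rewritten `json()` follows pydantic's own implementation but excludes `DocList`-typed fields from the generic serialization pass and re-inserts them as lists of per-document dicts; before this change, `json()` delegated entirely to pydantic, which, as the inline comment notes, does not recognize `DocList` correctly. A hedged sketch of the round trip this fixes (the `Inner`/`Outer` classes are illustrative, not from the repository):

```python
from docarray import BaseDoc, DocList


class Inner(BaseDoc):
    text: str


class Outer(BaseDoc):
    docs: DocList[Inner]


outer = Outer(docs=DocList[Inner]([Inner(text='hello'), Inner(text='world')]))

# The nested DocList is kept out of pydantic's pass and re-inserted as a list
# of per-document dicts, so it survives the JSON round trip.
json_str = outer.json()
restored = Outer.parse_raw(json_str)
assert [doc.text for doc in restored.docs] == ['hello', 'world']
```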

@no_type_check
@classmethod
def parse_raw(
@@ -253,7 +288,7 @@ def dict(
self,
*,
include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
-        exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None,
+        exclude: ExcludeType = None,
by_alias: bool = False,
skip_defaults: Optional[bool] = None,
exclude_unset: bool = False,
@@ -266,22 +301,9 @@
"""

-        doclist_exclude_fields = []
-        for field in self.__fields__.keys():
-            from docarray import DocList
-
-            type_ = self._get_field_type(field)
-            if isinstance(type_, type) and issubclass(type_, DocList):
-                doclist_exclude_fields.append(field)
-
-        original_exclude = exclude
-        if exclude is None:
-            exclude = set(doclist_exclude_fields)
-        elif isinstance(exclude, AbstractSet):
-            exclude = set([*exclude, *doclist_exclude_fields])
-        elif isinstance(exclude, Mapping):
-            exclude = dict(**exclude)
-            exclude.update({field: ... for field in doclist_exclude_fields})
+        exclude, original_exclude, doclist_exclude_fields = self._exclude_doclist(
+            exclude=exclude
+        )

data = super().dict(
include=include,
@@ -301,4 +323,30 @@

return data

def _exclude_doclist(
self, exclude: ExcludeType
) -> Tuple[ExcludeType, ExcludeType, List[str]]:
doclist_exclude_fields = []
for field in self.__fields__.keys():
from docarray import DocList

type_ = self._get_field_type(field)
if isinstance(type_, type) and issubclass(type_, DocList):
doclist_exclude_fields.append(field)

original_exclude = exclude
if exclude is None:
exclude = set(doclist_exclude_fields)
elif isinstance(exclude, AbstractSet):
exclude = set([*exclude, *doclist_exclude_fields])
elif isinstance(exclude, Mapping):
exclude = dict(**exclude)
exclude.update({field: ... for field in doclist_exclude_fields})

return (
exclude,
original_exclude,
doclist_exclude_fields,
)

to_json = json
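
The new `_exclude_doclist` helper centralizes the bookkeeping that `dict()` previously did inline and that `json()` now needs as well: it returns the exclude merged with every `DocList`-typed field (so pydantic skips them), the caller's original exclude (so the methods can tell whether a field was excluded on purpose), and the list of affected field names. A rough illustration of the return values, using made-up `Inner`/`Outer` documents that are not part of the repository:

```python
from docarray import BaseDoc, DocList


class Inner(BaseDoc):
    text: str


class Outer(BaseDoc):
    hello: str
    docs: DocList[Inner]


outer = Outer(hello='world', docs=DocList[Inner]([Inner(text='a')]))

# exclude=None: every DocList-typed field is excluded from pydantic's pass.
exclude, original, fields = outer._exclude_doclist(exclude=None)
assert exclude == {'docs'} and original is None and fields == ['docs']

# Set form: DocList fields are merged into the caller's set.
exclude, original, fields = outer._exclude_doclist(exclude={'hello'})
assert exclude == {'hello', 'docs'} and original == {'hello'}

# Mapping form: DocList fields are added with `...` as the value.
exclude, original, fields = outer._exclude_doclist(exclude={'hello': True})
assert exclude == {'hello': True, 'docs': ...} and original == {'hello': True}
```

`dict()` and `json()` then re-add each excluded `DocList` field as a list of plain dicts unless it appears in the caller's original exclude.
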
16 changes: 16 additions & 0 deletions tests/units/array/test_array.py
@@ -3,6 +3,7 @@
import numpy as np
import pytest
import torch
from pydantic import parse_obj_as

from docarray import BaseDoc, DocList
from docarray.typing import ImageUrl, NdArray, TorchTensor
@@ -463,3 +464,18 @@ class Image(BaseDoc):
assert docs.features == [None for _ in range(10)]
assert isinstance(docs.features, list)
assert not isinstance(docs.features, DocList)


def test_validate_list_dict():

images = [
dict(url=f'http://url.com/foo_{i}.png', tensor=NdArray(i)) for i in [2, 0, 1]
]

docs = parse_obj_as(DocList[Image], images)

assert docs.url == [
'http://url.com/foo_2.png',
'http://url.com/foo_0.png',
'http://url.com/foo_1.png',
]
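
The new test exercises `DocList` validation from a list of dicts directly; the same coercion also kicks in when the dicts arrive nested inside a parent document, which is what `parse_raw` relies on when reading back the JSON produced by the fixed `json()`. A hedged sketch with made-up `Photo`/`Album` classes:

```python
from pydantic import parse_obj_as

from docarray import BaseDoc, DocList


class Photo(BaseDoc):
    url: str


class Album(BaseDoc):
    photos: DocList[Photo]


# The outer model parses the nested dict, and DocList validation parses each
# inner dict into a Photo via parse_obj_as.
album = parse_obj_as(
    Album,
    {'photos': [{'url': 'http://url.com/a.png'}, {'url': 'http://url.com/b.png'}]},
)
assert isinstance(album.photos[0], Photo)
assert album.photos.url == ['http://url.com/a.png', 'http://url.com/b.png']
```
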
10 changes: 8 additions & 2 deletions tests/units/document/test_base_document.py
@@ -3,7 +3,8 @@
import numpy as np
import pytest

-from docarray import BaseDoc, DocList
+from docarray import DocList
+from docarray.base_doc.doc import BaseDoc
from docarray.typing import NdArray


@@ -80,6 +81,11 @@ def test_nested_to_dict_exclude_set(nested_docs):
assert 'hello' not in d.keys()


-def test_nested_to_dict_exclude_dict(nested_docs):  # doto change
+def test_nested_to_dict_exclude_dict(nested_docs):
    d = nested_docs.dict(exclude={'hello': True})
    assert 'hello' not in d.keys()


+def test_nested_to_json(nested_docs):
+    d = nested_docs.json()
+    nested_docs.__class__.parse_raw(d)
