Skip to content

Commit

Permalink
feat: initial support on pydantic (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Jan 14, 2022
1 parent 2027672 commit b7e5ce7
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docarray/array/mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
from .sample import SampleMixin
from .text import TextToolsMixin
from .traverse import TraverseMixin
from .pydantic import PydanticMixin


class AllMixins(
GetAttributeMixin,
ContentPropertyMixin,
PydanticMixin,
GroupMixin,
EmptyMixin,
CsvIOMixin,
Expand Down
36 changes: 36 additions & 0 deletions docarray/array/mixins/pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import TYPE_CHECKING, Type, List

if TYPE_CHECKING:
from ...document.pydantic_model import PydanticDocumentArray
from ...types import T
from pydantic import BaseModel


class PydanticMixin:
    """Provide helper functions to convert a DocumentArray to/from Pydantic models"""

    @classmethod
    def get_json_schema(cls, indent: int = 2) -> str:
        """Return the JSON Schema of the DocumentArray class.

        :param indent: indent value forwarded to ``pydantic.schema_json_of``
        :return: the schema, serialized as a JSON string
        """
        from pydantic import schema_json_of
        from ...document.pydantic_model import PydanticDocumentArray as DAP

        schema = schema_json_of(DAP, title='DocumentArray Schema', indent=indent)
        return schema

    def to_pydantic_model(self) -> 'PydanticDocumentArray':
        """Convert a DocumentArray object into a list of Pydantic models.

        :return: one model per element, in iteration order, each produced by
            the element's own ``to_pydantic_model()``
        """
        models = []
        for doc in self:
            models.append(doc.to_pydantic_model())
        return models

    @classmethod
    def from_pydantic_model(
        cls: Type['T'], model: List['BaseModel'], ndarray_as_list: bool = False
    ) -> 'T':
        """Build a DocumentArray from a list of Pydantic models.

        :param model: the pydantic data model objects, one per Document
        :param ndarray_as_list: forwarded unchanged to
            ``Document.from_pydantic_model`` for every element
        :return: a DocumentArray
        """
        from ... import Document

        # Hand a lazy generator to the constructor, one Document per model.
        docs = (Document.from_pydantic_model(item, ndarray_as_list) for item in model)
        return cls(docs)
2 changes: 2 additions & 0 deletions docarray/document/mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
from .porting import PortingMixin
from .property import PropertyMixin
from .protobuf import ProtobufMixin
from .pydantic import PydanticMixin
from .sugar import SingletonSugarMixin
from .text import TextDataMixin
from .video import VideoDataMixin


class AllMixins(
ProtobufMixin,
PydanticMixin,
PropertyMixin,
ContentPropertyMixin,
ConvertMixin,
Expand Down
52 changes: 52 additions & 0 deletions docarray/document/mixins/pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Type

import numpy as np

if TYPE_CHECKING:
from pydantic import BaseModel
from ...types import T
from ..pydantic_model import PydanticDocument


class PydanticMixin:
    """Provide helper functions to convert to/from a Pydantic model"""

    @classmethod
    def json_schema(cls, indent: int = 2) -> str:
        """Return a JSON Schema of Document class.

        :param indent: indent value forwarded to ``PydanticDocument.schema_json``
        :return: the schema, serialized as a JSON string
        """
        from ..pydantic_model import PydanticDocument as DP

        return DP.schema_json(indent=indent)

    def to_pydantic_model(self) -> 'PydanticDocument':
        """Convert a Document object into a Pydantic model.

        Only the attributes named in ``self.non_empty_fields`` are copied.

        :return: the populated Pydantic model
        """
        from ..pydantic_model import PydanticDocument as DP

        return DP(**{f: getattr(self, f) for f in self.non_empty_fields})

    @classmethod
    def from_pydantic_model(
        cls: Type['T'], model: 'BaseModel', ndarray_as_list: bool = False
    ) -> 'T':
        """Build a Document object from a Pydantic model

        :param model: the pydantic data model object that represents a Document
        :param ndarray_as_list: if set to True, `embedding` and `blob` are kept
            exactly as stored in the model; if False (default) they are
            converted with ``np.array``. The flag is also applied to nested
            `chunks` and `matches`.
        :return: a Document object
        """
        from ... import Document

        fields = {}
        for f_name, value in model.dict(exclude_none=True).items():
            if f_name in ('chunks', 'matches'):
                # ``model.dict()`` has already turned nested models into plain
                # dicts, which have no ``.dict()`` method; recurse over the
                # model objects held on the attribute instead.
                nested = getattr(model, f_name)
                fields[f_name] = [
                    Document.from_pydantic_model(d, ndarray_as_list) for d in nested
                ]
            elif f_name in ('scores', 'evaluations'):
                # ``defaultdict(value)`` raised TypeError because the first
                # argument must be a callable factory, not the data itself.
                # NOTE(review): ``dict`` mirrors the Dict[str, Dict[...]] shape
                # of the pydantic schema; swap in the project's score type if
                # Document expects one.
                fields[f_name] = defaultdict(dict, value)
            elif f_name in ('embedding', 'blob'):
                fields[f_name] = value if ndarray_as_list else np.array(value)
            else:
                # Every remaining field is passed through unchanged.
                fields[f_name] = value
        return Document(**fields)
47 changes: 47 additions & 0 deletions docarray/document/pydantic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Optional, List, Dict, Any, TYPE_CHECKING, Union

from pydantic import BaseModel, validator

from ..math.ndarray import to_list

if TYPE_CHECKING:
from ..types import ArrayType

# Scalar value allowed inside `tags`, `scores` and `evaluations`.
# NOTE(review): pydantic v1 presumably tries Union members left to right, so
# having `str` first may stringify numeric/bool values — confirm.
_ProtoValueType = Optional[Union[str, bool, float]]
# A scalar, a flat list of scalars, or a flat str-keyed dict of scalars.
_StructValueType = Union[
_ProtoValueType, List[_ProtoValueType], Dict[str, _ProtoValueType]
]


def _convert_ndarray_to_list(v: 'ArrayType'):
    """Return ``v`` converted by :func:`to_list`.

    Kept as a module-level function so it can be registered as a reusable
    pydantic validator on more than one field.
    """
    as_list = to_list(v)
    return as_list


class PydanticDocument(BaseModel):
    """Pydantic data model of a Document.

    Only `id` is required. Every other field carries an explicit ``None``
    default so that its optionality does not depend on pydantic's implicit
    handling of ``Optional[...]`` annotations without a default.
    """

    id: str
    parent_id: Optional[str] = None
    granularity: Optional[int] = None
    adjacency: Optional[int] = None
    buffer: Optional[bytes] = None
    # Typed as Any; normalized by the `_blob2list` validator below.
    blob: Optional[Any] = None
    mime_type: Optional[str] = None
    text: Optional[str] = None
    weight: Optional[float] = None
    uri: Optional[str] = None
    tags: Optional[Dict[str, '_StructValueType']] = None
    offset: Optional[float] = None
    location: Optional[List[float]] = None
    # Typed as Any; normalized by the `_embedding2list` validator below.
    embedding: Optional[Any] = None
    modality: Optional[str] = None
    evaluations: Optional[Dict[str, Dict[str, '_StructValueType']]] = None
    scores: Optional[Dict[str, Dict[str, '_StructValueType']]] = None
    # Self-referencing fields; the string annotation is resolved after the
    # class body by `update_forward_refs()`.
    chunks: Optional[List['PydanticDocument']] = None
    matches: Optional[List['PydanticDocument']] = None

    # `allow_reuse=True` lets the same function back both validators.
    _blob2list = validator('blob', allow_reuse=True)(_convert_ndarray_to_list)
    _embedding2list = validator('embedding', allow_reuse=True)(_convert_ndarray_to_list)


# Resolve the string annotations used inside the class body
# ('PydanticDocument', '_StructValueType') now that the names exist.
PydanticDocument.update_forward_refs()

# A DocumentArray is modelled as a plain list of PydanticDocument.
PydanticDocumentArray = List[PydanticDocument]
12 changes: 11 additions & 1 deletion docarray/math/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Tuple, Sequence, Optional
from typing import TYPE_CHECKING, Tuple, Sequence, Optional, List

import numpy as np

Expand Down Expand Up @@ -145,3 +145,13 @@ def to_numpy_array(value) -> 'np.ndarray':
if hasattr(v, 'numpy'):
v = v.numpy()
return v


def to_list(value) -> List[float]:
    """Convert ``value`` into a Python list.

    The input is first passed through :func:`to_numpy_array`.

    :param value: the object to convert
    :return: ``ndarray.tolist()`` when the intermediate result is a numpy
        array, or the intermediate result itself when it is already a list
    :raises TypeError: when the intermediate result is neither of the two
    """
    r = to_numpy_array(value)

    # Already a list: hand it back untouched.
    if isinstance(r, list):
        return r
    if isinstance(r, np.ndarray):
        return r.tolist()

    raise TypeError(f'{r} can not be converted into list')
69 changes: 69 additions & 0 deletions tests/unit/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import List

import numpy as np
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.testclient import TestClient

from docarray import DocumentArray, Document
from docarray.document.pydantic_model import PydanticDocument, PydanticDocumentArray


def test_pydantic_doc_da(pytestconfig):
    """Round-trip a DocumentArray through its Pydantic representation.

    The array is built from image files found under the pytest root dir, then
    converted to Pydantic models and back, first without and then with
    embeddings attached.
    """
    patterns = [
        f'{pytestconfig.rootdir}/**/*.png',
        f'{pytestconfig.rootdir}/**/*.jpg',
        f'{pytestconfig.rootdir}/**/*.jpeg',
    ]
    da = DocumentArray.from_files(patterns)

    assert da
    assert da.get_json_schema(2)

    # Round trip without embeddings.
    models = da.to_pydantic_model()
    assert da.from_pydantic_model(models)

    # Round trip with a 10-dim random embedding per Document.
    da.embeddings = np.random.random([len(da), 10])
    restored = da.from_pydantic_model(da.to_pydantic_model())
    assert restored.embeddings.shape == (len(da), 10)


# FastAPI application the routes below are registered on.
app = FastAPI()


# Response model with a single `id` field; used as `response_model` of the
# `/single` route, whose test asserts that only `id` comes back.
class IdOnly(BaseModel):
    id: str


# Response model with a single `text` field; used (as a list) as
# `response_model` of the `/multi` route.
class TextOnly(BaseModel):
    text: str


@app.post('/single', response_model=IdOnly)
async def create_item(item: PydanticDocument):
    # Rebuild a Document from the request model, then convert it straight
    # back to a Pydantic model for the response.
    doc = Document.from_pydantic_model(item)
    return doc.to_pydantic_model()


@app.post('/multi', response_model=List[TextOnly])
async def create_item(items: PydanticDocumentArray):
    # This reuses the name of the `/single` handler; that handler was already
    # registered by its decorator before this definition rebinds the name.
    docs = DocumentArray.from_pydantic_model(items)
    # Give every Document a text of the form `hello_<position>`.
    docs.texts = [f'hello_{j}' for j in range(len(docs))]
    return docs.to_pydantic_model()


# Test client bound to `app`; shared by the tests in this module.
client = TestClient(app)


def test_read_main():
    """Post Documents to both routes and check the shape of the responses."""
    # /single: one Document in, a payload holding only `id` out.
    single_payload = Document(text='hello').to_json()
    single = client.post('/single', single_payload)
    r = single.json()
    assert r['id']
    assert 'text' not in r
    assert len(r) == 1

    # /multi: two empty Documents in, two text-only payloads out.
    multi_payload = DocumentArray.empty(2).to_json()
    multi = client.post('/multi', multi_payload)

    r = multi.json()
    assert isinstance(r, list)
    first, second = r[0], r[1]
    assert len(first) == 1
    assert len(second) == 1
    assert first['text'] == 'hello_0'
    assert second['text'] == 'hello_1'

0 comments on commit b7e5ce7

Please sign in to comment.