-
Notifications
You must be signed in to change notification settings - Fork 225
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: initial support on pydantic (#49)
- Loading branch information
Showing
7 changed files
with
219 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import TYPE_CHECKING, Type, List | ||
|
||
if TYPE_CHECKING: | ||
from ...document.pydantic_model import PydanticDocumentArray | ||
from ...types import T | ||
from pydantic import BaseModel | ||
|
||
|
||
class PydanticMixin:
    """Provide helper functions to convert a DocumentArray to/from Pydantic models."""

    @classmethod
    def get_json_schema(cls, indent: int = 2) -> str:
        """Render the JSON Schema that describes a DocumentArray.

        :param indent: forwarded as ``indent`` to pydantic's ``schema_json_of``
        :return: the schema serialized as a JSON string
        """
        from pydantic import schema_json_of
        from ...document.pydantic_model import PydanticDocumentArray

        schema_options = {'title': 'DocumentArray Schema', 'indent': indent}
        return schema_json_of(PydanticDocumentArray, **schema_options)

    def to_pydantic_model(self) -> 'PydanticDocumentArray':
        """Convert every element of this array into its Pydantic model.

        :return: a list with the result of ``to_pydantic_model()`` for each
            element, in iteration order
        """
        models = [doc.to_pydantic_model() for doc in self]
        return models

    @classmethod
    def from_pydantic_model(
        cls: Type['T'], model: List['BaseModel'], ndarray_as_list: bool = False
    ) -> 'T':
        """Build a DocumentArray from a list of Pydantic models.

        :param model: the pydantic data model objects that represent a DocumentArray
        :param ndarray_as_list: handed unchanged to ``Document.from_pydantic_model``
            for every element of ``model``
        :return: a new ``cls`` instance fed with the converted Documents
        """
        from ... import Document

        to_document = Document.from_pydantic_model
        # A generator is handed to the constructor, so elements are converted lazily.
        return cls(to_document(item, ndarray_as_list) for item in model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from collections import defaultdict | ||
from typing import TYPE_CHECKING, Type | ||
|
||
import numpy as np | ||
|
||
if TYPE_CHECKING: | ||
from pydantic import BaseModel | ||
from ...types import T | ||
from ..pydantic_model import PydanticDocument | ||
|
||
|
||
class PydanticMixin:
    """Provide helper functions to convert to/from a Pydantic model"""

    @classmethod
    def json_schema(cls, indent: int = 2) -> str:
        """Return a JSON Schema of Document class.

        :param indent: forwarded as ``indent`` to ``PydanticDocument.schema_json``
        :return: the schema serialized as a JSON string
        """
        from ..pydantic_model import PydanticDocument as DP

        return DP.schema_json(indent=indent)

    def to_pydantic_model(self) -> 'PydanticDocument':
        """Convert a Document object into a Pydantic model.

        Only the attributes listed in ``self.non_empty_fields`` are passed on.

        :return: a ``PydanticDocument`` built from the non-empty fields
        """
        from ..pydantic_model import PydanticDocument as DP

        # NOTE(review): values are passed as-is; confirm that nested `chunks`/`matches`
        # and `scores`/`evaluations` are accepted by the model without conversion.
        return DP(**{f: getattr(self, f) for f in self.non_empty_fields})

    @classmethod
    def from_pydantic_model(
        cls: Type['T'], model: 'BaseModel', ndarray_as_list: bool = False
    ) -> 'T':
        """Build a Document object from a Pydantic model

        :param model: the pydantic data model object that represents a Document
        :param ndarray_as_list: if set to False (default), `embedding` and `blob` are
            cast to ``numpy.ndarray``; if set to True they are kept as received
            (typically nested lists).
        :return: a Document object
        """
        from ... import Document

        fields = {}
        for f_name, value in model.dict(exclude_none=True).items():
            if f_name in ('chunks', 'matches'):
                # `model.dict()` has already flattened nested models into plain dicts,
                # which have no `.dict()` method. Recurse over the model objects held
                # by the attribute instead, and keep the caller's `ndarray_as_list`.
                fields[f_name] = [
                    Document.from_pydantic_model(d, ndarray_as_list)
                    for d in getattr(model, f_name)
                ]
            elif f_name in ('scores', 'evaluations'):
                # `defaultdict(value)` would use the dict as default_factory and raise
                # TypeError; the mapping has to be the second argument.
                # NOTE(review): `dict` mirrors the Dict[str, Dict[...]] annotation of the
                # model; confirm whether Document expects a richer score type here.
                fields[f_name] = defaultdict(dict, value)
            elif f_name in ('embedding', 'blob'):
                # Previously the value was dropped entirely when `ndarray_as_list` was True.
                fields[f_name] = value if ndarray_as_list else np.array(value)
            else:
                fields[f_name] = value
        return Document(**fields)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import Optional, List, Dict, Any, TYPE_CHECKING, Union | ||
|
||
from pydantic import BaseModel, validator | ||
|
||
from ..math.ndarray import to_list | ||
|
||
if TYPE_CHECKING: | ||
from ..types import ArrayType | ||
|
||
# Scalar value accepted inside a struct-like field: a string, a bool, a float, or None.
# NOTE(review): pydantic v1 tries Union members in declaration order, so listing
# `str` first may coerce numeric/bool inputs to strings; confirm this order is intended.
_ProtoValueType = Optional[Union[str, bool, float]]
# A struct value is a scalar, a list of scalars, or a string-keyed dict of scalars.
_StructValueType = Union[
    _ProtoValueType, List[_ProtoValueType], Dict[str, _ProtoValueType]
]
|
||
|
||
def _convert_ndarray_to_list(v: 'ArrayType'):
    """Pass ``v`` through ``to_list`` and hand back whatever it returns."""
    as_list = to_list(v)
    return as_list
|
||
|
||
# Pydantic representation of a single Document.
# `id` is the only field not annotated as Optional; the order of the fields below
# is the order in which they are declared to pydantic.
class PydanticDocument(BaseModel):
    id: str
    parent_id: Optional[str]
    granularity: Optional[int]
    adjacency: Optional[int]
    buffer: Optional[bytes]
    # Typed as Any; the `_blob2list` validator below runs it through `to_list`.
    blob: Optional[Any]
    mime_type: Optional[str]
    text: Optional[str]
    weight: Optional[float]
    uri: Optional[str]
    tags: Optional[Dict[str, '_StructValueType']]
    offset: Optional[float]
    location: Optional[List[float]]
    # Typed as Any; the `_embedding2list` validator below runs it through `to_list`.
    embedding: Optional[Any]
    modality: Optional[str]
    evaluations: Optional[Dict[str, Dict[str, '_StructValueType']]]
    scores: Optional[Dict[str, Dict[str, '_StructValueType']]]
    # Self-referencing fields, written as string forward references.
    chunks: Optional[List['PydanticDocument']]
    matches: Optional[List['PydanticDocument']]

    # The same helper is registered for two fields, hence `allow_reuse=True`.
    _blob2list = validator('blob', allow_reuse=True)(_convert_ndarray_to_list)
    _embedding2list = validator('embedding', allow_reuse=True)(_convert_ndarray_to_list)
|
||
|
||
# Resolve the string forward references used in the field annotations
# (e.g. 'PydanticDocument' in `chunks`/`matches`).
PydanticDocument.update_forward_refs()

# A DocumentArray is modelled as a plain list of PydanticDocument.
PydanticDocumentArray = List[PydanticDocument]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from typing import List | ||
|
||
import numpy as np | ||
from fastapi import FastAPI | ||
from pydantic import BaseModel | ||
from starlette.testclient import TestClient | ||
|
||
from docarray import DocumentArray, Document | ||
from docarray.document.pydantic_model import PydanticDocument, PydanticDocumentArray | ||
|
||
|
||
def test_pydantic_doc_da(pytestconfig):
    """Round-trip a DocumentArray built from image files through the Pydantic model."""
    root = pytestconfig.rootdir
    image_patterns = [
        f'{root}/**/*.png',
        f'{root}/**/*.jpg',
        f'{root}/**/*.jpeg',
    ]
    da = DocumentArray.from_files(image_patterns)

    # The array must be non-empty, expose a schema, and survive a round trip.
    assert da
    assert da.get_json_schema(2)
    assert da.from_pydantic_model(da.to_pydantic_model())

    # After attaching embeddings, the round trip must preserve their shape.
    emb_dim = 10
    da.embeddings = np.random.random([len(da), emb_dim])
    as_models = da.to_pydantic_model()
    da_r = da.from_pydantic_model(as_models)
    assert da_r.embeddings.shape == (len(da), emb_dim)
|
||
|
||
# FastAPI application on which the endpoints below are registered.
app = FastAPI()
|
||
|
||
# Response model for the '/single' endpoint: declares only the `id` field.
class IdOnly(BaseModel):
    id: str
|
||
|
||
# Response model for the items of the '/multi' endpoint: declares only the `text` field.
class TextOnly(BaseModel):
    text: str
|
||
|
||
@app.post('/single', response_model=IdOnly)
async def create_item(item: PydanticDocument):
    # Round-trip the request body: Pydantic model -> Document -> Pydantic model.
    doc = Document.from_pydantic_model(item)
    return doc.to_pydantic_model()
|
||
|
||
@app.post('/multi', response_model=List[TextOnly])
async def create_item(items: PydanticDocumentArray):
    # NOTE: this handler reuses the name `create_item`, so the module-level name
    # ends up bound to this function; both routes are still registered.
    docs = DocumentArray.from_pydantic_model(items)
    count = len(docs)
    docs.texts = [f'hello_{j}' for j in range(count)]
    return docs.to_pydantic_model()
|
||
|
||
# Test client bound to `app`; used by the test below to post requests.
client = TestClient(app)
|
||
|
||
def test_read_main():
    """Post to both endpoints and check that the response models filter the fields."""
    single_body = Document(text='hello').to_json()
    single = client.post('/single', single_body).json()
    # Only `id` is expected in the reply; `text` must be absent.
    assert single['id']
    assert 'text' not in single
    assert len(single) == 1

    multi_body = DocumentArray.empty(2).to_json()
    multi = client.post('/multi', multi_body).json()
    assert isinstance(multi, list)
    # Each returned item must carry exactly one key ...
    for j in (0, 1):
        assert len(multi[j]) == 1
    # ... and that key is the text assigned by the endpoint.
    for j, expected in enumerate(('hello_0', 'hello_1')):
        assert multi[j]['text'] == expected