Skip to content

Commit

Permalink
feat: support pydantic data model (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Jan 14, 2022
1 parent b7e5ce7 commit abb332b
Show file tree
Hide file tree
Showing 9 changed files with 320 additions and 24 deletions.
6 changes: 3 additions & 3 deletions docarray/array/mixins/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def to_pydantic_model(self) -> 'PydanticDocumentArray':

@classmethod
def from_pydantic_model(
cls: Type['T'], model: List['BaseModel'], ndarray_as_list: bool = False
cls: Type['T'],
model: List['BaseModel'],
) -> 'T':
"""Convert a list of PydanticDocument into
:param model: the pydantic data model object that represents a DocumentArray
:param ndarray_as_list: if set to True, `embedding` and `blob` are auto-casted to ndarray. :return:
:return: a DocumentArray
"""
from ... import Document

return cls(Document.from_pydantic_model(m, ndarray_as_list) for m in model)
return cls(Document.from_pydantic_model(m) for m in model)
10 changes: 5 additions & 5 deletions docarray/document/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import overload, Dict, Optional, List, TYPE_CHECKING
from typing import overload, Dict, Optional, List, TYPE_CHECKING, Union, Sequence

from .data import DocumentData, default_values
from .mixins import AllMixins
Expand Down Expand Up @@ -58,10 +58,10 @@ def __init__(
location: Optional[List[float]] = None,
embedding: Optional['ArrayType'] = None,
modality: Optional[str] = None,
evaluations: Optional[Dict[str, 'NamedScore']] = None,
scores: Optional[Dict[str, 'NamedScore']] = None,
chunks: Optional['DocumentArray'] = None,
matches: Optional['DocumentArray'] = None,
evaluations: Optional[Dict[str, Dict[str, 'StructValueType']]] = None,
scores: Optional[Dict[str, Dict[str, 'StructValueType']]] = None,
chunks: Optional[Sequence['Document']] = None,
matches: Optional[Sequence['Document']] = None,
):
...

Expand Down
43 changes: 30 additions & 13 deletions docarray/document/mixins/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,31 @@ class PydanticMixin:
"""Provide helper functions to convert to/from a Pydantic model"""

@classmethod
def json_schema(cls, indent: int = 2) -> str:
def get_json_schema(cls, indent: int = 2) -> str:
"""Return a JSON Schema of Document class."""
from ..pydantic_model import PydanticDocument as DP

return DP.schema_json(indent=indent)
from pydantic import schema_json_of

return schema_json_of(DP, title='Document Schema', indent=indent)

def to_pydantic_model(self) -> 'PydanticDocument':
"""Convert a Document object into a Pydantic model."""
from ..pydantic_model import PydanticDocument as DP

return DP(**{f: getattr(self, f) for f in self.non_empty_fields})
_p_dict = {}
for f in self.non_empty_fields:
v = getattr(self, f)
if f in ('matches', 'chunks'):
_p_dict[f] = v.to_pydantic_model()
elif f in ('scores', 'evaluations'):
_p_dict[f] = {k: v.to_dict() for k, v in v.items()}
else:
_p_dict[f] = v
return DP(**_p_dict)

@classmethod
def from_pydantic_model(
cls: Type['T'], model: 'BaseModel', ndarray_as_list: bool = False
) -> 'T':
def from_pydantic_model(cls: Type['T'], model: 'BaseModel') -> 'T':
"""Build a Document object from a Pydantic model
:param model: the pydantic data model object that represents a Document
Expand All @@ -38,15 +47,23 @@ def from_pydantic_model(
from ... import Document

fields = {}
for (field, value) in model.dict(exclude_none=True).items():
if model.chunks:
fields['chunks'] = [Document.from_pydantic_model(d) for d in model.chunks]
if model.matches:
fields['matches'] = [Document.from_pydantic_model(d) for d in model.matches]

for (field, value) in model.dict(
exclude_none=True, exclude={'chunks', 'matches'}
).items():
f_name = field
if f_name == 'chunks' or f_name == 'matches':
fields[f_name] = [Document.from_pydantic_model(d) for d in value]
elif f_name == 'scores' or f_name == 'evaluations':
fields[f_name] = defaultdict(value)
if f_name == 'scores' or f_name == 'evaluations':
from docarray.score import NamedScore

fields[f_name] = defaultdict(NamedScore)
for k, v in value.items():
fields[f_name][k] = NamedScore(v)
elif f_name == 'embedding' or f_name == 'blob':
if not ndarray_as_list:
fields[f_name] = np.array(value)
fields[f_name] = np.array(value)
else:
fields[f_name] = value
return Document(**fields)
2 changes: 1 addition & 1 deletion docarray/proto/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def parse_proto(pb_msg: 'DocumentProto') -> 'Document':
elif f_name == 'location':
fields[f_name] = list(value)
elif f_name == 'scores' or f_name == 'evaluations':
fields[f_name] = defaultdict()
fields[f_name] = defaultdict(NamedScore)
for k, v in value.items():
fields[f_name][k] = NamedScore(
{ff.name: vv for (ff, vv) in v.ListFields()}
Expand Down
231 changes: 231 additions & 0 deletions docs/fundamentals/fastapi-support/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# FastAPI/pydantic Support

Long story short, DocArray supports [pydantic data model](https://pydantic-docs.helpmanual.io/) via {class}`~docarray.document.pydantic_model.PydanticDocument` and {class}`~docarray.document.pydantic_model.PydanticDocumentArray`.

But this is probably too short to make any sense. So let's take a step back and see what does this mean.

When you want to send/receive Document or DocumentArray object via REST API, you can use `.from_json`/`.to_json` that convert the Document/DocumentArray object into JSON. This has been introduced in the {ref}`docarray-serialization` section.

This way, although quite intuitive to many data scientists, is *not* the modern way of building API services. Your engineer friends won't be happy if you give them a service like this. The main problem here is the **data validation**.

Of course, you can include data validation inside your service logic, but it is often brainfuck as you will need to check field by field and repeat things like `isinstance(field, int)`, not even to mention handling nested JSON.

Modern web frameworks validate the data _before_ it enters the core logic. For example, [FastAPI](https://fastapi.tiangolo.com/) leverages [pydantic](https://pydantic-docs.helpmanual.io/) to validate input & output data.

This chapter will introduce how to leverage DocArray's pydantic support in a FastAPI service to build a modern API service.

```{tip}
Features introduced in this chapter require `fastapi` and `pydantic` as dependency, please do `pip install "docarray[full]"` to enable it.
```

## JSON Schema

You can get the [JSON Schema](https://json-schema.org/) (OpenAPI itself is based on JSON Schema) of Document and DocumentArray by {meth}`~docarray.array.mixins.pydantic.PydanticMixin.get_json_schema`.

````{tab} Document
```python
from docarray import Document
Document.get_json_schema()
```
```json
{
"$ref": "#/definitions/PydanticDocument",
"definitions": {
"PydanticDocument": {
"title": "PydanticDocument",
"type": "object",
"properties": {
"id": {
"title": "Id",
"type": "string"
},
```
````
````{tab} DocumentArray
```python
from docarray import DocumentArray
DocumentArray.get_json_schema()
```
```json
{
"title": "DocumentArray Schema",
"type": "array",
"items": {
"$ref": "#/definitions/PydanticDocument"
},
"definitions": {
"PydanticDocument": {
"title": "PydanticDocument",
"type": "object",
"properties": {
"id": {
"title": "Id",
```
````
Hand them over to your engineer friends, they will be happy as now they can understand what data format you are working on. With these schemas, they can easily integrate DocArray into the system.

## FastAPI usage

The fundamentals of FastAPI can be learned from its docs. I won't repeat them here again.

### Validate incoming Document and DocumentArray

You can import {class}`~docarray.document.pydantic_model.PydanticDocument` and {class}`~docarray.document.pydantic_model.PydanticDocumentArray` pydantic data models, and use them to type hint your endpoint. This will enable the data validation.

```python
from docarray.document.pydantic_model import PydanticDocument, PydanticDocumentArray
from fastapi import FastAPI

app = FastAPI()

@app.post('/single')
async def create_item(item: PydanticDocument):
...

@app.post('/multi')
async def create_array(items: PydanticDocumentArray):
...
```

Let's now send some JSON:

```python
from starlette.testclient import TestClient
client = TestClient(app)

response = client.post('/single', {'hello': 'world'})
print(response, response.text)
response = client.post('/single', {'id': [12, 23]})
print(response, response.text)
```

```text
<Response [422]> {"detail":[{"loc":["body"],"msg":"value is not a valid dict","type":"type_error.dict"}]}
<Response [422]> {"detail":[{"loc":["body"],"msg":"value is not a valid dict","type":"type_error.dict"}]}
```

Both got rejected (422 error) as they are not valid.

## Convert between pydantic model and DocArray objects

{class}`~docarray.document.pydantic_model.PydanticDocument` and {class}`~docarray.document.pydantic_model.PydanticDocumentArray` are mainly for data validation. When you want to implement real logics, you need to convert it into Document or DocumentArray. This can be easily achieved via {meth}`~docarray.array.mixins.pydantic.PydanticMixin.from_pydantic_model`. When you are done with processing and want to send back, you can call {meth}`~docarray.array.mixins.pydantic.PydanticMixin.to_pydantic_model`.

In a nutshell, the whole procedure looks like the following:

```{figure} lifetime-pydantic.svg
```


Let's see an example,

```python
from docarray import Document, DocumentArray

@app.post('/single')
async def create_item(item: PydanticDocument):
d = Document.from_pydantic_model(item)
# now `d` is a Document object
... # process `d` how ever you want
return d.to_pydantic_model()


@app.post('/multi')
async def create_array(items: PydanticDocumentArray):
da = DocumentArray.from_pydantic_model(items)
# now `da` is a DocumentArray object
... # process `da` how ever you want
return da.to_pydantic_model()
```



## Limit returned fields by response model

Supporting pydantic data model means much more beyond data validation. One useful pattern is to define a smaller data model and restrict the response to certain fields of the Document.

Imagine we have a DocumentArray with `.embeddings` on the server side. But we do not want to return them to the client for some reasons (1. meaningless to users; 2. too big to transfer). One can simply define the interested fields via
`pydantic.BaseModel` and then use it in `response_model=`.

```python
from pydantic import BaseModel
from docarray import Document

class IdOnly(BaseModel):
id: str

@app.get('/single', response_model=IdOnly)
async def get_item_no_embedding():
d = Document(embedding=[1, 2, 3])
return d.to_pydantic_model()
```

And you get:

```text
<Response [200]> {'id': '065a5548756211ecaa8d1e008a366d49'}
```

## Limit returned results recursively

The same idea applies to DocumentArray as well. Say after [`.match()`](../documentarray/matching.md), you are only interested in `.id` - the parent `.id` and all matches `id`. You can declare a `BaseModel` as follows:

```python
from typing import List, Optional

class IdAndMatch(BaseModel):
id: str
matches: Optional[List['IdMatch']]
```

Bind it to `response_model`:

```python
@app.get('/get_match', response_model=List[IdAndMatch])
async def get_match_id_only():
da = DocumentArray.empty(10)
da.embeddings = np.random.random([len(da), 3])
da.match(da)
return da.to_pydantic_model()
```

Then you get a very nice result of `id`s of matches (potentially unlimited depth).

```text
[{'id': 'ef82e4f4756411ecb2c01e008a366d49',
'matches': [{'id': 'ef82e4f4756411ecb2c01e008a366d49', 'matches': None},
{'id': 'ef82e6d4756411ecb2c01e008a366d49', 'matches': None},
{'id': 'ef82e760756411ecb2c01e008a366d49', 'matches': None},
{'id': 'ef82e7ec756411ecb2c01e008a366d49', 'matches': None},
...
```

If `'matches': None` is annoying to you (they are here because you didn't compute second-degree matches), you can further leverage FastAPI's feature and do:
```python
@app.get('/get_match', response_model=List[IdMatch], response_model_exclude_none=True)
async def get_match_id_only():
...
```

Finally, you get a very clean results with ids and matches only:

```text
[{'id': '3da6383e756511ecb7cb1e008a366d49',
'matches': [{'id': '3da6383e756511ecb7cb1e008a366d49'},
{'id': '3da63a14756511ecb7cb1e008a366d49'},
{'id': '3da6392e756511ecb7cb1e008a366d49'},
{'id': '3da63b72756511ecb7cb1e008a366d49'},
{'id': '3da639ce756511ecb7cb1e008a366d49'},
{'id': '3da63a5a756511ecb7cb1e008a366d49'},
{'id': '3da63ae6756511ecb7cb1e008a366d49'},
{'id': '3da63aa0756511ecb7cb1e008a366d49'},
{'id': '3da63b2c756511ecb7cb1e008a366d49'},
{'id': '3da63988756511ecb7cb1e008a366d49'}]},
{'id': '3da6392e756511ecb7cb1e008a366d49',
'matches': [{'id': '3da6392e756511ecb7cb1e008a366d49'},
{'id': '3da639ce756511ecb7cb1e008a366d49'},
...
```

More tricks and usages of pydantic model can be found in its docs. Same for FastAPI. I strongly recommend interested readers to go through their documentations.
1 change: 1 addition & 0 deletions docs/fundamentals/fastapi-support/lifetime-pydantic.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ get-started/what-is
fundamentals/document/index
fundamentals/documentarray/index
fundamentals/notebook-support/index
datatypes/index
fundamentals/notebook-support/index
fundamentals/fastapi-support/index
```


Expand Down
2 changes: 2 additions & 0 deletions tests/unit/document/test_protobuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def test_from_to_namescore_default_dict(attr, meth):
d = Document()
getattr(d, attr)['relevance'].value = 3.0
assert isinstance(d.scores, defaultdict)
assert isinstance(d.scores['random_score1'], NamedScore)

r_d = getattr(Document, f'from_{meth}')(getattr(d, f'to_{meth}')())
assert isinstance(r_d.scores, defaultdict)
assert isinstance(r_d.scores['random_score2'], NamedScore)
Loading

0 comments on commit abb332b

Please sign in to comment.