From 79c13a09980c97c9b4fcf9857566ff2a46c95c83 Mon Sep 17 00:00:00 2001 From: Han Xiao Date: Tue, 18 Jan 2022 18:30:25 +0100 Subject: [PATCH] fix(document): serialize blob with base64 in dict/json (#63) --- docarray/document/__init__.py | 2 -- docarray/document/mixins/pydantic.py | 10 +++++++++- docarray/document/pydantic_model.py | 9 +++++++++ tests/unit/test_pydantic.py | 17 +++++++++++++++++ 4 files changed, 35 insertions(+), 3 deletions(-) diff --git a/docarray/document/__init__.py b/docarray/document/__init__.py index 5083c10e7ac..b2bf5a6cb1f 100644 --- a/docarray/document/__init__.py +++ b/docarray/document/__init__.py @@ -6,8 +6,6 @@ if TYPE_CHECKING: from ..types import ArrayType, StructValueType, DocumentContentType - from .. import DocumentArray - from ..score import NamedScore class Document(AllMixins, BaseDCType): diff --git a/docarray/document/mixins/pydantic.py b/docarray/document/mixins/pydantic.py index a8e2abbb399..72347eb2d4d 100644 --- a/docarray/document/mixins/pydantic.py +++ b/docarray/document/mixins/pydantic.py @@ -1,3 +1,4 @@ +import base64 from collections import defaultdict from typing import TYPE_CHECKING, Type @@ -41,7 +42,6 @@ def from_pydantic_model(cls: Type['T'], model: 'BaseModel') -> 'T': """Build a Document object from a Pydantic model :param model: the pydantic data model object that represents a Document - :param ndarray_as_list: if set to True, `embedding` and `tensor` are auto-casted to ndarray. :return: a Document object """ from ... import Document @@ -65,6 +65,14 @@ def from_pydantic_model(cls: Type['T'], model: 'BaseModel') -> 'T': fields[f_name][k] = NamedScore(v) elif f_name == 'embedding' or f_name == 'tensor': fields[f_name] = np.array(value) + elif f_name == 'blob': + # here is a dirty fishy itchy trick + # the original bytes will be encoded two times: + # first time is real during `to_dict/to_json`, it converts into base64 string + # second time is at `from_dict/from_json`, it is unnecessary yet inevitable, the result string get + # converted into a binary string and encoded again. + # consequently, we need to decode two times here! + fields[f_name] = base64.b64decode(base64.b64decode(value)) else: fields[f_name] = value diff --git a/docarray/document/pydantic_model.py b/docarray/document/pydantic_model.py index 2be077c75eb..596fa568b78 100644 --- a/docarray/document/pydantic_model.py +++ b/docarray/document/pydantic_model.py @@ -1,3 +1,4 @@ +import base64 from typing import Optional, List, Dict, Any, TYPE_CHECKING, Union from pydantic import BaseModel, validator @@ -43,6 +44,14 @@ class PydanticDocument(BaseModel): _tensor2list = validator('tensor', allow_reuse=True)(_convert_ndarray_to_list) _embedding2list = validator('embedding', allow_reuse=True)(_convert_ndarray_to_list) + @validator('blob') + def _blob2base64(cls, v): + if v is not None: + if isinstance(v, bytes): + return base64.b64encode(v).decode('utf8') + else: + raise ValueError('must be bytes') + PydanticDocument.update_forward_refs() diff --git a/tests/unit/test_pydantic.py b/tests/unit/test_pydantic.py index 9f9c3e1c22f..cdd8a87da4d 100644 --- a/tests/unit/test_pydantic.py +++ b/tests/unit/test_pydantic.py @@ -1,3 +1,4 @@ +import os from collections import defaultdict from typing import List, Optional @@ -142,3 +143,19 @@ def test_tags_int_float_str_bool(tag_type, tag_value, protocol): dd = d.to_dict(protocol=protocol)['tags']['hello'][-1] assert dd == tag_value assert isinstance(dd, tag_type) + + +@pytest.mark.parametrize( + 'blob', [None, b'123', bytes(Document()), bytes(bytearray(os.urandom(512 * 4)))] +) +@pytest.mark.parametrize('protocol', ['jsonschema', 'protobuf']) +@pytest.mark.parametrize('to_fn', ['dict', 'json']) +def test_to_from_with_blob(protocol, to_fn, blob): + d = Document(blob=blob) + r_d = getattr(Document, f'from_{to_fn}')( + getattr(d, f'to_{to_fn}')(protocol=protocol), protocol=protocol + ) + + assert d.blob == r_d.blob + if d.blob: + assert isinstance(r_d.blob, bytes)