Skip to content

Commit

Permalink
fix(document): serialize tag value in the correct priority (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxiao committed Jan 18, 2022
1 parent f5c013f commit e179ef6
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docarray/document/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
if TYPE_CHECKING:
from ..types import ArrayType

_ProtoValueType = Optional[Union[str, bool, float]]
# this order must be preserved: https://pydantic-docs.helpmanual.io/usage/types/#unions
_ProtoValueType = Optional[Union[bool, float, str]]
_StructValueType = Union[
_ProtoValueType, List[_ProtoValueType], Dict[str, _ProtoValueType]
]
Expand Down
6 changes: 3 additions & 3 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def random_docs(

d = Document(id=doc_id)
d.text = text
d.tags['id'] = doc_id
d.tags['id'] = f'myself id is: {doc_id}'
if embedding:
if sparse_embedding:
from scipy.sparse import coo_matrix
Expand All @@ -42,8 +42,8 @@ def random_docs(
c.embedding = np.random.random(
[embed_dim + np.random.randint(0, jitter)]
)
c.tags['parent_id'] = doc_id
c.tags['id'] = chunk_doc_id
c.tags['parent_id'] = f'my parent is: {id}'
c.tags['id'] = f'myself id is: {doc_id}'
d.chunks.append(c)
next_chunk_doc_id += 1

Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Optional

import numpy as np
import pytest
from fastapi import FastAPI
from pydantic import BaseModel
from starlette.testclient import TestClient
Expand Down Expand Up @@ -116,3 +117,28 @@ def test_match_to_from_pydantic():
def test_with_embedding_no_tensor():
d = Document(embedding=np.random.rand(2, 2))
PydanticDocument.parse_obj(d.to_pydantic_model().dict())


@pytest.mark.parametrize(
'tag_value, tag_type',
[(3, float), (3.4, float), ('hello', str), (True, bool), (False, bool)],
)
@pytest.mark.parametrize('protocol', ['protobuf', 'jsonschema'])
def test_tags_int_float_str_bool(tag_type, tag_value, protocol):
d = Document(tags={'hello': tag_value})
dd = d.to_dict(protocol=protocol)['tags']['hello']
assert dd == tag_value
assert isinstance(dd, tag_type)

# now nested tags in dict

d = Document(tags={'hello': {'world': tag_value}})
dd = d.to_dict(protocol=protocol)['tags']['hello']['world']
assert dd == tag_value
assert isinstance(dd, tag_type)

# now nested in list
d = Document(tags={'hello': [tag_value] * 10})
dd = d.to_dict(protocol=protocol)['tags']['hello'][-1]
assert dd == tag_value
assert isinstance(dd, tag_type)

0 comments on commit e179ef6

Please sign in to comment.