Skip to content

Commit

Permalink
fix: fix create dynamic code class
Browse files Browse the repository at this point in the history
Signed-off-by: Joan Martinez <joan.fontanals.martinez@jina.ai>
  • Loading branch information
JoanFM committed Feb 19, 2024
1 parent 791e4a0 commit 569c431
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
26 changes: 20 additions & 6 deletions docarray/utils/create_dynamic_doc_class.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Optional, Type, Union, Set

from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
Expand All @@ -22,7 +22,10 @@
]


def create_pure_python_type_model(model: BaseModel) -> BaseDoc:
def create_pure_python_type_model(
model: BaseModel,
cached_models: Optional[Dict[str, Any]] = None,
) -> BaseDoc:
"""
Take a Pydantic model and cast DocList fields into List fields.
Expand All @@ -44,16 +47,18 @@ class MyDoc(BaseDoc):
texts: DocList[TextDoc]
MyDocCorrected = create_new_model_cast_doclist_to_list(CustomDoc)
MyDocCorrected = create_pure_python_type_model(CustomDoc)
```
---
:param model: The input model
:param cached_models: A set of names of models that have been converted to their pure python type model
:return: A new subclass of BaseDoc, where every DocList type in the schema is replaced by List.
"""
fields: Dict[str, Any] = {}
import copy

cached_models = cached_models or {}

Check warning on line 61 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L61

Added line #L61 was not covered by tests
fields_copy = copy.deepcopy(model.__fields__)
annotations_copy = copy.deepcopy(model.__annotations__)
for field_name, field in annotations_copy.items():
Expand All @@ -67,14 +72,23 @@ class MyDoc(BaseDoc):
try:
if safe_issubclass(field, DocList):
t: Any = field.doc_type
t_aux = create_pure_python_type_model(t)
fields[field_name] = (List[t_aux], field_info)
if t.__name__ in cached_models:
fields[field_name] = (List[cached_models[t.__name__]], field_info)

Check warning on line 76 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L75-L76

Added lines #L75 - L76 were not covered by tests
else:
t_aux = create_pure_python_type_model(t, cached_models)
cached_models[t.__name__] = t_aux
fields[field_name] = (List[t_aux], field_info)

Check warning on line 80 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L78-L80

Added lines #L78 - L80 were not covered by tests
else:
fields[field_name] = (field, field_info)
except TypeError:
fields[field_name] = (field, field_info)

return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields)
new_model = create_model(

Check warning on line 86 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L86

Added line #L86 was not covered by tests
model.__name__, __base__=model, __doc__=model.__doc__, **fields
)
cached_models[model.__name__] = new_model

Check warning on line 89 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L89

Added line #L89 was not covered by tests

return new_model

Check warning on line 91 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L91

Added line #L91 was not covered by tests


def _get_field_annotation_from_schema(
Expand Down
26 changes: 26 additions & 0 deletions tests/units/util/test_create_dynamic_code_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,29 @@ class SearchResult(BaseDoc):
QuoteFile_reconstructed_in_gateway_from_Search_results(id='0', texts=textlist)
)
assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey'


def test_create_pure_python_model_with_multiple_doclists_of_same_type():
from docarray import DocList, BaseDoc

class MyTextDoc(BaseDoc):
text: str

class QuoteFile(BaseDoc):
texts: DocList[MyTextDoc]

class QuoteFileType(BaseDoc):
"""
QuoteFileType class.
"""

id: str = (
None # same as name, compatibility reasons for a generic, shared `id` field
)
name: str = None
total_count: int = None
docs: DocList[QuoteFile] = None
chunks: DocList[QuoteFile] = None

new_model = create_pure_python_type_model(QuoteFileType)
new_model.schema()

0 comments on commit 569c431

Please sign in to comment.