Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix create dynamic code class #1871

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
24 changes: 19 additions & 5 deletions docarray/utils/create_dynamic_doc_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
]


def create_pure_python_type_model(model: BaseModel) -> BaseDoc:
def create_pure_python_type_model(
model: BaseModel,
cached_models: Optional[Dict[str, Any]] = None,
) -> BaseDoc:
"""
Take a Pydantic model and cast DocList fields into List fields.

Expand All @@ -44,16 +47,18 @@
texts: DocList[TextDoc]


MyDocCorrected = create_new_model_cast_doclist_to_list(CustomDoc)
MyDocCorrected = create_pure_python_type_model(CustomDoc)
```

---
:param model: The input model
:param cached_models: A set of names of models that have been converted to their pure python type model
:return: A new subclass of BaseDoc, where every DocList type in the schema is replaced by List.
"""
fields: Dict[str, Any] = {}
import copy

cached_models = cached_models or {}

Check warning on line 61 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L61

Added line #L61 was not covered by tests
fields_copy = copy.deepcopy(model.__fields__)
annotations_copy = copy.deepcopy(model.__annotations__)
for field_name, field in annotations_copy.items():
Expand All @@ -67,14 +72,23 @@
try:
if safe_issubclass(field, DocList):
t: Any = field.doc_type
t_aux = create_pure_python_type_model(t)
fields[field_name] = (List[t_aux], field_info)
if t.__name__ in cached_models:
fields[field_name] = (List[cached_models[t.__name__]], field_info)

Check warning on line 76 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L75-L76

Added lines #L75 - L76 were not covered by tests
else:
t_aux = create_pure_python_type_model(t, cached_models)
cached_models[t.__name__] = t_aux
fields[field_name] = (List[t_aux], field_info)

Check warning on line 80 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L78-L80

Added lines #L78 - L80 were not covered by tests
else:
fields[field_name] = (field, field_info)
except TypeError:
fields[field_name] = (field, field_info)

return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields)
new_model = create_model(

Check warning on line 86 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L86

Added line #L86 was not covered by tests
model.__name__, __base__=model, __doc__=model.__doc__, **fields
)
cached_models[model.__name__] = new_model

Check warning on line 89 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L89

Added line #L89 was not covered by tests

return new_model

Check warning on line 91 in docarray/utils/create_dynamic_doc_class.py

View check run for this annotation

Codecov / codecov/patch

docarray/utils/create_dynamic_doc_class.py#L91

Added line #L91 was not covered by tests


def _get_field_annotation_from_schema(
Expand Down
26 changes: 26 additions & 0 deletions tests/units/util/test_create_dynamic_code_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,29 @@ class SearchResult(BaseDoc):
QuoteFile_reconstructed_in_gateway_from_Search_results(id='0', texts=textlist)
)
assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey'


def test_create_pure_python_model_with_multiple_doclists_of_same_type():
from docarray import DocList, BaseDoc

class MyTextDoc(BaseDoc):
text: str

class QuoteFile(BaseDoc):
texts: DocList[MyTextDoc]

class QuoteFileType(BaseDoc):
"""
QuoteFileType class.
"""

id: str = (
None # same as name, compatibility reasons for a generic, shared `id` field
)
name: str = None
total_count: int = None
docs: DocList[QuoteFile] = None
chunks: DocList[QuoteFile] = None

new_model = create_pure_python_type_model(QuoteFileType)
new_model.schema()