diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py index d10f5bf23f..08e526d9cc 100644 --- a/docarray/utils/create_dynamic_doc_class.py +++ b/docarray/utils/create_dynamic_doc_class.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type, Union, Set from pydantic import BaseModel, create_model from pydantic.fields import FieldInfo @@ -22,7 +22,10 @@ ] -def create_pure_python_type_model(model: BaseModel) -> BaseDoc: +def create_pure_python_type_model( + model: BaseModel, + cached_models: Optional[Dict[str, Any]] = None, +) -> BaseDoc: """ Take a Pydantic model and cast DocList fields into List fields. @@ -44,16 +47,18 @@ class MyDoc(BaseDoc): texts: DocList[TextDoc] - MyDocCorrected = create_new_model_cast_doclist_to_list(CustomDoc) + MyDocCorrected = create_pure_python_type_model(CustomDoc) ``` --- :param model: The input model + :param cached_models: Optional dict mapping model names to their already converted pure python type models; the same dict (even when empty) is shared across recursive calls so that a doc type referenced by several DocList fields is converted only once :return: A new subclass of BaseDoc, where every DocList type in the schema is replaced by List. 
""" fields: Dict[str, Any] = {} import copy + cached_models = {} if cached_models is None else cached_models fields_copy = copy.deepcopy(model.__fields__) annotations_copy = copy.deepcopy(model.__annotations__) for field_name, field in annotations_copy.items(): @@ -67,14 +72,23 @@ class MyDoc(BaseDoc): try: if safe_issubclass(field, DocList): t: Any = field.doc_type - t_aux = create_pure_python_type_model(t) - fields[field_name] = (List[t_aux], field_info) + if t.__name__ in cached_models: + fields[field_name] = (List[cached_models[t.__name__]], field_info) + else: + t_aux = create_pure_python_type_model(t, cached_models) + cached_models[t.__name__] = t_aux + fields[field_name] = (List[t_aux], field_info) else: fields[field_name] = (field, field_info) except TypeError: fields[field_name] = (field, field_info) - return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields) + new_model = create_model( + model.__name__, __base__=model, __doc__=model.__doc__, **fields + ) + cached_models[model.__name__] = new_model + + return new_model def _get_field_annotation_from_schema( diff --git a/tests/units/util/test_create_dynamic_code_class.py b/tests/units/util/test_create_dynamic_code_class.py index 9d9ec3d0b2..e69b3367ee 100644 --- a/tests/units/util/test_create_dynamic_code_class.py +++ b/tests/units/util/test_create_dynamic_code_class.py @@ -315,3 +315,29 @@ class SearchResult(BaseDoc): QuoteFile_reconstructed_in_gateway_from_Search_results(id='0', texts=textlist) ) assert reconstructed_in_gateway_from_Search_results.texts[0].text == 'hey' + + +def test_create_pure_python_model_with_multiple_doclists_of_same_type(): + from docarray import DocList, BaseDoc + + class MyTextDoc(BaseDoc): + text: str + + class QuoteFile(BaseDoc): + texts: DocList[MyTextDoc] + + class QuoteFileType(BaseDoc): + """ + QuoteFileType class. 
+ """ + + id: str = ( + None # same as name, compatibility reasons for a generic, shared `id` field + ) + name: str = None + total_count: int = None + docs: DocList[QuoteFile] = None + chunks: DocList[QuoteFile] = None + + new_model = create_pure_python_type_model(QuoteFileType) + new_model.schema()