
support model_validate #755

@1st1

Description


We currently don't support model_validate, but we should: it's one of the most widely used Pydantic APIs. For the most part, the following patch does the trick:

diff --git a/gel/_internal/_qbmodel/_pydantic/_fields.py b/gel/_internal/_qbmodel/_pydantic/_fields.py
index 4656af7..7ea9cd6 100644
--- a/gel/_internal/_qbmodel/_pydantic/_fields.py
+++ b/gel/_internal/_qbmodel/_pydantic/_fields.py
@@ -70,6 +70,15 @@ class _MultiPointer(_abstract.PointerDescriptor[_T_co, _BT_co]):
     ) -> Any:
         raise NotImplementedError

+    @classmethod
+    def _get_final_type(
+        cls,
+        generic_args: tuple[type[Any], type[Any]],
+    ) -> type[Any]:
+        raise NotImplementedError(
+            f"{cls.__name__} is missing _get_final_type() implementation"
+        )
+
     @classmethod
     def __get_pydantic_core_schema__(
         cls,
@@ -78,11 +87,29 @@ class _MultiPointer(_abstract.PointerDescriptor[_T_co, _BT_co]):
     ) -> pydantic_core.CoreSchema:
         if _typing_inspect.is_generic_alias(source_type):
             args = typing.get_args(source_type)
-            return core_schema.no_info_plain_validator_function(
-                functools.partial(cls._validate, generic_args=args),
-                serialization=core_schema.plain_serializer_function_ser_schema(
-                    list,
-                ),
+            final_type = cls._get_final_type(args)
+
+            inner_schema = handler.generate_schema(args[0])
+            return core_schema.union_schema(
+                [
+                    core_schema.is_instance_schema(
+                        final_type,
+                        cls_repr=final_type.__name__,
+                    ),
+                    core_schema.chain_schema(
+                        [
+                            core_schema.list_schema(inner_schema),
+                            core_schema.no_info_plain_validator_function(
+                                functools.partial(
+                                    cls._validate, generic_args=args
+                                ),
+                                serialization=core_schema.plain_serializer_function_ser_schema(
+                                    list,
+                                ),
+                            ),
+                        ]
+                    ),
+                ]
             )
         else:
             return handler.generate_schema(source_type)
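For reference, here is a minimal, standalone sketch of the validation shape the patch builds. It is a union that first accepts an existing instance of the collection type as-is and otherwise validates a plain list element by element before wrapping it. `Bag` and `Holder` are hypothetical stand-ins, not Gel's actual classes. Unlike the patch, this sketch attaches the list serializer to the union instead of to the inner validator function.

```python
import typing
from typing import Any, Generic, TypeVar

import pydantic
from pydantic_core import core_schema

T = TypeVar("T")


class Bag(Generic[T]):
    """Hypothetical list-like collection standing in for a multi-link field."""

    def __init__(self, items: list[T]) -> None:
        self.items = list(items)

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Bag[int] -> validate elements as int; bare Bag -> accept anything.
        args = typing.get_args(source_type)
        inner_schema = handler.generate_schema(args[0] if args else Any)
        return core_schema.union_schema(
            [
                # 1. An existing Bag instance is passed through unchanged.
                core_schema.is_instance_schema(cls, cls_repr=cls.__name__),
                # 2. Otherwise validate a list element-wise, then wrap it in Bag.
                core_schema.chain_schema(
                    [
                        core_schema.list_schema(inner_schema),
                        core_schema.no_info_plain_validator_function(cls),
                    ]
                ),
            ],
            # Dump back to a plain list.
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda bag: bag.items
            ),
        )


class Holder(pydantic.BaseModel):
    numbers: Bag[int]


h = Holder.model_validate({"numbers": ["1", 2]})
print(type(h.numbers).__name__, h.numbers.items)  # Bag [1, 2]
print(h.model_dump())  # {'numbers': [1, 2]}
```

Putting the instance check first means already-built collections are never re-wrapped, while raw input still goes through per-element validation.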

However, patterns that involve inheritance don't work out of the box:

a = models.Account.model_validate(
  {"username": "aaa", "watchlist": [{"title": "bb"}]}
)
print(a)
# id=<UUID: UNSET> username='aaa' watchlist=[Content(title='bb')]

In other words, the declared base type (Content) was picked and instantiated instead of the intended subtype.

Vanilla Pydantic has exactly the same problem when used naively:

import pydantic

class Content(pydantic.BaseModel):
    title: str

class Movie(Content):
    pass

class TVShow(Content):
    pass

class Account(pydantic.BaseModel):
    username: str
    watchlist: list[Content]

a = Account.model_validate(
    {"username": "aaa", "watchlist": [{"title": "bb"}]}
)
print(a)
# username='aaa' watchlist=[Content(title='bb')]
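To make the problem concrete, one naive workaround in vanilla Pydantic is a wrap-mode model validator on the base class. It reads a type tag from raw input and re-dispatches to the matching subclass. This is only a sketch: the `__type__` input key and the `CONTENT_TYPES` registry are hypothetical. Input without the tag still falls back to `Content`, so this approach does not remove the ambiguity; it only gives callers a way to resolve it.

```python
from typing import Any, Callable

import pydantic


class Content(pydantic.BaseModel):
    title: str

    @pydantic.model_validator(mode="wrap")
    @classmethod
    def dispatch_on_type(
        cls, data: Any, handler: Callable[[Any], "Content"]
    ) -> "Content":
        # Only raw dicts carrying the hypothetical "__type__" key are redirected.
        if isinstance(data, dict) and "__type__" in data:
            target = CONTENT_TYPES.get(data["__type__"])
            if target is None or not issubclass(target, cls):
                raise ValueError(f"unexpected __type__: {data['__type__']!r}")
            if target is not cls:
                rest = {k: v for k, v in data.items() if k != "__type__"}
                return target.model_validate(rest)
        # No tag, or already the right class: normal validation for cls.
        return handler(data)


class Movie(Content):
    pass


class TVShow(Content):
    pass


# Hypothetical registry mapping "__type__" values to concrete classes.
CONTENT_TYPES: dict[str, type[Content]] = {
    "Content": Content,
    "Movie": Movie,
    "TVShow": TVShow,
}


class Account(pydantic.BaseModel):
    username: str
    watchlist: list[Content]


a = Account.model_validate(
    {
        "username": "aaa",
        "watchlist": [{"__type__": "Movie", "title": "bb"}, {"title": "cc"}],
    }
)
print(a)
# username='aaa' watchlist=[Movie(title='bb'), Content(title='cc')]
```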

However, Pydantic does provide mechanisms for specifying how a union should be discriminated: https://docs.pydantic.dev/latest/concepts/unions/

import typing

import pydantic


class Content(pydantic.BaseModel):
    title: str
    kind: typing.Literal["movie", "tvshow"]


class Movie(Content):
    kind: typing.Literal["movie"]


class TVShow(Content):
    kind: typing.Literal["tvshow"]


class Account(pydantic.BaseModel):
    username: str
    watchlist: list[
        typing.Annotated[
            Movie | TVShow,
            pydantic.Field(
                discriminator="kind",
            ),
        ]
    ]

a = Account.model_validate(
    {"username": "aaa", "watchlist": [{"kind": "movie", "title": "bb"}]}
)

print(a)
# username='aaa' watchlist=[Movie(title='bb', kind='movie')]

We need to decide what exactly to do here:

  • always require __type__ to be passed explicitly?
  • add another schema-level annotation that tags the field used as the discriminator?
  • add a schema-level union tag, which would be more formal (and validated) than an annotation?
  • do both? Other options?
