-
Notifications
You must be signed in to change notification settings - Fork 49
Closed
Description
We currently don't support `model_validate`, but we should: it's one of the most-used Pydantic APIs. For the most part, the following patch does the trick:
diff --git a/gel/_internal/_qbmodel/_pydantic/_fields.py b/gel/_internal/_qbmodel/_pydantic/_fields.py
index 4656af7..7ea9cd6 100644
--- a/gel/_internal/_qbmodel/_pydantic/_fields.py
+++ b/gel/_internal/_qbmodel/_pydantic/_fields.py
@@ -70,6 +70,15 @@ class _MultiPointer(_abstract.PointerDescriptor[_T_co, _BT_co]):
) -> Any:
raise NotImplementedError
+ @classmethod
+ def _get_final_type(
+ cls,
+ generic_args: tuple[type[Any], type[Any]],
+ ) -> type[Any]:
+ raise NotImplementedError(
+ f"{cls.__name__} is missing _get_final_type() implementation"
+ )
+
@classmethod
def __get_pydantic_core_schema__(
cls,
@@ -78,11 +87,29 @@ class _MultiPointer(_abstract.PointerDescriptor[_T_co, _BT_co]):
) -> pydantic_core.CoreSchema:
if _typing_inspect.is_generic_alias(source_type):
args = typing.get_args(source_type)
- return core_schema.no_info_plain_validator_function(
- functools.partial(cls._validate, generic_args=args),
- serialization=core_schema.plain_serializer_function_ser_schema(
- list,
- ),
+ final_type = cls._get_final_type(args)
+
+ inner_schema = handler.generate_schema(args[0])
+ return core_schema.union_schema(
+ [
+ core_schema.is_instance_schema(
+ final_type,
+ cls_repr=final_type.__name__,
+ ),
+ core_schema.chain_schema(
+ [
+ core_schema.list_schema(inner_schema),
+ core_schema.no_info_plain_validator_function(
+ functools.partial(
+ cls._validate, generic_args=args
+ ),
+ serialization=core_schema.plain_serializer_function_ser_schema(
+ list,
+ ),
+ ),
+ ]
+ ),
+ ]
)
else:
return handler.generate_schema(source_type)

However, patterns that involve inheritance don't work out of the box:
# Validate a plain dict into the generated Account model.  The single
# watchlist item carries only a title and no type information, so nothing in
# the input says which Content subtype it is meant to be.
a = models.Account.model_validate(
    {
        "username": "aaa",
        "watchlist": [{"title": "bb"}],
    }
)
print(a)
# id=<UUID: UNSET> username='aaa' watchlist=[Content(title='bb')]

In other words, the base type (`Content`) was picked up and used instead of a concrete subtype.
Vanilla Pydantic has exactly the same problem when used naively:
# Plain-Pydantic reproduction of the inheritance problem: subclasses exist,
# but nothing tells the validator which one an incoming dict should become.
class Content(pydantic.BaseModel):
    title: str


class Movie(Content):
    pass


class TVShow(Content):
    pass


class Account(pydantic.BaseModel):
    username: str
    # Annotated with the base class only, so an untagged item such as
    # {"title": "bb"} is validated into Content rather than Movie or TVShow.
    watchlist: list[Content]


# Account is the local model defined above (there is no `models` module in
# this vanilla-Pydantic example).
a = Account.model_validate(
    {"username": "aaa", "watchlist": [{"title": "bb"}]}
)
print(a)
# username='aaa' watchlist=[Content(title='bb')]

However, Pydantic has mechanisms for specifying how unions should be discriminated: https://docs.pydantic.dev/latest/concepts/unions/
# Base entry type.  `kind` is declared here and narrowed on each subclass so
# that it can serve as the tag of a discriminated union.
class Content(pydantic.BaseModel):
    title: str
    kind: typing.Literal["movie", "tvshow"]


# Each subclass pins `kind` to a single literal value, which is what lets the
# discriminator map a tag value to one concrete class.
class Movie(Content):
    kind: typing.Literal["movie"]


class TVShow(Content):
    kind: typing.Literal["tvshow"]


class Account(pydantic.BaseModel):
    username: str
    # Every watchlist item is dispatched on its `kind` value to Movie or TVShow.
    watchlist: list[
        typing.Annotated[
            typing.Union[Movie, TVShow],
            pydantic.Field(discriminator="kind"),
        ]
    ]


a = Account.model_validate(
    {
        "username": "aaa",
        "watchlist": [{"kind": "movie", "title": "bb"}],
    }
)
print(a)
# username='aaa' watchlist=[Movie(title='bb', kind='movie')]

We need to discuss what exactly we should do here:
- always require `__type__` to be passed explicitly?
- add another schema-level annotation to tag the field used for discrimination?
- add a schema-level union tag to make it more formal and validated than an annotation?
- do both? Or are there other options?
Metadata
Metadata
Assignees
Labels
No labels