Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(internal): improve deserialisation of discriminated unions #1227

Merged
merged 1 commit into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 159 additions & 1 deletion src/openai/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Protocol,
Required,
TypedDict,
TypeGuard,
final,
override,
runtime_checkable,
Expand All @@ -31,6 +32,7 @@
HttpxRequestFiles,
)
from ._utils import (
PropertyInfo,
is_list,
is_given,
is_mapping,
Expand All @@ -39,6 +41,7 @@
strip_not_given,
extract_type_arg,
is_annotated_type,
strip_annotated_type,
)
from ._compat import (
PYDANTIC_V2,
Expand All @@ -55,6 +58,9 @@
)
from ._constants import RAW_RESPONSE_HEADER

if TYPE_CHECKING:
from pydantic_core.core_schema import ModelField, ModelFieldsSchema

__all__ = ["BaseModel", "GenericModel"]

_T = TypeVar("_T")
Expand Down Expand Up @@ -268,14 +274,18 @@ def _construct_field(value: object, field: FieldInfo, key: str) -> object:

def is_basemodel(type_: type) -> bool:
    """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`

    For a union, every variant is checked recursively and the result is `True` as
    soon as one of them qualifies. Any other type is delegated to `is_basemodel_type()`.
    """
    if is_union(type_):
        return any(is_basemodel(variant) for variant in get_args(type_))

    return is_basemodel_type(type_)


def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
    """Returns whether or not the given type is a `BaseModel` or `GenericModel` subclass.

    Subscripted generics are unwrapped to their origin first, e.g. `List[int]` -> `list`.
    Anything that is not a class after unwrapping is reported as `False`.
    """
    origin = get_origin(type_) or type_

    # `issubclass()` raises `TypeError` when its first argument is not a class, which is
    # what we get for special forms such as `Literal['a']`, so reject those up front.
    if not inspect.isclass(origin):
        return False

    return issubclass(origin, BaseModel) or issubclass(origin, GenericModel)


Expand All @@ -286,7 +296,10 @@ def construct_type(*, value: object, type_: type) -> object:
"""
# unwrap `Annotated[T, ...]` -> `T`
if is_annotated_type(type_):
meta = get_args(type_)[1:]
type_ = extract_type_arg(type_, 0)
else:
meta = tuple()

# we need to use the origin class for any types that are subscripted generics
# e.g. Dict[str, object]
Expand All @@ -299,6 +312,28 @@ def construct_type(*, value: object, type_: type) -> object:
except Exception:
pass

# if the type is a discriminated union then we want to construct the right variant
# in the union, even if the data doesn't match exactly, otherwise we'd break code
# that relies on the constructed class types, e.g.
#
# class FooType:
# kind: Literal['foo']
# value: str
#
# class BarType:
# kind: Literal['bar']
# value: int
#
# without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
# we'd end up constructing `FooType` when it should be `BarType`.
discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
if discriminator and is_mapping(value):
variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
if variant_value and isinstance(variant_value, str):
variant_type = discriminator.mapping.get(variant_value)
if variant_type:
return construct_type(type_=variant_type, value=value)

# if the data is not valid, use the first variant that doesn't fail while deserializing
for variant in args:
try:
Expand Down Expand Up @@ -356,6 +391,129 @@ def construct_type(*, value: object, type_: type) -> object:
return value


@runtime_checkable
class CachedDiscriminatorType(Protocol):
    """Structural type for objects that carry a `__discriminator__` attribute.

    Being `runtime_checkable`, it can be used with `isinstance()` to detect
    whether such an attribute is present on an object.
    """

    __discriminator__: DiscriminatorDetails


class DiscriminatorDetails:
    """Everything needed to pick the right variant of a discriminated union."""

    field_name: str
    """Name of the discriminator attribute as declared on the variant class.

    Given

    ```py
    class Foo(BaseModel):
        type: Literal['foo']
    ```

    this is `'type'`.
    """

    field_alias_from: str | None
    """Name of the discriminator key as it appears in the API response, if aliased.

    Given

    ```py
    class Foo(BaseModel):
        type: Literal['foo'] = Field(alias='type_from_api')
    ```

    this is `'type_from_api'`.
    """

    mapping: dict[str, type]
    """Lookup table from discriminator value to the variant type it selects,
    e.g. `{'foo': FooVariant, 'bar': BarVariant}`.
    """

    def __init__(
        self,
        *,
        mapping: dict[str, type],
        discriminator_field: str,
        discriminator_alias: str | None,
    ) -> None:
        self.field_name = discriminator_field
        self.field_alias_from = discriminator_alias
        self.mapping = mapping


def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
    """Work out how the variants of a discriminated union are told apart.

    Args:
        union: The union type whose variants should be inspected.
        meta_annotations: Metadata attached to the union; the first `PropertyInfo`
            entry with a `discriminator` set names the discriminator field.

    Returns:
        The discriminator details, or `None` when no discriminator is declared or no
        variant declares a string literal value for the discriminator field.
    """
    # a previous call may already have stored the result on the union itself
    if isinstance(union, CachedDiscriminatorType):
        return union.__discriminator__

    discriminator_field_name: str | None = None

    for annotation in meta_annotations:
        if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
            discriminator_field_name = annotation.discriminator
            break

    if not discriminator_field_name:
        return None

    mapping: dict[str, type] = {}
    discriminator_alias: str | None = None

    for variant in get_args(union):
        variant = strip_annotated_type(variant)
        if not is_basemodel_type(variant):
            continue

        if PYDANTIC_V2:
            field = _extract_field_schema_pv2(variant, discriminator_field_name)
            if not field:
                continue

            # Note: if one variant defines an alias then they all should
            discriminator_alias = field.get("serialization_alias")

            field_schema = field["schema"]

            if field_schema["type"] == "literal":
                for entry in field_schema["expected"]:
                    if isinstance(entry, str):
                        mapping[entry] = variant
        else:
            field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name)  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
            if not field_info:
                continue

            # Note: if one variant defines an alias then they all should
            discriminator_alias = field_info.alias

            if field_info.annotation and is_literal_type(field_info.annotation):
                for entry in get_args(field_info.annotation):
                    if isinstance(entry, str):
                        mapping[entry] = variant

    if not mapping:
        return None

    details = DiscriminatorDetails(
        mapping=mapping,
        discriminator_field=discriminator_field_name,
        discriminator_alias=discriminator_alias,
    )

    try:
        cast(CachedDiscriminatorType, union).__discriminator__ = details
    except (AttributeError, TypeError):
        # Caching is only an optimisation. Some union objects, e.g. PEP 604 `A | B`
        # unions (`types.UnionType`), reject new attributes, so the details are simply
        # rebuilt on the next call instead of failing deserialisation.
        pass

    return details


def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
schema = model.__pydantic_core_schema__
if schema["type"] != "model":
return None

fields_schema = schema["schema"]
if fields_schema["type"] != "model-fields":
return None

fields_schema = cast("ModelFieldsSchema", fields_schema)

field = fields_schema["fields"].get(field_name)
if not field:
return None

return cast("ModelField", field) # pyright: ignore[reportUnnecessaryCast]


def validate_type(*, type_: type[_T], value: object) -> _T:
"""Strict validation that the given value matches the expected type"""
if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel):
Expand Down
5 changes: 4 additions & 1 deletion src/openai/_utils/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,24 @@ class MyParams(TypedDict):
alias: str | None
format: PropertyFormat | None
format_template: str | None
discriminator: str | None

def __init__(
    self,
    *,
    alias: str | None = None,
    format: PropertyFormat | None = None,
    format_template: str | None = None,
    discriminator: str | None = None,
) -> None:
    """Store each keyword option on the instance unchanged; all of them default to `None`."""
    self.discriminator = discriminator
    self.format_template = format_template
    self.format = format
    self.alias = alias

@override
def __repr__(self) -> str:
    """Return a debug representation listing every stored option, including `discriminator`."""
    return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"


def maybe_transform(
Expand Down
Loading