diff --git a/src/openai/_models.py b/src/openai/_models.py index af71a91850..51c3be04fa 100644 --- a/src/openai/_models.py +++ b/src/openai/_models.py @@ -3,6 +3,7 @@ import os import inspect from typing import TYPE_CHECKING, Any, Type, Tuple, Union, Generic, TypeVar, Callable, Optional, cast +from weakref import WeakKeyDictionary from datetime import date, datetime from typing_extensions import ( List, @@ -77,6 +78,8 @@ ReprArgs = Sequence[Tuple[Optional[str], Any]] +_DISCRIMINATOR_CACHE: "WeakKeyDictionary[type, DiscriminatorDetails]" = WeakKeyDictionary() + @runtime_checkable class _ConfigProtocol(Protocol): @@ -593,11 +596,6 @@ def construct_type(*, value: object, type_: object, metadata: Optional[List[Any] return value -@runtime_checkable -class CachedDiscriminatorType(Protocol): - __discriminator__: DiscriminatorDetails - - class DiscriminatorDetails: field_name: str """The name of the discriminator field in the variant class, e.g. @@ -640,8 +638,9 @@ def __init__( def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None: - if isinstance(union, CachedDiscriminatorType): - return union.__discriminator__ + cached_discriminator = _DISCRIMINATOR_CACHE.get(union) + if cached_discriminator is not None: + return cached_discriminator discriminator_field_name: str | None = None @@ -694,7 +693,7 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, discriminator_field=discriminator_field_name, discriminator_alias=discriminator_alias, ) - cast(CachedDiscriminatorType, union).__discriminator__ = details + _DISCRIMINATOR_CACHE[union] = details return details diff --git a/tests/test_models.py b/tests/test_models.py index 410ec3bf4e..48c4748816 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,7 +9,7 @@ from openai._utils import PropertyInfo from openai._compat import PYDANTIC_V1, parse_obj, model_dump, model_json -from openai._models import BaseModel, construct_type +from openai._models import BaseModel, construct_type, _DISCRIMINATOR_CACHE class BasicModel(BaseModel): @@ -809,7 +809,7 @@ class B(BaseModel): UnionType = cast(Any, Union[A, B]) - assert not hasattr(UnionType, "__discriminator__") + assert _DISCRIMINATOR_CACHE.get(UnionType) is None m = construct_type( value={"type": "b", "data": "foo"}, type_=cast(Any, Annotated[UnionType, PropertyInfo(discriminator="type")]) @@ -818,7 +818,8 @@ class B(BaseModel): assert m.type == "b" assert m.data == "foo" # type: ignore[comparison-overlap] - discriminator = UnionType.__discriminator__ + discriminator = _DISCRIMINATOR_CACHE.get(UnionType) + assert discriminator is not None m = construct_type( @@ -830,7 +831,7 @@ class B(BaseModel): # if the discriminator details object stays the same between invocations then # we hit the cache - assert UnionType.__discriminator__ is discriminator + assert _DISCRIMINATOR_CACHE.get(UnionType) is discriminator @pytest.mark.skipif(PYDANTIC_V1, reason="TypeAliasType is not supported in Pydantic v1")