diff --git a/src/openai/_response.py b/src/openai/_response.py index 933c37525e..6b7c86e544 100644 --- a/src/openai/_response.py +++ b/src/openai/_response.py @@ -5,12 +5,12 @@ import datetime import functools from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast -from typing_extensions import Awaitable, ParamSpec, get_args, override, get_origin +from typing_extensions import Awaitable, ParamSpec, override, get_origin import httpx from ._types import NoneType, UnknownResponse, BinaryResponseContent -from ._utils import is_given +from ._utils import is_given, extract_type_var_from_base from ._models import BaseModel, is_basemodel from ._constants import RAW_RESPONSE_HEADER from ._exceptions import APIResponseValidationError @@ -221,12 +221,13 @@ def __init__(self) -> None: def _extract_stream_chunk_type(stream_cls: type) -> type: - args = get_args(stream_cls) - if not args: - raise TypeError( - f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}", - ) - return cast(type, args[0]) + from ._base_client import Stream, AsyncStream + + return extract_type_var_from_base( + stream_cls, + index=0, + generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), + ) def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]: diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py index e323c59ac0..f1896a242a 100644 --- a/src/openai/_streaming.py +++ b/src/openai/_streaming.py @@ -2,12 +2,12 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Generic, Iterator, AsyncIterator -from typing_extensions import override +from types import TracebackType +from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast +from typing_extensions import Self, override import httpx -from ._types import ResponseT from ._utils import is_mapping from ._exceptions import APIError @@ -15,7 +15,10 @@ from ._client import OpenAI, AsyncOpenAI -class Stream(Generic[ResponseT]): +_T = TypeVar("_T") + + +class Stream(Generic[_T]): """Provides the core interface to iterate over a synchronous stream response.""" response: httpx.Response @@ -23,7 +26,7 @@ class Stream(Generic[ResponseT]): def __init__( self, *, - cast_to: type[ResponseT], + cast_to: type[_T], response: httpx.Response, client: OpenAI, ) -> None: @@ -33,18 +36,18 @@ def __init__( self._decoder = SSEDecoder() self._iterator = self.__stream__() - def __next__(self) -> ResponseT: + def __next__(self) -> _T: return self._iterator.__next__() - def __iter__(self) -> Iterator[ResponseT]: + def __iter__(self) -> Iterator[_T]: for item in self._iterator: yield item def _iter_events(self) -> Iterator[ServerSentEvent]: yield from self._decoder.iter(self.response.iter_lines()) - def __stream__(self) -> Iterator[ResponseT]: - cast_to = self._cast_to + def __stream__(self) -> Iterator[_T]: + cast_to = cast(Any, self._cast_to) response = self.response process_data = self._client._process_response_data iterator = self._iter_events() @@ -68,8 +71,27 @@ def __stream__(self) -> Iterator[ResponseT]: for _sse in iterator: ... + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + self.close() + + def close(self) -> None: + """ + Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + self.response.close() -class AsyncStream(Generic[ResponseT]): + +class AsyncStream(Generic[_T]): """Provides the core interface to iterate over an asynchronous stream response.""" response: httpx.Response @@ -77,7 +99,7 @@ class AsyncStream(Generic[ResponseT]): def __init__( self, *, - cast_to: type[ResponseT], + cast_to: type[_T], response: httpx.Response, client: AsyncOpenAI, ) -> None: @@ -87,10 +109,10 @@ def __init__( self._decoder = SSEDecoder() self._iterator = self.__stream__() - async def __anext__(self) -> ResponseT: + async def __anext__(self) -> _T: return await self._iterator.__anext__() - async def __aiter__(self) -> AsyncIterator[ResponseT]: + async def __aiter__(self) -> AsyncIterator[_T]: async for item in self._iterator: yield item @@ -98,8 +120,8 @@ async def _iter_events(self) -> AsyncIterator[ServerSentEvent]: async for sse in self._decoder.aiter(self.response.aiter_lines()): yield sse - async def __stream__(self) -> AsyncIterator[ResponseT]: - cast_to = self._cast_to + async def __stream__(self) -> AsyncIterator[_T]: + cast_to = cast(Any, self._cast_to) response = self.response process_data = self._client._process_response_data iterator = self._iter_events() @@ -123,6 +145,25 @@ async def __stream__(self) -> AsyncIterator[ResponseT]: async for _sse in iterator: ... + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self.close() + + async def close(self) -> None: + """ + Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + await self.response.aclose() + class ServerSentEvent: def __init__( diff --git a/src/openai/_types.py b/src/openai/_types.py index 8d543171eb..a20a4b4c1b 100644 --- a/src/openai/_types.py +++ b/src/openai/_types.py @@ -353,3 +353,17 @@ def get(self, __key: str) -> str | None: IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" PostParser = Callable[[Any], Any] + + +@runtime_checkable +class InheritsGeneric(Protocol): + """Represents a type that has inherited from `Generic` + The `__orig_bases__` property can be used to determine the resolved + type variable for a given base class. + """ + + __orig_bases__: tuple[_GenericAlias] + + +class _GenericAlias(Protocol): + __origin__: type[object] diff --git a/src/openai/_utils/__init__.py b/src/openai/_utils/__init__.py index 400ca9b828..a43201d3c7 100644 --- a/src/openai/_utils/__init__.py +++ b/src/openai/_utils/__init__.py @@ -9,13 +9,11 @@ from ._utils import parse_date as parse_date from ._utils import is_sequence as is_sequence from ._utils import coerce_float as coerce_float -from ._utils import is_list_type as is_list_type from ._utils import is_mapping_t as is_mapping_t from ._utils import removeprefix as removeprefix from ._utils import removesuffix as removesuffix from ._utils import extract_files as extract_files from ._utils import is_sequence_t as is_sequence_t -from ._utils import is_union_type as is_union_type from ._utils import required_args as required_args from ._utils import coerce_boolean as coerce_boolean from ._utils import coerce_integer as coerce_integer @@ -23,15 +21,20 @@ from ._utils import parse_datetime as parse_datetime from ._utils import strip_not_given as strip_not_given from ._utils import deepcopy_minimal as deepcopy_minimal -from ._utils import extract_type_arg as extract_type_arg -from ._utils import is_required_type as is_required_type from ._utils import get_async_library as get_async_library -from ._utils import is_annotated_type as is_annotated_type from ._utils import maybe_coerce_float as maybe_coerce_float from ._utils import get_required_header as get_required_header from ._utils import maybe_coerce_boolean as maybe_coerce_boolean from ._utils import maybe_coerce_integer as maybe_coerce_integer -from ._utils import strip_annotated_type as strip_annotated_type +from ._typing import is_list_type as is_list_type +from ._typing import is_union_type as is_union_type +from ._typing import extract_type_arg as extract_type_arg +from ._typing import is_required_type as is_required_type +from ._typing import is_annotated_type as is_annotated_type +from ._typing import strip_annotated_type as strip_annotated_type +from ._typing import extract_type_var_from_base as extract_type_var_from_base +from ._streams import consume_sync_iterator as consume_sync_iterator +from ._streams import consume_async_iterator as consume_async_iterator from ._transform import PropertyInfo as PropertyInfo from ._transform import transform as transform from ._transform import maybe_transform as maybe_transform diff --git a/src/openai/_utils/_streams.py b/src/openai/_utils/_streams.py new file mode 100644 index 0000000000..f4a0208f01 --- /dev/null +++ b/src/openai/_utils/_streams.py @@ -0,0 +1,12 @@ +from typing import Any +from typing_extensions import Iterator, AsyncIterator + + +def consume_sync_iterator(iterator: Iterator[Any]) -> None: + for _ in iterator: + ... + + +async def consume_async_iterator(iterator: AsyncIterator[Any]) -> None: + async for _ in iterator: + ... diff --git a/src/openai/_utils/_transform.py b/src/openai/_utils/_transform.py index 769f7362b9..9117559064 100644 --- a/src/openai/_utils/_transform.py +++ b/src/openai/_utils/_transform.py @@ -6,9 +6,8 @@ import pydantic -from ._utils import ( - is_list, - is_mapping, +from ._utils import is_list, is_mapping +from ._typing import ( is_list_type, is_union_type, extract_type_arg, diff --git a/src/openai/_utils/_typing.py b/src/openai/_utils/_typing.py new file mode 100644 index 0000000000..b5e2c2e397 --- /dev/null +++ b/src/openai/_utils/_typing.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import Any, cast +from typing_extensions import Required, Annotated, get_args, get_origin + +from .._types import InheritsGeneric +from .._compat import is_union as _is_union + + +def is_annotated_type(typ: type) -> bool: + return get_origin(typ) == Annotated + + +def is_list_type(typ: type) -> bool: + return (get_origin(typ) or typ) == list + + +def is_union_type(typ: type) -> bool: + return _is_union(get_origin(typ)) + + +def is_required_type(typ: type) -> bool: + return get_origin(typ) == Required + + +# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] +def strip_annotated_type(typ: type) -> type: + if is_required_type(typ) or is_annotated_type(typ): + return strip_annotated_type(cast(type, get_args(typ)[0])) + + return typ + + +def extract_type_arg(typ: type, index: int) -> type: + args = get_args(typ) + try: + return cast(type, args[index]) + except IndexError as err: + raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err + + +def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type: + """Given a type like `Foo[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyResponse(Foo[bytes]): + ... + + extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes + ``` + """ + cls = cast(object, get_origin(typ) or typ) + if cls in generic_bases: + # we're given the class directly + return extract_type_arg(typ, index) + + # if a subclass is given + # --- + # this is needed as __orig_bases__ is not present in the typeshed stubs + # because it is intended to be for internal use only, however there does + # not seem to be a way to resolve generic TypeVars for inherited subclasses + # without using it. + if isinstance(cls, InheritsGeneric): + target_base_class: Any | None = None + for base in cls.__orig_bases__: + if base.__origin__ in generic_bases: + target_base_class = base + break + + if target_base_class is None: + raise RuntimeError( + "Could not find the generic base class;\n" + "This should never happen;\n" + f"Does {cls} inherit from one of {generic_bases} ?" + ) + + return extract_type_arg(target_base_class, index) + + raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}") diff --git a/src/openai/_utils/_utils.py b/src/openai/_utils/_utils.py index c874d3682d..993462a66b 100644 --- a/src/openai/_utils/_utils.py +++ b/src/openai/_utils/_utils.py @@ -16,12 +16,11 @@ overload, ) from pathlib import Path -from typing_extensions import Required, Annotated, TypeGuard, get_args, get_origin +from typing_extensions import TypeGuard import sniffio from .._types import Headers, NotGiven, FileTypes, NotGivenOr, HeadersLike -from .._compat import is_union as _is_union from .._compat import parse_date as parse_date from .._compat import parse_datetime as parse_datetime @@ -166,38 +165,6 @@ def is_list(obj: object) -> TypeGuard[list[object]]: return isinstance(obj, list) -def is_annotated_type(typ: type) -> bool: - return get_origin(typ) == Annotated - - -def is_list_type(typ: type) -> bool: - return (get_origin(typ) or typ) == list - - -def is_union_type(typ: type) -> bool: - return _is_union(get_origin(typ)) - - -def is_required_type(typ: type) -> bool: - return get_origin(typ) == Required - - -# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] -def strip_annotated_type(typ: type) -> type: - if is_required_type(typ) or is_annotated_type(typ): - return strip_annotated_type(cast(type, get_args(typ)[0])) - - return typ - - -def extract_type_arg(typ: type, index: int) -> type: - args = get_args(typ) - try: - return cast(type, args[index]) - except IndexError as err: - raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err - - def deepcopy_minimal(item: _T) -> _T: """Minimal reimplementation of copy.deepcopy() that will only copy certain object types: