diff --git a/src/openai/__init__.py b/src/openai/__init__.py index 0de58b332..118fe8ee9 100644 --- a/src/openai/__init__.py +++ b/src/openai/__init__.py @@ -9,6 +9,7 @@ from ._types import NoneType, Transport, ProxiesTypes from ._utils import file_from_path from ._client import Client, OpenAI, Stream, Timeout, Transport, AsyncClient, AsyncOpenAI, AsyncStream, RequestOptions +from ._models import BaseModel from ._version import __title__, __version__ from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse from ._exceptions import ( @@ -59,6 +60,7 @@ "OpenAI", "AsyncOpenAI", "file_from_path", + "BaseModel", ] from .lib import azure as _azure diff --git a/src/openai/_legacy_response.py b/src/openai/_legacy_response.py index c36c94f16..6eaa691d9 100644 --- a/src/openai/_legacy_response.py +++ b/src/openai/_legacy_response.py @@ -5,25 +5,28 @@ import logging import datetime import functools -from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast -from typing_extensions import Awaitable, ParamSpec, get_args, override, deprecated, get_origin +from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, Iterator, AsyncIterator, cast, overload +from typing_extensions import Awaitable, ParamSpec, override, deprecated, get_origin import anyio import httpx +import pydantic from ._types import NoneType from ._utils import is_given from ._models import BaseModel, is_basemodel from ._constants import RAW_RESPONSE_HEADER +from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type from ._exceptions import APIResponseValidationError if TYPE_CHECKING: from ._models import FinalRequestOptions - from ._base_client import Stream, BaseClient, AsyncStream + from ._base_client import BaseClient P = ParamSpec("P") R = TypeVar("R") +_T = TypeVar("_T") log: logging.Logger = logging.getLogger(__name__) @@ -43,7 +46,7 @@ class LegacyAPIResponse(Generic[R]): _cast_to: type[R] _client: BaseClient[Any, Any] - _parsed: R | None + _parsed_by_type: dict[type[Any], Any] _stream: bool _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None _options: FinalRequestOptions @@ -62,27 +65,60 @@ def __init__( ) -> None: self._cast_to = cast_to self._client = client - self._parsed = None + self._parsed_by_type = {} self._stream = stream self._stream_cls = stream_cls self._options = options self.http_response = raw + @overload + def parse(self, *, to: type[_T]) -> _T: + ... + + @overload def parse(self) -> R: + ... + + def parse(self, *, to: type[_T] | None = None) -> R | _T: """Returns the rich python representation of this response's data. + NOTE: For the async client: this will become a coroutine in the next major version. + For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. - NOTE: For the async client: this will become a coroutine in the next major version. + You can customise the type that the response is parsed into through + the `to` argument, e.g. + + ```py + from openai import BaseModel + + + class MyModel(BaseModel): + foo: str + + + obj = response.parse(to=MyModel) + print(obj.foo) + ``` + + We support parsing: + - `BaseModel` + - `dict` + - `list` + - `Union` + - `str` + - `httpx.Response` """ - if self._parsed is not None: - return self._parsed + cache_key = to if to is not None else self._cast_to + cached = self._parsed_by_type.get(cache_key) + if cached is not None: + return cached # type: ignore[no-any-return] - parsed = self._parse() + parsed = self._parse(to=to) if is_given(self._options.post_parser): parsed = self._options.post_parser(parsed) - self._parsed = parsed + self._parsed_by_type[cache_key] = parsed return parsed @property @@ -135,13 +171,29 @@ def elapsed(self) -> datetime.timedelta: """The time taken for the complete request/response cycle to complete.""" return self.http_response.elapsed - def _parse(self) -> R: + def _parse(self, *, to: type[_T] | None = None) -> R | _T: if self._stream: + if to: + if not is_stream_class_type(to): + raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}") + + return cast( + _T, + to( + cast_to=extract_stream_chunk_type( + to, + failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]", + ), + response=self.http_response, + client=cast(Any, self._client), + ), + ) + if self._stream_cls: return cast( R, self._stream_cls( - cast_to=_extract_stream_chunk_type(self._stream_cls), + cast_to=extract_stream_chunk_type(self._stream_cls), response=self.http_response, client=cast(Any, self._client), ), @@ -160,7 +212,7 @@ def _parse(self) -> R: ), ) - cast_to = self._cast_to + cast_to = to if to is not None else self._cast_to if cast_to is NoneType: return cast(R, None) @@ -186,14 +238,9 @@ def _parse(self) -> R: raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`") return cast(R, response) - # The check here is necessary as we are subverting the the type system - # with casts as the relationship between TypeVars and Types are very strict - # which means we must return *exactly* what was input or transform it in a - # way that retains the TypeVar state. As we cannot do that in this function - # then we have to resort to using `cast`. At the time of writing, we know this - # to be safe as we have handled all the types that could be bound to the - # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then - # this function would become unsafe but a type checker would not report an error. + if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel): + raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`") + if ( cast_to is not object and not origin is list @@ -202,12 +249,12 @@ def _parse(self) -> R: and not issubclass(origin, BaseModel) ): raise RuntimeError( - f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}." + f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." ) # split is required to handle cases where additional information is included # in the response, e.g. application/json; charset=utf-8 - content_type, *_ = response.headers.get("content-type").split(";") + content_type, *_ = response.headers.get("content-type", "*").split(";") if content_type != "application/json": if is_basemodel(cast_to): try: @@ -253,15 +300,6 @@ def __init__(self) -> None: ) -def _extract_stream_chunk_type(stream_cls: type) -> type: - args = get_args(stream_cls) - if not args: - raise TypeError( - f"Expected stream_cls to have been given a generic type argument, e.g. Stream[Foo] but received {stream_cls}", - ) - return cast(type, args[0]) - - def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]: """Higher order function that takes one of our bound API methods and wraps it to support returning the raw `APIResponse` object directly. diff --git a/src/openai/_response.py b/src/openai/_response.py index 15a323afa..b1e070122 100644 --- a/src/openai/_response.py +++ b/src/openai/_response.py @@ -16,25 +16,29 @@ Iterator, AsyncIterator, cast, + overload, ) from typing_extensions import Awaitable, ParamSpec, override, get_origin import anyio import httpx +import pydantic from ._types import NoneType from ._utils import is_given, extract_type_var_from_base from ._models import BaseModel, is_basemodel from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER +from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type from ._exceptions import OpenAIError, APIResponseValidationError if TYPE_CHECKING: from ._models import FinalRequestOptions - from ._base_client import Stream, BaseClient, AsyncStream + from ._base_client import BaseClient P = ParamSpec("P") R = TypeVar("R") +_T = TypeVar("_T") _APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]") _AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]") @@ -44,7 +48,7 @@ class BaseAPIResponse(Generic[R]): _cast_to: type[R] _client: BaseClient[Any, Any] - _parsed: R | None + _parsed_by_type: dict[type[Any], Any] _is_sse_stream: bool _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None _options: FinalRequestOptions @@ -63,7 +67,7 @@ def __init__( ) -> None: self._cast_to = cast_to self._client = client - self._parsed = None + self._parsed_by_type = {} self._is_sse_stream = stream self._stream_cls = stream_cls self._options = options @@ -116,8 +120,24 @@ def __repr__(self) -> str: f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>" ) - def _parse(self) -> R: + def _parse(self, *, to: type[_T] | None = None) -> R | _T: if self._is_sse_stream: + if to: + if not is_stream_class_type(to): + raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}") + + return cast( + _T, + to( + cast_to=extract_stream_chunk_type( + to, + failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]", + ), + response=self.http_response, + client=cast(Any, self._client), + ), + ) + if self._stream_cls: return cast( R, @@ -141,7 +161,7 @@ def _parse(self) -> R: ), ) - cast_to = self._cast_to + cast_to = to if to is not None else self._cast_to if cast_to is NoneType: return cast(R, None) @@ -171,14 +191,9 @@ def _parse(self) -> R: raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`") return cast(R, response) - # The check here is necessary as we are subverting the the type system - # with casts as the relationship between TypeVars and Types are very strict - # which means we must return *exactly* what was input or transform it in a - # way that retains the TypeVar state. As we cannot do that in this function - # then we have to resort to using `cast`. At the time of writing, we know this - # to be safe as we have handled all the types that could be bound to the - # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then - # this function would become unsafe but a type checker would not report an error. + if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel): + raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`") + if ( cast_to is not object and not origin is list @@ -187,12 +202,12 @@ def _parse(self) -> R: and not issubclass(origin, BaseModel) ): raise RuntimeError( - f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}." + f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." ) # split is required to handle cases where additional information is included # in the response, e.g. application/json; charset=utf-8 - content_type, *_ = response.headers.get("content-type").split(";") + content_type, *_ = response.headers.get("content-type", "*").split(";") if content_type != "application/json": if is_basemodel(cast_to): try: @@ -228,22 +243,55 @@ def _parse(self) -> R: class APIResponse(BaseAPIResponse[R]): + @overload + def parse(self, *, to: type[_T]) -> _T: + ... + + @overload def parse(self) -> R: + ... + + def parse(self, *, to: type[_T] | None = None) -> R | _T: """Returns the rich python representation of this response's data. For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + + You can customise the type that the response is parsed into through + the `to` argument, e.g. + + ```py + from openai import BaseModel + + + class MyModel(BaseModel): + foo: str + + + obj = response.parse(to=MyModel) + print(obj.foo) + ``` + + We support parsing: + - `BaseModel` + - `dict` + - `list` + - `Union` + - `str` + - `httpx.Response` """ - if self._parsed is not None: - return self._parsed + cache_key = to if to is not None else self._cast_to + cached = self._parsed_by_type.get(cache_key) + if cached is not None: + return cached # type: ignore[no-any-return] if not self._is_sse_stream: self.read() - parsed = self._parse() + parsed = self._parse(to=to) if is_given(self._options.post_parser): parsed = self._options.post_parser(parsed) - self._parsed = parsed + self._parsed_by_type[cache_key] = parsed return parsed def read(self) -> bytes: @@ -297,22 +345,55 @@ def iter_lines(self) -> Iterator[str]: class AsyncAPIResponse(BaseAPIResponse[R]): + @overload + async def parse(self, *, to: type[_T]) -> _T: + ... + + @overload async def parse(self) -> R: + ... + + async def parse(self, *, to: type[_T] | None = None) -> R | _T: """Returns the rich python representation of this response's data. For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + + You can customise the type that the response is parsed into through + the `to` argument, e.g. + + ```py + from openai import BaseModel + + + class MyModel(BaseModel): + foo: str + + + obj = response.parse(to=MyModel) + print(obj.foo) + ``` + + We support parsing: + - `BaseModel` + - `dict` + - `list` + - `Union` + - `str` + - `httpx.Response` """ - if self._parsed is not None: - return self._parsed + cache_key = to if to is not None else self._cast_to + cached = self._parsed_by_type.get(cache_key) + if cached is not None: + return cached # type: ignore[no-any-return] if not self._is_sse_stream: await self.read() - parsed = self._parse() + parsed = self._parse(to=to) if is_given(self._options.post_parser): parsed = self._options.post_parser(parsed) - self._parsed = parsed + self._parsed_by_type[cache_key] = parsed return parsed async def read(self) -> bytes: @@ -708,26 +789,6 @@ def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]: return wrapped -def extract_stream_chunk_type(stream_cls: type) -> type: - """Given a type like `Stream[T]`, returns the generic type variable `T`. - - This also handles the case where a concrete subclass is given, e.g. - ```py - class MyStream(Stream[bytes]): - ... - - extract_stream_chunk_type(MyStream) -> bytes - ``` - """ - from ._base_client import Stream, AsyncStream - - return extract_type_var_from_base( - stream_cls, - index=0, - generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), - ) - - def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type: """Given a type like `APIResponse[T]`, returns the generic type variable `T`. diff --git a/src/openai/_streaming.py b/src/openai/_streaming.py index 85cec70c1..74878fd0a 100644 --- a/src/openai/_streaming.py +++ b/src/openai/_streaming.py @@ -2,13 +2,14 @@ from __future__ import annotations import json +import inspect from types import TracebackType from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast -from typing_extensions import Self, override +from typing_extensions import Self, TypeGuard, override, get_origin import httpx -from ._utils import is_mapping +from ._utils import is_mapping, extract_type_var_from_base from ._exceptions import APIError if TYPE_CHECKING: @@ -281,3 +282,34 @@ def decode(self, line: str) -> ServerSentEvent | None: pass # Field is ignored. return None + + +def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]: + """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`""" + origin = get_origin(typ) or typ + return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream)) + + +def extract_stream_chunk_type( + stream_cls: type, + *, + failure_message: str | None = None, +) -> type: + """Given a type like `Stream[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyStream(Stream[bytes]): + ... + + extract_stream_chunk_type(MyStream) -> bytes + ``` + """ + from ._base_client import Stream, AsyncStream + + return extract_type_var_from_base( + stream_cls, + index=0, + generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), + failure_message=failure_message, + ) diff --git a/src/openai/_utils/_typing.py b/src/openai/_utils/_typing.py index a020822bc..c1d1ebb9a 100644 --- a/src/openai/_utils/_typing.py +++ b/src/openai/_utils/_typing.py @@ -45,7 +45,13 @@ def extract_type_arg(typ: type, index: int) -> type: raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err -def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type: +def extract_type_var_from_base( + typ: type, + *, + generic_bases: tuple[type, ...], + index: int, + failure_message: str | None = None, +) -> type: """Given a type like `Foo[T]`, returns the generic type variable `T`. This also handles the case where a concrete subclass is given, e.g. @@ -104,4 +110,4 @@ class MyResponse(Foo[_T]): return extracted - raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}") + raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") diff --git a/tests/test_legacy_response.py b/tests/test_legacy_response.py new file mode 100644 index 000000000..995250a58 --- /dev/null +++ b/tests/test_legacy_response.py @@ -0,0 +1,65 @@ +import json + +import httpx +import pytest +import pydantic + +from openai import OpenAI, BaseModel +from openai._streaming import Stream +from openai._base_client import FinalRequestOptions +from openai._legacy_response import LegacyAPIResponse + + +class PydanticModel(pydantic.BaseModel): + ... + + +def test_response_parse_mismatched_basemodel(client: OpenAI) -> None: + response = LegacyAPIResponse( + raw=httpx.Response(200, content=b"foo"), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + with pytest.raises( + TypeError, + match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`", + ): + response.parse(to=PydanticModel) + + +def test_response_parse_custom_stream(client: OpenAI) -> None: + response = LegacyAPIResponse( + raw=httpx.Response(200, content=b"foo"), + client=client, + stream=True, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + stream = response.parse(to=Stream[int]) + assert stream._cast_to == int + + +class CustomModel(BaseModel): + foo: str + bar: int + + +def test_response_parse_custom_model(client: OpenAI) -> None: + response = LegacyAPIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse(to=CustomModel) + assert obj.foo == "hello!" + assert obj.bar == 2 diff --git a/tests/test_response.py b/tests/test_response.py index 335ca7922..7c99327b4 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,8 +1,11 @@ +import json from typing import List import httpx import pytest +import pydantic +from openai import OpenAI, BaseModel, AsyncOpenAI from openai._response import ( APIResponse, BaseAPIResponse, @@ -11,6 +14,8 @@ AsyncBinaryAPIResponse, extract_response_type, ) +from openai._streaming import Stream +from openai._base_client import FinalRequestOptions class ConcreteBaseAPIResponse(APIResponse[bytes]): @@ -48,3 +53,107 @@ def test_extract_response_type_concrete_subclasses() -> None: def test_extract_response_type_binary_response() -> None: assert extract_response_type(BinaryAPIResponse) == bytes assert extract_response_type(AsyncBinaryAPIResponse) == bytes + + +class PydanticModel(pydantic.BaseModel): + ... + + +def test_response_parse_mismatched_basemodel(client: OpenAI) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo"), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + with pytest.raises( + TypeError, + match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`", + ): + response.parse(to=PydanticModel) + + +@pytest.mark.asyncio +async def test_async_response_parse_mismatched_basemodel(async_client: AsyncOpenAI) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo"), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + with pytest.raises( + TypeError, + match="Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`", + ): + await response.parse(to=PydanticModel) + + +def test_response_parse_custom_stream(client: OpenAI) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo"), + client=client, + stream=True, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + stream = response.parse(to=Stream[int]) + assert stream._cast_to == int + + +@pytest.mark.asyncio +async def test_async_response_parse_custom_stream(async_client: AsyncOpenAI) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo"), + client=async_client, + stream=True, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + stream = await response.parse(to=Stream[int]) + assert stream._cast_to == int + + +class CustomModel(BaseModel): + foo: str + bar: int + + +def test_response_parse_custom_model(client: OpenAI) -> None: + response = APIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse(to=CustomModel) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +@pytest.mark.asyncio +async def test_async_response_parse_custom_model(async_client: AsyncOpenAI) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse(to=CustomModel) + assert obj.foo == "hello!" + assert obj.bar == 2