From aae8a978d05b32fd70eb650dadda491c7b02a4d2 Mon Sep 17 00:00:00 2001 From: q0w <43147888+q0w@users.noreply.github.com> Date: Sat, 11 May 2024 13:32:50 +0300 Subject: [PATCH 1/9] Add request_class --- httpx/_client.py | 19 +++++++++++++++++-- tests/client/test_client.py | 21 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/httpx/_client.py b/httpx/_client.py index d95877e8be..f085896914 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -164,6 +164,7 @@ class BaseClient: def __init__( self, *, + request_class: type[Request] | None = None, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -178,6 +179,8 @@ def __init__( ) -> None: event_hooks = {} if event_hooks is None else event_hooks + self._request_class = request_class or Request + self._base_url = self._enforce_trailing_slash(URL(base_url)) self._auth = self._build_auth(auth) @@ -195,6 +198,14 @@ def __init__( self._default_encoding = default_encoding self._state = ClientState.UNOPENED + @property + def request_class(self) -> type[Request]: + return self._request_class + + @request_class.setter + def request_class(self, request_class: type[Request]) -> None: + self._request_class = request_class + @property def is_closed(self) -> bool: """ @@ -356,7 +367,7 @@ def build_request( else Timeout(timeout) ) extensions = dict(**extensions, timeout=timeout.as_dict()) - return Request( + return self.request_class( method, url, content=content, @@ -463,7 +474,7 @@ def _build_redirect_request(self, request: Request, response: Response) -> Reque headers = self._redirect_headers(request, url, method) stream = self._redirect_stream(request, method) cookies = Cookies(self.cookies) - return Request( + return self.request_class( method=method, url=url, headers=headers, @@ -629,6 +640,7 @@ class Client(BaseClient): def __init__( self, *, + request_class: type[Request] | None = None, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -652,6 +664,7 @@ def __init__( default_encoding: str | typing.Callable[[bytes], str] = "utf-8", ) -> None: super().__init__( + request_class=request_class, auth=auth, params=params, headers=headers, @@ -1376,6 +1389,7 @@ class AsyncClient(BaseClient): def __init__( self, *, + request_class: type[Request] | None = None, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -1399,6 +1413,7 @@ def __init__( default_encoding: str | typing.Callable[[bytes], str] = "utf-8", ) -> None: super().__init__( + request_class=request_class, auth=auth, params=params, headers=headers, diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 657839018a..17b618eab7 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -460,3 +460,24 @@ def cp1252_but_no_content_type(request): assert response.reason_phrase == "OK" assert response.encoding == "ISO-8859-1" assert response.text == text + + +def test_client_request_class(): + class Request(httpx.Request): + def __init__(self, *args, **kwargs): + kwargs["content"] = "foobar" + super().__init__(*args, **kwargs) + + class Client(httpx.Client): + request_class = Request + + class AsyncClient(httpx.AsyncClient): + request_class = Request + + request = Client().build_request("GET", "http://www.example.com/") + assert isinstance(request, Request) + assert request.content == b"foobar" + + request = AsyncClient().build_request("GET", "http://www.example.com/") + assert isinstance(request, Request) + assert request.content == b"foobar" From 9a62cb986f5f6c41536814b1f3553b1b28256b1c Mon Sep 17 00:00:00 2001 From: q0w <43147888+q0w@users.noreply.github.com> Date: Sat, 11 May 2024 14:20:40 +0300 Subject: [PATCH 2/9] Add response_class --- httpx/_client.py | 24 ++++++++++++++++++++++++ httpx/_transports/default.py | 10 ++++++++-- tests/client/test_client.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 2 deletions(-) diff --git a/httpx/_client.py b/httpx/_client.py index f085896914..d7c106639a 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -641,6 +641,7 @@ def __init__( self, *, request_class: type[Request] | None = None, + response_class: type[Response] | None = None, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -706,6 +707,8 @@ def __init__( allow_env_proxies = trust_env and app is None and transport is None proxy_map = self._get_proxy_map(proxies or proxy, allow_env_proxies) + self._response_class = response_class or Response + self._transport = self._init_transport( verify=verify, cert=cert, @@ -737,6 +740,14 @@ def __init__( self._mounts = dict(sorted(self._mounts.items())) + @property + def response_class(self) -> type[Response]: + return self._response_class + + @response_class.setter + def response_class(self, response_class: type[Response]) -> None: + self._response_class = response_class + def _init_transport( self, verify: VerifyTypes = True, @@ -761,6 +772,7 @@ def _init_transport( http2=http2, limits=limits, trust_env=trust_env, + response_class=self.response_class, ) def _init_proxy_transport( @@ -1390,6 +1402,7 @@ def __init__( self, *, request_class: type[Request] | None = None, + response_class: type[Response] | None = None, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -1455,6 +1468,8 @@ def __init__( allow_env_proxies = trust_env and app is None and transport is None proxy_map = self._get_proxy_map(proxies or proxy, allow_env_proxies) + self._response_class = response_class or Response + self._transport = self._init_transport( verify=verify, cert=cert, @@ -1486,6 +1501,14 @@ def __init__( ) self._mounts = dict(sorted(self._mounts.items())) + @property + def response_class(self) -> type[Response]: + return self._response_class + + @response_class.setter + def response_class(self, response_class: type[Response]) -> None: + self._response_class = response_class + def _init_transport( self, verify: VerifyTypes = True, @@ -1510,6 +1533,7 @@ def _init_transport( http2=http2, limits=limits, trust_env=trust_env, + response_class=self.response_class, ) def _init_proxy_transport( diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index 33db416dd1..551440d62c 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -135,6 +135,7 @@ def __init__( local_address: str | None = None, retries: int = 0, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + response_class: type[Response] | None = None, ) -> None: ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy @@ -201,6 +202,8 @@ def __init__( f" but got {proxy.url.scheme!r}." ) + self._response_class = response_class or Response + def __enter__(self: T) -> T: # Use generics for subclass support. self._pool.__enter__() return self @@ -237,7 +240,7 @@ def handle_request( assert isinstance(resp.stream, typing.Iterable) - return Response( + return self._response_class( status_code=resp.status, headers=resp.headers, stream=ResponseStream(resp.stream), @@ -276,6 +279,7 @@ def __init__( local_address: str | None = None, retries: int = 0, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, + response_class: type[Response] | None = None, ) -> None: ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy @@ -342,6 +346,8 @@ def __init__( " but got {proxy.url.scheme!r}." ) + self._response_class = response_class or Response + async def __aenter__(self: A) -> A: # Use generics for subclass support. await self._pool.__aenter__() return self @@ -378,7 +384,7 @@ async def handle_async_request( assert isinstance(resp.stream, typing.AsyncIterable) - return Response( + return self._response_class( status_code=resp.status, headers=resp.headers, stream=AsyncResponseStream(resp.stream), diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 17b618eab7..2bfbae6970 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -481,3 +481,32 @@ class AsyncClient(httpx.AsyncClient): request = AsyncClient().build_request("GET", "http://www.example.com/") assert isinstance(request, Request) assert request.content == b"foobar" + + +@pytest.mark.anyio +async def test_client_response_class(server): + class Response(httpx.Response): + def iter_bytes(self, chunk_size: int | None = None) -> typing.Iterator[bytes]: + yield b"foobar" + + class Client(httpx.Client): + response_class = Response + + class AsyncResponse(httpx.Response): + async def aiter_bytes( + self, chunk_size: int | None = None + ) -> typing.AsyncIterator[bytes]: + yield b"foobar" + + class AsyncClient(httpx.AsyncClient): + response_class = AsyncResponse + + with Client() as client: + response = client.get(server.url) + assert isinstance(response, Response) + assert response.content == b"foobar" + + async with AsyncClient() as async_client: + response = await async_client.get(server.url) + assert isinstance(response, AsyncResponse) + assert await response.aread() == b"foobar" From 8045218e61bf158018ad9691d125df86d5d66606 Mon Sep 17 00:00:00 2001 From: q0w <43147888+q0w@users.noreply.github.com> Date: Sat, 11 May 2024 15:22:52 +0300 Subject: [PATCH 3/9] fixup! Add request_class --- tests/client/test_client.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 2bfbae6970..2f4db71d7a 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -482,6 +482,12 @@ class AsyncClient(httpx.AsyncClient): assert isinstance(request, Request) assert request.content == b"foobar" + client = httpx.Client() + client.request_class = Request + request = client.build_request("GET", "http://www.example.com/") + assert isinstance(request, Request) + assert request.content == b"foobar" + @pytest.mark.anyio async def test_client_response_class(server): From 1fa7e160f7a278d147307de5ee3d827d0e0f841a Mon Sep 17 00:00:00 2001 From: q0w <43147888+q0w@users.noreply.github.com> Date: Sat, 11 May 2024 15:29:23 +0300 Subject: [PATCH 4/9] fixup! Add response_class --- httpx/_client.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/httpx/_client.py b/httpx/_client.py index d7c106639a..83ccfe7e8f 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -744,10 +744,6 @@ def __init__( def response_class(self) -> type[Response]: return self._response_class - @response_class.setter - def response_class(self, response_class: type[Response]) -> None: - self._response_class = response_class - def _init_transport( self, verify: VerifyTypes = True, @@ -793,6 +789,7 @@ def _init_proxy_transport( limits=limits, trust_env=trust_env, proxy=proxy, + response_class=self.response_class, ) def _transport_for_url(self, url: URL) -> BaseTransport: @@ -1505,10 +1502,6 @@ def __init__( def response_class(self) -> type[Response]: return self._response_class - @response_class.setter - def response_class(self, response_class: type[Response]) -> None: - self._response_class = response_class - def _init_transport( self, verify: VerifyTypes = True, @@ -1554,6 +1547,7 @@ def _init_proxy_transport( limits=limits, trust_env=trust_env, proxy=proxy, + response_class=self.response_class, ) def _transport_for_url(self, url: URL) -> AsyncBaseTransport: From 1e375db608809785286b09740e7573baed381558 Mon Sep 17 00:00:00 2001 From: q0w <43147888+q0w@users.noreply.github.com> Date: Sat, 11 May 2024 15:37:56 +0300 Subject: [PATCH 5/9] Make request_class read-only --- httpx/_client.py | 4 ---- tests/client/test_client.py | 16 ++++++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/httpx/_client.py b/httpx/_client.py index 83ccfe7e8f..d1bfe95911 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -202,10 +202,6 @@ def __init__( def request_class(self) -> type[Request]: return self._request_class - @request_class.setter - def request_class(self, request_class: type[Request]) -> None: - self._request_class = request_class - @property def is_closed(self) -> bool: """ diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 2f4db71d7a..cecee778ad 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -482,11 +482,10 @@ class AsyncClient(httpx.AsyncClient): assert isinstance(request, Request) assert request.content == b"foobar" - client = httpx.Client() - client.request_class = Request - request = client.build_request("GET", "http://www.example.com/") - assert isinstance(request, Request) - assert request.content == b"foobar" + with httpx.Client(request_class=Request) as client: + request = client.build_request("GET", "http://www.example.com/") + assert isinstance(request, Request) + assert request.content == b"foobar" @pytest.mark.anyio @@ -510,9 +509,14 @@ class AsyncClient(httpx.AsyncClient): with Client() as client: response = client.get(server.url) assert isinstance(response, Response) - assert response.content == b"foobar" + assert response.read() == b"foobar" async with AsyncClient() as async_client: response = await async_client.get(server.url) assert isinstance(response, AsyncResponse) assert await response.aread() == b"foobar" + + with httpx.Client(response_class=Response) as httpx_client: + response = httpx_client.get(server.url) + assert isinstance(response, Response) + assert response.read() == b"foobar" From 3dd628900e2ccf4033dd635b663aa4616d4b4142 Mon Sep 17 00:00:00 2001 From: q0w <43147888+q0w@users.noreply.github.com> Date: Mon, 23 Sep 2024 16:08:40 +0300 Subject: [PATCH 6/9] Move response_class to BaseClient --- httpx/_client.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/httpx/_client.py b/httpx/_client.py index d1bfe95911..2710ad4162 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -164,7 +164,8 @@ class BaseClient: def __init__( self, *, - request_class: type[Request] | None = None, + request_class: type[Request] = Request, + response_class: type[Response] = Response, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -179,7 +180,8 @@ def __init__( ) -> None: event_hooks = {} if event_hooks is None else event_hooks - self._request_class = request_class or Request + self._request_class = request_class + self._response_class = response_class self._base_url = self._enforce_trailing_slash(URL(base_url)) @@ -202,6 +204,10 @@ def __init__( def request_class(self) -> type[Request]: return self._request_class + @property + def response_class(self) -> type[Response]: + return self._response_class + @property def is_closed(self) -> bool: """ @@ -636,8 +642,8 @@ class Client(BaseClient): def __init__( self, *, - request_class: type[Request] | None = None, - response_class: type[Response] | None = None, + request_class: type[Request] = Request, + response_class: type[Response] = Response, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -662,6 +668,7 @@ def __init__( ) -> None: super().__init__( request_class=request_class, + response_class=response_class, auth=auth, params=params, headers=headers, @@ -703,8 +710,6 @@ def __init__( allow_env_proxies = trust_env and app is None and transport is None proxy_map = self._get_proxy_map(proxies or proxy, allow_env_proxies) - self._response_class = response_class or Response - self._transport = self._init_transport( verify=verify, cert=cert, @@ -736,10 +741,6 @@ def __init__( self._mounts = dict(sorted(self._mounts.items())) - @property - def response_class(self) -> type[Response]: - return self._response_class - def _init_transport( self, verify: VerifyTypes = True, @@ -1394,8 +1395,8 @@ class AsyncClient(BaseClient): def __init__( self, *, - request_class: type[Request] | None = None, - response_class: type[Response] | None = None, + request_class: type[Request] = Request, + response_class: type[Response] = Response, auth: AuthTypes | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, @@ -1420,6 +1421,7 @@ def __init__( ) -> None: super().__init__( request_class=request_class, + response_class=response_class, auth=auth, params=params, headers=headers, @@ -1461,8 +1463,6 @@ def __init__( allow_env_proxies = trust_env and app is None and transport is None proxy_map = self._get_proxy_map(proxies or proxy, allow_env_proxies) - self._response_class = response_class or Response - self._transport = self._init_transport( verify=verify, cert=cert, @@ -1494,10 +1494,6 @@ def __init__( ) self._mounts = dict(sorted(self._mounts.items())) - @property - def response_class(self) -> type[Response]: - return self._response_class - def _init_transport( self, verify: VerifyTypes = True, From 8d3e7d291acfa48ba1087aa656a097e939627bfc Mon Sep 17 00:00:00 2001 From: q0w <43147888+q0w@users.noreply.github.com> Date: Mon, 23 Sep 2024 16:10:26 +0300 Subject: [PATCH 7/9] default response_class = Response --- httpx/_transports/default.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index 551440d62c..7752de49b1 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -135,7 +135,7 @@ def __init__( local_address: str | None = None, retries: int = 0, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, - response_class: type[Response] | None = None, + response_class: type[Response] = Response, ) -> None: ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy @@ -202,7 +202,7 @@ def __init__( f" but got {proxy.url.scheme!r}." ) - self._response_class = response_class or Response + self._response_class = response_class def __enter__(self: T) -> T: # Use generics for subclass support. self._pool.__enter__() From 73b29379b9aba9e06f2c632a7f3daa079716649a Mon Sep 17 00:00:00 2001 From: q0w <43147888+q0w@users.noreply.github.com> Date: Mon, 23 Sep 2024 16:11:12 +0300 Subject: [PATCH 8/9] default response_class = Response --- httpx/_transports/default.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index 7752de49b1..9aa304fd8c 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -279,7 +279,7 @@ def __init__( local_address: str | None = None, retries: int = 0, socket_options: typing.Iterable[SOCKET_OPTION] | None = None, - response_class: type[Response] | None = None, + response_class: type[Response] = Response, ) -> None: ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env) proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy @@ -346,7 +346,7 @@ def __init__( " but got {proxy.url.scheme!r}." ) - self._response_class = response_class or Response + self._response_class = response_class async def __aenter__(self: A) -> A: # Use generics for subclass support. await self._pool.__aenter__() From 8f763f924b205ded31e28c79a2cc3169eb4dfe6c Mon Sep 17 00:00:00 2001 From: q0w <43147888+q0w@users.noreply.github.com> Date: Tue, 24 Sep 2024 10:04:02 +0300 Subject: [PATCH 9/9] typing self workaround --- httpx/_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/httpx/_models.py b/httpx/_models.py index 01d9583bc5..f5eb057e4f 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -446,6 +446,9 @@ def __setstate__(self, state: dict[str, typing.Any]) -> None: self.stream = UnattachedStream() +_ResponseT = typing.TypeVar("_ResponseT", bound="Response") + + class Response: def __init__( self, @@ -725,7 +728,7 @@ def has_redirect_location(self) -> bool: and "Location" in self.headers ) - def raise_for_status(self) -> Response: + def raise_for_status(self: _ResponseT) -> _ResponseT: """ Raise the `HTTPStatusError` if one occurred. """