Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide request_class, response_class for httpx.Client #3199

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class BaseClient:
def __init__(
self,
*,
request_class: type[Request] | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably also response_class, right?
Also I guess move this to the end of the parameters.

Copy link
Author

@q0w q0w May 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably also response_class, right?

I thought about that, but baseclient has no logic related with responses, so response_class is now added only in Client for passing to transports. I can add response_class, if needed.

auth: AuthTypes | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
Expand All @@ -178,6 +179,8 @@ def __init__(
) -> None:
event_hooks = {} if event_hooks is None else event_hooks

self._request_class = request_class or Request

self._base_url = self._enforce_trailing_slash(URL(base_url))

self._auth = self._build_auth(auth)
Expand All @@ -195,6 +198,10 @@ def __init__(
self._default_encoding = default_encoding
self._state = ClientState.UNOPENED

@property
def request_class(self) -> type[Request]:
return self._request_class

@property
def is_closed(self) -> bool:
"""
Expand Down Expand Up @@ -356,7 +363,7 @@ def build_request(
else Timeout(timeout)
)
extensions = dict(**extensions, timeout=timeout.as_dict())
return Request(
return self.request_class(
method,
url,
content=content,
Expand Down Expand Up @@ -463,7 +470,7 @@ def _build_redirect_request(self, request: Request, response: Response) -> Reque
headers = self._redirect_headers(request, url, method)
stream = self._redirect_stream(request, method)
cookies = Cookies(self.cookies)
return Request(
return self.request_class(
method=method,
url=url,
headers=headers,
Expand Down Expand Up @@ -629,6 +636,8 @@ class Client(BaseClient):
def __init__(
self,
*,
request_class: type[Request] | None = None,
response_class: type[Response] | None = None,
auth: AuthTypes | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
Expand All @@ -652,6 +661,7 @@ def __init__(
default_encoding: str | typing.Callable[[bytes], str] = "utf-8",
) -> None:
super().__init__(
request_class=request_class,
auth=auth,
params=params,
headers=headers,
Expand Down Expand Up @@ -693,6 +703,8 @@ def __init__(
allow_env_proxies = trust_env and app is None and transport is None
proxy_map = self._get_proxy_map(proxies or proxy, allow_env_proxies)

self._response_class = response_class or Response

self._transport = self._init_transport(
verify=verify,
cert=cert,
Expand Down Expand Up @@ -724,6 +736,10 @@ def __init__(

self._mounts = dict(sorted(self._mounts.items()))

@property
def response_class(self) -> type[Response]:
return self._response_class

def _init_transport(
self,
verify: VerifyTypes = True,
Expand All @@ -748,6 +764,7 @@ def _init_transport(
http2=http2,
limits=limits,
trust_env=trust_env,
response_class=self.response_class,
)

def _init_proxy_transport(
Expand All @@ -768,6 +785,7 @@ def _init_proxy_transport(
limits=limits,
trust_env=trust_env,
proxy=proxy,
response_class=self.response_class,
)

def _transport_for_url(self, url: URL) -> BaseTransport:
Expand Down Expand Up @@ -1376,6 +1394,8 @@ class AsyncClient(BaseClient):
def __init__(
self,
*,
request_class: type[Request] | None = None,
response_class: type[Response] | None = None,
auth: AuthTypes | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
Expand All @@ -1399,6 +1419,7 @@ def __init__(
default_encoding: str | typing.Callable[[bytes], str] = "utf-8",
) -> None:
super().__init__(
request_class=request_class,
auth=auth,
params=params,
headers=headers,
Expand Down Expand Up @@ -1440,6 +1461,8 @@ def __init__(
allow_env_proxies = trust_env and app is None and transport is None
proxy_map = self._get_proxy_map(proxies or proxy, allow_env_proxies)

self._response_class = response_class or Response

self._transport = self._init_transport(
verify=verify,
cert=cert,
Expand Down Expand Up @@ -1471,6 +1494,10 @@ def __init__(
)
self._mounts = dict(sorted(self._mounts.items()))

@property
def response_class(self) -> type[Response]:
return self._response_class

def _init_transport(
self,
verify: VerifyTypes = True,
Expand All @@ -1495,6 +1522,7 @@ def _init_transport(
http2=http2,
limits=limits,
trust_env=trust_env,
response_class=self.response_class,
)

def _init_proxy_transport(
Expand All @@ -1515,6 +1543,7 @@ def _init_proxy_transport(
limits=limits,
trust_env=trust_env,
proxy=proxy,
response_class=self.response_class,
)

def _transport_for_url(self, url: URL) -> AsyncBaseTransport:
Expand Down
10 changes: 8 additions & 2 deletions httpx/_transports/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(
local_address: str | None = None,
retries: int = 0,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
response_class: type[Response] | None = None,
) -> None:
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
Expand Down Expand Up @@ -201,6 +202,8 @@ def __init__(
f" but got {proxy.url.scheme!r}."
)

self._response_class = response_class or Response

def __enter__(self: T) -> T: # Use generics for subclass support.
self._pool.__enter__()
return self
Expand Down Expand Up @@ -237,7 +240,7 @@ def handle_request(

assert isinstance(resp.stream, typing.Iterable)

return Response(
return self._response_class(
status_code=resp.status,
headers=resp.headers,
stream=ResponseStream(resp.stream),
Expand Down Expand Up @@ -276,6 +279,7 @@ def __init__(
local_address: str | None = None,
retries: int = 0,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
response_class: type[Response] | None = None,
) -> None:
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
Expand Down Expand Up @@ -342,6 +346,8 @@ def __init__(
" but got {proxy.url.scheme!r}."
)

self._response_class = response_class or Response

async def __aenter__(self: A) -> A: # Use generics for subclass support.
await self._pool.__aenter__()
return self
Expand Down Expand Up @@ -378,7 +384,7 @@ async def handle_async_request(

assert isinstance(resp.stream, typing.AsyncIterable)

return Response(
return self._response_class(
status_code=resp.status,
headers=resp.headers,
stream=AsyncResponseStream(resp.stream),
Expand Down
60 changes: 60 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,63 @@ def cp1252_but_no_content_type(request):
assert response.reason_phrase == "OK"
assert response.encoding == "ISO-8859-1"
assert response.text == text


def test_client_request_class():
class Request(httpx.Request):
def __init__(self, *args, **kwargs):
kwargs["content"] = "foobar"
super().__init__(*args, **kwargs)

class Client(httpx.Client):
request_class = Request

class AsyncClient(httpx.AsyncClient):
request_class = Request

request = Client().build_request("GET", "http://www.example.com/")
assert isinstance(request, Request)
assert request.content == b"foobar"

request = AsyncClient().build_request("GET", "http://www.example.com/")
assert isinstance(request, Request)
assert request.content == b"foobar"

with httpx.Client(request_class=Request) as client:
request = client.build_request("GET", "http://www.example.com/")
assert isinstance(request, Request)
assert request.content == b"foobar"


@pytest.mark.anyio
async def test_client_response_class(server):
class Response(httpx.Response):
def iter_bytes(self, chunk_size: int | None = None) -> typing.Iterator[bytes]:
yield b"foobar"

class Client(httpx.Client):
response_class = Response

class AsyncResponse(httpx.Response):
async def aiter_bytes(
self, chunk_size: int | None = None
) -> typing.AsyncIterator[bytes]:
yield b"foobar"

class AsyncClient(httpx.AsyncClient):
response_class = AsyncResponse

with Client() as client:
response = client.get(server.url)
assert isinstance(response, Response)
assert response.read() == b"foobar"

async with AsyncClient() as async_client:
response = await async_client.get(server.url)
assert isinstance(response, AsyncResponse)
assert await response.aread() == b"foobar"

with httpx.Client(response_class=Response) as httpx_client:
response = httpx_client.get(server.url)
assert isinstance(response, Response)
assert response.read() == b"foobar"