diff --git a/CHANGELOG.md b/CHANGELOG.md index 992f4e4e38..fa4d61112a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,14 +4,31 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). -## 0.17.1 +## Master + +The 0.18.x release series formalises our low-level Transport API, introducing the +base classes `httpx.BaseTransport` and `httpx.AsyncBaseTransport`. + +See the "Writing custom transports" documentation and the `httpx.BaseTransport.handle_request()` +docstring for more complete details on implementing custom transports. + +Pull request #1522 includes a checklist of differences from the previous `httpcore` transport API, +for developers implementing custom transports. + +### Changed + +* Transport instances now inherit from `httpx.BaseTransport` or `httpx.AsyncBaseTransport`, + and should implement either the `handle_request` method or `handle_async_request` method. +* The `response.ext` property and `Response(ext=...)` argument are now named `extensions`. + +## 0.17.1 (March 15th, 2021) ### Fixed * Type annotation on `CertTypes` allows `keyfile` and `password` to be optional. (Pull #1503) * Fix httpcore pinned version. (Pull #1495) -## 0.17.0 +## 0.17.0 (February 28th, 2021) ### Added diff --git a/docs/advanced.md b/docs/advanced.md index 61bf4c1938..0b31b47855 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -1015,31 +1015,39 @@ This [public gist](https://gist.github.com/florimondmanca/d56764d78d748eb9f73165 ### Writing custom transports -A transport instance must implement the Transport API defined by -[`httpcore`](https://www.encode.io/httpcore/api/). You -should either subclass `httpcore.AsyncHTTPTransport` to implement a transport to -use with `AsyncClient`, or subclass `httpcore.SyncHTTPTransport` to implement a -transport to use with `Client`. +A transport instance must implement the low-level Transport API, which deals +with sending a single request, and returning a response. You should either +subclass `httpx.BaseTransport` to implement a transport to use with `Client`, +or subclass `httpx.AsyncBaseTransport` to implement a transport to +use with `AsyncClient`. + +At the layer of the transport API we're using plain primitives. +No `Request` or `Response` models, no fancy `URL` or `Header` handling. +This strict point of cut-off provides a clear design separation between the +HTTPX API, and the low-level network handling. + +See the `handle_request` and `handle_async_request` docstrings for more details +on the specifics of the Transport API. A complete example of a custom transport implementation would be: ```python import json -import httpcore +import httpx -class HelloWorldTransport(httpcore.SyncHTTPTransport): +class HelloWorldTransport(httpx.BaseTransport): """ A mock transport that always returns a JSON "Hello, world!" response. """ - def request(self, method, url, headers=None, stream=None, ext=None): + def handle_request(self, method, url, headers, stream, extensions): message = {"text": "Hello, world!"} content = json.dumps(message).encode("utf-8") - stream = httpcore.PlainByteStream(content) + stream = [content] headers = [(b"content-type", b"application/json")] - ext = {"http_version": b"HTTP/1.1"} - return 200, headers, stream, ext + extensions = {} + return 200, headers, stream, extensions ``` Which we can use in the same way: @@ -1084,24 +1092,23 @@ which transport an outgoing request should be routed via, with [the same style used for specifying proxy routing](#routing). ```python -import httpcore import httpx -class HTTPSRedirectTransport(httpcore.SyncHTTPTransport): +class HTTPSRedirectTransport(httpx.BaseTransport): """ A transport that always redirects to HTTPS. """ - def request(self, method, url, headers=None, stream=None, ext=None): + def handle_request(self, method, url, headers, stream, extensions): scheme, host, port, path = url if port is None: location = b"https://%s%s" % (host, path) else: location = b"https://%s:%d%s" % (host, port, path) - stream = httpcore.PlainByteStream(b"") + stream = [b""] headers = [(b"location", location)] - ext = {"http_version": b"HTTP/1.1"} - return 303, headers, stream, ext + extensions = {} + return 303, headers, stream, extensions # A client where any `http` requests are always redirected to `https` diff --git a/httpx/__init__.py b/httpx/__init__.py index 96d9e0c2f8..a441669bf6 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -36,6 +36,7 @@ from ._models import URL, Cookies, Headers, QueryParams, Request, Response from ._status_codes import StatusCode, codes from ._transports.asgi import ASGITransport +from ._transports.base import AsyncBaseTransport, BaseTransport from ._transports.default import AsyncHTTPTransport, HTTPTransport from ._transports.mock import MockTransport from ._transports.wsgi import WSGITransport @@ -45,9 +46,11 @@ "__title__", "__version__", "ASGITransport", + "AsyncBaseTransport", "AsyncClient", "AsyncHTTPTransport", "Auth", + "BaseTransport", "BasicAuth", "Client", "CloseError", diff --git a/httpx/_client.py b/httpx/_client.py index da38a14346..691111ba13 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -4,8 +4,6 @@ import warnings from types import TracebackType -import httpcore - from .__version__ import __version__ from ._auth import Auth, BasicAuth, FunctionAuth from ._config import ( @@ -20,15 +18,15 @@ ) from ._decoders import SUPPORTED_DECODERS from ._exceptions import ( - HTTPCORE_EXC_MAP, InvalidURL, RemoteProtocolError, TooManyRedirects, - map_exceptions, + request_context, ) from ._models import URL, Cookies, Headers, QueryParams, Request, Response from ._status_codes import codes from ._transports.asgi import ASGITransport +from ._transports.base import AsyncBaseTransport, BaseTransport from ._transports.default import AsyncHTTPTransport, HTTPTransport from ._transports.wsgi import WSGITransport from ._types import ( @@ -569,14 +567,14 @@ def __init__( cert: CertTypes = None, http2: bool = False, proxies: ProxiesTypes = None, - mounts: typing.Mapping[str, httpcore.SyncHTTPTransport] = None, + mounts: typing.Mapping[str, BaseTransport] = None, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, limits: Limits = DEFAULT_LIMITS, pool_limits: Limits = None, max_redirects: int = DEFAULT_MAX_REDIRECTS, event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None, base_url: URLTypes = "", - transport: httpcore.SyncHTTPTransport = None, + transport: BaseTransport = None, app: typing.Callable = None, trust_env: bool = True, ): @@ -620,9 +618,7 @@ def __init__( app=app, trust_env=trust_env, ) - self._mounts: typing.Dict[ - URLPattern, typing.Optional[httpcore.SyncHTTPTransport] - ] = { + self._mounts: typing.Dict[URLPattern, typing.Optional[BaseTransport]] = { URLPattern(key): None if proxy is None else self._init_proxy_transport( @@ -648,10 +644,10 @@ def _init_transport( cert: CertTypes = None, http2: bool = False, limits: Limits = DEFAULT_LIMITS, - transport: httpcore.SyncHTTPTransport = None, + transport: BaseTransport = None, app: typing.Callable = None, trust_env: bool = True, - ) -> httpcore.SyncHTTPTransport: + ) -> BaseTransport: if transport is not None: return transport @@ -670,7 +666,7 @@ def _init_proxy_transport( http2: bool = False, limits: Limits = DEFAULT_LIMITS, trust_env: bool = True, - ) -> httpcore.SyncHTTPTransport: + ) -> BaseTransport: return HTTPTransport( verify=verify, cert=cert, @@ -680,7 +676,7 @@ def _init_proxy_transport( proxy=proxy, ) - def _transport_for_url(self, url: URL) -> httpcore.SyncHTTPTransport: + def _transport_for_url(self, url: URL) -> BaseTransport: """ Returns the transport instance that should be used for a given URL. This will either be the standard connection pool, or a proxy. @@ -775,21 +771,18 @@ def send( allow_redirects=allow_redirects, history=[], ) - - if not stream: - try: + try: + if not stream: response.read() - finally: - response.close() - try: for hook in self._event_hooks["response"]: hook(response) - except Exception: - response.close() - raise - return response + return response + + except Exception as exc: + response.close() + raise exc def _send_handling_auth( self, @@ -813,18 +806,20 @@ def _send_handling_auth( history=history, ) try: - next_request = auth_flow.send(response) - except StopIteration: - return response - except BaseException as exc: - response.close() - raise exc from None - else: + try: + next_request = auth_flow.send(response) + except StopIteration: + return response + response.history = list(history) response.read() request = next_request history.append(response) + except Exception as exc: + response.close() + raise exc + def _send_handling_redirects( self, request: Request, @@ -839,19 +834,24 @@ def _send_handling_redirects( ) response = self._send_single_request(request, timeout) - response.history = list(history) + try: + response.history = list(history) - if not response.is_redirect: - return response + if not response.is_redirect: + return response - if allow_redirects: - response.read() - request = self._build_redirect_request(request, response) - history = history + [response] + request = self._build_redirect_request(request, response) + history = history + [response] - if not allow_redirects: - response.next_request = request - return response + if allow_redirects: + response.read() + else: + response.next_request = request + return response + + except Exception as exc: + response.close() + raise exc def _send_single_request(self, request: Request, timeout: Timeout) -> Response: """ @@ -861,25 +861,25 @@ def _send_single_request(self, request: Request, timeout: Timeout) -> Response: timer = Timer() timer.sync_start() - with map_exceptions(HTTPCORE_EXC_MAP, request=request): - (status_code, headers, stream, ext) = transport.request( + with request_context(request=request): + (status_code, headers, stream, extensions) = transport.handle_request( request.method.encode(), request.url.raw, headers=request.headers.raw, stream=request.stream, # type: ignore - ext={"timeout": timeout.as_dict()}, + extensions={"timeout": timeout.as_dict()}, ) def on_close(response: Response) -> None: response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed()) - if hasattr(stream, "close"): - stream.close() + if "close" in extensions: + extensions["close"]() response = Response( status_code, headers=headers, - stream=stream, # type: ignore - ext=ext, + stream=stream, + extensions=extensions, request=request, on_close=on_close, ) @@ -1202,14 +1202,14 @@ def __init__( cert: CertTypes = None, http2: bool = False, proxies: ProxiesTypes = None, - mounts: typing.Mapping[str, httpcore.AsyncHTTPTransport] = None, + mounts: typing.Mapping[str, AsyncBaseTransport] = None, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, limits: Limits = DEFAULT_LIMITS, pool_limits: Limits = None, max_redirects: int = DEFAULT_MAX_REDIRECTS, event_hooks: typing.Mapping[str, typing.List[typing.Callable]] = None, base_url: URLTypes = "", - transport: httpcore.AsyncHTTPTransport = None, + transport: AsyncBaseTransport = None, app: typing.Callable = None, trust_env: bool = True, ): @@ -1254,9 +1254,7 @@ def __init__( trust_env=trust_env, ) - self._mounts: typing.Dict[ - URLPattern, typing.Optional[httpcore.AsyncHTTPTransport] - ] = { + self._mounts: typing.Dict[URLPattern, typing.Optional[AsyncBaseTransport]] = { URLPattern(key): None if proxy is None else self._init_proxy_transport( @@ -1281,10 +1279,10 @@ def _init_transport( cert: CertTypes = None, http2: bool = False, limits: Limits = DEFAULT_LIMITS, - transport: httpcore.AsyncHTTPTransport = None, + transport: AsyncBaseTransport = None, app: typing.Callable = None, trust_env: bool = True, - ) -> httpcore.AsyncHTTPTransport: + ) -> AsyncBaseTransport: if transport is not None: return transport @@ -1303,7 +1301,7 @@ def _init_proxy_transport( http2: bool = False, limits: Limits = DEFAULT_LIMITS, trust_env: bool = True, - ) -> httpcore.AsyncHTTPTransport: + ) -> AsyncBaseTransport: return AsyncHTTPTransport( verify=verify, cert=cert, @@ -1313,7 +1311,7 @@ def _init_proxy_transport( proxy=proxy, ) - def _transport_for_url(self, url: URL) -> httpcore.AsyncHTTPTransport: + def _transport_for_url(self, url: URL) -> AsyncBaseTransport: """ Returns the transport instance that should be used for a given URL. This will either be the standard connection pool, or a proxy. @@ -1409,21 +1407,18 @@ async def send( allow_redirects=allow_redirects, history=[], ) - - if not stream: - try: + try: + if not stream: await response.aread() - finally: - await response.aclose() - try: for hook in self._event_hooks["response"]: await hook(response) - except Exception: - await response.aclose() - raise - return response + return response + + except Exception as exc: + await response.aclose() + raise exc async def _send_handling_auth( self, @@ -1447,18 +1442,20 @@ async def _send_handling_auth( history=history, ) try: - next_request = await auth_flow.asend(response) - except StopAsyncIteration: - return response - except BaseException as exc: - await response.aclose() - raise exc from None - else: + try: + next_request = await auth_flow.asend(response) + except StopAsyncIteration: + return response + response.history = list(history) await response.aread() request = next_request history.append(response) + except Exception as exc: + await response.aclose() + raise exc + async def _send_handling_redirects( self, request: Request, @@ -1473,19 +1470,24 @@ async def _send_handling_redirects( ) response = await self._send_single_request(request, timeout) - response.history = list(history) + try: + response.history = list(history) - if not response.is_redirect: - return response + if not response.is_redirect: + return response - if allow_redirects: - await response.aread() - request = self._build_redirect_request(request, response) - history = history + [response] + request = self._build_redirect_request(request, response) + history = history + [response] - if not allow_redirects: - response.next_request = request - return response + if allow_redirects: + await response.aread() + else: + response.next_request = request + return response + + except Exception as exc: + await response.aclose() + raise exc async def _send_single_request( self, request: Request, timeout: Timeout @@ -1497,26 +1499,30 @@ async def _send_single_request( timer = Timer() await timer.async_start() - with map_exceptions(HTTPCORE_EXC_MAP, request=request): - (status_code, headers, stream, ext) = await transport.arequest( + with request_context(request=request): + ( + status_code, + headers, + stream, + extensions, + ) = await transport.handle_async_request( request.method.encode(), request.url.raw, headers=request.headers.raw, stream=request.stream, # type: ignore - ext={"timeout": timeout.as_dict()}, + extensions={"timeout": timeout.as_dict()}, ) async def on_close(response: Response) -> None: response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed()) - if hasattr(stream, "aclose"): - with map_exceptions(HTTPCORE_EXC_MAP, request=request): - await stream.aclose() + if "aclose" in extensions: + await extensions["aclose"]() response = Response( status_code, headers=headers, - stream=stream, # type: ignore - ext=ext, + stream=stream, + extensions=extensions, request=request, on_close=on_close, ) diff --git a/httpx/_decoders.py b/httpx/_decoders.py index 8ef0157e6f..c0d51a4cdc 100644 --- a/httpx/_decoders.py +++ b/httpx/_decoders.py @@ -8,6 +8,8 @@ import typing import zlib +from ._exceptions import DecodingError + try: import brotli except ImportError: # pragma: nocover @@ -54,13 +56,13 @@ def decode(self, data: bytes) -> bytes: if was_first_attempt: self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS) return self.decode(data) - raise ValueError(str(exc)) + raise DecodingError(str(exc)) from exc def flush(self) -> bytes: try: return self.decompressor.flush() except zlib.error as exc: # pragma: nocover - raise ValueError(str(exc)) + raise DecodingError(str(exc)) from exc class GZipDecoder(ContentDecoder): @@ -77,13 +79,13 @@ def decode(self, data: bytes) -> bytes: try: return self.decompressor.decompress(data) except zlib.error as exc: - raise ValueError(str(exc)) + raise DecodingError(str(exc)) from exc def flush(self) -> bytes: try: return self.decompressor.flush() except zlib.error as exc: # pragma: nocover - raise ValueError(str(exc)) + raise DecodingError(str(exc)) from exc class BrotliDecoder(ContentDecoder): @@ -118,7 +120,7 @@ def decode(self, data: bytes) -> bytes: try: return self._decompress(data) except brotli.error as exc: - raise ValueError(str(exc)) + raise DecodingError(str(exc)) from exc def flush(self) -> bytes: if not self.seen_data: @@ -128,7 +130,7 @@ def flush(self) -> bytes: self.decompressor.finish() return b"" except brotli.error as exc: # pragma: nocover - raise ValueError(str(exc)) + raise DecodingError(str(exc)) from exc class MultiDecoder(ContentDecoder): diff --git a/httpx/_exceptions.py b/httpx/_exceptions.py index bade9f9b81..092dbcf04e 100644 --- a/httpx/_exceptions.py +++ b/httpx/_exceptions.py @@ -34,8 +34,6 @@ import contextlib import typing -import httpcore - if typing.TYPE_CHECKING: from ._models import Request, Response # pragma: nocover @@ -58,9 +56,8 @@ class HTTPError(Exception): ``` """ - def __init__(self, message: str, *, request: "Request") -> None: + def __init__(self, message: str) -> None: super().__init__(message) - self.request = request class RequestError(HTTPError): @@ -68,15 +65,30 @@ class RequestError(HTTPError): Base class for all exceptions that may occur when issuing a `.request()`. """ - def __init__(self, message: str, *, request: "Request") -> None: - super().__init__(message, request=request) + def __init__(self, message: str, *, request: "Request" = None) -> None: + super().__init__(message) + # At the point an exception is raised we won't typically have a request + # instance to associate it with. + # + # The 'request_context' context manager is used within the Client and + # Response methods in order to ensure that any raised exceptions + # have a `.request` property set on them. + self._request = request + + @property + def request(self) -> "Request": + if self._request is None: + raise RuntimeError("The .request property has not been set.") + return self._request + + @request.setter + def request(self, request: "Request") -> None: + self._request = request class TransportError(RequestError): """ Base class for all exceptions that occur at the level of the Transport API. - - All of these exceptions also have an equivelent mapping in `httpcore`. """ @@ -219,7 +231,8 @@ class HTTPStatusError(HTTPError): def __init__( self, message: str, *, request: "Request", response: "Response" ) -> None: - super().__init__(message, request=request) + super().__init__(message) + self.request = request self.response = response @@ -318,45 +331,14 @@ def __init__(self) -> None: @contextlib.contextmanager -def map_exceptions( - mapping: typing.Mapping[typing.Type[Exception], typing.Type[Exception]], - **kwargs: typing.Any, -) -> typing.Iterator[None]: +def request_context(request: "Request" = None) -> typing.Iterator[None]: + """ + A context manager that can be used to attach the given request context + to any `RequestError` exceptions that are raised within the block. + """ try: yield - except Exception as exc: - mapped_exc = None - - for from_exc, to_exc in mapping.items(): - if not isinstance(exc, from_exc): - continue - # We want to map to the most specific exception we can find. - # Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to - # `httpx.ReadTimeout`, not just `httpx.TimeoutException`. - if mapped_exc is None or issubclass(to_exc, mapped_exc): - mapped_exc = to_exc - - if mapped_exc is None: - raise - - message = str(exc) - raise mapped_exc(message, **kwargs) from exc # type: ignore - - -HTTPCORE_EXC_MAP = { - httpcore.TimeoutException: TimeoutException, - httpcore.ConnectTimeout: ConnectTimeout, - httpcore.ReadTimeout: ReadTimeout, - httpcore.WriteTimeout: WriteTimeout, - httpcore.PoolTimeout: PoolTimeout, - httpcore.NetworkError: NetworkError, - httpcore.ConnectError: ConnectError, - httpcore.ReadError: ReadError, - httpcore.WriteError: WriteError, - httpcore.CloseError: CloseError, - httpcore.ProxyError: ProxyError, - httpcore.UnsupportedProtocol: UnsupportedProtocol, - httpcore.ProtocolError: ProtocolError, - httpcore.LocalProtocolError: LocalProtocolError, - httpcore.RemoteProtocolError: RemoteProtocolError, -} + except RequestError as exc: + if request is not None: + exc.request = request + raise exc diff --git a/httpx/_models.py b/httpx/_models.py index 83deb9a243..34fb2d388c 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -1,5 +1,4 @@ import cgi -import contextlib import datetime import email.message import json as jsonlib @@ -24,16 +23,14 @@ TextDecoder, ) from ._exceptions import ( - HTTPCORE_EXC_MAP, CookieConflict, - DecodingError, HTTPStatusError, InvalidURL, RequestNotRead, ResponseClosed, ResponseNotRead, StreamConsumed, - map_exceptions, + request_context, ) from ._status_codes import codes from ._types import ( @@ -909,7 +906,7 @@ def __init__( json: typing.Any = None, stream: ByteStream = None, request: Request = None, - ext: dict = None, + extensions: dict = None, history: typing.List["Response"] = None, on_close: typing.Callable = None, ): @@ -924,7 +921,7 @@ def __init__( self.call_next: typing.Optional[typing.Callable] = None - self.ext = {} if ext is None else ext + self.extensions = {} if extensions is None else extensions self.history = [] if history is None else list(history) self._on_close = on_close @@ -995,11 +992,17 @@ def request(self, value: Request) -> None: @property def http_version(self) -> str: - return self.ext.get("http_version", "HTTP/1.1") + try: + return self.extensions["http_version"].decode("ascii", errors="ignore") + except KeyError: + return "HTTP/1.1" @property def reason_phrase(self) -> str: - return self.ext.get("reason", codes.get_reason_phrase(self.status_code)) + try: + return self.extensions["reason_phrase"].decode("ascii", errors="ignore") + except KeyError: + return codes.get_reason_phrase(self.status_code) @property def url(self) -> typing.Optional[URL]: @@ -1152,17 +1155,6 @@ def num_bytes_downloaded(self) -> int: def __repr__(self) -> str: return f"" - @contextlib.contextmanager - def _wrap_decoder_errors(self) -> typing.Iterator[None]: - # If the response has an associated request instance, we want decoding - # errors to be raised as proper `httpx.DecodingError` exceptions. - try: - yield - except ValueError as exc: - if self._request is None: - raise exc - raise DecodingError(message=str(exc), request=self.request) from exc - def read(self) -> bytes: """ Read and return the response content. @@ -1183,7 +1175,7 @@ def iter_bytes(self, chunk_size: int = None) -> typing.Iterator[bytes]: else: decoder = self._get_content_decoder() chunker = ByteChunker(chunk_size=chunk_size) - with self._wrap_decoder_errors(): + with request_context(request=self._request): for raw_bytes in self.iter_raw(): decoded = decoder.decode(raw_bytes) for chunk in chunker.decode(decoded): @@ -1202,7 +1194,7 @@ def iter_text(self, chunk_size: int = None) -> typing.Iterator[str]: """ decoder = TextDecoder(encoding=self.encoding) chunker = TextChunker(chunk_size=chunk_size) - with self._wrap_decoder_errors(): + with request_context(request=self._request): for byte_content in self.iter_bytes(): text_content = decoder.decode(byte_content) for chunk in chunker.decode(text_content): @@ -1215,7 +1207,7 @@ def iter_text(self, chunk_size: int = None) -> typing.Iterator[str]: def iter_lines(self) -> typing.Iterator[str]: decoder = LineDecoder() - with self._wrap_decoder_errors(): + with request_context(request=self._request): for text in self.iter_text(): for line in decoder.decode(text): yield line @@ -1237,7 +1229,7 @@ def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]: self._num_bytes_downloaded = 0 chunker = ByteChunker(chunk_size=chunk_size) - with map_exceptions(HTTPCORE_EXC_MAP, request=self._request): + with request_context(request=self._request): for raw_stream_bytes in self.stream: self._num_bytes_downloaded += len(raw_stream_bytes) for chunk in chunker.decode(raw_stream_bytes): @@ -1256,7 +1248,8 @@ def close(self) -> None: if not self.is_closed: self.is_closed = True if self._on_close is not None: - self._on_close(self) + with request_context(request=self._request): + self._on_close(self) async def aread(self) -> bytes: """ @@ -1278,7 +1271,7 @@ async def aiter_bytes(self, chunk_size: int = None) -> typing.AsyncIterator[byte else: decoder = self._get_content_decoder() chunker = ByteChunker(chunk_size=chunk_size) - with self._wrap_decoder_errors(): + with request_context(request=self._request): async for raw_bytes in self.aiter_raw(): decoded = decoder.decode(raw_bytes) for chunk in chunker.decode(decoded): @@ -1297,7 +1290,7 @@ async def aiter_text(self, chunk_size: int = None) -> typing.AsyncIterator[str]: """ decoder = TextDecoder(encoding=self.encoding) chunker = TextChunker(chunk_size=chunk_size) - with self._wrap_decoder_errors(): + with request_context(request=self._request): async for byte_content in self.aiter_bytes(): text_content = decoder.decode(byte_content) for chunk in chunker.decode(text_content): @@ -1310,7 +1303,7 @@ async def aiter_text(self, chunk_size: int = None) -> typing.AsyncIterator[str]: async def aiter_lines(self) -> typing.AsyncIterator[str]: decoder = LineDecoder() - with self._wrap_decoder_errors(): + with request_context(request=self._request): async for text in self.aiter_text(): for line in decoder.decode(text): yield line @@ -1332,7 +1325,7 @@ async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes] self._num_bytes_downloaded = 0 chunker = ByteChunker(chunk_size=chunk_size) - with map_exceptions(HTTPCORE_EXC_MAP, request=self._request): + with request_context(request=self._request): async for raw_stream_bytes in self.stream: self._num_bytes_downloaded += len(raw_stream_bytes) for chunk in chunker.decode(raw_stream_bytes): @@ -1351,7 +1344,8 @@ async def aclose(self) -> None: if not self.is_closed: self.is_closed = True if self._on_close is not None: - await self._on_close(self) + with request_context(request=self._request): + await self._on_close(self) class Cookies(MutableMapping): diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py index 758d8375b2..ef0a3ef29a 100644 --- a/httpx/_transports/asgi.py +++ b/httpx/_transports/asgi.py @@ -1,15 +1,16 @@ -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union +import typing from urllib.parse import unquote -import httpcore import sniffio -if TYPE_CHECKING: # pragma: no cover +from .base import AsyncBaseTransport + +if typing.TYPE_CHECKING: # pragma: no cover import asyncio import trio - Event = Union[asyncio.Event, trio.Event] + Event = typing.Union[asyncio.Event, trio.Event] def create_event() -> "Event": @@ -23,7 +24,7 @@ def create_event() -> "Event": return asyncio.Event() -class ASGITransport(httpcore.AsyncHTTPTransport): +class ASGITransport(AsyncBaseTransport): """ A custom AsyncTransport that handles sending requests directly to an ASGI app. The simplest way to use this functionality is to use the `app` argument. @@ -58,27 +59,26 @@ class ASGITransport(httpcore.AsyncHTTPTransport): def __init__( self, - app: Callable, + app: typing.Callable, raise_app_exceptions: bool = True, root_path: str = "", - client: Tuple[str, int] = ("127.0.0.1", 123), + client: typing.Tuple[str, int] = ("127.0.0.1", 123), ) -> None: self.app = app self.raise_app_exceptions = raise_app_exceptions self.root_path = root_path self.client = client - async def arequest( + async def handle_async_request( self, method: bytes, - url: Tuple[bytes, bytes, Optional[int], bytes], - headers: List[Tuple[bytes, bytes]] = None, - stream: httpcore.AsyncByteStream = None, - ext: dict = None, - ) -> Tuple[int, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream, dict]: - headers = [] if headers is None else headers - stream = httpcore.PlainByteStream(content=b"") if stream is None else stream - + url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], + stream: typing.AsyncIterable[bytes], + extensions: dict, + ) -> typing.Tuple[ + int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict + ]: # ASGI scope. scheme, host, port, full_path = url path, _, query = full_path.partition(b"?") @@ -155,7 +155,9 @@ async def send(message: dict) -> None: assert status_code is not None assert response_headers is not None - stream = httpcore.PlainByteStream(content=b"".join(body_parts)) - ext = {} + async def response_stream() -> typing.AsyncIterator[bytes]: + yield b"".join(body_parts) + + extensions = {} - return (status_code, response_headers, stream, ext) + return (status_code, response_headers, response_stream(), extensions) diff --git a/httpx/_transports/base.py b/httpx/_transports/base.py new file mode 100644 index 0000000000..e26938f94b --- /dev/null +++ b/httpx/_transports/base.py @@ -0,0 +1,129 @@ +import typing +from types import TracebackType + +T = typing.TypeVar("T", bound="BaseTransport") +A = typing.TypeVar("A", bound="AsyncBaseTransport") + + +class BaseTransport: + def __enter__(self: T) -> T: + return self + + def __exit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + self.close() + + def handle_request( + self, + method: bytes, + url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], + stream: typing.Iterable[bytes], + extensions: dict, + ) -> typing.Tuple[ + int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict + ]: + """ + Send a single HTTP request and return a response. + + At this layer of API we're simply using plain primitives. No `Request` or + `Response` models, no fancy `URL` or `Header` handling. This strict point + of cut-off provides a clear design seperation between the HTTPX API, + and the low-level network handling. + + Developers shouldn't typically ever need to call into this API directly, + since the Client class provides all the higher level user-facing API + niceties. + + Example usage: + + with httpx.HTTPTransport() as transport: + status_code, headers, stream, extensions = transport.handle_request( + method=b'GET', + url=(b'https', b'www.example.com', 443, b'/'), + headers=[(b'Host', b'www.example.com')], + stream=[], + extensions={} + ) + try: + body = b''.join([part for part in stream]) + finally: + if 'close' in extensions: + extensions['close']() + print(status_code, headers, body) + + Arguments: + + method: The request method as bytes. Eg. b'GET'. + url: The components of the request URL, as a tuple of `(scheme, host, port, target)`. + The target will usually be the URL path, but also allows for alternative + formulations, such as proxy requests which include the complete URL in + the target portion of the HTTP request, or for "OPTIONS *" requests, which + cannot be expressed in a URL string. + headers: The request headers as a list of byte pairs. + stream: The request body as a bytes iterator. + extensions: An open ended dictionary, including optional extensions to the + core request/response API. Keys may include: + timeout: A dictionary of str:Optional[float] timeout values. + May include values for 'connect', 'read', 'write', or 'pool'. + + Returns a tuple of: + + status_code: The response status code as an integer. Should be in the range 1xx-5xx. + headers: The response headers as a list of byte pairs. + stream: The response body as a bytes iterator. + extensions: An open ended dictionary, including optional extensions to the + core request/response API. Keys are plain strings, and may include: + reason_phrase: The reason-phrase of the HTTP response, as bytes. Eg b'OK'. + HTTP/2 onwards does not include a reason phrase on the wire. + When no key is included, a default based on the status code may + be used. An empty-string reason phrase should not be substituted + for a default, as it indicates the server left the portion blank + eg. the leading response bytes were b"HTTP/1.1 200 ". + http_version: The HTTP version, as bytes. Eg. b"HTTP/1.1". + When no http_version key is included, HTTP/1.1 may be assumed. + close: A callback which should be invoked to release any network + resources. + aclose: An async callback which should be invoked to release any + network resources. + """ + raise NotImplementedError( + "The 'handle_request' method must be implemented." + ) # pragma: nocover + + def close(self) -> None: + pass + + +class AsyncBaseTransport: + async def __aenter__(self: A) -> A: + return self + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.aclose() + + async def handle_async_request( + self, + method: bytes, + url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], + stream: typing.AsyncIterable[bytes], + extensions: dict, + ) -> typing.Tuple[ + int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict + ]: + raise NotImplementedError( + "The 'handle_async_request' method must be implemented." + ) # pragma: nocover + + async def aclose(self) -> None: + pass diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index 84aeb26be8..67f62322af 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -24,21 +24,93 @@ transport = httpx.HTTPTransport(uds="socket.uds") client = httpx.Client(transport=transport) """ +import contextlib import typing from types import TracebackType import httpcore from .._config import DEFAULT_LIMITS, Limits, Proxy, create_ssl_context +from .._exceptions import ( + CloseError, + ConnectError, + ConnectTimeout, + LocalProtocolError, + NetworkError, + PoolTimeout, + ProtocolError, + ProxyError, + ReadError, + ReadTimeout, + RemoteProtocolError, + TimeoutException, + UnsupportedProtocol, + WriteError, + WriteTimeout, +) from .._types import CertTypes, VerifyTypes +from .base import AsyncBaseTransport, BaseTransport T = typing.TypeVar("T", bound="HTTPTransport") A = typing.TypeVar("A", bound="AsyncHTTPTransport") -Headers = typing.List[typing.Tuple[bytes, bytes]] -URL = typing.Tuple[bytes, bytes, typing.Optional[int], bytes] -class HTTPTransport(httpcore.SyncHTTPTransport): +@contextlib.contextmanager +def map_httpcore_exceptions() -> typing.Iterator[None]: + try: + yield + except Exception as exc: + mapped_exc = None + + for from_exc, to_exc in HTTPCORE_EXC_MAP.items(): + if not isinstance(exc, from_exc): + continue + # We want to map to the most specific exception we can find. + # Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to + # `httpx.ReadTimeout`, not just `httpx.TimeoutException`. + if mapped_exc is None or issubclass(to_exc, mapped_exc): + mapped_exc = to_exc + + if mapped_exc is None: # pragma: nocover + raise + + message = str(exc) + raise mapped_exc(message) from exc + + +def ensure_http_version_reason_phrase_as_bytes(extensions: dict) -> None: + # From HTTPX 0.18 onwards we're treating the "reason_phrase" and "http_version" + # extensions as bytes, in order to be more precise. Also we're using the + # "reason_phrase" key in preference to "reason", in order to match properly + # with the HTTP spec naming. + # HTTPCore 0.12 does not yet use these same conventions for the extensions, + # so we bridge between the two styles for now. + if "reason" in extensions: + extensions["reason_phrase"] = extensions.pop("reason").encode("ascii") + if "http_version" in extensions: + extensions["http_version"] = extensions["http_version"].encode("ascii") + + +HTTPCORE_EXC_MAP = { + httpcore.TimeoutException: TimeoutException, + httpcore.ConnectTimeout: ConnectTimeout, + httpcore.ReadTimeout: ReadTimeout, + httpcore.WriteTimeout: WriteTimeout, + httpcore.PoolTimeout: PoolTimeout, + httpcore.NetworkError: NetworkError, + httpcore.ConnectError: ConnectError, + httpcore.ReadError: ReadError, + httpcore.WriteError: WriteError, + httpcore.CloseError: CloseError, + httpcore.ProxyError: ProxyError, + httpcore.UnsupportedProtocol: UnsupportedProtocol, + httpcore.ProtocolError: ProtocolError, + httpcore.LocalProtocolError: LocalProtocolError, + httpcore.RemoteProtocolError: RemoteProtocolError, +} + + +class HTTPTransport(BaseTransport): def __init__( self, verify: VerifyTypes = True, @@ -91,21 +163,44 @@ def __exit__( ) -> None: self._pool.__exit__(exc_type, exc_value, traceback) - def request( + def handle_request( self, method: bytes, - url: URL, - headers: Headers = None, - stream: httpcore.SyncByteStream = None, - ext: dict = None, - ) -> typing.Tuple[int, Headers, httpcore.SyncByteStream, dict]: - return self._pool.request(method, url, headers=headers, stream=stream, ext=ext) + url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], + stream: typing.Iterable[bytes], + extensions: dict, + ) -> typing.Tuple[ + int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict + ]: + with map_httpcore_exceptions(): + status_code, headers, byte_stream, extensions = self._pool.request( + method=method, + url=url, + headers=headers, + stream=httpcore.IteratorByteStream(iter(stream)), + ext=extensions, + ) + + def response_stream() -> typing.Iterator[bytes]: + with map_httpcore_exceptions(): + for part in byte_stream: + yield part + + def close() -> None: + with map_httpcore_exceptions(): + byte_stream.close() + + ensure_http_version_reason_phrase_as_bytes(extensions) + extensions["close"] = close + + return status_code, headers, response_stream(), extensions def close(self) -> None: self._pool.close() -class AsyncHTTPTransport(httpcore.AsyncHTTPTransport): +class AsyncHTTPTransport(AsyncBaseTransport): def __init__( self, verify: VerifyTypes = True, @@ -158,17 +253,38 @@ async def __aexit__( ) -> None: await self._pool.__aexit__(exc_type, exc_value, traceback) - async def arequest( + async def handle_async_request( self, method: bytes, - url: URL, - headers: Headers = None, - stream: httpcore.AsyncByteStream = None, - ext: dict = None, - ) -> typing.Tuple[int, Headers, httpcore.AsyncByteStream, dict]: - return await self._pool.arequest( - method, url, headers=headers, stream=stream, ext=ext - ) + url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], + stream: typing.AsyncIterable[bytes], + extensions: dict, + ) -> typing.Tuple[ + int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict + ]: + with map_httpcore_exceptions(): + status_code, headers, byte_stream, extenstions = await self._pool.arequest( + method=method, + url=url, + headers=headers, + stream=httpcore.AsyncIteratorByteStream(stream.__aiter__()), + ext=extensions, + ) + + async def response_stream() -> typing.AsyncIterator[bytes]: + with map_httpcore_exceptions(): + async for part in byte_stream: + yield part + + async def aclose() -> None: + with map_httpcore_exceptions(): + await byte_stream.aclose() + + ensure_http_version_reason_phrase_as_bytes(extensions) + extensions["aclose"] = aclose + + return status_code, headers, response_stream(), extensions async def aclose(self) -> None: await self._pool.aclose() diff --git a/httpx/_transports/mock.py b/httpx/_transports/mock.py index a55a88b7a2..b6ca353a31 100644 --- a/httpx/_transports/mock.py +++ b/httpx/_transports/mock.py @@ -1,23 +1,24 @@ import asyncio -from typing import Callable, List, Optional, Tuple - -import httpcore +import typing from .._models import Request +from .base import AsyncBaseTransport, BaseTransport -class MockTransport(httpcore.SyncHTTPTransport, httpcore.AsyncHTTPTransport): - def __init__(self, handler: Callable) -> None: +class MockTransport(AsyncBaseTransport, BaseTransport): + def __init__(self, handler: typing.Callable) -> None: self.handler = handler - def request( + def handle_request( self, method: bytes, - url: Tuple[bytes, bytes, Optional[int], bytes], - headers: List[Tuple[bytes, bytes]] = None, - stream: httpcore.SyncByteStream = None, - ext: dict = None, - ) -> Tuple[int, List[Tuple[bytes, bytes]], httpcore.SyncByteStream, dict]: + url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], + stream: typing.Iterable[bytes], + extensions: dict, + ) -> typing.Tuple[ + int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict + ]: request = Request( method=method, url=url, @@ -30,17 +31,19 @@ def request( response.status_code, response.headers.raw, response.stream, - response.ext, + response.extensions, ) - async def arequest( + async def handle_async_request( self, method: bytes, - url: Tuple[bytes, bytes, Optional[int], bytes], - headers: List[Tuple[bytes, bytes]] = None, - stream: httpcore.AsyncByteStream = None, - ext: dict = None, - ) -> Tuple[int, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream, dict]: + url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], + headers: typing.List[typing.Tuple[bytes, bytes]], + stream: typing.AsyncIterable[bytes], + extensions: dict, + ) -> typing.Tuple[ + int, typing.List[typing.Tuple[bytes, bytes]], typing.AsyncIterable[bytes], dict + ]: request = Request( method=method, url=url, @@ -63,5 +66,5 @@ async def arequest( response.status_code, response.headers.raw, response.stream, - response.ext, + response.extensions, ) diff --git a/httpx/_transports/wsgi.py b/httpx/_transports/wsgi.py index 67b44bde42..3b7651fba7 100644 --- a/httpx/_transports/wsgi.py +++ b/httpx/_transports/wsgi.py @@ -3,7 +3,7 @@ import typing from urllib.parse import unquote -import httpcore +from .base import BaseTransport def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable: @@ -14,7 +14,7 @@ def _skip_leading_empty_chunks(body: typing.Iterable) -> typing.Iterable: return [] -class WSGITransport(httpcore.SyncHTTPTransport): +class WSGITransport(BaseTransport): """ A custom transport that handles sending requests directly to an WSGI app. The simplest way to use this functionality is to use the `app` argument. @@ -59,18 +59,17 @@ def __init__( self.script_name = script_name self.remote_addr = remote_addr - def request( + def handle_request( self, method: bytes, url: typing.Tuple[bytes, bytes, typing.Optional[int], bytes], - headers: typing.List[typing.Tuple[bytes, bytes]] = None, - stream: httpcore.SyncByteStream = None, - ext: dict = None, + headers: typing.List[typing.Tuple[bytes, bytes]], + stream: typing.Iterable[bytes], + extensions: dict, ) -> typing.Tuple[ - int, typing.List[typing.Tuple[bytes, bytes]], httpcore.SyncByteStream, dict + int, typing.List[typing.Tuple[bytes, bytes]], typing.Iterable[bytes], dict ]: - headers = [] if headers is None else headers - stream = httpcore.PlainByteStream(content=b"") if stream is None else stream + wsgi_input = io.BytesIO(b"".join(stream)) scheme, host, port, full_path = url path, _, query = full_path.partition(b"?") @@ -80,7 +79,7 @@ def request( environ = { "wsgi.version": (1, 0), "wsgi.url_scheme": scheme.decode("ascii"), - "wsgi.input": io.BytesIO(b"".join(stream)), + "wsgi.input": wsgi_input, "wsgi.errors": io.BytesIO(), "wsgi.multithread": True, "wsgi.multiprocess": False, @@ -126,7 +125,6 @@ def start_response( (key.encode("ascii"), value.encode("ascii")) for key, value in seen_response_headers ] - stream = httpcore.IteratorByteStream(iterator=result) - ext = {} + extensions = {} - return (status_code, headers, stream, ext) + return (status_code, headers, result, extensions) diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 1d3f4ccafa..99493c43ab 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -1,7 +1,6 @@ import typing from datetime import timedelta -import httpcore import pytest import httpx @@ -169,12 +168,12 @@ async def test_100_continue(server): @pytest.mark.usefixtures("async_environment") async def test_context_managed_transport(): - class Transport(httpcore.AsyncHTTPTransport): + class Transport(httpx.AsyncBaseTransport): def __init__(self): self.events = [] async def aclose(self): - # The base implementation of httpcore.AsyncHTTPTransport just + # The base implementation of httpx.AsyncBaseTransport just # calls into `.aclose`, so simple transport cases can just override # this method for any cleanup, where more complex cases # might want to additionally override `__aenter__`/`__aexit__`. @@ -201,13 +200,13 @@ async def __aexit__(self, *args): @pytest.mark.usefixtures("async_environment") async def test_context_managed_transport_and_mount(): - class Transport(httpcore.AsyncHTTPTransport): + class Transport(httpx.AsyncBaseTransport): def __init__(self, name: str): self.name: str = name self.events: typing.List[str] = [] async def aclose(self): - # The base implementation of httpcore.AsyncHTTPTransport just + # The base implementation of httpx.AsyncBaseTransport just # calls into `.aclose`, so simple transport cases can just override # this method for any cleanup, where more complex cases # might want to additionally override `__aenter__`/`__aexit__`. @@ -303,25 +302,6 @@ async def test_mounted_transport(): assert response.json() == {"app": "mounted"} -@pytest.mark.usefixtures("async_environment") -async def test_response_aclose_map_exceptions(): - class BrokenStream: - async def __aiter__(self): - # so we're an AsyncIterator - pass # pragma: nocover - - async def aclose(self): - raise httpcore.CloseError(OSError(104, "Connection reset by peer")) - - def handle(request: httpx.Request) -> httpx.Response: - return httpx.Response(200, stream=BrokenStream()) - - async with httpx.AsyncClient(transport=httpx.MockTransport(handle)) as client: - async with client.stream("GET", "http://example.com") as response: - with pytest.raises(httpx.CloseError): - await response.aclose() - - @pytest.mark.usefixtures("async_environment") async def test_async_mock_transport(): async def hello_world(request): diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 7e32bcf6f3..386cd7480c 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1,7 +1,6 @@ import typing from datetime import timedelta -import httpcore import pytest import httpx @@ -224,12 +223,12 @@ def test_pool_limits_deprecated(): def test_context_managed_transport(): - class Transport(httpcore.SyncHTTPTransport): + class Transport(httpx.BaseTransport): def __init__(self): self.events = [] def close(self): - # The base implementation of httpcore.SyncHTTPTransport just + # The base implementation of httpx.BaseTransport just # calls into `.close`, so simple transport cases can just override # this method for any cleanup, where more complex cases # might want to additionally override `__enter__`/`__exit__`. @@ -255,13 +254,13 @@ def __exit__(self, *args): def test_context_managed_transport_and_mount(): - class Transport(httpcore.SyncHTTPTransport): + class Transport(httpx.BaseTransport): def __init__(self, name: str): self.name: str = name self.events: typing.List[str] = [] def close(self): - # The base implementation of httpcore.SyncHTTPTransport just + # The base implementation of httpx.BaseTransport just # calls into `.close`, so simple transport cases can just override # this method for any cleanup, where more complex cases # might want to additionally override `__enter__`/`__exit__`. diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index 84d371e9fa..22c5aa0f1a 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -1,4 +1,3 @@ -import httpcore import pytest import httpx @@ -6,9 +5,7 @@ def redirects(request: httpx.Request) -> httpx.Response: if request.url.scheme not in ("http", "https"): - raise httpcore.UnsupportedProtocol( - f"Scheme {request.url.scheme!r} not supported." - ) + raise httpx.UnsupportedProtocol(f"Scheme {request.url.scheme!r} not supported.") if request.url.path == "/redirect_301": status_code = httpx.codes.MOVED_PERMANENTLY @@ -396,3 +393,10 @@ def test_redirect_custom_scheme(): with pytest.raises(httpx.UnsupportedProtocol) as e: client.post("https://example.org/redirect_custom_scheme") assert str(e.value) == "Scheme 'market' not supported." + + +@pytest.mark.usefixtures("async_environment") +async def test_async_invalid_redirect(): + async with httpx.AsyncClient(transport=httpx.MockTransport(redirects)) as client: + with pytest.raises(httpx.RemoteProtocolError): + await client.get("http://example.org/invalid_redirect") diff --git a/tests/conftest.py b/tests/conftest.py index 12db1b0bb2..62c10c9fb4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -76,8 +76,6 @@ async def app(scope, receive, send): assert scope["type"] == "http" if scope["path"].startswith("/slow_response"): await slow_response(scope, receive, send) - elif scope["path"].startswith("/slow_stream_response"): - await slow_stream_response(scope, receive, send) elif scope["path"].startswith("/status"): await status_code(scope, receive, send) elif scope["path"].startswith("/echo_body"): @@ -113,19 +111,6 @@ async def slow_response(scope, receive, send): await send({"type": "http.response.body", "body": b"Hello, world!"}) -async def slow_stream_response(scope, receive, send): - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [[b"content-type", b"text/plain"]], - } - ) - - await sleep(1) - await send({"type": "http.response.body", "body": b"", "more_body": False}) - - async def status_code(scope, receive, send): status_code = int(scope["path"].replace("/status/", "")) await send( diff --git a/tests/models/test_responses.py b/tests/models/test_responses.py index cb46719c17..793fad3b76 100644 --- a/tests/models/test_responses.py +++ b/tests/models/test_responses.py @@ -733,7 +733,7 @@ def test_json_without_specified_encoding_value_error(): # force incorrect guess from `guess_json_utf` to trigger error with mock.patch("httpx._models.guess_json_utf", return_value="utf-32"): response = httpx.Response(200, content=content, headers=headers) - with pytest.raises(ValueError): + with pytest.raises(json.decoder.JSONDecodeError): response.json() @@ -767,7 +767,7 @@ def test_decode_error_with_request(header_value): headers = [(b"Content-Encoding", header_value)] body = b"test 123" compressed_body = brotli.compress(body)[3:] - with pytest.raises(ValueError): + with pytest.raises(httpx.DecodingError): httpx.Response( 200, headers=headers, @@ -788,7 +788,7 @@ def test_value_error_without_request(header_value): headers = [(b"Content-Encoding", header_value)] body = b"test 123" compressed_body = brotli.compress(body)[3:] - with pytest.raises(ValueError): + with pytest.raises(httpx.DecodingError): httpx.Response(200, headers=headers, content=compressed_body) diff --git a/tests/test_asgi.py b/tests/test_asgi.py index b16f68246c..d7cf9412af 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -70,6 +70,42 @@ async def raise_exc_after_response(scope, receive, send): raise RuntimeError() +async def empty_stream(): + yield b"" + + +@pytest.mark.usefixtures("async_environment") +async def test_asgi_transport(): + async with httpx.ASGITransport(app=hello_world) as transport: + status_code, headers, stream, ext = await transport.handle_async_request( + method=b"GET", + url=(b"http", b"www.example.org", 80, b"/"), + headers=[(b"Host", b"www.example.org")], + stream=empty_stream(), + extensions={}, + ) + body = b"".join([part async for part in stream]) + + assert status_code == 200 + assert body == b"Hello, World!" + + +@pytest.mark.usefixtures("async_environment") +async def test_asgi_transport_no_body(): + async with httpx.ASGITransport(app=echo_body) as transport: + status_code, headers, stream, ext = await transport.handle_async_request( + method=b"GET", + url=(b"http", b"www.example.org", 80, b"/"), + headers=[(b"Host", b"www.example.org")], + stream=empty_stream(), + extensions={}, + ) + body = b"".join([part async for part in stream]) + + assert status_code == 200 + assert body == b"" + + @pytest.mark.usefixtures("async_environment") async def test_asgi(): async with httpx.AsyncClient(app=hello_world) as client: diff --git a/tests/test_decoders.py b/tests/test_decoders.py index f8c432cc89..faaf71d2fb 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -170,7 +170,7 @@ def test_decoding_errors(header_value): request = httpx.Request("GET", "https://example.org") httpx.Response(200, headers=headers, content=compressed_body, request=request) - with pytest.raises(ValueError): + with pytest.raises(httpx.DecodingError): httpx.Response(200, headers=headers, content=compressed_body) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index f1c7005bba..1bc6723a87 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,10 +1,10 @@ -from typing import Any +from unittest import mock import httpcore import pytest import httpx -from httpx._exceptions import HTTPCORE_EXC_MAP +from httpx._transports.default import HTTPCORE_EXC_MAP def test_httpcore_all_exceptions_mapped() -> None: @@ -29,25 +29,40 @@ def test_httpcore_exception_mapping(server) -> None: HTTPCore exception mapping works as expected. """ - # Make sure we don't just map to `NetworkError`. - with pytest.raises(httpx.ConnectError): - httpx.get("http://doesnotexist") + def connect_failed(*args, **kwargs): + raise httpcore.ConnectError() - # Make sure streaming methods also map exceptions. - url = server.url.copy_with(path="/slow_stream_response") - timeout = httpx.Timeout(None, read=0.1) - with httpx.stream("GET", url, timeout=timeout) as stream: - with pytest.raises(httpx.ReadTimeout): - stream.read() + class TimeoutStream: + def __iter__(self): + raise httpcore.ReadTimeout() + + def close(self): + pass + + class CloseFailedStream: + def __iter__(self): + yield b"" - # Make sure it also works with custom transports. - class MockTransport(httpcore.SyncHTTPTransport): - def request(self, *args: Any, **kwargs: Any) -> Any: - raise httpcore.ProtocolError() + def close(self): + raise httpcore.CloseError() - client = httpx.Client(transport=MockTransport()) - with pytest.raises(httpx.ProtocolError): - client.get("http://testserver") + with mock.patch("httpcore.SyncConnectionPool.request", side_effect=connect_failed): + with pytest.raises(httpx.ConnectError): + httpx.get(server.url) + + with mock.patch( + "httpcore.SyncConnectionPool.request", + return_value=(200, [], TimeoutStream(), {}), + ): + with pytest.raises(httpx.ReadTimeout): + httpx.get(server.url) + + with mock.patch( + "httpcore.SyncConnectionPool.request", + return_value=(200, [], CloseFailedStream(), {}), + ): + with pytest.raises(httpx.CloseError): + httpx.get(server.url) def test_httpx_exceptions_exposed() -> None: @@ -66,3 +81,15 @@ def test_httpx_exceptions_exposed() -> None: if not_exposed: # pragma: nocover pytest.fail(f"Unexposed HTTPX exceptions: {not_exposed}") + + +def test_request_attribute() -> None: + # Exception without request attribute + exc = httpx.ReadTimeout("Read operation timed out") + with pytest.raises(RuntimeError): + exc.request + + # Exception with request attribute + request = httpx.Request("GET", "https://www.example.com") + exc = httpx.ReadTimeout("Read operation timed out", request=request) + assert exc.request == request