Skip to content

Commit

Permalink
Push httpcore exception wrapping out of client into transport (#1524)
Browse files Browse the repository at this point in the history
* Push httpcore exception wrapping out of client into transport

* Include close/aclose extensions in docstring

* Comment about the request property on RequestError exceptions
  • Loading branch information
tomchristie committed Mar 23, 2021
1 parent e94beae commit ee2a612
Show file tree
Hide file tree
Showing 12 changed files with 219 additions and 158 deletions.
20 changes: 9 additions & 11 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
)
from ._decoders import SUPPORTED_DECODERS
from ._exceptions import (
HTTPCORE_EXC_MAP,
InvalidURL,
RemoteProtocolError,
TooManyRedirects,
map_exceptions,
request_context,
)
from ._models import URL, Cookies, Headers, QueryParams, Request, Response
from ._status_codes import codes
Expand Down Expand Up @@ -849,7 +848,7 @@ def _send_single_request(self, request: Request, timeout: Timeout) -> Response:
timer = Timer()
timer.sync_start()

with map_exceptions(HTTPCORE_EXC_MAP, request=request):
with request_context(request=request):
(status_code, headers, stream, extensions) = transport.handle_request(
request.method.encode(),
request.url.raw,
Expand All @@ -860,13 +859,13 @@ def _send_single_request(self, request: Request, timeout: Timeout) -> Response:

def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(seconds=timer.sync_elapsed())
if hasattr(stream, "close"):
stream.close() # type: ignore
if "close" in extensions:
extensions["close"]()

response = Response(
status_code,
headers=headers,
stream=stream, # type: ignore
stream=stream,
extensions=extensions,
request=request,
on_close=on_close,
Expand Down Expand Up @@ -1483,7 +1482,7 @@ async def _send_single_request(
timer = Timer()
await timer.async_start()

with map_exceptions(HTTPCORE_EXC_MAP, request=request):
with request_context(request=request):
(
status_code,
headers,
Expand All @@ -1499,14 +1498,13 @@ async def _send_single_request(

async def on_close(response: Response) -> None:
response.elapsed = datetime.timedelta(seconds=await timer.async_elapsed())
if hasattr(stream, "aclose"):
with map_exceptions(HTTPCORE_EXC_MAP, request=request):
await stream.aclose() # type: ignore
if "aclose" in extensions:
await extensions["aclose"]()

response = Response(
status_code,
headers=headers,
stream=stream, # type: ignore
stream=stream,
extensions=extensions,
request=request,
on_close=on_close,
Expand Down
14 changes: 8 additions & 6 deletions httpx/_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import typing
import zlib

from ._exceptions import DecodingError

try:
import brotli
except ImportError: # pragma: nocover
Expand Down Expand Up @@ -54,13 +56,13 @@ def decode(self, data: bytes) -> bytes:
if was_first_attempt:
self.decompressor = zlib.decompressobj(-zlib.MAX_WBITS)
return self.decode(data)
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc

def flush(self) -> bytes:
try:
return self.decompressor.flush()
except zlib.error as exc: # pragma: nocover
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc


class GZipDecoder(ContentDecoder):
Expand All @@ -77,13 +79,13 @@ def decode(self, data: bytes) -> bytes:
try:
return self.decompressor.decompress(data)
except zlib.error as exc:
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc

def flush(self) -> bytes:
try:
return self.decompressor.flush()
except zlib.error as exc: # pragma: nocover
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc


class BrotliDecoder(ContentDecoder):
Expand Down Expand Up @@ -118,7 +120,7 @@ def decode(self, data: bytes) -> bytes:
try:
return self._decompress(data)
except brotli.error as exc:
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc

def flush(self) -> bytes:
if not self.seen_data:
Expand All @@ -128,7 +130,7 @@ def flush(self) -> bytes:
self.decompressor.finish()
return b""
except brotli.error as exc: # pragma: nocover
raise ValueError(str(exc))
raise DecodingError(str(exc)) from exc


class MultiDecoder(ContentDecoder):
Expand Down
80 changes: 31 additions & 49 deletions httpx/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@
import contextlib
import typing

import httpcore

if typing.TYPE_CHECKING:
from ._models import Request, Response # pragma: nocover

Expand All @@ -58,25 +56,39 @@ class HTTPError(Exception):
```
"""

def __init__(self, message: str, *, request: "Request") -> None:
def __init__(self, message: str) -> None:
super().__init__(message)
self.request = request


class RequestError(HTTPError):
"""
Base class for all exceptions that may occur when issuing a `.request()`.
"""

def __init__(self, message: str, *, request: "Request") -> None:
super().__init__(message, request=request)
def __init__(self, message: str, *, request: "Request" = None) -> None:
super().__init__(message)
# At the point an exception is raised we won't typically have a request
# instance to associate it with.
#
# The 'request_context' context manager is used within the Client and
# Response methods in order to ensure that any raised exceptions
# have a `.request` property set on them.
self._request = request

@property
def request(self) -> "Request":
if self._request is None:
raise RuntimeError("The .request property has not been set.")
return self._request

@request.setter
def request(self, request: "Request") -> None:
self._request = request


class TransportError(RequestError):
"""
Base class for all exceptions that occur at the level of the Transport API.
All of these exceptions also have an equivelent mapping in `httpcore`.
"""


Expand Down Expand Up @@ -219,7 +231,8 @@ class HTTPStatusError(HTTPError):
def __init__(
self, message: str, *, request: "Request", response: "Response"
) -> None:
super().__init__(message, request=request)
super().__init__(message)
self.request = request
self.response = response


Expand Down Expand Up @@ -318,45 +331,14 @@ def __init__(self) -> None:


@contextlib.contextmanager
def map_exceptions(
mapping: typing.Mapping[typing.Type[Exception], typing.Type[Exception]],
**kwargs: typing.Any,
) -> typing.Iterator[None]:
def request_context(request: "Request" = None) -> typing.Iterator[None]:
"""
A context manager that can be used to attach the given request context
to any `RequestError` exceptions that are raised within the block.
"""
try:
yield
except Exception as exc:
mapped_exc = None

for from_exc, to_exc in mapping.items():
if not isinstance(exc, from_exc):
continue
# We want to map to the most specific exception we can find.
# Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to
# `httpx.ReadTimeout`, not just `httpx.TimeoutException`.
if mapped_exc is None or issubclass(to_exc, mapped_exc):
mapped_exc = to_exc

if mapped_exc is None:
raise

message = str(exc)
raise mapped_exc(message, **kwargs) from exc # type: ignore


HTTPCORE_EXC_MAP = {
httpcore.TimeoutException: TimeoutException,
httpcore.ConnectTimeout: ConnectTimeout,
httpcore.ReadTimeout: ReadTimeout,
httpcore.WriteTimeout: WriteTimeout,
httpcore.PoolTimeout: PoolTimeout,
httpcore.NetworkError: NetworkError,
httpcore.ConnectError: ConnectError,
httpcore.ReadError: ReadError,
httpcore.WriteError: WriteError,
httpcore.CloseError: CloseError,
httpcore.ProxyError: ProxyError,
httpcore.UnsupportedProtocol: UnsupportedProtocol,
httpcore.ProtocolError: ProtocolError,
httpcore.LocalProtocolError: LocalProtocolError,
httpcore.RemoteProtocolError: RemoteProtocolError,
}
except RequestError as exc:
if request is not None:
exc.request = request
raise exc
38 changes: 13 additions & 25 deletions httpx/_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import cgi
import contextlib
import datetime
import email.message
import json as jsonlib
Expand All @@ -24,16 +23,14 @@
TextDecoder,
)
from ._exceptions import (
HTTPCORE_EXC_MAP,
CookieConflict,
DecodingError,
HTTPStatusError,
InvalidURL,
RequestNotRead,
ResponseClosed,
ResponseNotRead,
StreamConsumed,
map_exceptions,
request_context,
)
from ._status_codes import codes
from ._types import (
Expand Down Expand Up @@ -1145,17 +1142,6 @@ def num_bytes_downloaded(self) -> int:
def __repr__(self) -> str:
return f"<Response [{self.status_code} {self.reason_phrase}]>"

@contextlib.contextmanager
def _wrap_decoder_errors(self) -> typing.Iterator[None]:
# If the response has an associated request instance, we want decoding
# errors to be raised as proper `httpx.DecodingError` exceptions.
try:
yield
except ValueError as exc:
if self._request is None:
raise exc
raise DecodingError(message=str(exc), request=self.request) from exc

def read(self) -> bytes:
"""
Read and return the response content.
Expand All @@ -1176,7 +1162,7 @@ def iter_bytes(self, chunk_size: int = None) -> typing.Iterator[bytes]:
else:
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
with request_context(request=self._request):
for raw_bytes in self.iter_raw():
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
Expand All @@ -1195,7 +1181,7 @@ def iter_text(self, chunk_size: int = None) -> typing.Iterator[str]:
"""
decoder = TextDecoder(encoding=self.encoding)
chunker = TextChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
with request_context(request=self._request):
for byte_content in self.iter_bytes():
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
Expand All @@ -1208,7 +1194,7 @@ def iter_text(self, chunk_size: int = None) -> typing.Iterator[str]:

def iter_lines(self) -> typing.Iterator[str]:
decoder = LineDecoder()
with self._wrap_decoder_errors():
with request_context(request=self._request):
for text in self.iter_text():
for line in decoder.decode(text):
yield line
Expand All @@ -1230,7 +1216,7 @@ def iter_raw(self, chunk_size: int = None) -> typing.Iterator[bytes]:
self._num_bytes_downloaded = 0
chunker = ByteChunker(chunk_size=chunk_size)

with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
with request_context(request=self._request):
for raw_stream_bytes in self.stream:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
Expand All @@ -1249,7 +1235,8 @@ def close(self) -> None:
if not self.is_closed:
self.is_closed = True
if self._on_close is not None:
self._on_close(self)
with request_context(request=self._request):
self._on_close(self)

async def aread(self) -> bytes:
"""
Expand All @@ -1271,7 +1258,7 @@ async def aiter_bytes(self, chunk_size: int = None) -> typing.AsyncIterator[byte
else:
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
with request_context(request=self._request):
async for raw_bytes in self.aiter_raw():
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
Expand All @@ -1290,7 +1277,7 @@ async def aiter_text(self, chunk_size: int = None) -> typing.AsyncIterator[str]:
"""
decoder = TextDecoder(encoding=self.encoding)
chunker = TextChunker(chunk_size=chunk_size)
with self._wrap_decoder_errors():
with request_context(request=self._request):
async for byte_content in self.aiter_bytes():
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
Expand All @@ -1303,7 +1290,7 @@ async def aiter_text(self, chunk_size: int = None) -> typing.AsyncIterator[str]:

async def aiter_lines(self) -> typing.AsyncIterator[str]:
decoder = LineDecoder()
with self._wrap_decoder_errors():
with request_context(request=self._request):
async for text in self.aiter_text():
for line in decoder.decode(text):
yield line
Expand All @@ -1325,7 +1312,7 @@ async def aiter_raw(self, chunk_size: int = None) -> typing.AsyncIterator[bytes]
self._num_bytes_downloaded = 0
chunker = ByteChunker(chunk_size=chunk_size)

with map_exceptions(HTTPCORE_EXC_MAP, request=self._request):
with request_context(request=self._request):
async for raw_stream_bytes in self.stream:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
Expand All @@ -1344,7 +1331,8 @@ async def aclose(self) -> None:
if not self.is_closed:
self.is_closed = True
if self._on_close is not None:
await self._on_close(self)
with request_context(request=self._request):
await self._on_close(self)


class Cookies(MutableMapping):
Expand Down
Loading

0 comments on commit ee2a612

Please sign in to comment.