diff --git a/docs/advanced.md b/docs/advanced.md index 809d84383c..c407ef0fff 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -471,3 +471,160 @@ If you do need to make HTTPS connections to a local server, for example to test >>> r Response <200 OK> ``` + +## Middleware + +Middleware is a general-purpose mechanism for extending the built-in functionality of a `Client`. + +### Using middleware + +Middleware generally comes in the form of classes. Middleware classes and their configuration parameters are meant to be passed as a list of `httpx.Middleware` instances to a client: + +```python +from example.middleware import ExampleMiddleware +import httpx + +middleware = [ + httpx.Middleware(ExampleMiddleware, client_param="value", ...), +] + +with httpx.Client(middleware=middleware) as client: + # This request will pass through ExampleMiddleware before + # reaching the core processing layers of the Client. + r = client.get("https://example.org") +``` + +### Writing middleware + +Middleware classes should accept a `get_response` parameter, that represents the inner middleware, as well as any keyword arguments for middleware-specific configuration options. They should implement the `.__call__()` method, that accepts a `request`, `timeout` configuration, and any per-request keyword arguments coming directly from keyword arguments passed to `client.get()`, `client.post()`, etc. The `__call__()` method should return a generator that returns a `Response` instance. To get a response from the inner middleware, use `response = yield from get_response(...)`. + +#### Basic example + +Here is a "do-nothing" middleware that sends the request unmodified, and returns the response unmodified: + +```python +import httpx + +class PassThroughMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request, timeout, **kwargs): + return (yield from self.get_response(request, timeout, **kwargs)) +``` + +#### Inspecting requests and responses + +Here is a middleware that prints out information about the request and the response, as well as any client or request options: + +```python +import httpx + +class ExampleMiddleware: + def __init__(self, get_response, **options): + self.get_response = get_response + self.options = options + print(f"Client options: {self.options}") + + def __call__(self, request, timeout, **kwargs): + print(f"Response options: {kwargs}") + print(f"Request: {request}") + response = yield from self.get_response(request, timeout, **kwargs) + print(f"Response: {response}") + return response + +middleware = [httpx.Middleware(ExampleMiddleware, client_option="example")] + +with httpx.Client(middleware=middleware) as client: + print("Sending request...") + r = client.get("https://example.org", request_option="example") + print("Got response") +``` + +Output: + +```console +Client options: {'client_option': 'example'} +Sending request... +Request: +Request options: {'request_option': 'example'} +Options: {'param': 'value'} +Response: +Got response +``` + +#### Sending multiple requests + +Middleware can use the `yield from get_response()` construct multiple times to send multiple requests. + +This can be useful to implement behaviors such as retries, mechanisms involving web hooks, and other advanced features. + +#### Example: retry middleware + +The example below shows how to implement a general-purpose retry middleware based on the excellent [Tenacity](https://github.com/jd/tenacity) library, including type annotations usage. + +```python +# retries.py +from typing import Any, Callable, Generator + +import httpx +import tenacity + + +class RetryingMiddleware: + def __init__(self, get_response: Callable, *, retrying: tenacity.Retrying) -> None: + self.get_response = get_response + self.retrying = retrying + + def __call__( + self, request: httpx.Request, timeout: httpx.Timeout, **kwargs: Any + ) -> Generator[Any, Any, httpx.Response]: + # Allow overriding the retries algorithm per-request. + retrying = self.retrying if "retrying" not in kwargs else kwargs["retrying"] + + try: + for attempt in retrying: + with attempt: + response = yield from self.get_response(request, timeout, **kwargs) + break + print("Failed!") + except tenacity.RetryError as exc: + # Wrap as an HTTPX-specific exception. + raise httpx.HTTPError(exc, request=request) + else: + return response +``` + +Usage: + +```python +import httpx +import tenacity + +from .retries import RetryingMiddleware + +middleware = [ + httpx.Middleware( + RetryingMiddleware, + retrying=tenacity.Retrying( + retry=tenacity.retry_if_exception_type(httpx._exceptions.NetworkError), + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(multiplier=1, min=0.5, max=10), + ), + ) +] + +with httpx.Client(middleware=middleware) as client: + # Network failures on this request (such as failures to establish + # a connection, or failures to keep the connection open) will be + # retried on at most 3 times. + r = client.get("https://doesnotexist.org") + + # Read timeouts on this request will be retried on at most 5 times, + # with a constant 200ms delay between retries. + retry_on_read_timeouts = tenacity.Retrying( + retry=tenacity.retry_if_exception_type(httpx.ReadTimeout), + wait=tenacity.wait_fixed(0.2), + ) + r = client.get("https://flakyserver.io", retrying=retry_on_read_timeouts) +``` diff --git a/httpx/__init__.py b/httpx/__init__.py index 55f8526c59..b8364e1df6 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -27,6 +27,7 @@ TooManyRedirects, WriteTimeout, ) +from ._middleware import Middleware from ._models import URL, Cookies, Headers, QueryParams, Request, Response from ._status_codes import StatusCode, codes @@ -83,4 +84,5 @@ "Response", "DigestAuth", "WSGIDispatch", + "Middleware", ] diff --git a/httpx/_client.py b/httpx/_client.py index 109c8351f8..d7d6a12f8b 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -34,6 +34,7 @@ RequestBodyUnavailable, TooManyRedirects, ) +from ._middleware import Middleware, MiddlewareStack from ._models import ( URL, Cookies, @@ -50,7 +51,7 @@ URLTypes, ) from ._status_codes import codes -from ._utils import NetRCInfo, get_environment_proxies, get_logger +from ._utils import NetRCInfo, consume_generator, get_environment_proxies, get_logger logger = get_logger(__name__) @@ -451,6 +452,7 @@ def __init__( dispatch: SyncDispatcher = None, app: typing.Callable = None, trust_env: bool = True, + middleware: typing.Sequence[Middleware] = None, ): super().__init__( auth=auth, @@ -484,6 +486,13 @@ def __init__( for key, proxy in proxy_map.items() } + def get_response( + request: Request, timeout: Timeout, **kwargs: typing.Any + ) -> typing.Generator: + yield self.send_handling_redirects(request, timeout=timeout, **kwargs) + + self.middleware_stack = MiddlewareStack(get_response, middleware) + def init_dispatch( self, verify: VerifyTypes = True, @@ -558,6 +567,7 @@ def request( auth: AuthTypes = None, allow_redirects: bool = True, timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, + **kwargs: typing.Any, ) -> Response: request = self.build_request( method=method, @@ -570,7 +580,11 @@ def request( cookies=cookies, ) return self.send( - request, auth=auth, allow_redirects=allow_redirects, timeout=timeout, + request, + auth=auth, + allow_redirects=allow_redirects, + timeout=timeout, + **kwargs, ) def send( @@ -581,6 +595,7 @@ def send( auth: AuthTypes = None, allow_redirects: bool = True, timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, + **kwargs: typing.Any, ) -> Response: if request.url.scheme not in ("http", "https"): raise InvalidURL('URL scheme must be "http" or "https".') @@ -589,8 +604,10 @@ def send( auth = self.build_auth(request, auth) - response = self.send_handling_redirects( - request, auth=auth, timeout=timeout, allow_redirects=allow_redirects, + response = consume_generator( + self.middleware_stack( + request, timeout, auth=auth, allow_redirects=allow_redirects, **kwargs + ) ) if not stream: @@ -705,6 +722,7 @@ def get( auth: AuthTypes = None, allow_redirects: bool = True, timeout: typing.Union[TimeoutTypes, UnsetType] = UNSET, + **kwargs: typing.Any, ) -> Response: return self.request( "GET", @@ -715,6 +733,7 @@ def get( auth=auth, allow_redirects=allow_redirects, timeout=timeout, + **kwargs, ) def options( diff --git a/httpx/_middleware.py b/httpx/_middleware.py new file mode 100644 index 0000000000..0a48b44d69 --- /dev/null +++ b/httpx/_middleware.py @@ -0,0 +1,51 @@ +import typing + +from ._config import Timeout +from ._models import Request, Response +from ._utils import get_logger + +logger = get_logger(__name__) + + +class MiddlewareInstance(typing.Protocol): + def __call__( + self, request: Request, timeout: Timeout, **kwargs: typing.Any + ) -> typing.Generator[typing.Any, typing.Any, Response]: + ... + + +MiddlewareType = typing.Callable[[MiddlewareInstance], MiddlewareInstance] + + +class Middleware: + def __init__(self, middleware: typing.Callable, **kwargs: typing.Any) -> None: + self.middleware = middleware + self.kwargs = kwargs + + def __call__(self, get_response: MiddlewareInstance) -> MiddlewareInstance: + return self.middleware(get_response, **self.kwargs) + + +class MiddlewareStack: + """ + Container for representing a stack of middleware classes. + """ + + def __init__( + self, + get_response: MiddlewareInstance, + middleware: typing.Sequence[Middleware] = None, + ) -> None: + self.get_response = get_response + self.middleware = list(middleware) if middleware is not None else [] + + def __call__( + self, request: Request, timeout: Timeout, **kwargs: typing.Any + ) -> typing.Generator[typing.Any, typing.Any, Response]: + if not hasattr(self, "_stack"): + get_response = self.get_response + for middleware in self.middleware: + get_response = middleware(get_response) + self._stack = get_response + + return self._stack(request, timeout, **kwargs) diff --git a/httpx/_utils.py b/httpx/_utils.py index b884f8193e..6c7fcd0c91 100644 --- a/httpx/_utils.py +++ b/httpx/_utils.py @@ -20,6 +20,8 @@ from ._models import URL +T = typing.TypeVar("T") + _HTML5_FORM_ENCODING_REPLACEMENTS = {'"': "%22", "\\": "\\\\"} _HTML5_FORM_ENCODING_REPLACEMENTS.update( {chr(c): "%{:02X}".format(c) for c in range(0x00, 0x1F + 1) if c != 0x1B} @@ -367,3 +369,19 @@ def as_network_error(*exception_classes: type) -> typing.Iterator[None]: if isinstance(exc, cls): raise NetworkError(exc) from exc raise + + +def consume_generator(gen: typing.Generator[typing.Any, typing.Any, T]) -> T: + """ + Run a generator to completion and return the result, assuming that yielded + values are synchronous (i.e. they're not coroutines). + """ + value: typing.Any = None + + while True: + try: + value = gen.send(value) + except StopIteration as exc: + return exc.value + except BaseException as exc: + value = gen.throw(type(exc), exc, exc.__traceback__)