diff --git a/httpx/client.py b/httpx/client.py index e61761e6e4..2d13269db4 100644 --- a/httpx/client.py +++ b/httpx/client.py @@ -1,3 +1,4 @@ +import functools import inspect import typing from types import TracebackType @@ -18,13 +19,16 @@ ) from .dispatch.asgi import ASGIDispatch from .dispatch.base import AsyncDispatcher, Dispatcher -from .dispatch.basic_auth import BasicAuthDispatcher from .dispatch.connection_pool import ConnectionPool -from .dispatch.custom_auth import CustomAuthDispatcher -from .dispatch.redirect import RedirectDispatcher from .dispatch.threaded import ThreadedDispatcher from .dispatch.wsgi import WSGIDispatch from .exceptions import HTTPError, InvalidURL +from .middleware import ( + BaseMiddleware, + BasicAuthMiddleware, + CustomAuthMiddleware, + RedirectMiddleware, +) from .models import ( URL, AsyncRequest, @@ -157,73 +161,79 @@ async def send( verify: VerifyTypes = None, cert: CertTypes = None, timeout: TimeoutTypes = None, - trust_env: typing.Optional[bool] = None, + trust_env: bool = None, ) -> AsyncResponse: - url = request.url - if url.scheme not in ("http", "https"): + if request.url.scheme not in ("http", "https"): raise InvalidURL('URL scheme must be "http" or "https".') - dispatcher: AsyncDispatcher = self._resolve_dispatcher( - request, - auth or self.auth, - self.trust_env if trust_env is None else trust_env, - allow_redirects, + async def get_response(request: AsyncRequest) -> AsyncResponse: + try: + response = await self.dispatch.send( + request, verify=verify, cert=cert, timeout=timeout + ) + except HTTPError as exc: + # Add the original request to any HTTPError + exc.request = request + raise + + self.cookies.extract_cookies(response) + if not stream: + try: + await response.read() + finally: + await response.close() + + return response + + def wrap( + get_response: typing.Callable, middleware: BaseMiddleware + ) -> typing.Callable: + return functools.partial(middleware, get_response=get_response) + + get_response = wrap( + get_response, + RedirectMiddleware(allow_redirects=allow_redirects, cookies=self.cookies), ) - try: - response = await dispatcher.send( - request, verify=verify, cert=cert, timeout=timeout - ) - except HTTPError as exc: - # Add the original request to any HTTPError - exc.request = request - raise + auth_middleware = self._get_auth_middleware( + request=request, + trust_env=self.trust_env if trust_env is None else trust_env, + auth=self.auth if auth is None else auth, + ) - self.cookies.extract_cookies(response) - if not stream: - try: - await response.read() - finally: - await response.close() + if auth_middleware is not None: + get_response = wrap(get_response, auth_middleware) - return response + return await get_response(request) - def _resolve_dispatcher( - self, - request: AsyncRequest, - auth: AuthTypes = None, - trust_env: bool = False, - allow_redirects: bool = True, - ) -> AsyncDispatcher: - dispatcher: AsyncDispatcher = RedirectDispatcher( - next_dispatcher=self.dispatch, - base_cookies=self.cookies, - allow_redirects=allow_redirects, - ) + def _get_auth_middleware( + self, request: AsyncRequest, trust_env: bool, auth: AuthTypes = None + ) -> typing.Optional[BaseMiddleware]: + if isinstance(auth, tuple): + return BasicAuthMiddleware(username=auth[0], password=auth[1]) - username: typing.Optional[typing.Union[str, bytes]] = None - password: typing.Optional[typing.Union[str, bytes]] = None - if auth is None: - if request.url.username or request.url.password: - username, password = request.url.username, request.url.password - elif trust_env: - netrc_login = get_netrc_login(request.url.authority) - if netrc_login: - username, _, password = netrc_login - else: - if isinstance(auth, tuple): - username, password = auth[0], auth[1] - elif callable(auth): - dispatcher = CustomAuthDispatcher( - next_dispatcher=dispatcher, auth_callable=auth - ) + if callable(auth): + return CustomAuthMiddleware(auth=auth) - if username is not None and password is not None: - dispatcher = BasicAuthDispatcher( - next_dispatcher=dispatcher, username=username, password=password + if auth is not None: + raise TypeError( + 'When specified, "auth" must be a (username, password) tuple or ' + "a callable with signature (AsyncRequest) -> AsyncRequest " + f"(got {auth!r})" ) - return dispatcher + if request.url.username or request.url.password: + return BasicAuthMiddleware( + username=request.url.username, password=request.url.password + ) + + if trust_env: + netrc_login = get_netrc_login(request.url.authority) + if netrc_login: + username, _, password = netrc_login + return BasicAuthMiddleware(username=username, password=password) + + return None class AsyncClient(BaseClient): diff --git a/httpx/dispatch/basic_auth.py b/httpx/dispatch/basic_auth.py deleted file mode 100644 index d806988e4a..0000000000 --- a/httpx/dispatch/basic_auth.py +++ /dev/null @@ -1,43 +0,0 @@ -import typing -from base64 import b64encode - -from ..config import CertTypes, TimeoutTypes, VerifyTypes -from ..models import AsyncRequest, AsyncResponse -from .base import AsyncDispatcher - - -class BasicAuthDispatcher(AsyncDispatcher): - def __init__( - self, - next_dispatcher: AsyncDispatcher, - username: typing.Union[str, bytes], - password: typing.Union[str, bytes], - ): - self.next_dispatcher = next_dispatcher - self.username = username - self.password = password - - async def send( - self, - request: AsyncRequest, - verify: VerifyTypes = None, - cert: CertTypes = None, - timeout: TimeoutTypes = None, - ) -> AsyncResponse: - request.headers["Authorization"] = self.build_auth_header() - return await self.next_dispatcher.send( - request, verify=verify, cert=cert, timeout=timeout - ) - - def build_auth_header(self) -> str: - username, password = self.username, self.password - - if isinstance(username, str): - username = username.encode("latin1") - - if isinstance(password, str): - password = password.encode("latin1") - - userpass = b":".join((username, password)) - token = b64encode(userpass).decode().strip() - return f"Basic {token}" diff --git a/httpx/dispatch/custom_auth.py b/httpx/dispatch/custom_auth.py deleted file mode 100644 index 9f42524465..0000000000 --- a/httpx/dispatch/custom_auth.py +++ /dev/null @@ -1,27 +0,0 @@ -import typing - -from ..config import CertTypes, TimeoutTypes, VerifyTypes -from ..models import AsyncRequest, AsyncResponse -from .base import AsyncDispatcher - - -class CustomAuthDispatcher(AsyncDispatcher): - def __init__( - self, - next_dispatcher: AsyncDispatcher, - auth_callable: typing.Callable[[AsyncRequest], AsyncRequest], - ): - self.next_dispatcher = next_dispatcher - self.auth_callable = auth_callable - - async def send( - self, - request: AsyncRequest, - verify: VerifyTypes = None, - cert: CertTypes = None, - timeout: TimeoutTypes = None, - ) -> AsyncResponse: - request = self.auth_callable(request) - return await self.next_dispatcher.send( - request, verify=verify, cert=cert, timeout=timeout - ) diff --git a/httpx/dispatch/redirect.py b/httpx/middleware.py similarity index 63% rename from httpx/dispatch/redirect.py rename to httpx/middleware.py index 1683c24052..4ed750e6bb 100644 --- a/httpx/dispatch/redirect.py +++ b/httpx/middleware.py @@ -1,41 +1,74 @@ +import functools import typing +from base64 import b64encode -from ..config import DEFAULT_MAX_REDIRECTS, CertTypes, TimeoutTypes, VerifyTypes -from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects -from ..models import URL, AsyncRequest, AsyncResponse, Cookies, Headers -from ..status_codes import codes -from .base import AsyncDispatcher +from .config import DEFAULT_MAX_REDIRECTS +from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects +from .models import URL, AsyncRequest, AsyncResponse, Cookies, Headers +from .status_codes import codes -class RedirectDispatcher(AsyncDispatcher): +class BaseMiddleware: + async def __call__( + self, request: AsyncRequest, get_response: typing.Callable + ) -> AsyncResponse: + raise NotImplementedError # pragma: no cover + + +class BasicAuthMiddleware(BaseMiddleware): + def __init__( + self, username: typing.Union[str, bytes], password: typing.Union[str, bytes] + ): + if isinstance(username, str): + username = username.encode("latin1") + + if isinstance(password, str): + password = password.encode("latin1") + + userpass = b":".join((username, password)) + token = b64encode(userpass).decode().strip() + + self.authorization_header = f"Basic {token}" + + async def __call__( + self, request: AsyncRequest, get_response: typing.Callable + ) -> AsyncResponse: + request.headers["Authorization"] = self.authorization_header + return await get_response(request) + + +class CustomAuthMiddleware(BaseMiddleware): + def __init__(self, auth: typing.Callable[[AsyncRequest], AsyncRequest]): + self.auth = auth + + async def __call__( + self, request: AsyncRequest, get_response: typing.Callable + ) -> AsyncResponse: + request = self.auth(request) + return await get_response(request) + + +class RedirectMiddleware(BaseMiddleware): def __init__( self, - next_dispatcher: AsyncDispatcher, allow_redirects: bool = True, max_redirects: int = DEFAULT_MAX_REDIRECTS, - base_cookies: typing.Optional[Cookies] = None, + cookies: typing.Optional[Cookies] = None, ): - self.next_dispatcher = next_dispatcher self.allow_redirects = allow_redirects self.max_redirects = max_redirects - self.base_cookies = base_cookies - self.history = [] # type: list + self.cookies = cookies + self.history: typing.List[AsyncResponse] = [] - async def send( - self, - request: AsyncRequest, - verify: VerifyTypes = None, - cert: CertTypes = None, - timeout: TimeoutTypes = None, + async def __call__( + self, request: AsyncRequest, get_response: typing.Callable ) -> AsyncResponse: if len(self.history) > self.max_redirects: raise TooManyRedirects() if request.url in (response.url for response in self.history): raise RedirectLoop() - response = await self.next_dispatcher.send( - request, verify=verify, cert=cert, timeout=timeout - ) + response = await get_response(request) response.history = list(self.history) if not response.is_redirect: @@ -43,19 +76,12 @@ async def send( self.history.append(response) next_request = self.build_redirect_request(request, response) - if self.allow_redirects: - return await self.send( - next_request, verify=verify, cert=cert, timeout=timeout - ) - else: - async def call_next() -> AsyncResponse: - return await self.send( - next_request, verify=verify, cert=cert, timeout=timeout - ) + if self.allow_redirects: + return await self(next_request, get_response) - response.call_next = call_next # type: ignore - return response + response.call_next = functools.partial(self, next_request, get_response) + return response def build_redirect_request( self, request: AsyncRequest, response: AsyncResponse @@ -64,7 +90,7 @@ def build_redirect_request( url = self.redirect_url(request, response) headers = self.redirect_headers(request, url) # TODO: merge headers? content = self.redirect_content(request, method) - cookies = Cookies(self.base_cookies) + cookies = Cookies(self.cookies) cookies.update(request.cookies) return AsyncRequest( method=method, url=url, headers=headers, data=content, cookies=cookies diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 725ea56c00..fc3b192f49 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1,6 +1,8 @@ import json import os +import pytest + from httpx import ( URL, AsyncDispatcher, @@ -118,3 +120,10 @@ def test_auth_hidden_header(): response = client.get(url, auth=auth) assert "'authorization': '[secure]'" in str(response.request.headers) + + +def test_auth_invalid_type(): + url = "https://example.org/" + with Client(dispatch=MockDispatch(), auth="not a tuple, not a callable") as client: + with pytest.raises(TypeError): + client.get(url)