Skip to content

Commit

Permalink
Refactor middleware interface
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Sep 1, 2019
1 parent c829891 commit 71ed2b6
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 160 deletions.
126 changes: 68 additions & 58 deletions httpx/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import inspect
import typing
from types import TracebackType
Expand All @@ -18,13 +19,16 @@
)
from .dispatch.asgi import ASGIDispatch
from .dispatch.base import AsyncDispatcher, Dispatcher
from .dispatch.basic_auth import BasicAuthDispatcher
from .dispatch.connection_pool import ConnectionPool
from .dispatch.custom_auth import CustomAuthDispatcher
from .dispatch.redirect import RedirectDispatcher
from .dispatch.threaded import ThreadedDispatcher
from .dispatch.wsgi import WSGIDispatch
from .exceptions import HTTPError, InvalidURL
from .middleware import (
BaseMiddleware,
BasicAuthMiddleware,
CustomAuthMiddleware,
RedirectMiddleware,
)
from .models import (
URL,
AsyncRequest,
Expand Down Expand Up @@ -157,73 +161,79 @@ async def send(
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
trust_env: typing.Optional[bool] = None,
trust_env: bool = None,
) -> AsyncResponse:
url = request.url
if url.scheme not in ("http", "https"):
if request.url.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')

dispatcher: AsyncDispatcher = self._resolve_dispatcher(
request,
auth or self.auth,
self.trust_env if trust_env is None else trust_env,
allow_redirects,
async def get_response(request: AsyncRequest) -> AsyncResponse:
try:
response = await self.dispatch.send(
request, verify=verify, cert=cert, timeout=timeout
)
except HTTPError as exc:
# Add the original request to any HTTPError
exc.request = request
raise

self.cookies.extract_cookies(response)
if not stream:
try:
await response.read()
finally:
await response.close()

return response

def wrap(
get_response: typing.Callable, middleware: BaseMiddleware
) -> typing.Callable:
return functools.partial(middleware, get_response=get_response)

get_response = wrap(
get_response,
RedirectMiddleware(allow_redirects=allow_redirects, cookies=self.cookies),
)

try:
response = await dispatcher.send(
request, verify=verify, cert=cert, timeout=timeout
)
except HTTPError as exc:
# Add the original request to any HTTPError
exc.request = request
raise
auth_middleware = self._get_auth_middleware(
request=request,
trust_env=self.trust_env if trust_env is None else trust_env,
auth=self.auth if auth is None else auth,
)

self.cookies.extract_cookies(response)
if not stream:
try:
await response.read()
finally:
await response.close()
if auth_middleware is not None:
get_response = wrap(get_response, auth_middleware)

return response
return await get_response(request)

def _resolve_dispatcher(
self,
request: AsyncRequest,
auth: AuthTypes = None,
trust_env: bool = False,
allow_redirects: bool = True,
) -> AsyncDispatcher:
dispatcher: AsyncDispatcher = RedirectDispatcher(
next_dispatcher=self.dispatch,
base_cookies=self.cookies,
allow_redirects=allow_redirects,
)
def _get_auth_middleware(
self, request: AsyncRequest, trust_env: bool, auth: AuthTypes = None
) -> typing.Optional[BaseMiddleware]:
if isinstance(auth, tuple):
return BasicAuthMiddleware(username=auth[0], password=auth[1])

username: typing.Optional[typing.Union[str, bytes]] = None
password: typing.Optional[typing.Union[str, bytes]] = None
if auth is None:
if request.url.username or request.url.password:
username, password = request.url.username, request.url.password
elif trust_env:
netrc_login = get_netrc_login(request.url.authority)
if netrc_login:
username, _, password = netrc_login
else:
if isinstance(auth, tuple):
username, password = auth[0], auth[1]
elif callable(auth):
dispatcher = CustomAuthDispatcher(
next_dispatcher=dispatcher, auth_callable=auth
)
if callable(auth):
return CustomAuthMiddleware(auth=auth)

if username is not None and password is not None:
dispatcher = BasicAuthDispatcher(
next_dispatcher=dispatcher, username=username, password=password
if auth is not None:
raise TypeError(
'When specified, "auth" must be a (username, password) tuple or '
"a callable with signature (AsyncRequest) -> AsyncRequest "
f"(got {auth!r})"
)

return dispatcher
if request.url.username or request.url.password:
return BasicAuthMiddleware(
username=request.url.username, password=request.url.password
)

if trust_env:
netrc_login = get_netrc_login(request.url.authority)
if netrc_login:
username, _, password = netrc_login
return BasicAuthMiddleware(username=username, password=password)

return None


class AsyncClient(BaseClient):
Expand Down
43 changes: 0 additions & 43 deletions httpx/dispatch/basic_auth.py

This file was deleted.

27 changes: 0 additions & 27 deletions httpx/dispatch/custom_auth.py

This file was deleted.

90 changes: 58 additions & 32 deletions httpx/dispatch/redirect.py → httpx/middleware.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,87 @@
import functools
import typing
from base64 import b64encode

from ..config import DEFAULT_MAX_REDIRECTS, CertTypes, TimeoutTypes, VerifyTypes
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from ..models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
from ..status_codes import codes
from .base import AsyncDispatcher
from .config import DEFAULT_MAX_REDIRECTS
from .exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from .models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
from .status_codes import codes


class RedirectDispatcher(AsyncDispatcher):
class BaseMiddleware:
async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
raise NotImplementedError # pragma: no cover


class BasicAuthMiddleware(BaseMiddleware):
def __init__(
self, username: typing.Union[str, bytes], password: typing.Union[str, bytes]
):
if isinstance(username, str):
username = username.encode("latin1")

if isinstance(password, str):
password = password.encode("latin1")

userpass = b":".join((username, password))
token = b64encode(userpass).decode().strip()

self.authorization_header = f"Basic {token}"

async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
request.headers["Authorization"] = self.authorization_header
return await get_response(request)


class CustomAuthMiddleware(BaseMiddleware):
def __init__(self, auth: typing.Callable[[AsyncRequest], AsyncRequest]):
self.auth = auth

async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
request = self.auth(request)
return await get_response(request)


class RedirectMiddleware(BaseMiddleware):
def __init__(
self,
next_dispatcher: AsyncDispatcher,
allow_redirects: bool = True,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
base_cookies: typing.Optional[Cookies] = None,
cookies: typing.Optional[Cookies] = None,
):
self.next_dispatcher = next_dispatcher
self.allow_redirects = allow_redirects
self.max_redirects = max_redirects
self.base_cookies = base_cookies
self.history = [] # type: list
self.cookies = cookies
self.history: typing.List[AsyncResponse] = []

async def send(
self,
request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
async def __call__(
self, request: AsyncRequest, get_response: typing.Callable
) -> AsyncResponse:
if len(self.history) > self.max_redirects:
raise TooManyRedirects()
if request.url in (response.url for response in self.history):
raise RedirectLoop()

response = await self.next_dispatcher.send(
request, verify=verify, cert=cert, timeout=timeout
)
response = await get_response(request)
response.history = list(self.history)

if not response.is_redirect:
return response

self.history.append(response)
next_request = self.build_redirect_request(request, response)
if self.allow_redirects:
return await self.send(
next_request, verify=verify, cert=cert, timeout=timeout
)
else:

async def call_next() -> AsyncResponse:
return await self.send(
next_request, verify=verify, cert=cert, timeout=timeout
)
if self.allow_redirects:
return await self(next_request, get_response)

response.call_next = call_next # type: ignore
return response
response.call_next = functools.partial(self, next_request, get_response)
return response

def build_redirect_request(
self, request: AsyncRequest, response: AsyncResponse
Expand All @@ -64,7 +90,7 @@ def build_redirect_request(
url = self.redirect_url(request, response)
headers = self.redirect_headers(request, url) # TODO: merge headers?
content = self.redirect_content(request, method)
cookies = Cookies(self.base_cookies)
cookies = Cookies(self.cookies)
cookies.update(request.cookies)
return AsyncRequest(
method=method, url=url, headers=headers, data=content, cookies=cookies
Expand Down
9 changes: 9 additions & 0 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import os

import pytest

from httpx import (
URL,
AsyncDispatcher,
Expand Down Expand Up @@ -118,3 +120,10 @@ def test_auth_hidden_header():
response = client.get(url, auth=auth)

assert "'authorization': '[secure]'" in str(response.request.headers)


def test_auth_invalid_type():
url = "https://example.org/"
with Client(dispatch=MockDispatch(), auth="not a tuple, not a callable") as client:
with pytest.raises(TypeError):
client.get(url)

0 comments on commit 71ed2b6

Please sign in to comment.