Skip to content

Commit

Permalink
Refactor middleware interface
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Aug 22, 2019
1 parent 86dc0a4 commit 5ae8b13
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 147 deletions.
100 changes: 52 additions & 48 deletions httpx/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import inspect
import typing
from types import TracebackType
Expand All @@ -18,10 +19,8 @@
)
from .dispatch.asgi import ASGIDispatch
from .dispatch.base import AsyncDispatcher, Dispatcher
from .dispatch.basic_auth import BasicAuthDispatcher
from .dispatch.connection_pool import ConnectionPool
from .dispatch.custom_auth import CustomAuthDispatcher
from .dispatch.redirect import RedirectDispatcher
from .dispatch.middleware import Middleware, basic_auth, custom_auth, redirect
from .dispatch.threaded import ThreadedDispatcher
from .dispatch.wsgi import WSGIDispatch
from .exceptions import HTTPError, InvalidURL
Expand Down Expand Up @@ -131,31 +130,16 @@ def merge_headers(
return merged_headers
return headers

async def send(
async def get_response(
self,
request: AsyncRequest,
*,
stream: bool = False,
auth: AuthTypes = None,
allow_redirects: bool = True,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
trust_env: typing.Optional[bool] = None,
) -> AsyncResponse:
url = request.url
if url.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')

dispatcher: AsyncDispatcher = self._resolve_dispatcher(
request,
auth or self.auth,
self.trust_env if trust_env is None else trust_env,
allow_redirects,
)

try:
response = await dispatcher.send(
response = await self.dispatch.send(
request, verify=verify, cert=cert, timeout=timeout
)
except HTTPError as exc:
Expand All @@ -172,42 +156,62 @@ async def send(

return response

def _resolve_dispatcher(
async def send(
self,
request: AsyncRequest,
*,
stream: bool = False,
auth: AuthTypes = None,
trust_env: bool = False,
allow_redirects: bool = True,
) -> AsyncDispatcher:
dispatcher: AsyncDispatcher = RedirectDispatcher(
next_dispatcher=self.dispatch,
base_cookies=self.cookies,
allow_redirects=allow_redirects,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
trust_env: bool = None,
) -> AsyncResponse:
if request.url.scheme not in ("http", "https"):
raise InvalidURL('URL scheme must be "http" or "https".')

get_response = functools.partial(
self.get_response, stream=stream, verify=verify, cert=cert, timeout=timeout
)

username: typing.Optional[typing.Union[str, bytes]] = None
password: typing.Optional[typing.Union[str, bytes]] = None
if auth is None:
if request.url.username or request.url.password:
username, password = request.url.username, request.url.password
elif trust_env:
netrc_login = get_netrc_login(request.url.authority)
if netrc_login:
username, _, password = netrc_login
else:
if isinstance(auth, tuple):
username, password = auth[0], auth[1]
elif callable(auth):
dispatcher = CustomAuthDispatcher(
next_dispatcher=dispatcher, auth_callable=auth
)

if username is not None and password is not None:
dispatcher = BasicAuthDispatcher(
next_dispatcher=dispatcher, username=username, password=password
get_response = functools.partial(
redirect(allow_redirects, cookies=self.cookies), get_response=get_response
)

auth_middleware = self._get_auth_middleware(
request=request,
trust_env=self.trust_env if trust_env is None else trust_env,
auth=self.auth if auth is None else auth,
)
if auth_middleware is not None:
get_response = functools.partial(auth_middleware, get_response=get_response)

return await get_response(request)

def _get_auth_middleware(
self, request: AsyncRequest, trust_env: bool, auth: AuthTypes = None
) -> typing.Optional[Middleware]:
if isinstance(auth, tuple):
return basic_auth(username=auth[0], password=auth[1])

if callable(auth):
return custom_auth(auth=auth)

assert auth is None

if request.url.username or request.url.password:
return basic_auth(
username=request.url.username, password=request.url.password
)

return dispatcher
if trust_env:
netrc_login = get_netrc_login(request.url.authority)
if netrc_login:
username, _, password = netrc_login
return basic_auth(username=username, password=password)

return None


class AsyncClient(BaseClient):
Expand Down
43 changes: 0 additions & 43 deletions httpx/dispatch/basic_auth.py

This file was deleted.

27 changes: 0 additions & 27 deletions httpx/dispatch/custom_auth.py

This file was deleted.

88 changes: 59 additions & 29 deletions httpx/dispatch/redirect.py → httpx/dispatch/middleware.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,91 @@
import functools
import typing
from base64 import b64encode

from ..config import DEFAULT_MAX_REDIRECTS, CertTypes, TimeoutTypes, VerifyTypes
from ..config import DEFAULT_MAX_REDIRECTS
from ..exceptions import RedirectBodyUnavailable, RedirectLoop, TooManyRedirects
from ..models import URL, AsyncRequest, AsyncResponse, Cookies, Headers
from ..status_codes import codes
from .base import AsyncDispatcher

Responder = typing.Callable[[AsyncRequest], typing.Coroutine[None, None, AsyncResponse]]
Middleware = typing.Callable[
[AsyncRequest, typing.Callable], typing.Coroutine[None, None, AsyncResponse]
]

class RedirectDispatcher(AsyncDispatcher):

def basic_auth(
username: typing.Union[str, bytes], password: typing.Union[str, bytes]
) -> Middleware:
if isinstance(username, str):
username = username.encode("latin1")

if isinstance(password, str):
password = password.encode("latin1")

userpass = b":".join((username, password))
token = b64encode(userpass).decode().strip()
print(username, password, userpass, token)
authorization_header = f"Basic {token}"

async def dispatch(request: AsyncRequest, get_response: Responder) -> AsyncResponse:
request.headers["Authorization"] = authorization_header
return await get_response(request)

return dispatch


def custom_auth(auth: typing.Callable[[AsyncRequest], AsyncRequest]) -> Middleware:
async def dispatch(request: AsyncRequest, get_response: Responder) -> AsyncResponse:
request = auth(request)
return await get_response(request)

return dispatch


def redirect(
allow_redirects: bool = True,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
cookies: typing.Optional[Cookies] = None,
) -> Middleware:
return Redirect(
allow_redirects=allow_redirects, max_redirects=max_redirects, cookies=cookies
).dispatch


class Redirect:
def __init__(
self,
next_dispatcher: AsyncDispatcher,
allow_redirects: bool = True,
max_redirects: int = DEFAULT_MAX_REDIRECTS,
base_cookies: typing.Optional[Cookies] = None,
cookies: typing.Optional[Cookies] = None,
):
self.next_dispatcher = next_dispatcher
self.allow_redirects = allow_redirects
self.max_redirects = max_redirects
self.base_cookies = base_cookies
self.history = [] # type: list
self.cookies = cookies
self.history: typing.List[AsyncResponse] = []

async def send(
self,
request: AsyncRequest,
verify: VerifyTypes = None,
cert: CertTypes = None,
timeout: TimeoutTypes = None,
async def dispatch(
self, request: AsyncRequest, get_response: Responder
) -> AsyncResponse:
if len(self.history) > self.max_redirects:
raise TooManyRedirects()
if request.url in (response.url for response in self.history):
raise RedirectLoop()

response = await self.next_dispatcher.send(
request, verify=verify, cert=cert, timeout=timeout
)
response = await get_response(request)
response.history = list(self.history)

if not response.is_redirect:
return response

self.history.append(response)
next_request = self.build_redirect_request(request, response)
if self.allow_redirects:
return await self.send(
next_request, verify=verify, cert=cert, timeout=timeout
)
else:

async def send_next() -> AsyncResponse:
return await self.send(
next_request, verify=verify, cert=cert, timeout=timeout
)
if self.allow_redirects:
return await self.dispatch(next_request, get_response)

response.next = send_next # type: ignore
return response
response.next = functools.partial(self.dispatch, next_request, get_response)
return response

def build_redirect_request(
self, request: AsyncRequest, response: AsyncResponse
Expand All @@ -64,7 +94,7 @@ def build_redirect_request(
url = self.redirect_url(request, response)
headers = self.redirect_headers(request, url) # TODO: merge headers?
content = self.redirect_content(request, method)
cookies = Cookies(self.base_cookies)
cookies = Cookies(self.cookies)
cookies.update(request.cookies)
return AsyncRequest(
method=method, url=url, headers=headers, data=content, cookies=cookies
Expand Down

0 comments on commit 5ae8b13

Please sign in to comment.