Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add trio concurrency backend #276

Merged
merged 10 commits into from
Sep 21, 2019
2 changes: 1 addition & 1 deletion httpx/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
if param_count == 2:
dispatch = WSGIDispatch(app=app)
else:
dispatch = ASGIDispatch(app=app)
dispatch = ASGIDispatch(app=app, backend=backend)

self.trust_env = True if trust_env is None else trust_env

Expand Down
14 changes: 14 additions & 0 deletions httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,20 @@ async def write(
raise WriteTimeout() from None

def is_connection_dropped(self) -> bool:
# Counter-intuitively, what we really want to know here is whether the socket is
# *readable*, i.e. whether it would return immediately with empty bytes if we
# called `.recv()` on it, indicating that the other end has closed the socket.
# See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
#
# As it turns out, asyncio checks for readability in the background
# (see: https://github.com/encode/httpx/pull/276#discussion_r322000402),
# so checking for EOF or readability here would yield the same result.
#
# At the cost of rigour, we check for EOF instead of readability because asyncio
# does not expose any public API to check for readability.
# (For a solution that uses private asyncio APIs, see:
# https://github.com/encode/httpx/pull/143#issuecomment-515202982)

return self.stream_reader.at_eof()

async def close(self) -> None:
Expand Down
7 changes: 7 additions & 0 deletions httpx/concurrency/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,10 @@ async def __aexit__(
traceback: TracebackType = None,
) -> None:
raise NotImplementedError() # pragma: no cover

async def close(self, exception: BaseException = None) -> None:
if exception is None:
await self.__aexit__(None, None, None)
else:
traceback = exception.__traceback__ # type: ignore
await self.__aexit__(type(exception), exception, traceback)
255 changes: 255 additions & 0 deletions httpx/concurrency/trio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
import functools
import math
import ssl
import typing
from types import TracebackType

import trio

from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
BaseQueue,
BaseTCPStream,
ConcurrencyBackend,
TimeoutFlag,
)


def _or_inf(value: typing.Optional[float]) -> float:
return value if value is not None else float("inf")


class TCPStream(BaseTCPStream):
def __init__(
self,
stream: typing.Union[trio.SocketStream, trio.SSLStream],
timeout: TimeoutConfig,
) -> None:
self.stream = stream
self.timeout = timeout
self.write_buffer = b""
self.write_lock = trio.Lock()

def get_http_version(self) -> str:
if not isinstance(self.stream, trio.SSLStream):
return "HTTP/1.1"

ident = self.stream.selected_alpn_protocol()
if ident is None:
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
return "HTTP/1.1"

return "HTTP/2" if ident == "h2" else "HTTP/1.1"

async def read(
self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
) -> bytes:
if timeout is None:
timeout = self.timeout

while True:
# Check our flag at the first possible moment, and use a fine
# grained retry loop if we're not yet in read-timeout mode.
should_raise = flag is None or flag.raise_on_read_timeout
read_timeout = _or_inf(timeout.read_timeout if should_raise else 0.01)

with trio.move_on_after(read_timeout):
return await self.stream.receive_some(max_bytes=n)

if should_raise:
raise ReadTimeout() from None

def is_connection_dropped(self) -> bool:
# Adapted from: https://github.com/encode/httpx/pull/143#issuecomment-515202982
stream = self.stream

# Peek through any SSLStream wrappers to get the underlying SocketStream.
while hasattr(stream, "transport_stream"):
stream = stream.transport_stream
assert isinstance(stream, trio.SocketStream)

# Counter-intuitively, what we really want to know here is whether the socket is
# *readable*, i.e. whether it would return immediately with empty bytes if we
# called `.recv()` on it, indicating that the other end has closed the socket.
# See: https://github.com/encode/httpx/pull/143#issuecomment-515181778
return stream.socket.is_readable()

def write_no_block(self, data: bytes) -> None:
self.write_buffer += data

async def write(
self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
) -> None:
if self.write_buffer:
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
previous_data = self.write_buffer
# Reset before recursive call, otherwise we'll go through
# this branch indefinitely.
self.write_buffer = b""
try:
await self.write(previous_data, timeout=timeout, flag=flag)
except WriteTimeout:
self.writer_buffer = previous_data
raise

if not data:
return

if timeout is None:
timeout = self.timeout

write_timeout = _or_inf(timeout.write_timeout)

while True:
with trio.move_on_after(write_timeout):
async with self.write_lock:
await self.stream.send_all(data)
break
# We check our flag at the first possible moment, in order to
# allow us to suppress write timeouts, if we've since
# switched over to read-timeout mode.
should_raise = flag is None or flag.raise_on_write_timeout
if should_raise:
raise WriteTimeout() from None

async def close(self) -> None:
await self.stream.aclose()


class PoolSemaphore(BasePoolSemaphore):
def __init__(self, pool_limits: PoolLimits):
self.pool_limits = pool_limits

@property
def semaphore(self) -> typing.Optional[trio.Semaphore]:
if not hasattr(self, "_semaphore"):
max_connections = self.pool_limits.hard_limit
if max_connections is None:
self._semaphore = None
else:
self._semaphore = trio.Semaphore(
max_connections, max_value=max_connections
)
return self._semaphore

async def acquire(self) -> None:
if self.semaphore is None:
return

timeout = _or_inf(self.pool_limits.pool_timeout)

with trio.move_on_after(timeout):
await self.semaphore.acquire()
return

raise PoolTimeout()

def release(self) -> None:
if self.semaphore is None:
return

self.semaphore.release()


class TrioBackend(ConcurrencyBackend):
async def open_tcp_stream(
self,
hostname: str,
port: int,
ssl_context: typing.Optional[ssl.SSLContext],
timeout: TimeoutConfig,
) -> TCPStream:
connect_timeout = _or_inf(timeout.connect_timeout)

with trio.move_on_after(connect_timeout) as cancel_scope:
stream: trio.SocketStream = await trio.open_tcp_stream(hostname, port)
if ssl_context is not None:
stream = trio.SSLStream(stream, ssl_context, server_hostname=hostname)
await stream.do_handshake()

if cancel_scope.cancelled_caught:
raise ConnectTimeout()

return TCPStream(stream=stream, timeout=timeout)

async def run_in_threadpool(
self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
return await trio.to_thread.run_sync(
functools.partial(func, **kwargs) if kwargs else func, *args
)

def run(
self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any
) -> typing.Any:
return trio.run(
functools.partial(coroutine, **kwargs) if kwargs else coroutine, *args
)

def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)

def create_queue(self, max_size: int) -> BaseQueue:
return Queue(max_size=max_size)

def create_event(self) -> BaseEvent:
return Event()

def background_manager(
self, coroutine: typing.Callable, *args: typing.Any
) -> "BackgroundManager":
return BackgroundManager(coroutine, *args)


class Queue(BaseQueue):
def __init__(self, max_size: int) -> None:
self.send_channel, self.receive_channel = trio.open_memory_channel(math.inf)

async def get(self) -> typing.Any:
return await self.receive_channel.receive()

async def put(self, value: typing.Any) -> None:
await self.send_channel.send(value)


class Event(BaseEvent):
def __init__(self) -> None:
self._event = trio.Event()

def set(self) -> None:
self._event.set()

def is_set(self) -> bool:
return self._event.is_set()

async def wait(self) -> None:
await self._event.wait()

def clear(self) -> None:
# trio.Event.clear() was deprecated in Trio 0.12.
# https://github.com/python-trio/trio/issues/637
self._event = trio.Event()


class BackgroundManager(BaseBackgroundManager):
def __init__(self, coroutine: typing.Callable, *args: typing.Any) -> None:
self.coroutine = coroutine
self.args = args
self.nursery_manager = trio.open_nursery()
self.nursery: typing.Optional[trio.Nursery] = None

async def __aenter__(self) -> "BackgroundManager":
self.nursery = await self.nursery_manager.__aenter__()
self.nursery.start_soon(self.coroutine, *self.args)
return self

async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
assert self.nursery is not None
florimondmanca marked this conversation as resolved.
Show resolved Hide resolved
await self.nursery_manager.__aexit__(exc_type, exc_value, traceback)
3 changes: 2 additions & 1 deletion httpx/dispatch/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ async def run_app() -> None:
await response_started_or_failed.wait()

if app_exc is not None and self.raise_app_exceptions:
await background.close(app_exc)
raise app_exc

assert status_code is not None, "application did not return a response."
Expand All @@ -138,7 +139,7 @@ async def run_app() -> None:
async def on_close() -> None:
nonlocal response_body
await response_body.drain()
await background.__aexit__(None, None, None)
await background.close(app_exc)
if app_exc is not None and self.raise_app_exceptions:
raise app_exc

Expand Down
5 changes: 5 additions & 0 deletions httpx/dispatch/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ..concurrency.base import BaseEvent, BaseTCPStream, ConcurrencyBackend, TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..exceptions import ProtocolError
from ..models import AsyncRequest, AsyncResponse
from ..utils import get_logger

Expand Down Expand Up @@ -187,6 +188,10 @@ async def receive_event(
logger.debug(
f"receive_event stream_id={event_stream_id} event={event!r}"
)

if hasattr(event, "error_code"):
raise ProtocolError(event)

if isinstance(event, h2.events.WindowUpdated):
if event_stream_id == 0:
for window_update_event in self.window_update_received.values():
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ combine_as_imports = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = httpx,tests
known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,nox,pytest,requests,rfc3986,setuptools,trustme,uvicorn
known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,nox,pytest,requests,rfc3986,setuptools,trio,trustme,uvicorn
line_length = 88
multi_line_output = 3

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def get_packages(package):
"idna==2.*",
"rfc3986==1.*",
],
extras_require={"trio": ["trio"]},
classifiers=[
"Development Status :: 3 - Alpha",
"Environment :: Web Environment",
Expand Down
3 changes: 2 additions & 1 deletion test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-e .
-e .[trio]

# Optional
brotlipy==0.7.*
Expand All @@ -11,6 +11,7 @@ isort
mypy
pytest
pytest-asyncio
pytest-trio
pytest-cov
trustme
uvicorn
Expand Down
12 changes: 12 additions & 0 deletions tests/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,15 @@ async def sleep(backend, seconds: int):
@sleep.register(AsyncioBackend)
async def _sleep_asyncio(backend, seconds: int):
await asyncio.sleep(seconds)


try:
import trio
from httpx.concurrency.trio import TrioBackend
except ImportError: # pragma: no cover
pass
else:

@sleep.register(TrioBackend)
async def _sleep_trio(backend, seconds: int):
await trio.sleep(seconds)
12 changes: 11 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,17 @@ def clean_environ() -> typing.Dict[str, typing.Any]:
os.environ.update(original_environ)


@pytest.fixture(params=[pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)])
backend_params = [pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)]

try:
from httpx.concurrency.trio import TrioBackend
except ImportError: # pragma: no cover
pass
else:
backend_params.append(pytest.param(TrioBackend, marks=pytest.mark.trio))


@pytest.fixture(params=backend_params)
def backend(request):
backend_cls = request.param
return backend_cls()
Expand Down
Loading