Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added AnyIO support #169

Merged
merged 7 commits into from Sep 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
194 changes: 194 additions & 0 deletions httpcore/_backends/anyio.py
@@ -0,0 +1,194 @@
import select
from ssl import SSLContext
from typing import Optional

import anyio.abc
from anyio import BrokenResourceError, EndOfStream
from anyio.abc import ByteStream, SocketAttribute
from anyio.streams.tls import TLSAttribute, TLSStream

from .._exceptions import (
CloseError,
ConnectError,
ConnectTimeout,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
)
from .._types import TimeoutDict
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream


class SocketStream(AsyncSocketStream):
def __init__(self, stream: ByteStream) -> None:
self.stream = stream
self.read_lock = anyio.create_lock()
self.write_lock = anyio.create_lock()

def get_http_version(self) -> str:
alpn_protocol = self.stream.extra(TLSAttribute.alpn_protocol, None)
return "HTTP/2" if alpn_protocol == "h2" else "HTTP/1.1"

async def start_tls(
self,
hostname: bytes,
ssl_context: SSLContext,
timeout: TimeoutDict,
) -> "SocketStream":
connect_timeout = timeout.get("connect")
try:
async with anyio.fail_after(connect_timeout):
ssl_stream = await TLSStream.wrap(
self.stream,
ssl_context=ssl_context,
hostname=hostname.decode("ascii"),
)
except TimeoutError:
raise ConnectTimeout from None
except BrokenResourceError as exc:
raise ConnectError from exc

return SocketStream(ssl_stream)

async def read(self, n: int, timeout: TimeoutDict) -> bytes:
read_timeout = timeout.get("read")
async with self.read_lock:
try:
async with anyio.fail_after(read_timeout):
return await self.stream.receive(n)
except TimeoutError:
raise ReadTimeout from None
except BrokenResourceError as exc:
raise ReadError from exc
except EndOfStream:
raise ReadError("Server disconnected while attempting read") from None

async def write(self, data: bytes, timeout: TimeoutDict) -> None:
if not data:
return

write_timeout = timeout.get("write")
async with self.write_lock:
try:
async with anyio.fail_after(write_timeout):
return await self.stream.send(data)
except TimeoutError:
raise WriteTimeout from None
except BrokenResourceError as exc:
raise WriteError from exc

async def aclose(self) -> None:
async with self.write_lock:
try:
await self.stream.aclose()
except BrokenResourceError as exc:
raise CloseError from exc

def is_connection_dropped(self) -> bool:
raw_socket = self.stream.extra(SocketAttribute.raw_socket)
rready, _wready, _xready = select.select([raw_socket], [], [], 0)
return bool(rready)
Comment on lines +88 to +91
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to look into #182 some more before deciding that this trick is worth committing to. 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense for us to approach it in this order?...

  1. Merge this PR.
  2. Update Tweak dropped connection detection #185 to include curio and anyio.
  3. Get Tweak dropped connection detection #185 merged.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure yup, #185 isn't blocking us from merging this, just wanted to clarify that this select() approach needed a bit of refining. :)



class Lock(AsyncLock):
def __init__(self) -> None:
self._lock = anyio.create_lock()

async def release(self) -> None:
await self._lock.release()

async def acquire(self) -> None:
await self._lock.acquire()


class Semaphore(AsyncSemaphore):
def __init__(self, max_value: int, exc_class: type):
self.max_value = max_value
self.exc_class = exc_class

@property
def semaphore(self) -> anyio.abc.Semaphore:
if not hasattr(self, "_semaphore"):
self._semaphore = anyio.create_semaphore(self.max_value)
return self._semaphore

async def acquire(self, timeout: float = None) -> None:
async with anyio.move_on_after(timeout):
await self.semaphore.acquire()
return

raise self.exc_class()

async def release(self) -> None:
await self.semaphore.release()


class AnyIOBackend(AsyncBackend):
async def open_tcp_stream(
self,
hostname: bytes,
port: int,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
*,
local_address: Optional[str],
) -> AsyncSocketStream:
connect_timeout = timeout.get("connect")
unicode_host = hostname.decode("utf-8")

try:
async with anyio.fail_after(connect_timeout):
stream: anyio.abc.ByteStream
stream = await anyio.connect_tcp(
unicode_host, port, local_host=local_address
)
if ssl_context:
stream = await TLSStream.wrap(
stream,
hostname=unicode_host,
ssl_context=ssl_context,
standard_compatible=False,
)
except TimeoutError:
raise ConnectTimeout from None
except BrokenResourceError as exc:
raise ConnectError from exc

return SocketStream(stream=stream)

async def open_uds_stream(
self,
path: str,
hostname: bytes,
ssl_context: Optional[SSLContext],
timeout: TimeoutDict,
) -> AsyncSocketStream:
connect_timeout = timeout.get("connect")
unicode_host = hostname.decode("utf-8")

try:
async with anyio.fail_after(connect_timeout):
stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
if ssl_context:
stream = await TLSStream.wrap(
stream,
hostname=unicode_host,
ssl_context=ssl_context,
standard_compatible=False,
)
except TimeoutError:
raise ConnectTimeout from None
except BrokenResourceError as exc:
raise ConnectError from exc

return SocketStream(stream=stream)

def create_lock(self) -> AsyncLock:
return Lock()

def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
return Semaphore(max_value, exc_class=exc_class)

async def time(self) -> float:
return await anyio.current_time()
4 changes: 4 additions & 0 deletions httpcore/_backends/base.py
Expand Up @@ -25,6 +25,10 @@ def lookup_async_backend(name: str) -> "AsyncBackend":
from .curio import CurioBackend

return CurioBackend()
elif name == "anyio":
from .anyio import AnyIOBackend

return AnyIOBackend()

raise ValueError("Invalid backend name {name!r}")

Expand Down
69 changes: 41 additions & 28 deletions tests/async_tests/test_interfaces.py
Expand Up @@ -5,10 +5,15 @@

import httpcore
from httpcore._types import URL
from tests.conftest import Server, detect_backend
from tests.conftest import Server
from tests.utils import lookup_async_backend


@pytest.fixture(params=["auto", "anyio"])
def backend(request):
return request.param


async def read_body(stream: httpcore.AsyncByteStream) -> bytes:
try:
body = []
Expand All @@ -20,8 +25,8 @@ async def read_body(stream: httpcore.AsyncByteStream) -> bytes:


@pytest.mark.anyio
async def test_http_request() -> None:
async with httpcore.AsyncConnectionPool() as http:
async def test_http_request(backend: str) -> None:
async with httpcore.AsyncConnectionPool(backend=backend) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
Expand All @@ -37,8 +42,8 @@ async def test_http_request() -> None:


@pytest.mark.anyio
async def test_https_request() -> None:
async with httpcore.AsyncConnectionPool() as http:
async def test_https_request(backend: str) -> None:
async with httpcore.AsyncConnectionPool(backend=backend) as http:
method = b"GET"
url = (b"https", b"example.org", 443, b"/")
headers = [(b"host", b"example.org")]
Expand All @@ -54,8 +59,8 @@ async def test_https_request() -> None:


@pytest.mark.anyio
async def test_request_unsupported_protocol() -> None:
async with httpcore.AsyncConnectionPool() as http:
async def test_request_unsupported_protocol(backend: str) -> None:
async with httpcore.AsyncConnectionPool(backend=backend) as http:
method = b"GET"
url = (b"ftp", b"example.org", 443, b"/")
headers = [(b"host", b"example.org")]
Expand All @@ -64,8 +69,8 @@ async def test_request_unsupported_protocol() -> None:


@pytest.mark.anyio
async def test_http2_request() -> None:
async with httpcore.AsyncConnectionPool(http2=True) as http:
async def test_http2_request(backend: str) -> None:
async with httpcore.AsyncConnectionPool(backend=backend, http2=True) as http:
method = b"GET"
url = (b"https", b"example.org", 443, b"/")
headers = [(b"host", b"example.org")]
Expand All @@ -81,8 +86,8 @@ async def test_http2_request() -> None:


@pytest.mark.anyio
async def test_closing_http_request() -> None:
async with httpcore.AsyncConnectionPool() as http:
async def test_closing_http_request(backend: str) -> None:
async with httpcore.AsyncConnectionPool(backend=backend) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org"), (b"connection", b"close")]
Expand All @@ -98,8 +103,8 @@ async def test_closing_http_request() -> None:


@pytest.mark.anyio
async def test_http_request_reuse_connection() -> None:
async with httpcore.AsyncConnectionPool() as http:
async def test_http_request_reuse_connection(backend: str) -> None:
async with httpcore.AsyncConnectionPool(backend=backend) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
Expand Down Expand Up @@ -128,8 +133,8 @@ async def test_http_request_reuse_connection() -> None:


@pytest.mark.anyio
async def test_https_request_reuse_connection() -> None:
async with httpcore.AsyncConnectionPool() as http:
async def test_https_request_reuse_connection(backend: str) -> None:
async with httpcore.AsyncConnectionPool(backend=backend) as http:
method = b"GET"
url = (b"https", b"example.org", 443, b"/")
headers = [(b"host", b"example.org")]
Expand Down Expand Up @@ -158,8 +163,8 @@ async def test_https_request_reuse_connection() -> None:


@pytest.mark.anyio
async def test_http_request_cannot_reuse_dropped_connection() -> None:
async with httpcore.AsyncConnectionPool() as http:
async def test_http_request_cannot_reuse_dropped_connection(backend: str) -> None:
async with httpcore.AsyncConnectionPool(backend=backend) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
Expand Down Expand Up @@ -193,13 +198,16 @@ async def test_http_request_cannot_reuse_dropped_connection() -> None:

@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY"])
@pytest.mark.anyio
async def test_http_proxy(proxy_server: URL, proxy_mode: str) -> None:
async def test_http_proxy(proxy_server: URL, proxy_mode: str, backend: str) -> None:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
max_connections = 1
async with httpcore.AsyncHTTPProxy(
proxy_server, proxy_mode=proxy_mode, max_connections=max_connections
proxy_server,
proxy_mode=proxy_mode,
max_connections=max_connections,
backend=backend,
) as http:
http_version, status_code, reason, headers, stream = await http.request(
method, url, headers
Expand All @@ -212,11 +220,13 @@ async def test_http_proxy(proxy_server: URL, proxy_mode: str) -> None:


@pytest.mark.anyio
async def test_http_request_local_address() -> None:
if lookup_async_backend() == "trio":
async def test_http_request_local_address(backend: str) -> None:
if backend == "auto" and lookup_async_backend() == "trio":
pytest.skip("The trio backend does not support local_address")

async with httpcore.AsyncConnectionPool(local_address="0.0.0.0") as http:
async with httpcore.AsyncConnectionPool(
backend=backend, local_address="0.0.0.0"
) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
Expand Down Expand Up @@ -294,9 +304,10 @@ async def test_connection_pool_get_connection_info(
keepalive_expiry: float,
expected_during_active: dict,
expected_during_idle: dict,
backend: str,
) -> None:
async with httpcore.AsyncConnectionPool(
http2=http2, keepalive_expiry=keepalive_expiry
http2=http2, keepalive_expiry=keepalive_expiry, backend=backend
) as http:
method = b"GET"
url = (b"https", b"example.org", 443, b"/")
Expand Down Expand Up @@ -324,10 +335,12 @@ async def test_connection_pool_get_connection_info(
reason="Unix Domain Sockets only exist on Unix",
)
@pytest.mark.anyio
async def test_http_request_unix_domain_socket(uds_server: Server) -> None:
async def test_http_request_unix_domain_socket(
uds_server: Server, backend: str
) -> None:
uds = uds_server.config.uds
assert uds is not None
async with httpcore.AsyncConnectionPool(uds=uds) as http:
async with httpcore.AsyncConnectionPool(uds=uds, backend=backend) as http:
method = b"GET"
url = (b"http", b"localhost", None, b"/")
headers = [(b"host", b"localhost")]
Expand All @@ -345,10 +358,10 @@ async def test_http_request_unix_domain_socket(uds_server: Server) -> None:
@pytest.mark.parametrize("connections_number", [4])
@pytest.mark.anyio
async def test_max_keepalive_connections_handled_correctly(
max_keepalive: int, connections_number: int
max_keepalive: int, connections_number: int, backend: str
) -> None:
async with httpcore.AsyncConnectionPool(
max_keepalive_connections=max_keepalive, keepalive_expiry=60
max_keepalive_connections=max_keepalive, keepalive_expiry=60, backend=backend
) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
Expand All @@ -371,7 +384,7 @@ async def test_max_keepalive_connections_handled_correctly(

@pytest.mark.anyio
async def test_explicit_backend_name() -> None:
async with httpcore.AsyncConnectionPool(backend=detect_backend()) as http:
async with httpcore.AsyncConnectionPool(backend=lookup_async_backend()) as http:
method = b"GET"
url = (b"http", b"example.org", 80, b"/")
headers = [(b"host", b"example.org")]
Expand Down