Skip to content

Commit

Permalink
Use AnyIO by default on asyncio (#345)
Browse files Browse the repository at this point in the history
* Use AnyIO by default on asyncio

* Test on AnyIO 3.1.0

* Changed the anyio version specifier

* Removed AnyIO 2.x compatibility code
  • Loading branch information
agronholm committed May 25, 2021
1 parent 958d992 commit 911d61c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 35 deletions.
44 changes: 13 additions & 31 deletions httpcore/_backends/anyio.py
Expand Up @@ -19,30 +19,12 @@
from .._utils import is_socket_readable
from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream

# For compatibility with both AnyIO 2.x and 3.x
# (some functions and context managers were converted from async to sync in 3.0)
try:
from anyio import maybe_async, maybe_async_cm
except ImportError:

def maybe_async(x): # type: ignore
return x

def maybe_async_cm(x): # type: ignore
return x


try:
from anyio import Lock as create_lock, Semaphore as create_semaphore
except ImportError:
from anyio import create_lock, create_semaphore # type: ignore


class SocketStream(AsyncSocketStream):
def __init__(self, stream: ByteStream) -> None:
self.stream = stream
self.read_lock = create_lock()
self.write_lock = create_lock()
self.read_lock = anyio.Lock()
self.write_lock = anyio.Lock()

def get_http_version(self) -> str:
alpn_protocol = self.stream.extra(TLSAttribute.alpn_protocol, None)
Expand All @@ -56,7 +38,7 @@ async def start_tls(
) -> "SocketStream":
connect_timeout = timeout.get("connect")
try:
async with maybe_async_cm(anyio.fail_after(connect_timeout)):
with anyio.fail_after(connect_timeout):
ssl_stream = await TLSStream.wrap(
self.stream,
ssl_context=ssl_context,
Expand All @@ -73,7 +55,7 @@ async def read(self, n: int, timeout: TimeoutDict) -> bytes:
read_timeout = timeout.get("read")
async with self.read_lock:
try:
async with maybe_async_cm(anyio.fail_after(read_timeout)):
with anyio.fail_after(read_timeout):
return await self.stream.receive(n)
except TimeoutError:
raise ReadTimeout from None
Expand All @@ -89,7 +71,7 @@ async def write(self, data: bytes, timeout: TimeoutDict) -> None:
write_timeout = timeout.get("write")
async with self.write_lock:
try:
async with maybe_async_cm(anyio.fail_after(write_timeout)):
with anyio.fail_after(write_timeout):
return await self.stream.send(data)
except TimeoutError:
raise WriteTimeout from None
Expand All @@ -110,10 +92,10 @@ def is_readable(self) -> bool:

class Lock(AsyncLock):
def __init__(self) -> None:
self._lock = create_lock()
self._lock = anyio.Lock()

async def release(self) -> None:
await maybe_async(self._lock.release())
self._lock.release()

async def acquire(self) -> None:
await self._lock.acquire()
Expand All @@ -127,18 +109,18 @@ def __init__(self, max_value: int, exc_class: type):
@property
def semaphore(self) -> anyio.abc.Semaphore:
if not hasattr(self, "_semaphore"):
self._semaphore = create_semaphore(self.max_value)
self._semaphore = anyio.Semaphore(self.max_value)
return self._semaphore

async def acquire(self, timeout: float = None) -> None:
async with maybe_async_cm(anyio.move_on_after(timeout)):
with anyio.move_on_after(timeout):
await self.semaphore.acquire()
return

raise self.exc_class()

async def release(self) -> None:
await maybe_async(self.semaphore.release())
self.semaphore.release()


class AnyIOBackend(AsyncBackend):
Expand All @@ -160,7 +142,7 @@ async def open_tcp_stream(
}

with map_exceptions(exc_map):
async with maybe_async_cm(anyio.fail_after(connect_timeout)):
with anyio.fail_after(connect_timeout):
stream: anyio.abc.ByteStream
stream = await anyio.connect_tcp(
unicode_host, port, local_host=local_address
Expand Down Expand Up @@ -191,7 +173,7 @@ async def open_uds_stream(
}

with map_exceptions(exc_map):
async with maybe_async_cm(anyio.fail_after(connect_timeout)):
with anyio.fail_after(connect_timeout):
stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
if ssl_context:
stream = await TLSStream.wrap(
Expand All @@ -210,7 +192,7 @@ def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore:
return Semaphore(max_value, exc_class=exc_class)

async def time(self) -> float:
return await maybe_async(anyio.current_time())
return float(anyio.current_time())

async def sleep(self, seconds: float) -> None:
await anyio.sleep(seconds)
4 changes: 2 additions & 2 deletions httpcore/_backends/auto.py
Expand Up @@ -17,9 +17,9 @@ def backend(self) -> AsyncBackend:
backend = sniffio.current_async_library()

if backend == "asyncio":
from .asyncio import AsyncioBackend
from .anyio import AnyIOBackend

self._backend_implementation: AsyncBackend = AsyncioBackend()
self._backend_implementation: AsyncBackend = AnyIOBackend()
elif backend == "trio":
from .trio import TrioBackend

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -20,7 +20,7 @@ twine==3.4.1
wheel==0.36.2

# Tests & Linting
anyio==3.0.1
anyio==3.1.0
autoflake==1.4
black==21.4b2
coverage==5.5
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -53,7 +53,7 @@ def get_packages(package):
packages=get_packages("httpcore"),
include_package_data=True,
zip_safe=False,
install_requires=["h11>=0.11,<0.13", "sniffio==1.*"],
install_requires=["h11>=0.11,<0.13", "sniffio==1.*", "anyio==3.*"],
extras_require={
"http2": ["h2>=3,<5"],
},
Expand Down

0 comments on commit 911d61c

Please sign in to comment.