Skip to content

Commit

Permalink
Upgraded to AnyIO 4.0 (#2211)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
agronholm and Kludex committed Jul 23, 2023
1 parent e160a17 commit 1a71441
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 12 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ types-contextvars==2.4.7.2
types-PyYAML==6.0.12.10
types-dataclasses==0.6.6
pytest==7.4.0
trio==0.21.0
trio==0.22.1
anyio@git+https://github.com/agronholm/anyio.git

# Documentation
mkdocs==1.4.3
Expand Down
28 changes: 24 additions & 4 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
import sys
import typing
from contextlib import contextmanager

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

if sys.version_info < (3, 11): # pragma: no cover
from exceptiongroup import BaseExceptionGroup

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")


@contextmanager
def _convert_excgroups() -> typing.Generator[None, None, None]:
try:
yield
except BaseException as exc:
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0]

raise exc


class _CachedRequest(Request):
"""
If the user calls Request.body() from their dispatch function
Expand Down Expand Up @@ -107,6 +124,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
Expand Down Expand Up @@ -182,10 +201,11 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
response.raw_headers = message["headers"]
return response

async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
with _convert_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down
4 changes: 4 additions & 0 deletions starlette/middleware/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette.types import Receive, Scope, Send

Expand Down Expand Up @@ -72,6 +73,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:


class WSGIResponder:
stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]

def __init__(self, app: typing.Callable, scope: Scope) -> None:
self.app = app
self.scope = scope
Expand Down
19 changes: 13 additions & 6 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import anyio
import anyio.from_thread
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.streams.stapled import StapledObjectStream

from starlette._utils import is_async_callable
Expand Down Expand Up @@ -737,12 +738,18 @@ def __enter__(self) -> "TestClient":
def reset_portal() -> None:
self.portal = None

self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
send1: ObjectSendStream[
typing.Optional[typing.MutableMapping[str, typing.Any]]
]
receive1: ObjectReceiveStream[
typing.Optional[typing.MutableMapping[str, typing.Any]]
]
send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send1, receive1 = anyio.create_memory_object_stream(math.inf)
send2, receive2 = anyio.create_memory_object_stream(math.inf)
self.stream_send = StapledObjectStream(send1, receive1)
self.stream_receive = StapledObjectStream(send2, receive2)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)

Expand Down
8 changes: 7 additions & 1 deletion tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from starlette.middleware.wsgi import WSGIMiddleware, build_environ

if sys.version_info < (3, 11): # pragma: no cover
from exceptiongroup import ExceptionGroup


def hello_world(environ, start_response):
status = "200 OK"
Expand Down Expand Up @@ -66,9 +69,12 @@ def test_wsgi_exception(test_client_factory):
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(raise_exception)
client = test_client_factory(app)
with pytest.raises(RuntimeError):
with pytest.raises(ExceptionGroup) as exc:
client.get("/")

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], RuntimeError)


def test_wsgi_exc_info(test_client_factory):
# Note that we're testing the WSGI app directly here.
Expand Down
4 changes: 4 additions & 0 deletions tests/test_websockets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys
from typing import Any, MutableMapping

import anyio
import pytest
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette import status
from starlette.types import Receive, Scope, Send
Expand Down Expand Up @@ -178,6 +180,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:


def test_websocket_concurrency_pattern(test_client_factory):
stream_send: ObjectSendStream[MutableMapping[str, Any]]
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
stream_send, stream_receive = anyio.create_memory_object_stream()

async def reader(websocket):
Expand Down

0 comments on commit 1a71441

Please sign in to comment.