Skip to content

Commit

Permalink
Cancel WebSocketTestSession on close (#2427)
Browse files Browse the repository at this point in the history
* Cancel `WebSocketTestSession` on close

* Undo some noise

* Fix test

* Undo pyproject

* Undo anyio bump

* Undo changes on test_authentication

* Always call cancel scope
  • Loading branch information
Kludex committed Jan 20, 2024
1 parent 13c66c9 commit 3ae161e
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 69 deletions.
95 changes: 59 additions & 36 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from __future__ import annotations

import contextlib
import inspect
import io
import json
import math
import queue
import sys
import typing
import warnings
from concurrent.futures import Future
from types import GeneratorType
from urllib.parse import unquote, urljoin

import anyio
import anyio.abc
import anyio.from_thread
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.streams.stapled import StapledObjectStream
Expand All @@ -19,6 +23,11 @@
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

if sys.version_info >= (3, 10): # pragma: no cover
from typing import TypeGuard
else: # pragma: no cover
from typing_extensions import TypeGuard

try:
import httpx
except ModuleNotFoundError: # pragma: no cover
Expand All @@ -39,7 +48,7 @@
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]


def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> TypeGuard[ASGI3App]:
if inspect.isclass(app):
return hasattr(app, "__await__")
return is_async_callable(app)
Expand All @@ -64,7 +73,7 @@ class _AsyncBackend(typing.TypedDict):


class _Upgrade(Exception):
def __init__(self, session: "WebSocketTestSession") -> None:
def __init__(self, session: WebSocketTestSession) -> None:
self.session = session


Expand All @@ -79,16 +88,17 @@ def __init__(
self.scope = scope
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: "queue.Queue[Message]" = queue.Queue()
self._send_queue: "queue.Queue[Message | BaseException]" = queue.Queue()
self._receive_queue: queue.Queue[Message] = queue.Queue()
self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
self.extra_headers = None

def __enter__(self) -> "WebSocketTestSession":
def __enter__(self) -> WebSocketTestSession:
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(self.portal_factory())
self.should_close = anyio.Event()

try:
_: "Future[None]" = self.portal.start_task_soon(self._run)
_: Future[None] = self.portal.start_task_soon(self._run)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
Expand All @@ -99,10 +109,14 @@ def __enter__(self) -> "WebSocketTestSession":
self.extra_headers = message.get("headers", None)
return self

async def _notify_close(self) -> None:
self.should_close.set()

def __exit__(self, *args: typing.Any) -> None:
try:
self.close(1000)
finally:
self.portal.start_task_soon(self._notify_close)
self.exit_stack.close()
while not self._send_queue.empty():
message = self._send_queue.get()
Expand All @@ -113,14 +127,22 @@ async def _run(self) -> None:
"""
The sub-thread in which the websocket session runs.
"""
scope = self.scope
receive = self._asgi_receive
send = self._asgi_send
try:
await self.app(scope, receive, send)
except BaseException as exc:
self._send_queue.put(exc)
raise

async def run_app(tg: anyio.abc.TaskGroup) -> None:
try:
await self.app(self.scope, self._asgi_receive, self._asgi_send)
except anyio.get_cancelled_exc_class():
...
except BaseException as exc:
self._send_queue.put(exc)
raise
finally:
tg.cancel_scope.cancel()

async with anyio.create_task_group() as tg:
tg.start_soon(run_app, tg)
await self.should_close.wait()
tg.cancel_scope.cancel()

async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
Expand Down Expand Up @@ -153,7 +175,7 @@ def send_json(self, data: typing.Any, mode: str = "text") -> None:
else:
self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})

def close(self, code: int = 1000, reason: typing.Union[str, None] = None) -> None:
def close(self, code: int = 1000, reason: str | None = None) -> None:
self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

def receive(self) -> Message:
Expand All @@ -172,8 +194,9 @@ def receive_bytes(self) -> bytes:
self._raise_on_close(message)
return typing.cast(bytes, message["bytes"])

def receive_json(self, mode: str = "text") -> typing.Any:
assert mode in ["text", "binary"]
def receive_json(
self, mode: typing.Literal["text", "binary"] = "text"
) -> typing.Any:
message = self.receive()
self._raise_on_close(message)
if mode == "text":
Expand All @@ -191,7 +214,7 @@ def __init__(
raise_server_exceptions: bool = True,
root_path: str = "",
*,
app_state: typing.Dict[str, typing.Any],
app_state: dict[str, typing.Any],
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
Expand All @@ -217,7 +240,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:

# Include the 'host' header.
if "host" in request.headers:
headers: typing.List[typing.Tuple[bytes, bytes]] = []
headers: list[tuple[bytes, bytes]] = []
elif port == default_port: # pragma: no cover
headers = [(b"host", host.encode())]
else: # pragma: no cover
Expand All @@ -229,7 +252,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
for key, value in request.headers.multi_items()
]

scope: typing.Dict[str, typing.Any]
scope: dict[str, typing.Any]

if scheme in {"ws", "wss"}:
subprotocol = request.headers.get("sec-websocket-protocol", None)
Expand Down Expand Up @@ -272,7 +295,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
request_complete = False
response_started = False
response_complete: anyio.Event
raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
template = None
context = None

Expand Down Expand Up @@ -363,26 +386,25 @@ async def send(message: Message) -> None:

class TestClient(httpx.Client):
__test__ = False
task: "Future[None]"
portal: typing.Optional[anyio.abc.BlockingPortal] = None
task: Future[None]
portal: anyio.abc.BlockingPortal | None = None

def __init__(
self,
app: ASGIApp,
base_url: str = "http://testserver",
raise_server_exceptions: bool = True,
root_path: str = "",
backend: str = "asyncio",
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
cookies: httpx._types.CookieTypes = None,
headers: typing.Dict[str, str] = None,
backend: typing.Literal["asyncio", "trio"] = "asyncio",
backend_options: typing.Dict[str, typing.Any] | None = None,
cookies: httpx._types.CookieTypes | None = None,
headers: typing.Dict[str, str] | None = None,
follow_redirects: bool = True,
) -> None:
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
)
if _is_asgi3(app):
app = typing.cast(ASGI3App, app)
asgi_app = app
else:
app = typing.cast(ASGI2App, app) # type: ignore[assignment]
Expand Down Expand Up @@ -419,13 +441,11 @@ def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, No
yield portal

def _choose_redirect_arg(
self,
follow_redirects: typing.Optional[bool],
allow_redirects: typing.Optional[bool],
) -> typing.Union[bool, httpx._client.UseClientDefault]:
redirect: typing.Union[
bool, httpx._client.UseClientDefault
] = httpx._client.USE_CLIENT_DEFAULT
self, follow_redirects: bool | None, allow_redirects: bool | None
) -> bool | httpx._client.UseClientDefault:
redirect: bool | httpx._client.UseClientDefault = (
httpx._client.USE_CLIENT_DEFAULT
)
if allow_redirects is not None:
message = (
"The `allow_redirects` argument is deprecated. "
Expand Down Expand Up @@ -709,7 +729,10 @@ def delete( # type: ignore[override]
)

def websocket_connect(
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
self,
url: str,
subprotocols: typing.Sequence[str] | None = None,
**kwargs: typing.Any,
) -> "WebSocketTestSession":
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
Expand Down

0 comments on commit 3ae161e

Please sign in to comment.