Skip to content

Commit

Permalink
Add type hints to test_testclient.py (#2493)
Browse files Browse the repository at this point in the history
* Add type hints to test_testclient.py

* Fix check errors

* Apply suggestions from code review

* Use ASGIInstance instead

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
3 people committed Feb 9, 2024
1 parent eaee85b commit 3f38038
Showing 1 changed file with 45 additions and 41 deletions.
86 changes: 45 additions & 41 deletions tests/test_testclient.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import itertools
import sys
from asyncio import current_task as asyncio_current_task
from asyncio import Task, current_task as asyncio_current_task
from contextlib import asynccontextmanager
from typing import Callable
from typing import Any, AsyncGenerator, Callable

import anyio
import anyio.lowlevel
Expand All @@ -15,19 +17,21 @@
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse, Response
from starlette.routing import Route
from starlette.testclient import TestClient
from starlette.testclient import ASGIInstance, TestClient
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect

TestClientFactory = Callable[..., TestClient]


def mock_service_endpoint(request: Request):
def mock_service_endpoint(request: Request) -> JSONResponse:
return JSONResponse({"mock": "example"})


mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)])


def current_task():
def current_task() -> Task[Any] | trio.lowlevel.Task:
# anyio's TaskInfo comparisons are invalid after their associated native
# task object is GC'd https://github.com/agronholm/anyio/issues/324
asynclib_name = sniffio.current_async_library()
Expand All @@ -42,19 +46,19 @@ def current_task():
raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover


def startup():
def startup() -> None:
raise RuntimeError()


def test_use_testclient_in_endpoint(test_client_factory: Callable[..., TestClient]):
def test_use_testclient_in_endpoint(test_client_factory: TestClientFactory) -> None:
"""
We should be able to use the test client within applications.
This is useful if we need to mock out other services,
during tests or in development.
"""

def homepage(request: Request):
def homepage(request: Request) -> JSONResponse:
client = test_client_factory(mock_service)
response = client.get("/")
return JSONResponse(response.json())
Expand All @@ -66,7 +70,7 @@ def homepage(request: Request):
assert response.json() == {"mock": "example"}


def test_testclient_headers_behavior():
def test_testclient_headers_behavior() -> None:
"""
We should be able to use the test client with user defined headers.
Expand All @@ -86,16 +90,16 @@ def test_testclient_headers_behavior():


def test_use_testclient_as_contextmanager(
test_client_factory: Callable[..., TestClient], anyio_backend_name: str
):
test_client_factory: TestClientFactory, anyio_backend_name: str
) -> None:
"""
This test asserts a number of properties that are important for an
app level task_group
"""
counter = itertools.count()
identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar")

def get_identity():
def get_identity() -> int:
try:
return identity_runvar.get()
except LookupError:
Expand All @@ -109,7 +113,7 @@ def get_identity():
shutdown_loop = None

@asynccontextmanager
async def lifespan_context(app: Starlette):
async def lifespan_context(app: Starlette) -> AsyncGenerator[None, None]:
nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop

startup_task = current_task()
Expand All @@ -119,7 +123,7 @@ async def lifespan_context(app: Starlette):
shutdown_task = current_task()
shutdown_loop = get_identity()

async def loop_id(request: Request):
async def loop_id(request: Request) -> JSONResponse:
return JSONResponse(get_identity())

app = Starlette(
Expand All @@ -143,7 +147,7 @@ async def loop_id(request: Request):
assert startup_task is shutdown_task

# outside the TestClient context, new requests continue to spawn in new
# eventloops in new threads
# event loops in new threads
assert client.get("/loop_id").json() == 1
assert client.get("/loop_id").json() == 2

Expand All @@ -165,7 +169,7 @@ async def loop_id(request: Request):
assert first_task is not startup_task


def test_error_on_startup(test_client_factory: Callable[..., TestClient]):
def test_error_on_startup(test_client_factory: TestClientFactory) -> None:
with pytest.deprecated_call(
match="The on_startup and on_shutdown parameters are deprecated"
):
Expand All @@ -176,15 +180,15 @@ def test_error_on_startup(test_client_factory: Callable[..., TestClient]):
pass # pragma: no cover


def test_exception_in_middleware(test_client_factory: Callable[..., TestClient]):
def test_exception_in_middleware(test_client_factory: TestClientFactory) -> None:
class MiddlewareException(Exception):
pass

class BrokenMiddleware:
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
raise MiddlewareException()

broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)])
Expand All @@ -194,9 +198,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send):
pass # pragma: no cover


def test_testclient_asgi2(test_client_factory: Callable[..., TestClient]):
def app(scope: Scope):
async def inner(receive: Receive, send: Send):
def test_testclient_asgi2(test_client_factory: TestClientFactory) -> None:
def app(scope: Scope) -> ASGIInstance:
async def inner(receive: Receive, send: Send) -> None:
await send(
{
"type": "http.response.start",
Expand All @@ -213,8 +217,8 @@ async def inner(receive: Receive, send: Send):
assert response.text == "Hello, world!"


def test_testclient_asgi3(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send):
def test_testclient_asgi3(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
await send(
{
"type": "http.response.start",
Expand All @@ -229,12 +233,12 @@ async def app(scope: Scope, receive: Receive, send: Send):
assert response.text == "Hello, world!"


def test_websocket_blocking_receive(test_client_factory: Callable[..., TestClient]):
def app(scope: Scope):
async def respond(websocket: WebSocket):
def test_websocket_blocking_receive(test_client_factory: TestClientFactory) -> None:
def app(scope: Scope) -> ASGIInstance:
async def respond(websocket: WebSocket) -> None:
await websocket.send_json({"message": "test"})

async def asgi(receive: Receive, send: Send):
async def asgi(receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
async with anyio.create_task_group() as task_group:
Expand All @@ -254,9 +258,9 @@ async def asgi(receive: Receive, send: Send):
assert data == {"message": "test"}


def test_websocket_not_block_on_close(test_client_factory: Callable[..., TestClient]):
def app(scope: Scope):
async def asgi(receive: Receive, send: Send):
def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None:
def app(scope: Scope) -> ASGIInstance:
async def asgi(receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
while True:
Expand All @@ -271,8 +275,8 @@ async def asgi(receive: Receive, send: Send):


@pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà"))
def test_query_params(test_client_factory: Callable[..., TestClient], param: str):
def homepage(request: Request):
def test_query_params(test_client_factory: TestClientFactory, param: str) -> None:
def homepage(request: Request) -> Response:
return Response(request.query_params["param"])

app = Starlette(routes=[Route("/", endpoint=homepage)])
Expand Down Expand Up @@ -301,8 +305,8 @@ def homepage(request: Request):
],
)
def test_domain_restricted_cookies(
test_client_factory: Callable[..., TestClient], domain: str, ok: bool
):
test_client_factory: TestClientFactory, domain: str, ok: bool
) -> None:
"""
Test that test client discards domain restricted cookies which do not match the
base_url of the testclient (`http://testserver` by default).
Expand All @@ -312,7 +316,7 @@ def test_domain_restricted_cookies(
in accordance with RFC 2965.
"""

async def app(scope: Scope, receive: Receive, send: Send):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response("Hello, world!", media_type="text/plain")
response.set_cookie(
"mycookie",
Expand All @@ -328,8 +332,8 @@ async def app(scope: Scope, receive: Receive, send: Send):
assert cookie_set == ok


def test_forward_follow_redirects(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send):
def test_forward_follow_redirects(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if "/ok" in scope["path"]:
response = Response("ok")
else:
Expand All @@ -341,8 +345,8 @@ async def app(scope: Scope, receive: Receive, send: Send):
assert response.status_code == 200


def test_forward_nofollow_redirects(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send):
def test_forward_nofollow_redirects(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = RedirectResponse("/ok")
await response(scope, receive, send)

Expand All @@ -351,7 +355,7 @@ async def app(scope: Scope, receive: Receive, send: Send):
assert response.status_code == 307


def test_with_duplicate_headers(test_client_factory: Callable[[Starlette], TestClient]):
def test_with_duplicate_headers(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> JSONResponse:
return JSONResponse({"x-token": request.headers.getlist("x-token")})

Expand All @@ -361,7 +365,7 @@ def homepage(request: Request) -> JSONResponse:
assert response.json() == {"x-token": ["foo", "bar"]}


def test_merge_url(test_client_factory: Callable[..., TestClient]):
def test_merge_url(test_client_factory: TestClientFactory) -> None:
def homepage(request: Request) -> Response:
return Response(request.url.path)

Expand Down

0 comments on commit 3f38038

Please sign in to comment.