diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index 392c2ba16..a2dbabd8a 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -1,17 +1,26 @@ +from typing import Any, Callable + import pytest from starlette.applications import Starlette from starlette.background import BackgroundTask from starlette.middleware.errors import ServerErrorMiddleware +from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route +from starlette.testclient import TestClient +from starlette.types import Receive, Scope, Send + +TestClientFactory = Callable[..., TestClient] -def test_handler(test_client_factory): - async def app(scope, receive, send): +def test_handler( + test_client_factory: TestClientFactory, +) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") - def error_500(request, exc): + def error_500(request: Request, exc: Exception) -> JSONResponse: return JSONResponse({"detail": "Server Error"}, status_code=500) app = ServerErrorMiddleware(app, handler=error_500) @@ -21,8 +30,8 @@ def error_500(request, exc): assert response.json() == {"detail": "Server Error"} -def test_debug_text(test_client_factory): - async def app(scope, receive, send): +def test_debug_text(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) @@ -33,8 +42,8 @@ async def app(scope, receive, send): assert "RuntimeError: Something went wrong" in response.text -def test_debug_html(test_client_factory): - async def app(scope, receive, send): +def test_debug_html(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app, debug=True) @@ -45,8 +54,8 @@ async def app(scope, receive, send): assert "RuntimeError" in response.text -def test_debug_after_response_sent(test_client_factory): - async def app(scope, receive, send): +def test_debug_after_response_sent(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(b"", status_code=204) await response(scope, receive, send) raise RuntimeError("Something went wrong") @@ -57,12 +66,12 @@ async def app(scope, receive, send): client.get("/") -def test_debug_not_http(test_client_factory): +def test_debug_not_http(test_client_factory: TestClientFactory) -> None: """ DebugMiddleware should just pass through any non-http messages as-is. """ - async def app(scope, receive, send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: raise RuntimeError("Something went wrong") app = ServerErrorMiddleware(app) @@ -73,17 +82,17 @@ async def app(scope, receive, send): pass # pragma: nocover -def test_background_task(test_client_factory): +def test_background_task(test_client_factory: TestClientFactory) -> None: accessed_error_handler = False - def error_handler(request, exc): + def error_handler(request: Request, exc: Exception) -> Any: nonlocal accessed_error_handler accessed_error_handler = True - def raise_exception(): + def raise_exception() -> None: raise Exception("Something went wrong") - async def endpoint(request): + async def endpoint(request: Request) -> Response: task = BackgroundTask(raise_exception) return Response(status_code=204, background=task)