Skip to content

Commit

Permalink
Added type annotations to test_error.py (#2462)
Browse files Browse the repository at this point in the history
* added type annotations to test_error.py

* Apply suggestions from code review

---------

Co-authored-by: Scirlat Danut <scirlatdanut@scirlats-mini.lan>
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
3 people committed Feb 4, 2024
1 parent 93e74a4 commit c158ef4
Showing 1 changed file with 24 additions and 15 deletions.
39 changes: 24 additions & 15 deletions tests/middleware/test_errors.py
@@ -1,17 +1,26 @@
from typing import Any, Callable

import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route
from starlette.testclient import TestClient
from starlette.types import Receive, Scope, Send

TestClientFactory = Callable[..., TestClient]


def test_handler(test_client_factory):
async def app(scope, receive, send):
def test_handler(
test_client_factory: TestClientFactory,
) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")

def error_500(request, exc):
def error_500(request: Request, exc: Exception) -> JSONResponse:
return JSONResponse({"detail": "Server Error"}, status_code=500)

app = ServerErrorMiddleware(app, handler=error_500)
Expand All @@ -21,8 +30,8 @@ def error_500(request, exc):
assert response.json() == {"detail": "Server Error"}


def test_debug_text(test_client_factory):
async def app(scope, receive, send):
def test_debug_text(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")

app = ServerErrorMiddleware(app, debug=True)
Expand All @@ -33,8 +42,8 @@ async def app(scope, receive, send):
assert "RuntimeError: Something went wrong" in response.text


def test_debug_html(test_client_factory):
async def app(scope, receive, send):
def test_debug_html(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")

app = ServerErrorMiddleware(app, debug=True)
Expand All @@ -45,8 +54,8 @@ async def app(scope, receive, send):
assert "RuntimeError" in response.text


def test_debug_after_response_sent(test_client_factory):
async def app(scope, receive, send):
def test_debug_after_response_sent(test_client_factory: TestClientFactory) -> None:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
response = Response(b"", status_code=204)
await response(scope, receive, send)
raise RuntimeError("Something went wrong")
Expand All @@ -57,12 +66,12 @@ async def app(scope, receive, send):
client.get("/")


def test_debug_not_http(test_client_factory):
def test_debug_not_http(test_client_factory: TestClientFactory) -> None:
"""
DebugMiddleware should just pass through any non-http messages as-is.
"""

async def app(scope, receive, send):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
raise RuntimeError("Something went wrong")

app = ServerErrorMiddleware(app)
Expand All @@ -73,17 +82,17 @@ async def app(scope, receive, send):
pass # pragma: nocover


def test_background_task(test_client_factory):
def test_background_task(test_client_factory: TestClientFactory) -> None:
accessed_error_handler = False

def error_handler(request, exc):
def error_handler(request: Request, exc: Exception) -> Any:
nonlocal accessed_error_handler
accessed_error_handler = True

def raise_exception():
def raise_exception() -> None:
raise Exception("Something went wrong")

async def endpoint(request):
async def endpoint(request: Request) -> Response:
task = BackgroundTask(raise_exception)
return Response(status_code=204, background=task)

Expand Down

0 comments on commit c158ef4

Please sign in to comment.