Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to test_exceptions.py #2479

Merged
merged 2 commits into from
Feb 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 29 additions & 21 deletions tests/test_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
from typing import Callable, Generator

import pytest

Expand All @@ -7,33 +8,37 @@
from starlette.requests import Request
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Route, Router, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import Receive, Scope, Send

TestClientFactory = Callable[..., TestClient]

def raise_runtime_error(request):

def raise_runtime_error(request: Request) -> None:
raise RuntimeError("Yikes")


def not_acceptable(request):
def not_acceptable(request: Request) -> None:
raise HTTPException(status_code=406)


def no_content(request):
def no_content(request: Request) -> None:
raise HTTPException(status_code=204)


def not_modified(request):
def not_modified(request: Request) -> None:
raise HTTPException(status_code=304)


def with_headers(request):
def with_headers(request: Request) -> None:
raise HTTPException(status_code=200, headers={"x-potato": "always"})


class BadBodyException(HTTPException):
pass


async def read_body_and_raise_exc(request: Request):
async def read_body_and_raise_exc(request: Request) -> None:
await request.body()
raise BadBodyException(422)

Expand All @@ -46,7 +51,7 @@ async def handler_that_reads_body(


class HandledExcAfterResponse:
async def __call__(self, scope, receive, send):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
response = PlainTextResponse("OK", status_code=200)
await response(scope, receive, send)
raise HTTPException(status_code=406)
Expand Down Expand Up @@ -77,42 +82,45 @@ async def __call__(self, scope, receive, send):


@pytest.fixture
def client(test_client_factory):
def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]:
with test_client_factory(app) as client:
yield client


def test_not_acceptable(client):
def test_not_acceptable(client: TestClient) -> None:
response = client.get("/not_acceptable")
assert response.status_code == 406
assert response.text == "Not Acceptable"


def test_no_content(client):
def test_no_content(client: TestClient) -> None:
response = client.get("/no_content")
assert response.status_code == 204
assert "content-length" not in response.headers


def test_not_modified(client):
def test_not_modified(client: TestClient) -> None:
response = client.get("/not_modified")
assert response.status_code == 304
assert response.text == ""


def test_with_headers(client):
def test_with_headers(client: TestClient) -> None:
response = client.get("/with_headers")
assert response.status_code == 200
assert response.headers["x-potato"] == "always"


def test_websockets_should_raise(client):
def test_websockets_should_raise(client: TestClient) -> None:
with pytest.raises(RuntimeError):
with client.websocket_connect("/runtime_error"):
pass # pragma: nocover


def test_handled_exc_after_response(test_client_factory, client):
def test_handled_exc_after_response(
test_client_factory: TestClientFactory,
client: TestClient,
) -> None:
# A 406 HttpException is raised *after* the response has already been sent.
# The exception middleware should raise a RuntimeError.
with pytest.raises(RuntimeError):
Expand All @@ -126,13 +134,13 @@ def test_handled_exc_after_response(test_client_factory, client):
assert response.text == "OK"


def test_force_500_response(test_client_factory):
def test_force_500_response(test_client_factory: TestClientFactory) -> None:
# use a sentinal variable to make sure we actually
# make it into the endpoint and don't get a 500
# from an incorrect ASGI app signature or something
called = False

async def app(scope, receive, send):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal called
called = True
raise RuntimeError()
Expand All @@ -144,13 +152,13 @@ async def app(scope, receive, send):
assert response.text == ""


def test_http_str():
def test_http_str() -> None:
assert str(HTTPException(status_code=404)) == "404: Not Found"
assert str(HTTPException(404, "Not Found: foo")) == "404: Not Found: foo"
assert str(HTTPException(404, headers={"key": "value"})) == "404: Not Found"


def test_http_repr():
def test_http_repr() -> None:
assert repr(HTTPException(404)) == (
"HTTPException(status_code=404, detail='Not Found')"
)
Expand All @@ -166,12 +174,12 @@ class CustomHTTPException(HTTPException):
)


def test_websocket_str():
def test_websocket_str() -> None:
assert str(WebSocketException(1008)) == "1008: "
assert str(WebSocketException(1008, "Policy Violation")) == "1008: Policy Violation"


def test_websocket_repr():
def test_websocket_repr() -> None:
assert repr(WebSocketException(1008, reason="Policy Violation")) == (
"WebSocketException(code=1008, reason='Policy Violation')"
)
Expand All @@ -198,7 +206,7 @@ def test_exception_middleware_deprecation() -> None:
starlette.exceptions.ExceptionMiddleware


def test_request_in_app_and_handler_is_the_same_object(client) -> None:
def test_request_in_app_and_handler_is_the_same_object(client: TestClient) -> None:
response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!")
assert response.status_code == 422
assert response.json() == {"body": "Hello!"}