diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 9895a4559..eeb0f2322 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,12 +1,19 @@ +from typing import Callable, Iterator + import pytest from starlette.endpoints import HTTPEndpoint, WebSocketEndpoint +from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route, Router +from starlette.testclient import TestClient +from starlette.websockets import WebSocket + +TestClientFactory = Callable[..., TestClient] class Homepage(HTTPEndpoint): - async def get(self, request): + async def get(self, request: Request) -> PlainTextResponse: username = request.path_params.get("username") if username is None: return PlainTextResponse("Hello, world!") @@ -19,33 +26,33 @@ async def get(self, request): @pytest.fixture -def client(test_client_factory): +def client(test_client_factory: TestClientFactory) -> Iterator[TestClient]: with test_client_factory(app) as client: yield client -def test_http_endpoint_route(client): +def test_http_endpoint_route(client: TestClient) -> None: response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world!" -def test_http_endpoint_route_path_params(client): +def test_http_endpoint_route_path_params(client: TestClient) -> None: response = client.get("/tomchristie") assert response.status_code == 200 assert response.text == "Hello, tomchristie!" -def test_http_endpoint_route_method(client): +def test_http_endpoint_route_method(client: TestClient) -> None: response = client.post("/") assert response.status_code == 405 assert response.text == "Method Not Allowed" assert response.headers["allow"] == "GET" -def test_websocket_endpoint_on_connect(test_client_factory): +def test_websocket_endpoint_on_connect(test_client_factory: TestClientFactory) -> None: class WebSocketApp(WebSocketEndpoint): - async def on_connect(self, websocket): + async def on_connect(self, websocket: WebSocket) -> None: assert websocket["subprotocols"] == ["soap", "wamp"] await websocket.accept(subprotocol="wamp") @@ -54,11 +61,13 @@ async def on_connect(self, websocket): assert websocket.accepted_subprotocol == "wamp" -def test_websocket_endpoint_on_receive_bytes(test_client_factory): +def test_websocket_endpoint_on_receive_bytes( + test_client_factory: TestClientFactory, +) -> None: class WebSocketApp(WebSocketEndpoint): encoding = "bytes" - async def on_receive(self, websocket, data): + async def on_receive(self, websocket: WebSocket, data: bytes) -> None: await websocket.send_bytes(b"Message bytes was: " + data) client = test_client_factory(WebSocketApp) @@ -72,11 +81,13 @@ async def on_receive(self, websocket, data): websocket.send_text("Hello world") -def test_websocket_endpoint_on_receive_json(test_client_factory): +def test_websocket_endpoint_on_receive_json( + test_client_factory: TestClientFactory, +) -> None: class WebSocketApp(WebSocketEndpoint): encoding = "json" - async def on_receive(self, websocket, data): + async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_json({"message": data}) client = test_client_factory(WebSocketApp) @@ -90,11 +101,13 @@ async def on_receive(self, websocket, data): websocket.send_text("Hello world") -def test_websocket_endpoint_on_receive_json_binary(test_client_factory): +def test_websocket_endpoint_on_receive_json_binary( + test_client_factory: TestClientFactory, +) -> None: class WebSocketApp(WebSocketEndpoint): encoding = "json" - async def on_receive(self, websocket, data): + async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_json({"message": data}, mode="binary") client = test_client_factory(WebSocketApp) @@ -104,11 +117,13 @@ async def on_receive(self, websocket, data): assert data == {"message": {"hello": "world"}} -def test_websocket_endpoint_on_receive_text(test_client_factory): +def test_websocket_endpoint_on_receive_text( + test_client_factory: TestClientFactory, +) -> None: class WebSocketApp(WebSocketEndpoint): encoding = "text" - async def on_receive(self, websocket, data): + async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_text(f"Message text was: {data}") client = test_client_factory(WebSocketApp) @@ -122,11 +137,11 @@ async def on_receive(self, websocket, data): websocket.send_bytes(b"Hello world") -def test_websocket_endpoint_on_default(test_client_factory): +def test_websocket_endpoint_on_default(test_client_factory: TestClientFactory) -> None: class WebSocketApp(WebSocketEndpoint): encoding = None - async def on_receive(self, websocket, data): + async def on_receive(self, websocket: WebSocket, data: str) -> None: await websocket.send_text(f"Message text was: {data}") client = test_client_factory(WebSocketApp) @@ -136,9 +151,11 @@ async def on_receive(self, websocket, data): assert _text == "Message text was: Hello, world!" -def test_websocket_endpoint_on_disconnect(test_client_factory): +def test_websocket_endpoint_on_disconnect( + test_client_factory: TestClientFactory, +) -> None: class WebSocketApp(WebSocketEndpoint): - async def on_disconnect(self, websocket, close_code): + async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: assert close_code == 1001 await websocket.close(code=close_code)