diff --git a/tests/test_responses.py b/tests/test_responses.py index 291c46e6d..57a594901 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -3,6 +3,7 @@ import time from http.cookies import SimpleCookie from pathlib import Path +from typing import AsyncIterator, Callable, Iterator, Union import anyio import pytest @@ -19,11 +20,13 @@ StreamingResponse, ) from starlette.testclient import TestClient -from starlette.types import Message +from starlette.types import Message, Receive, Scope, Send +TestClientFactory = Callable[..., TestClient] -def test_text_response(test_client_factory): - async def app(scope, receive, send): + +def test_text_response(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("hello, world", media_type="text/plain") await response(scope, receive, send) @@ -32,8 +35,8 @@ async def app(scope, receive, send): assert response.text == "hello, world" -def test_bytes_response(test_client_factory): - async def app(scope, receive, send): +def test_bytes_response(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(b"xxxxx", media_type="image/png") await response(scope, receive, send) @@ -42,8 +45,8 @@ async def app(scope, receive, send): assert response.content == b"xxxxx" -def test_json_none_response(test_client_factory): - async def app(scope, receive, send): +def test_json_none_response(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: response = JSONResponse(None) await response(scope, receive, send) @@ -53,8 +56,8 @@ async def app(scope, receive, send): assert response.content == b"null" -def test_redirect_response(test_client_factory): - async def app(scope, receive, send): +def test_redirect_response(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = Response("hello, world", media_type="text/plain") else: @@ -67,8 +70,8 @@ async def app(scope, receive, send): assert response.url == "http://testserver/" -def test_quoting_redirect_response(test_client_factory): - async def app(scope, receive, send): +def test_quoting_redirect_response(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/I ♥ Starlette/": response = Response("hello, world", media_type="text/plain") else: @@ -81,8 +84,10 @@ async def app(scope, receive, send): assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/" -def test_redirect_response_content_length_header(test_client_factory): - async def app(scope, receive, send): +def test_redirect_response_content_length_header( + test_client_factory: TestClientFactory, +) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["path"] == "/": response = Response("hello", media_type="text/plain") # pragma: nocover else: @@ -95,18 +100,18 @@ async def app(scope, receive, send): assert response.headers["content-length"] == "0" -def test_streaming_response(test_client_factory): +def test_streaming_response(test_client_factory: TestClientFactory) -> None: filled_by_bg_task = "" - async def app(scope, receive, send): - async def numbers(minimum, maximum): + async def app(scope: Scope, receive: Receive, send: Send) -> None: + async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]: for i in range(minimum, maximum + 1): yield str(i) if i != maximum: yield ", " await anyio.sleep(0) - async def numbers_for_cleanup(start=1, stop=5): + async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: nonlocal filled_by_bg_task async for thing in numbers(start, stop): filled_by_bg_task = filled_by_bg_task + thing @@ -125,16 +130,18 @@ async def numbers_for_cleanup(start=1, stop=5): assert filled_by_bg_task == "6, 7, 8, 9" -def test_streaming_response_custom_iterator(test_client_factory): - async def app(scope, receive, send): +def test_streaming_response_custom_iterator( + test_client_factory: TestClientFactory, +) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: class CustomAsyncIterator: - def __init__(self): + def __init__(self) -> None: self._called = 0 - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[str]: return self - async def __anext__(self): + async def __anext__(self) -> str: if self._called == 5: raise StopAsyncIteration() self._called += 1 @@ -148,10 +155,12 @@ async def __anext__(self): assert response.text == "12345" -def test_streaming_response_custom_iterable(test_client_factory): - async def app(scope, receive, send): +def test_streaming_response_custom_iterable( + test_client_factory: TestClientFactory, +) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: class CustomAsyncIterable: - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator[Union[str, bytes]]: for i in range(5): yield str(i + 1) @@ -163,9 +172,9 @@ async def __aiter__(self): assert response.text == "12345" -def test_sync_streaming_response(test_client_factory): - async def app(scope, receive, send): - def numbers(minimum, maximum): +def test_sync_streaming_response(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + def numbers(minimum: int, maximum: int) -> Iterator[str]: for i in range(minimum, maximum + 1): yield str(i) if i != maximum: @@ -180,8 +189,8 @@ def numbers(minimum, maximum): assert response.text == "1, 2, 3, 4, 5" -def test_response_headers(test_client_factory): - async def app(scope, receive, send): +def test_response_headers(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: headers = {"x-header-1": "123", "x-header-2": "456"} response = Response("hello, world", media_type="text/plain", headers=headers) response.headers["x-header-2"] = "789" @@ -193,7 +202,7 @@ async def app(scope, receive, send): assert response.headers["x-header-2"] == "789" -def test_response_phrase(test_client_factory): +def test_response_phrase(test_client_factory: TestClientFactory) -> None: app = Response(status_code=204) client = test_client_factory(app) response = client.get("/") @@ -205,7 +214,7 @@ def test_response_phrase(test_client_factory): assert response.reason_phrase == "" -def test_file_response(tmpdir, test_client_factory): +def test_file_response(tmpdir: Path, test_client_factory: TestClientFactory) -> None: path = os.path.join(tmpdir, "xyz") content = b"" * 1000 with open(path, "wb") as file: @@ -213,21 +222,21 @@ def test_file_response(tmpdir, test_client_factory): filled_by_bg_task = "" - async def numbers(minimum, maximum): + async def numbers(minimum: int, maximum: int) -> AsyncIterator[str]: for i in range(minimum, maximum + 1): yield str(i) if i != maximum: yield ", " await anyio.sleep(0) - async def numbers_for_cleanup(start=1, stop=5): + async def numbers_for_cleanup(start: int = 1, stop: int = 5) -> None: nonlocal filled_by_bg_task async for thing in numbers(start, stop): filled_by_bg_task = filled_by_bg_task + thing cleanup_task = BackgroundTask(numbers_for_cleanup, start=6, stop=9) - async def app(scope, receive, send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: response = FileResponse( path=path, filename="example.png", background=cleanup_task ) @@ -248,7 +257,7 @@ async def app(scope, receive, send): @pytest.mark.anyio -async def test_file_response_on_head_method(tmpdir: Path): +async def test_file_response_on_head_method(tmpdir: Path) -> None: path = os.path.join(tmpdir, "xyz") content = b"" * 1000 with open(path, "wb") as file: @@ -277,7 +286,9 @@ async def send(message: Message) -> None: await app({"type": "http", "method": "head"}, receive, send) -def test_file_response_with_directory_raises_error(tmpdir, test_client_factory): +def test_file_response_with_directory_raises_error( + tmpdir: Path, test_client_factory: TestClientFactory +) -> None: app = FileResponse(path=tmpdir, filename="example.png") client = test_client_factory(app) with pytest.raises(RuntimeError) as exc_info: @@ -285,7 +296,9 @@ def test_file_response_with_directory_raises_error(tmpdir, test_client_factory): assert "is not a file" in str(exc_info.value) -def test_file_response_with_missing_file_raises_error(tmpdir, test_client_factory): +def test_file_response_with_missing_file_raises_error( + tmpdir: Path, test_client_factory: TestClientFactory +) -> None: path = os.path.join(tmpdir, "404.txt") app = FileResponse(path=path, filename="404.txt") client = test_client_factory(app) @@ -294,7 +307,9 @@ def test_file_response_with_missing_file_raises_error(tmpdir, test_client_factor assert "does not exist" in str(exc_info.value) -def test_file_response_with_chinese_filename(tmpdir, test_client_factory): +def test_file_response_with_chinese_filename( + tmpdir: Path, test_client_factory: TestClientFactory +) -> None: content = b"file content" filename = "你好.txt" # probably "Hello.txt" in Chinese path = os.path.join(tmpdir, filename) @@ -309,7 +324,9 @@ def test_file_response_with_chinese_filename(tmpdir, test_client_factory): assert response.headers["content-disposition"] == expected_disposition -def test_file_response_with_inline_disposition(tmpdir, test_client_factory): +def test_file_response_with_inline_disposition( + tmpdir: Path, test_client_factory: TestClientFactory +) -> None: content = b"file content" filename = "hello.txt" path = os.path.join(tmpdir, filename) @@ -324,13 +341,15 @@ def test_file_response_with_inline_disposition(tmpdir, test_client_factory): assert response.headers["content-disposition"] == expected_disposition -def test_file_response_with_method_warns(tmpdir, test_client_factory): +def test_file_response_with_method_warns( + tmpdir: Path, test_client_factory: TestClientFactory +) -> None: with pytest.warns(DeprecationWarning): FileResponse(path=tmpdir, filename="example.png", method="GET") @pytest.mark.anyio -async def test_file_response_with_pathsend(tmpdir: Path): +async def test_file_response_with_pathsend(tmpdir: Path) -> None: path = os.path.join(tmpdir, "xyz") content = b"" * 1000 with open(path, "wb") as file: @@ -361,12 +380,14 @@ async def send(message: Message) -> None: ) -def test_set_cookie(test_client_factory, monkeypatch): +def test_set_cookie( + test_client_factory: TestClientFactory, monkeypatch: pytest.MonkeyPatch +) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc) monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp()) - async def app(scope, receive, send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie( "mycookie", @@ -401,12 +422,16 @@ async def app(scope, receive, send): pytest.param(10, id="int"), ], ) -def test_expires_on_set_cookie(test_client_factory, monkeypatch, expires): +def test_expires_on_set_cookie( + test_client_factory: TestClientFactory, + monkeypatch: pytest.MonkeyPatch, + expires: str, +) -> None: # Mock time used as a reference for `Expires` by stdlib `SimpleCookie`. mocked_now = dt.datetime(2037, 1, 22, 12, 0, 0, tzinfo=dt.timezone.utc) monkeypatch.setattr(time, "time", lambda: mocked_now.timestamp()) - async def app(scope, receive, send): + async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response("Hello, world!", media_type="text/plain") response.set_cookie("mycookie", "myvalue", expires=expires) await response(scope, receive, send) @@ -417,8 +442,8 @@ async def app(scope, receive, send): assert cookie["mycookie"]["expires"] == "Thu, 22 Jan 2037 12:00:10 GMT" -def test_delete_cookie(test_client_factory): - async def app(scope, receive, send): +def test_delete_cookie(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) response = Response("Hello, world!", media_type="text/plain") if request.cookies.get("mycookie"): @@ -434,7 +459,7 @@ async def app(scope, receive, send): assert not response.cookies.get("mycookie") -def test_populate_headers(test_client_factory): +def test_populate_headers(test_client_factory: TestClientFactory) -> None: app = Response(content="hi", headers={}, media_type="text/html") client = test_client_factory(app) response = client.get("/") @@ -443,14 +468,14 @@ def test_populate_headers(test_client_factory): assert response.headers["content-type"] == "text/html; charset=utf-8" -def test_head_method(test_client_factory): +def test_head_method(test_client_factory: TestClientFactory) -> None: app = Response("hello, world", media_type="text/plain") client = test_client_factory(app) response = client.head("/") assert response.text == "" -def test_empty_response(test_client_factory): +def test_empty_response(test_client_factory: TestClientFactory) -> None: app = Response() client: TestClient = test_client_factory(app) response = client.get("/") @@ -459,28 +484,32 @@ def test_empty_response(test_client_factory): assert "content-type" not in response.headers -def test_empty_204_response(test_client_factory): +def test_empty_204_response(test_client_factory: TestClientFactory) -> None: app = Response(status_code=204) client: TestClient = test_client_factory(app) response = client.get("/") assert "content-length" not in response.headers -def test_non_empty_response(test_client_factory): +def test_non_empty_response(test_client_factory: TestClientFactory) -> None: app = Response(content="hi") client: TestClient = test_client_factory(app) response = client.get("/") assert response.headers["content-length"] == "2" -def test_response_do_not_add_redundant_charset(test_client_factory): +def test_response_do_not_add_redundant_charset( + test_client_factory: TestClientFactory, +) -> None: app = Response(media_type="text/plain; charset=utf-8") client = test_client_factory(app) response = client.get("/") assert response.headers["content-type"] == "text/plain; charset=utf-8" -def test_file_response_known_size(tmpdir, test_client_factory): +def test_file_response_known_size( + tmpdir: Path, test_client_factory: TestClientFactory +) -> None: path = os.path.join(tmpdir, "xyz") content = b"" * 1000 with open(path, "wb") as file: @@ -492,14 +521,16 @@ def test_file_response_known_size(tmpdir, test_client_factory): assert response.headers["content-length"] == str(len(content)) -def test_streaming_response_unknown_size(test_client_factory): +def test_streaming_response_unknown_size( + test_client_factory: TestClientFactory, +) -> None: app = StreamingResponse(content=iter(["hello", "world"])) client: TestClient = test_client_factory(app) response = client.get("/") assert "content-length" not in response.headers -def test_streaming_response_known_size(test_client_factory): +def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None: app = StreamingResponse( content=iter(["hello", "world"]), headers={"content-length": "10"} ) @@ -509,16 +540,16 @@ def test_streaming_response_known_size(test_client_factory): @pytest.mark.anyio -async def test_streaming_response_stops_if_receiving_http_disconnect(): +async def test_streaming_response_stops_if_receiving_http_disconnect() -> None: streamed = 0 disconnected = anyio.Event() - async def receive_disconnect(): + async def receive_disconnect() -> Message: await disconnected.wait() return {"type": "http.disconnect"} - async def send(message): + async def send(message: Message) -> None: nonlocal streamed if message["type"] == "http.response.body": streamed += len(message.get("body", b"")) @@ -526,7 +557,7 @@ async def send(message): if streamed >= 16: disconnected.set() - async def stream_indefinitely(): + async def stream_indefinitely() -> AsyncIterator[bytes]: while True: # Need a sleep for the event loop to switch to another task await anyio.sleep(0)