diff --git a/tests/conftest.py b/tests/conftest.py index 724ca65d3..21f56a054 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,12 @@ from __future__ import annotations import functools -from typing import Any, Callable, Literal +from typing import Any, Literal import pytest from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory @pytest.fixture @@ -21,4 +20,4 @@ def test_client_factory( TestClient, backend=anyio_backend_name, backend_options=anyio_backend_options, - ) + ) # type: ignore diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 2176404d8..3ad1751a2 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -5,7 +5,6 @@ from typing import ( Any, AsyncGenerator, - Callable, Generator, ) @@ -23,8 +22,7 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket - -TestClientFactory = Callable[[ASGIApp], TestClient] +from tests.types import TestClientFactory class CustomMiddleware(BaseHTTPMiddleware): diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 09ec9513f..630361243 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -1,15 +1,10 @@ -from typing import Callable - from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route -from starlette.testclient import TestClient -from starlette.types import ASGIApp - -TestClientFactory = Callable[[ASGIApp], TestClient] +from tests.types import TestClientFactory def test_cors_allow_all( diff --git a/tests/middleware/test_errors.py b/tests/middleware/test_errors.py index a2dbabd8a..e32f406ae 100644 --- a/tests/middleware/test_errors.py +++ b/tests/middleware/test_errors.py @@ -1,4 +1,4 @@ -from typing import Any, Callable +from typing import Any import pytest @@ -8,10 +8,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route -from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_handler( diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index 5bfecadb7..ea9489452 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,15 +1,10 @@ -from typing import Callable - from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse from starlette.routing import Route -from starlette.testclient import TestClient -from starlette.types import ASGIApp - -TestClientFactory = Callable[[ASGIApp], TestClient] +from tests.types import TestClientFactory def test_gzip_responses(test_client_factory: TestClientFactory) -> None: @@ -29,7 +24,9 @@ def homepage(request: Request) -> PlainTextResponse: assert int(response.headers["Content-Length"]) < 4000 -def test_gzip_not_in_accept_encoding(test_client_factory: TestClientFactory) -> None: +def test_gzip_not_in_accept_encoding( + test_client_factory: TestClientFactory, +) -> None: def homepage(request: Request) -> PlainTextResponse: return PlainTextResponse("x" * 4000, status_code=200) diff --git a/tests/middleware/test_https_redirect.py b/tests/middleware/test_https_redirect.py index 9195694a3..22dfc14b6 100644 --- a/tests/middleware/test_https_redirect.py +++ b/tests/middleware/test_https_redirect.py @@ -1,14 +1,10 @@ -from typing import Callable - from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route -from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_https_redirect_middleware(test_client_factory: TestClientFactory) -> None: diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 4fbeec88c..9a0d70a0d 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -1,5 +1,4 @@ import re -from typing import Callable from starlette.applications import Starlette from starlette.middleware import Middleware @@ -8,8 +7,7 @@ from starlette.responses import JSONResponse from starlette.routing import Mount, Route from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def view_session(request: Request) -> JSONResponse: diff --git a/tests/middleware/test_trusted_host.py b/tests/middleware/test_trusted_host.py index 466302210..ddff46c48 100644 --- a/tests/middleware/test_trusted_host.py +++ b/tests/middleware/test_trusted_host.py @@ -1,14 +1,10 @@ -from typing import Callable - from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.requests import Request from starlette.responses import PlainTextResponse from starlette.routing import Route -from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_trusted_host_middleware(test_client_factory: TestClientFactory) -> None: diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 69842d3ad..58696bb65 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -5,10 +5,9 @@ from starlette._utils import collapse_excgroups from starlette.middleware.wsgi import WSGIMiddleware, build_environ -from starlette.testclient import TestClient +from tests.types import TestClientFactory WSGIResponse = Iterable[bytes] -TestClientFactory = Callable[..., TestClient] StartResponse = Callable[..., Any] Environment = Dict[str, Any] diff --git a/tests/test_applications.py b/tests/test_applications.py index 5b6c9d545..1a7228a33 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,7 +1,7 @@ import os from contextlib import asynccontextmanager from pathlib import Path -from typing import AsyncGenerator, AsyncIterator, Callable, Generator +from typing import AsyncGenerator, AsyncIterator, Generator import anyio import pytest @@ -20,8 +20,7 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory async def error_500(request: Request, exc: HTTPException) -> JSONResponse: @@ -132,7 +131,9 @@ def custom_ws_exception_handler(websocket: WebSocket, exc: CustomWSException) -> @pytest.fixture -def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]: +def client( + test_client_factory: TestClientFactory, +) -> Generator[TestClient, None, None]: with test_client_factory(app) as client: yield client diff --git a/tests/test_authentication.py b/tests/test_authentication.py index ecddda75e..35c1110d1 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -21,10 +21,9 @@ from starlette.requests import HTTPConnection, Request from starlette.responses import JSONResponse, Response from starlette.routing import Route, WebSocketRoute -from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect +from tests.types import TestClientFactory -TestClientFactory = Callable[..., TestClient] AsyncEndpoint = Callable[..., Awaitable[Response]] SyncEndpoint = Callable[..., Response] diff --git a/tests/test_background.py b/tests/test_background.py index 846deecfd..cbffcc06a 100644 --- a/tests/test_background.py +++ b/tests/test_background.py @@ -1,13 +1,9 @@ -from typing import Callable - import pytest from starlette.background import BackgroundTask, BackgroundTasks from starlette.responses import Response -from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_async_task(test_client_factory: TestClientFactory) -> None: diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index aba3ceb1a..bac6814e4 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -1,5 +1,5 @@ from contextvars import ContextVar -from typing import Callable, Iterator +from typing import Iterator import anyio import pytest @@ -9,9 +9,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route -from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory @pytest.mark.anyio diff --git a/tests/test_convertors.py b/tests/test_convertors.py index 72ee17a82..520c98767 100644 --- a/tests/test_convertors.py +++ b/tests/test_convertors.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Callable, Iterator +from typing import Iterator import pytest @@ -8,9 +8,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Route, Router -from starlette.testclient import TestClient - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory @pytest.fixture(scope="module", autouse=True) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index eeb0f2322..96ad3c4ea 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterator +from typing import Iterator import pytest @@ -8,8 +8,7 @@ from starlette.routing import Route, Router from starlette.testclient import TestClient from starlette.websockets import WebSocket - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory class Homepage(HTTPEndpoint): @@ -50,7 +49,9 @@ def test_http_endpoint_route_method(client: TestClient) -> None: assert response.headers["allow"] == "GET" -def test_websocket_endpoint_on_connect(test_client_factory: TestClientFactory) -> None: +def test_websocket_endpoint_on_connect( + test_client_factory: TestClientFactory, +) -> None: class WebSocketApp(WebSocketEndpoint): async def on_connect(self, websocket: WebSocket) -> None: assert websocket["subprotocols"] == ["soap", "wamp"] @@ -137,7 +138,9 @@ async def on_receive(self, websocket: WebSocket, data: str) -> None: websocket.send_bytes(b"Hello world") -def test_websocket_endpoint_on_default(test_client_factory: TestClientFactory) -> None: +def test_websocket_endpoint_on_default( + test_client_factory: TestClientFactory, +) -> None: class WebSocketApp(WebSocketEndpoint): encoding = None diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 401ad8212..ca4e1d72b 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Generator +from typing import Generator import pytest @@ -10,8 +10,7 @@ from starlette.routing import Route, Router, WebSocketRoute from starlette.testclient import TestClient from starlette.types import Receive, Scope, Send - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def raise_runtime_error(request: Request) -> None: @@ -82,7 +81,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @pytest.fixture -def client(test_client_factory: TestClientFactory) -> Generator[TestClient, None, None]: +def client( + test_client_factory: TestClientFactory, +) -> Generator[TestClient, None, None]: with test_client_factory(app) as client: yield client diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index ed2226878..8d97a0ba7 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -13,10 +13,8 @@ from starlette.requests import Request from starlette.responses import JSONResponse from starlette.routing import Mount -from starlette.testclient import TestClient from starlette.types import ASGIApp, Receive, Scope, Send - -TestClientFactory = typing.Callable[..., TestClient] +from tests.types import TestClientFactory class ForceMultipartDict(typing.Dict[typing.Any, typing.Any]): diff --git a/tests/test_requests.py b/tests/test_requests.py index d8e2e9477..be31074c0 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,7 +1,7 @@ from __future__ import annotations import sys -from typing import Any, Callable, Iterator +from typing import Any, Iterator import anyio import pytest @@ -9,10 +9,8 @@ from starlette.datastructures import Address, State from starlette.requests import ClientDisconnect, Request from starlette.responses import JSONResponse, PlainTextResponse, Response -from starlette.testclient import TestClient from starlette.types import Message, Receive, Scope, Send - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_request_url(test_client_factory: TestClientFactory) -> None: @@ -133,7 +131,9 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.json() == {"form": {"abc": "123 @"}} -def test_request_form_context_manager(test_client_factory: TestClientFactory) -> None: +def test_request_form_context_manager( + test_client_factory: TestClientFactory, +) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) async with request.form() as form: diff --git a/tests/test_responses.py b/tests/test_responses.py index fa3c1009f..dfd56cf43 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -5,7 +5,7 @@ import time from http.cookies import SimpleCookie from pathlib import Path -from typing import AsyncIterator, Callable, Iterator +from typing import AsyncIterator, Iterator import anyio import pytest @@ -23,8 +23,7 @@ ) from starlette.testclient import TestClient from starlette.types import Message, Receive, Scope, Send - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_text_response(test_client_factory: TestClientFactory) -> None: @@ -532,7 +531,9 @@ def test_streaming_response_unknown_size( assert "content-length" not in response.headers -def test_streaming_response_known_size(test_client_factory: TestClientFactory) -> None: +def test_streaming_response_known_size( + test_client_factory: TestClientFactory, +) -> None: app = StreamingResponse( content=iter(["hello", "world"]), headers={"content-length": "10"} ) diff --git a/tests/test_routing.py b/tests/test_routing.py index b75fc47f0..e26295a1f 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -17,8 +17,7 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect - -TestClientFactory = typing.Callable[..., TestClient] +from tests.types import TestClientFactory def homepage(request: Request) -> Response: diff --git a/tests/test_schemas.py b/tests/test_schemas.py index e00b2b8de..f4a5b4ad9 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -1,20 +1,16 @@ -from typing import Callable - from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint from starlette.requests import Request from starlette.responses import Response from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.schemas import SchemaGenerator -from starlette.testclient import TestClient from starlette.websockets import WebSocket +from tests.types import TestClientFactory schemas = SchemaGenerator( {"openapi": "3.0.0", "info": {"title": "Example API", "version": "1.0"}} ) -TestClientFactory = Callable[..., TestClient] - def ws(session: WebSocket) -> None: """ws""" diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index d20bb7ef7..aa2dd0aa4 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -16,9 +16,7 @@ from starlette.responses import Response from starlette.routing import Mount from starlette.staticfiles import StaticFiles -from starlette.testclient import TestClient - -TestClientFactory = typing.Callable[..., TestClient] +from tests.types import TestClientFactory def test_staticfiles(tmpdir: Path, test_client_factory: TestClientFactory) -> None: diff --git a/tests/test_templates.py b/tests/test_templates.py index 10a1366bc..8e344f331 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import typing from pathlib import Path from unittest import mock @@ -16,9 +15,7 @@ from starlette.responses import Response from starlette.routing import Route from starlette.templating import Jinja2Templates -from starlette.testclient import TestClient - -TestClientFactory = typing.Callable[..., TestClient] +from tests.types import TestClientFactory def test_templates(tmpdir: Path, test_client_factory: TestClientFactory) -> None: diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 4ed1ced9a..d8ad4d783 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -4,7 +4,7 @@ import sys from asyncio import Task, current_task as asyncio_current_task from contextlib import asynccontextmanager -from typing import Any, AsyncGenerator, Callable +from typing import Any, AsyncGenerator import anyio import anyio.lowlevel @@ -20,8 +20,7 @@ from starlette.testclient import ASGIInstance, TestClient from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def mock_service_endpoint(request: Request) -> JSONResponse: @@ -212,7 +211,7 @@ async def inner(receive: Receive, send: Send) -> None: return inner - client = test_client_factory(app) + client = test_client_factory(app) # type: ignore response = client.get("/") assert response.text == "Hello, world!" @@ -252,13 +251,15 @@ async def asgi(receive: Receive, send: Send) -> None: return asgi - client = test_client_factory(app) + client = test_client_factory(app) # type: ignore with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} -def test_websocket_not_block_on_close(test_client_factory: TestClientFactory) -> None: +def test_websocket_not_block_on_close( + test_client_factory: TestClientFactory, +) -> None: def app(scope: Scope) -> ASGIInstance: async def asgi(receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) @@ -268,7 +269,7 @@ async def asgi(receive: Receive, send: Send) -> None: return asgi - client = test_client_factory(app) + client = test_client_factory(app) # type: ignore with client.websocket_connect("/") as websocket: ... assert websocket.should_close.is_set() diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 854c26914..d7488c479 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,5 +1,5 @@ import sys -from typing import Any, Callable, MutableMapping +from typing import Any, MutableMapping import anyio import pytest @@ -7,11 +7,10 @@ from starlette import status from starlette.responses import Response -from starlette.testclient import TestClient, WebSocketDenialResponse +from starlette.testclient import WebSocketDenialResponse from starlette.types import Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState - -TestClientFactory = Callable[..., TestClient] +from tests.types import TestClientFactory def test_websocket_url(test_client_factory: TestClientFactory) -> None: @@ -207,7 +206,9 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert data == {"message": {"hello": "world"}} -def test_websocket_concurrency_pattern(test_client_factory: TestClientFactory) -> None: +def test_websocket_concurrency_pattern( + test_client_factory: TestClientFactory, +) -> None: stream_send: ObjectSendStream[MutableMapping[str, Any]] stream_receive: ObjectReceiveStream[MutableMapping[str, Any]] stream_send, stream_receive = anyio.create_memory_object_stream() @@ -379,7 +380,9 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert exc.value.code == status.WS_1000_NORMAL_CLOSURE -def test_send_response_duplicate_start(test_client_factory: TestClientFactory) -> None: +def test_send_response_duplicate_start( + test_client_factory: TestClientFactory, +) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) msg = await websocket.receive() @@ -564,7 +567,9 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: pass # pragma: nocover -def test_receive_bytes_before_accept(test_client_factory: TestClientFactory) -> None: +def test_receive_bytes_before_accept( + test_client_factory: TestClientFactory, +) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) await websocket.receive_bytes() diff --git a/tests/types.py b/tests/types.py new file mode 100644 index 000000000..4ab704159 --- /dev/null +++ b/tests/types.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import Protocol + +import httpx + +from starlette.testclient import TestClient +from starlette.types import ASGIApp + + +class TestClientFactory(Protocol): # pragma: no cover + __test__ = False # type: ignore + + def __call__( + self, + app: ASGIApp, + base_url: str = "http://testserver", + raise_server_exceptions: bool = True, + root_path: str = "", + cookies: httpx._types.CookieTypes | None = None, + headers: dict[str, str] | None = None, + follow_redirects: bool = True, + ) -> TestClient: + ...