diff --git a/starlette/routing.py b/starlette/routing.py index 38bb3517e..cb1b1d919 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -217,6 +217,7 @@ def __init__( methods: typing.Optional[typing.List[str]] = None, name: typing.Optional[str] = None, include_in_schema: bool = True, + middleware: typing.Optional[typing.Sequence[Middleware]] = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -236,6 +237,10 @@ def __init__( # Endpoint is a class. Treat it as ASGI. self.app = endpoint + if middleware is not None: + for cls, options in reversed(middleware): + self.app = cls(app=self.app, **options) + if methods is None: self.methods = None else: @@ -309,6 +314,7 @@ def __init__( endpoint: typing.Callable[..., typing.Any], *, name: typing.Optional[str] = None, + middleware: typing.Optional[typing.Sequence[Middleware]] = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -325,6 +331,10 @@ def __init__( # Endpoint is a class. Treat it as ASGI. self.app = endpoint + if middleware is not None: + for cls, options in reversed(middleware): + self.app = cls(app=self.app, **options) + self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: diff --git a/starlette/testclient.py b/starlette/testclient.py index a0046f2ff..4a750cd8c 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -710,7 +710,7 @@ def delete( # type: ignore[override] def websocket_connect( self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any - ) -> typing.Any: + ) -> "WebSocketTestSession": url = urljoin("ws://testserver", url) headers = kwargs.get("headers", {}) headers.setdefault("connection", "upgrade") diff --git a/tests/test_routing.py b/tests/test_routing.py index 7159a4bfc..7644f5a24 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -919,6 +919,18 @@ def assert_middleware_header_route(request: Request) -> Response: return Response() +route_with_middleware = Starlette( + routes=[ + Route( + "/http", + endpoint=assert_middleware_header_route, + methods=["GET"], + middleware=[Middleware(AddHeadersMiddleware)], + ), + Route("/home", homepage), + ] +) + mounted_routes_with_middleware = Starlette( routes=[ Mount( @@ -960,9 +972,10 @@ def assert_middleware_header_route(request: Request) -> Response: [ mounted_routes_with_middleware, mounted_app_with_middleware, + route_with_middleware, ], ) -def test_mount_middleware( +def test_base_route_middleware( test_client_factory: typing.Callable[..., TestClient], app: Starlette, ) -> None: @@ -1076,6 +1089,44 @@ async def modified_send(msg: Message) -> None: assert "X-Mounted" in resp.headers +def test_websocket_route_middleware( + test_client_factory: typing.Callable[..., TestClient] +): + async def websocket_endpoint(session: WebSocket): + await session.accept() + await session.send_text("Hello, world!") + await session.close() + + class WebsocketMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + async def modified_send(msg: Message) -> None: + if msg["type"] == "websocket.accept": + msg["headers"].append((b"X-Test", b"Set by middleware")) + await send(msg) + + await self.app(scope, receive, modified_send) + + app = Starlette( + routes=[ + WebSocketRoute( + "/ws", + endpoint=websocket_endpoint, + middleware=[Middleware(WebsocketMiddleware)], + ) + ] + ) + + client = test_client_factory(app) + + with client.websocket_connect("/ws") as websocket: + text = websocket.receive_text() + assert text == "Hello, world!" + assert websocket.extra_headers == [(b"X-Test", b"Set by middleware")] + + def test_route_repr() -> None: route = Route("/welcome", endpoint=homepage) assert (