diff --git a/docs/middleware.md b/docs/middleware.md index b21914291..626ba54c5 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -39,8 +39,37 @@ application would look like this: * `HTTPSRedirectMiddleware` * `ExceptionMiddleware` * Routing +* Route middleware * Endpoint +Middleware can also be added at the route level, in which case it will be executed after routing occurs: + +```python +from starlette.applications import Starlette +from starlette.middleware.trustedhost import TrustedHostMiddleware + +example_route_middleware = [ + Middleware(TrustedHostMiddleware, allowed_hosts=['example.com', '*.example.com']), +] + +routes = [ + Route( + "/example", + endpoint=..., + middleware=example_route_middleware, + ) +] + +middleware = [ + Middleware(HTTPSRedirectMiddleware) +] + +app = Starlette(routes=routes, middleware=middleware) +``` + +Note that since this is run after routing, modifying the path in the middleware will have no effect. +There is also no built-in error handling for route middleware, so your middleware will need to handle exceptions and resource cleanup itself. + The following middleware implementations are available in the Starlette package: ## CORSMiddleware diff --git a/starlette/routing.py b/starlette/routing.py index 0388304c9..9948203da 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -14,6 +14,7 @@ from starlette.convertors import CONVERTOR_TYPES, Convertor from starlette.datastructures import URL, Headers, URLPath from starlette.exceptions import HTTPException +from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse from starlette.types import ASGIApp, Receive, Scope, Send @@ -195,6 +196,7 @@ def __init__( methods: typing.List[str] = None, name: str = None, include_in_schema: bool = True, + middleware: typing.Optional[typing.Sequence[Middleware]] = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -214,6 +216,10 @@ def __init__( # Endpoint is a class. Treat it as ASGI. self.app = endpoint + if middleware is not None: + for cls, options in reversed(middleware): + self.app = cls(app=self.app, **options) + if methods is None: self.methods = None else: @@ -333,12 +339,16 @@ def __eq__(self, other: typing.Any) -> bool: class Mount(BaseRoute): + _routes: typing.List[BaseRoute] + def __init__( self, path: str, app: ASGIApp = None, routes: typing.Sequence[BaseRoute] = None, name: str = None, + *, + middleware: typing.Optional[typing.Sequence[Middleware]] = None, ) -> None: assert path == "" or path.startswith("/"), "Routed paths must start with '/'" assert ( @@ -349,14 +359,19 @@ def __init__( self.app: ASGIApp = app else: self.app = Router(routes=routes) + self._routes = getattr(self.app, "routes", []) self.name = name self.path_regex, self.path_format, self.param_convertors = compile_path( self.path + "/{path:path}" ) + if middleware is not None: + for cls, options in reversed(middleware): + self.app = cls(app=self.app, **options) + @property def routes(self) -> typing.List[BaseRoute]: - return getattr(self.app, "routes", []) + return self._routes def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: if scope["type"] in ("http", "websocket"): @@ -404,7 +419,7 @@ def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath: ) if path_kwarg is not None: remaining_params["path"] = path_kwarg - for route in self.routes or []: + for route in self.routes: try: url = route.url_path_for(remaining_name, **remaining_params) return URLPath( diff --git a/tests/test_routing.py b/tests/test_routing.py index 7077c5616..8848e8979 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,11 +1,24 @@ import functools +import typing import uuid import pytest from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse, Response -from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute +from starlette.routing import ( + BaseRoute, + Host, + Mount, + NoMatchFound, + Route, + Router, + WebSocketRoute, +) +from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -710,3 +723,132 @@ def test_duplicated_param_names(): match="Duplicated param names id, name at path /{id}/{name}/{id}/{name}", ): Route("/{id}/{name}/{id}/{name}", user) + + +def assert_middleware_header_route(request: Request): + assert getattr(request.state, "middleware_touched") == "Set by middleware" + return Response() + + +class AddHeadersMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + setattr(request.state, "middleware_touched", "Set by middleware") + response: Response = await call_next(request) + response.headers["X-Test"] = "Set by middleware" + return response + + +route_with_middleware = ( + Route( + "/http", + endpoint=assert_middleware_header_route, + methods=["GET"], + middleware=[Middleware(AddHeadersMiddleware)], + ), +) + + +mounted_routes_with_middleware = Mount( + "/http", + routes=[ + Route( + "/", + endpoint=assert_middleware_header_route, + methods=["GET"], + name="route", + ), + ], + middleware=[Middleware(AddHeadersMiddleware)], +) + + +mounted_app_with_middleware = Mount( + "/http", + app=Route( + "/", + endpoint=assert_middleware_header_route, + methods=["GET"], + name="route", + ), + middleware=[Middleware(AddHeadersMiddleware)], +) + + +mounted_routes_with_route_middleware = Mount( + "/http", + routes=[ + Route( + "/", + endpoint=assert_middleware_header_route, + methods=["GET"], + name="route", + middleware=[Middleware(AddHeadersMiddleware)], + ), + ], +) + + +mounted_app_with_route_middleware = Mount( + "/http", + app=Route( + "/", + endpoint=assert_middleware_header_route, + methods=["GET"], + name="route", + middleware=[Middleware(AddHeadersMiddleware)], + ), +) + + +@pytest.mark.parametrize( + "route", + [ + mounted_routes_with_middleware, + mounted_routes_with_middleware, + mounted_app_with_middleware, + mounted_routes_with_route_middleware, + mounted_app_with_route_middleware, + ], +) +def test_route_level_middleware( + test_client_factory: typing.Callable[..., TestClient], + route: BaseRoute, +) -> None: + test_client = test_client_factory(Router([route])) + response = test_client.get("/http") + assert response.status_code == 200 + assert response.headers["X-Test"] == "Set by middleware" + + +@pytest.mark.parametrize( + "route", + [ + mounted_routes_with_middleware, + mounted_routes_with_middleware, + mounted_routes_with_route_middleware, + ], +) +def test_mount_middleware_url_path_for_(route: BaseRoute) -> None: + """Checks that url_path_for still works with middelware on Mounts""" + router = Router([route]) + assert router.url_path_for("route") == "/http/" + + +def test_add_route_to_app_after_mount( + test_client_factory: typing.Callable[..., TestClient], +) -> None: + """Checks that mounds will pick up routes + added to the underlaying app after it is mounted + """ + inner_app = Router() + app = Mount("/http", app=inner_app) + inner_app.add_route( + "/inner", + endpoint=lambda request: Response(), + methods=["GET"], + ) + client = test_client_factory(app) + response = client.get("/http/inner") + assert response.status_code == 200