From c2809bd50c4d98b2090bec134e519ab99861f729 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 12 Sep 2021 23:45:44 -0500 Subject: [PATCH 01/21] Add route-level middleware --- starlette/routing.py | 17 ++++++++++++++++- tests/test_routing.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/starlette/routing.py b/starlette/routing.py index 9a1a5e12d..69251b398 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -14,6 +14,7 @@ from starlette.convertors import CONVERTOR_TYPES, Convertor from starlette.datastructures import URL, Headers, URLPath from starlette.exceptions import HTTPException +from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse from starlette.types import ASGIApp, Receive, Scope, Send @@ -191,6 +192,7 @@ def __init__( methods: typing.List[str] = None, name: str = None, include_in_schema: bool = True, + middleware: typing.Optional[typing.Sequence[Middleware]] = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -209,6 +211,10 @@ def __init__( else: # Endpoint is a class. Treat it as ASGI. self.app = endpoint + + if middleware is not None: + for cls, options in reversed(middleware): + self.app = cls(app=self.app, **options) if methods is None: self.methods = None @@ -269,7 +275,12 @@ def __eq__(self, other: typing.Any) -> bool: class WebSocketRoute(BaseRoute): def __init__( - self, path: str, endpoint: typing.Callable, *, name: str = None + self, + path: str, + endpoint: typing.Callable, + *, + name: str = None, + middleware: typing.Optional[typing.Sequence[Middleware]] = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -283,6 +294,10 @@ def __init__( # Endpoint is a class. Treat it as ASGI. self.app = endpoint + if middleware is not None: + for cls, options in reversed(middleware): + self.app = cls(app=self.app, **options) + self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: diff --git a/tests/test_routing.py b/tests/test_routing.py index 9e734b9cc..8a1e34665 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,10 +1,15 @@ import functools +from starlette.testclient import TestClient +import typing import uuid import pytest from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.responses import JSONResponse, PlainTextResponse, Response +from starlette.requests import Request from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute from starlette.websockets import WebSocket, WebSocketDisconnect @@ -648,3 +653,37 @@ def test_duplicated_param_names(): match="Duplicated param names id, name at path /{id}/{name}/{id}/{name}", ): Route("/{id}/{name}/{id}/{name}", user) + + + +def assert_middleware_header_route(request: Request): + assert getattr(request.state, "middleware_touched") == "Set by middleware" + return Response() + + +class AddHeadersMiddleware(BaseHTTPMiddleware): + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + setattr(request.state, "middleware_touched", "Set by middleware") + response: Response = await call_next(request) + response.headers["X-Test"] = "Set by middleware" + return response + + +middleware_router = Router( + [ + Route( + "/http", + endpoint=assert_middleware_header_route, + methods=["GET"], + middleware=[Middleware(AddHeadersMiddleware)] + ), + ] +) + + +def test_router_middleware_http(test_client_factory: typing.Callable[..., TestClient]) -> None: + test_client = test_client_factory(middleware_router) + response = test_client.get("/http") + assert response.status_code == 200 + assert response.headers["X-Test"] == "Set by middleware" From 2fb13a6e577591a5e1421bb1cf40fd8d763e68bf Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 12 Sep 2021 23:48:19 -0500 Subject: [PATCH 02/21] Remove websocker stuff --- starlette/routing.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 69251b398..49fe2f79b 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -280,7 +280,6 @@ def __init__( endpoint: typing.Callable, *, name: str = None, - middleware: typing.Optional[typing.Sequence[Middleware]] = None, ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path @@ -293,11 +292,7 @@ def __init__( else: # Endpoint is a class. Treat it as ASGI. self.app = endpoint - - if middleware is not None: - for cls, options in reversed(middleware): - self.app = cls(app=self.app, **options) - + self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: From e57a42ffc1e01bc423252c6b8234cc15b244d24d Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 12 Sep 2021 23:48:49 -0500 Subject: [PATCH 03/21] Remove spaces --- starlette/routing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/routing.py b/starlette/routing.py index 49fe2f79b..e442e0e14 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -292,7 +292,7 @@ def __init__( else: # Endpoint is a class. Treat it as ASGI. self.app = endpoint - + self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: From 6e4c6dc9d542ae04032210538a26eb533fff660b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 12 Sep 2021 23:49:34 -0500 Subject: [PATCH 04/21] remove newline --- tests/test_routing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_routing.py b/tests/test_routing.py index 8a1e34665..c02892ae2 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -655,7 +655,6 @@ def test_duplicated_param_names(): Route("/{id}/{name}/{id}/{name}", user) - def assert_middleware_header_route(request: Request): assert getattr(request.state, "middleware_touched") == "Set by middleware" return Response() From ad318575b6dae46bfb456bdc341d6c3a92a5ca8d Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 12 Sep 2021 23:54:42 -0500 Subject: [PATCH 05/21] revert more changes --- starlette/routing.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index e442e0e14..698e97348 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -275,11 +275,7 @@ def __eq__(self, other: typing.Any) -> bool: class WebSocketRoute(BaseRoute): def __init__( - self, - path: str, - endpoint: typing.Callable, - *, - name: str = None, + self, path: str, endpoint: typing.Callable, *, name: str = None ) -> None: assert path.startswith("/"), "Routed paths must start with '/'" self.path = path From 4549f2cca9702e0d9b5ebdc0d0d02730ad817fe4 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 18 Sep 2021 12:53:12 -0500 Subject: [PATCH 06/21] linting --- starlette/routing.py | 2 +- tests/test_routing.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 698e97348..0f6e300dc 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -211,7 +211,7 @@ def __init__( else: # Endpoint is a class. Treat it as ASGI. self.app = endpoint - + if middleware is not None: for cls, options in reversed(middleware): self.app = cls(app=self.app, **options) diff --git a/tests/test_routing.py b/tests/test_routing.py index c02892ae2..db252bdfc 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,5 +1,4 @@ import functools -from starlette.testclient import TestClient import typing import uuid @@ -8,9 +7,10 @@ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.requests import Request +from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute +from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -661,8 +661,9 @@ def assert_middleware_header_route(request: Request): class AddHeadersMiddleware(BaseHTTPMiddleware): - - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: setattr(request.state, "middleware_touched", "Set by middleware") response: Response = await call_next(request) response.headers["X-Test"] = "Set by middleware" @@ -675,13 +676,15 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - "/http", endpoint=assert_middleware_header_route, methods=["GET"], - middleware=[Middleware(AddHeadersMiddleware)] + middleware=[Middleware(AddHeadersMiddleware)], ), ] ) -def test_router_middleware_http(test_client_factory: typing.Callable[..., TestClient]) -> None: +def test_router_middleware_http( + test_client_factory: typing.Callable[..., TestClient] +) -> None: test_client = test_client_factory(middleware_router) response = test_client.get("/http") assert response.status_code == 200 From 5ec587175979241e95a5ef7c36e1546768ebd308 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 11 Nov 2021 12:05:09 -0600 Subject: [PATCH 07/21] add Mount middleware and docs --- docs/middleware.md | 25 +++++++++++++++++++++++++ starlette/routing.py | 6 ++++++ tests/test_routing.py | 26 ++++++++++++++++++++++++++ 3 files changed, 57 insertions(+) diff --git a/docs/middleware.md b/docs/middleware.md index 4c6fb8a9e..3c6c4c07f 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -41,6 +41,31 @@ application would look like this: * Routing * Endpoint +Middleware can also be added at the route level, in which case it will be executed after routing ocurrs: + +```python +from starlette.applications import Starlette +from starlette.middleware.trustedhost import TrustedHostMiddleware + +example_route_middleware = [ + Middleware(TrustedHostMiddleware, allowed_hosts=['example.com', '*.example.com']), +] + +routes = [ + Route( + "/example", + endpoint=..., + middleware=example_route_middleware, + ) +] + +middleware = [ + Middleware(HTTPSRedirectMiddleware) +] + +app = Starlette(routes=routes, middleware=middleware) +``` + The following middleware implementations are available in the Starlette package: ## CORSMiddleware diff --git a/starlette/routing.py b/starlette/routing.py index 0f6e300dc..2b66ffcc8 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -335,6 +335,8 @@ def __init__( app: ASGIApp = None, routes: typing.Sequence[BaseRoute] = None, name: str = None, + *, + middleware: typing.Optional[typing.Sequence[Middleware]] = None, ) -> None: assert path == "" or path.startswith("/"), "Routed paths must start with '/'" assert ( @@ -350,6 +352,10 @@ def __init__( self.path + "/{path:path}" ) + if middleware is not None: + for cls, options in reversed(middleware): + self.app = cls(app=self.app, **options) + @property def routes(self) -> typing.List[BaseRoute]: return getattr(self.app, "routes", None) diff --git a/tests/test_routing.py b/tests/test_routing.py index db252bdfc..7ec9e6a30 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -689,3 +689,29 @@ def test_router_middleware_http( response = test_client.get("/http") assert response.status_code == 200 assert response.headers["X-Test"] == "Set by middleware" + + +mounted_middleware_router = Router( + [ + Mount( + "/http", + routes=[ + Route( + "/", + endpoint=assert_middleware_header_route, + methods=["GET"], + ), + ], + middleware=[Middleware(AddHeadersMiddleware)], + ) + ] +) + + +def test_mounted_router_middleware( + test_client_factory: typing.Callable[..., TestClient] +) -> None: + test_client = test_client_factory(mounted_middleware_router) + response = test_client.get("/http") + assert response.status_code == 200 + assert response.headers["X-Test"] == "Set by middleware" From 3c64e176975e119e7b4b130354e9841c74f8b0d0 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 11 Nov 2021 12:10:13 -0600 Subject: [PATCH 08/21] add warning about modifying the path --- docs/middleware.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/middleware.md b/docs/middleware.md index 3c6c4c07f..f9f9f022c 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -66,6 +66,8 @@ middleware = [ app = Starlette(routes=routes, middleware=middleware) ``` +Note that since this is run after routing, modifying the path in the middleware will have no effect. + The following middleware implementations are available in the Starlette package: ## CORSMiddleware From 586ed21e5cea12874d733de63545981b112acf62 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 11 Nov 2021 12:12:07 -0600 Subject: [PATCH 09/21] combine tests --- tests/test_routing.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/test_routing.py b/tests/test_routing.py index 7ec9e6a30..dd92bbcfd 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -682,15 +682,6 @@ async def dispatch( ) -def test_router_middleware_http( - test_client_factory: typing.Callable[..., TestClient] -) -> None: - test_client = test_client_factory(middleware_router) - response = test_client.get("/http") - assert response.status_code == 200 - assert response.headers["X-Test"] == "Set by middleware" - - mounted_middleware_router = Router( [ Mount( @@ -707,11 +698,17 @@ def test_router_middleware_http( ] ) - -def test_mounted_router_middleware( - test_client_factory: typing.Callable[..., TestClient] +@pytest.mark.parametrize( + "router", [ + middleware_router, + mounted_middleware_router, + ] +) +def test_http_route_middleware( + test_client_factory: typing.Callable[..., TestClient], + router: Router, ) -> None: - test_client = test_client_factory(mounted_middleware_router) + test_client = test_client_factory(router) response = test_client.get("/http") assert response.status_code == 200 assert response.headers["X-Test"] == "Set by middleware" From 24988bfe5d2157999a662ab2d83969a512ebb836 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 11 Nov 2021 12:59:26 -0600 Subject: [PATCH 10/21] Update docs/middleware.md Co-authored-by: Marcelo Trylesinski --- docs/middleware.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/middleware.md b/docs/middleware.md index f9f9f022c..58fd244a0 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -41,7 +41,7 @@ application would look like this: * Routing * Endpoint -Middleware can also be added at the route level, in which case it will be executed after routing ocurrs: +Middleware can also be added at the route level, in which case it will be executed after routing occurs: ```python from starlette.applications import Starlette From 73f920cc784c9edfd315283a0f4e08adf5c5b481 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 15 Nov 2021 14:16:26 -0600 Subject: [PATCH 11/21] linting --- tests/test_routing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_routing.py b/tests/test_routing.py index dd92bbcfd..92c692411 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -698,11 +698,13 @@ async def dispatch( ] ) + @pytest.mark.parametrize( - "router", [ + "router", + [ middleware_router, mounted_middleware_router, - ] + ], ) def test_http_route_middleware( test_client_factory: typing.Callable[..., TestClient], From 14f85bdb881c0bde1f4b794f03f1217d7d0b4390 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 22 Nov 2021 17:34:40 -0600 Subject: [PATCH 12/21] Add note on error handling --- docs/middleware.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/middleware.md b/docs/middleware.md index 58fd244a0..141c2b04c 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -39,6 +39,7 @@ application would look like this: * `HTTPSRedirectMiddleware` * `ExceptionMiddleware` * Routing +* Route middleware * Endpoint Middleware can also be added at the route level, in which case it will be executed after routing occurs: @@ -67,6 +68,7 @@ app = Starlette(routes=routes, middleware=middleware) ``` Note that since this is run after routing, modifying the path in the middleware will have no effect. +There is also no built-in error handling for route middleware, so your middleware will need to handle exceptions and resource cleanup itself. The following middleware implementations are available in the Starlette package: From eb7d41d3a7fa9a07f8b984c90135ae99407603d3 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 22 Nov 2021 19:50:15 -0600 Subject: [PATCH 13/21] capture routes before wrapping --- starlette/routing.py | 7 +++- tests/test_routing.py | 98 ++++++++++++++++++++++++++++++++----------- 2 files changed, 78 insertions(+), 27 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 2b66ffcc8..292b64ece 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -345,8 +345,11 @@ def __init__( self.path = path.rstrip("/") if app is not None: self.app: ASGIApp = app + routes = getattr(self.app, "routes", []) else: self.app = Router(routes=routes) + routes = list(routes or []) + self._routes = routes self.name = name self.path_regex, self.path_format, self.param_convertors = compile_path( self.path + "/{path:path}" @@ -358,7 +361,7 @@ def __init__( @property def routes(self) -> typing.List[BaseRoute]: - return getattr(self.app, "routes", None) + return self._routes def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: if scope["type"] in ("http", "websocket"): @@ -406,7 +409,7 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath: ) if path_kwarg is not None: remaining_params["path"] = path_kwarg - for route in self.routes or []: + for route in self.routes: try: url = route.url_path_for(remaining_name, **remaining_params) return URLPath( diff --git a/tests/test_routing.py b/tests/test_routing.py index 92c692411..8a9c4b0b1 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -9,7 +9,7 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse, Response -from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute +from starlette.routing import BaseRoute, Host, Mount, NoMatchFound, Route, Router, WebSocketRoute from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -670,47 +670,95 @@ async def dispatch( return response -middleware_router = Router( - [ +route_with_middleware = Route( + "/http", + endpoint=assert_middleware_header_route, + methods=["GET"], + middleware=[Middleware(AddHeadersMiddleware)], +), + + +mounted_routes_with_middleware = Mount( + "/http", + routes=[ Route( - "/http", + "/", endpoint=assert_middleware_header_route, methods=["GET"], - middleware=[Middleware(AddHeadersMiddleware)], + name="route", ), - ] + ], + middleware=[Middleware(AddHeadersMiddleware)], ) -mounted_middleware_router = Router( - [ - Mount( - "/http", - routes=[ - Route( - "/", - endpoint=assert_middleware_header_route, - methods=["GET"], - ), - ], +mounted_app_with_middleware = Mount( + "/http", + app=Route( + "/", + endpoint=assert_middleware_header_route, + methods=["GET"], + name="route", + ), + middleware=[Middleware(AddHeadersMiddleware)], +) + + +mounted_routes_with_route_middleware = Mount( + "/http", + routes=[ + Route( + "/", + endpoint=assert_middleware_header_route, + methods=["GET"], + name="route", middleware=[Middleware(AddHeadersMiddleware)], - ) - ] + ), + ], +) + + +mounted_app_with_route_middleware = Mount( + "/http", + app=Route( + "/", + endpoint=assert_middleware_header_route, + methods=["GET"], + name="route", + middleware=[Middleware(AddHeadersMiddleware)], + ), ) @pytest.mark.parametrize( - "router", + "route", [ - middleware_router, - mounted_middleware_router, + mounted_routes_with_middleware, + mounted_routes_with_middleware, + mounted_app_with_middleware, + mounted_routes_with_route_middleware, + mounted_app_with_route_middleware ], ) -def test_http_route_middleware( +def test_route_level_middleware( test_client_factory: typing.Callable[..., TestClient], - router: Router, + route: BaseRoute, ) -> None: - test_client = test_client_factory(router) + test_client = test_client_factory(Router([route])) response = test_client.get("/http") assert response.status_code == 200 assert response.headers["X-Test"] == "Set by middleware" + + +@pytest.mark.parametrize( + "route", + [ + mounted_routes_with_middleware, + mounted_routes_with_middleware, + mounted_routes_with_route_middleware, + ], +) +def test_mount_middleware_url_path_for_(route: BaseRoute) -> None: + """Checks that url_path_for still works with middelware on Mounts""" + router = Router([route]) + assert router.url_path_for("route") == "/http/" From 93ef0e33152841ad4ab422dfc2546bd6fc70a5e1 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 7 Dec 2021 09:30:03 -0600 Subject: [PATCH 14/21] chore: run linting --- tests/test_routing.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/test_routing.py b/tests/test_routing.py index 516e77a4c..f1e29626b 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -9,7 +9,15 @@ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse, Response -from starlette.routing import BaseRoute, Host, Mount, NoMatchFound, Route, Router, WebSocketRoute +from starlette.routing import ( + BaseRoute, + Host, + Mount, + NoMatchFound, + Route, + Router, + WebSocketRoute, +) from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -696,12 +704,14 @@ async def dispatch( return response -route_with_middleware = Route( - "/http", - endpoint=assert_middleware_header_route, - methods=["GET"], - middleware=[Middleware(AddHeadersMiddleware)], -), +route_with_middleware = ( + Route( + "/http", + endpoint=assert_middleware_header_route, + methods=["GET"], + middleware=[Middleware(AddHeadersMiddleware)], + ), +) mounted_routes_with_middleware = Mount( @@ -763,7 +773,7 @@ async def dispatch( mounted_routes_with_middleware, mounted_app_with_middleware, mounted_routes_with_route_middleware, - mounted_app_with_route_middleware + mounted_app_with_route_middleware, ], ) def test_route_level_middleware( From b22526928a8cada7551d87c1ec6d967d025e9f47 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 17 Dec 2021 13:51:44 -0600 Subject: [PATCH 15/21] fix botched merge --- starlette/routing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/routing.py b/starlette/routing.py index 3058521a1..73b263191 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -364,7 +364,7 @@ def __init__( @property def routes(self) -> typing.List[BaseRoute]: - return getattr(self.app, "routes", []) + return self._routes def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: if scope["type"] in ("http", "websocket"): From 62c3b612a17a8d7ab000d8f226daa294ce11b144 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 17 Dec 2021 18:13:32 -0600 Subject: [PATCH 16/21] add test for preservation of behavior of modifying an app after mounting --- starlette/routing.py | 5 ++--- tests/test_routing.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 73b263191..7a64f2a95 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -351,8 +351,7 @@ def __init__( routes = getattr(self.app, "routes", []) else: self.app = Router(routes=routes) - routes = list(routes or []) - self._routes = routes + self._user_app = self.app self.name = name self.path_regex, self.path_format, self.param_convertors = compile_path( self.path + "/{path:path}" @@ -364,7 +363,7 @@ def __init__( @property def routes(self) -> typing.List[BaseRoute]: - return self._routes + return getattr(self._user_app, "routes", []) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: if scope["type"] in ("http", "websocket"): diff --git a/tests/test_routing.py b/tests/test_routing.py index 26460ebe6..5142109bb 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -823,3 +823,24 @@ def test_mount_middleware_url_path_for_(route: BaseRoute) -> None: """Checks that url_path_for still works with middelware on Mounts""" router = Router([route]) assert router.url_path_for("route") == "/http/" + + +def test_add_route_to_app_after_mount( + test_client_factory: typing.Callable[..., TestClient], +) -> None: + """Checks that mounds will pick up routes + added to the underlaying app after it is mounted + """ + inner_app = Router() + app = Mount( + "/http", + app=inner_app + ) + inner_app.add_route( + "/inner", + endpoint=lambda request: Response(), + methods=["GET"], + ) + client = test_client_factory(app) + response = client.get("/http/inner") + assert response.status_code == 200 From c25ac029f46aa31118b11134029595f5056bcf60 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 21 Dec 2021 18:16:46 -0600 Subject: [PATCH 17/21] lint --- tests/test_routing.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_routing.py b/tests/test_routing.py index 5142109bb..b3368305c 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -832,10 +832,7 @@ def test_add_route_to_app_after_mount( added to the underlaying app after it is mounted """ inner_app = Router() - app = Mount( - "/http", - app=inner_app - ) + app = Mount("/http", app=inner_app) inner_app.add_route( "/inner", endpoint=lambda request: Response(), From 6ec05f29886bcd0cdad240b8708b8262b1bf0d4a Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 10 Jan 2022 08:37:13 -0800 Subject: [PATCH 18/21] remove unused variable --- starlette/routing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starlette/routing.py b/starlette/routing.py index 7a64f2a95..27da55991 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -348,7 +348,6 @@ def __init__( self.path = path.rstrip("/") if app is not None: self.app: ASGIApp = app - routes = getattr(self.app, "routes", []) else: self.app = Router(routes=routes) self._user_app = self.app From b58f0c32b1161c442207254a25294c0df71083ce Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 10 Jan 2022 08:38:08 -0800 Subject: [PATCH 19/21] add comment on why we dynamically fetch routes --- starlette/routing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/starlette/routing.py b/starlette/routing.py index 27da55991..3e59e7fa7 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -362,6 +362,8 @@ def __init__( @property def routes(self) -> typing.List[BaseRoute]: + # we dynamically grab the routes so that if this is a Starlette router + # it can have routes added to it after it is mounted return getattr(self._user_app, "routes", []) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: From e9ff90328df6285b2ebf77ca5abf2c053d42add6 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 10 Jan 2022 08:43:28 -0800 Subject: [PATCH 20/21] grab routes in __init__ and remove _user_app --- starlette/routing.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 3e59e7fa7..cdf534368 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -332,6 +332,8 @@ def __eq__(self, other: typing.Any) -> bool: class Mount(BaseRoute): + _routes: typing.List[BaseRoute] + def __init__( self, path: str, @@ -348,9 +350,10 @@ def __init__( self.path = path.rstrip("/") if app is not None: self.app: ASGIApp = app + self._routes = getattr(app, "routes", []) else: self.app = Router(routes=routes) - self._user_app = self.app + self._routes = getattr(self.app, "routes", []) self.name = name self.path_regex, self.path_format, self.param_convertors = compile_path( self.path + "/{path:path}" @@ -362,9 +365,7 @@ def __init__( @property def routes(self) -> typing.List[BaseRoute]: - # we dynamically grab the routes so that if this is a Starlette router - # it can have routes added to it after it is mounted - return getattr(self._user_app, "routes", []) + return self._routes def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: if scope["type"] in ("http", "websocket"): From 3b01035b6dc8c33adb228d89abcc11428c11432b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 29 Jan 2022 12:10:54 -0600 Subject: [PATCH 21/21] Update starlette/routing.py Co-authored-by: Marcelo Trylesinski --- starlette/routing.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index cdf534368..b4981b7cf 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -350,10 +350,9 @@ def __init__( self.path = path.rstrip("/") if app is not None: self.app: ASGIApp = app - self._routes = getattr(app, "routes", []) else: self.app = Router(routes=routes) - self._routes = getattr(self.app, "routes", []) + self._routes = getattr(self.app, "routes", []) self.name = name self.path_regex, self.path_format, self.param_convertors = compile_path( self.path + "/{path:path}"