Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add include_router to Starlette and Router #2189

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,21 @@ def decorator(func: typing.Callable) -> typing.Callable:

return decorator

def include_router(
self,
router: Router,
*,
prefix: str = "",
include_in_schema: bool = True,
**kwargs: typing.Any, # arguments to extend by inherits
) -> None:
self.router.include_router(
router,
prefix=prefix,
include_in_schema=include_in_schema,
**kwargs,
)

def middleware(self, middleware_type: str) -> typing.Callable:
"""
We no longer document this decorator style API, and its usage is discouraged.
Expand Down
76 changes: 75 additions & 1 deletion starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.types import AppType, ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose


Expand Down Expand Up @@ -703,6 +703,65 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
else:
await send({"type": "lifespan.shutdown.complete"})

def include_router(
self,
router: "Router",
*,
prefix: str = "",
include_in_schema: bool = True,
**kwargs: typing.Any, # arguments to extend by inherits
) -> None:
if prefix:
assert prefix.startswith("/"), "A path prefix must start with '/'"
assert not prefix.endswith(
"/"
), "A path prefix must not end with '/', as the routes will start with '/'"

for route in router.routes:
self._mv_router_route(
route=route,
prefix=prefix,
include_in_schema=include_in_schema,
**kwargs,
)

self._merge_router_events(router)

def _mv_router_route(
self,
route: BaseRoute,
*,
prefix: str = "",
include_in_schema: bool = True,
**kwargs: typing.Any, # arguments to extend by inherits
) -> None:
if isinstance(route, Route):
methods = list(route.methods or [])
self.add_route(
prefix + route.path,
route.endpoint,
methods=methods,
include_in_schema=route.include_in_schema and include_in_schema,
name=route.name,
)
elif isinstance(route, WebSocketRoute):
self.add_websocket_route(
prefix + route.path,
route.endpoint,
name=route.name,
)

def _merge_router_events(self, router: typing.Optional["Router"]) -> None:
if router is not None:
for handler in router.on_startup:
self.add_event_handler("startup", handler)
for handler in router.on_shutdown:
self.add_event_handler("shutdown", handler)
self.lifespan_context = _merge_lifespan_context(
self.lifespan_context,
router.lifespan_context,
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
The main entry point to the Router class.
Expand Down Expand Up @@ -762,6 +821,7 @@ def mount(
self, path: str, app: ASGIApp, name: typing.Optional[str] = None
) -> None: # pragma: nocover
route = Mount(path, app=app, name=name)
self._merge_router_events(getattr(app, "router", None))
self.routes.append(route)

def host(
Expand Down Expand Up @@ -869,3 +929,17 @@ def decorator(func: typing.Callable) -> typing.Callable:
return func

return decorator


def _merge_lifespan_context(
original_context: Lifespan[typing.Any], nested_context: Lifespan[typing.Any]
) -> Lifespan[typing.Any]:
@asynccontextmanager
async def merged_lifespan(
app: AppType,
) -> typing.AsyncIterator[typing.Mapping[str, typing.Any]]:
async with original_context(app) as maybe_self_context:
async with nested_context(app) as maybe_nested_context:
yield {**(maybe_self_context or {}), **(maybe_nested_context or {})}

return merged_lifespan
36 changes: 36 additions & 0 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.testclient import TestClient
from starlette.types import ASGIApp
from starlette.websockets import WebSocket

Expand Down Expand Up @@ -281,6 +282,41 @@ def test_app_mount(tmpdir, test_client_factory):
assert response.text == "Method Not Allowed"


def test_app_mount_events():
app = Starlette()
nested_app = Starlette()

with pytest.warns(DeprecationWarning):

@nested_app.on_event("startup")
async def startup():
app.state.nested_started = True

app.mount(path="/", app=nested_app)

with TestClient(app):
assert app.state.nested_started is True


def test_app_mount_lifespan():
app = Starlette()

@asynccontextmanager
async def lifespan(app: Starlette):
app.state.nested_started = True
yield {"router": True}
app.state.nested_shutdown = True

nested_app = Starlette(lifespan=lifespan)
app.mount(path="/", app=nested_app)

with TestClient(app):
assert app.state.nested_started is True

assert app.state.nested_started is True
assert app.state.nested_shutdown is True


def test_app_debug(test_client_factory):
async def homepage(request):
raise RuntimeError()
Expand Down
131 changes: 131 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,3 +1121,134 @@ async def startup() -> None:
... # pragma: nocover

router.on_event("startup")(startup)


def test_include_router() -> None:
base_router = Router()

nested_router = Router(
routes=[
Route("/home", endpoint=homepage, include_in_schema=True),
WebSocketRoute("/ws", endpoint=ws_helloworld),
]
)

base_router.include_router(nested_router, prefix="/nested", include_in_schema=False)

route1 = base_router.routes[0]
assert isinstance(route1, Route)
assert route1.path == "/nested/home"
assert route1.include_in_schema is False

route2 = base_router.routes[1]
assert isinstance(route2, WebSocketRoute)
assert route2.path == "/nested/ws"


def test_router_nested_lifespan_state() -> None:
@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
app.state.app_startup = True
yield {"app": True}
app.state.app_shutdown = True

@contextlib.asynccontextmanager
async def router_lifespan(app: Starlette):
app.state.router_startup = True
yield {"router": True}
app.state.router_shutdown = True

@contextlib.asynccontextmanager
async def subrouter_lifespan(app: Starlette):
app.state.sub_router_startup = True
yield {"sub_router": True}
app.state.sub_router_shutdown = True

def main(request: Request):
assert request.state.app
assert request.state.router
assert request.state.sub_router
return JSONResponse({"message": "Hello World"})

sub_router = Router(lifespan=subrouter_lifespan, routes=[Route("/", endpoint=main)])

router = Router(lifespan=router_lifespan)
router.include_router(sub_router)

app = Starlette(lifespan=lifespan)
app.include_router(router)

with TestClient(app) as client:
assert app.state.app_startup is True
assert app.state.router_startup is True
assert app.state.sub_router_startup is True

response = client.get("/")
assert response.status_code == 200, response.text
assert response.json() == {"message": "Hello World"}

assert app.state.app_startup is True
assert app.state.router_startup is True
assert app.state.sub_router_startup is True
assert app.state.app_shutdown is True
assert app.state.router_shutdown is True
assert app.state.sub_router_shutdown is True


def test_router_events() -> None:
app = Starlette()

with pytest.warns(DeprecationWarning):

@app.on_event("startup")
def app_startup() -> None:
app.state.app_startup = True

with pytest.warns(DeprecationWarning):

@app.on_event("shutdown")
def app_shutdown() -> None:
app.state.app_shutdown = True

router = Router()

with pytest.warns(DeprecationWarning):

@router.on_event("startup")
def router_startup() -> None:
app.state.router_startup = True

with pytest.warns(DeprecationWarning):

@router.on_event("shutdown")
def router_shutdown() -> None:
app.state.router_shutdown = True

sub_router = Router()

with pytest.warns(DeprecationWarning):

@sub_router.on_event("startup")
def sub_router_startup() -> None:
app.state.sub_router_startup = True

with pytest.warns(DeprecationWarning):

@sub_router.on_event("shutdown")
def sub_router_shutdown() -> None:
app.state.sub_router_shutdown = True

router.include_router(sub_router)
app.include_router(router)

with TestClient(app):
assert app.state.app_startup is True
assert app.state.router_startup is True
assert app.state.sub_router_startup is True

assert app.state.app_startup is True
assert app.state.router_startup is True
assert app.state.sub_router_startup is True
assert app.state.app_shutdown is True
assert app.state.router_shutdown is True
assert app.state.sub_router_shutdown is True