Skip to content

Commit

Permalink
Lifespan route instance (#401)
Browse files Browse the repository at this point in the history
* Add Mount(routes=...)
* Lifespan as a standard routing component
* Linting
  • Loading branch information
tomchristie committed Feb 19, 2019
1 parent 933d786 commit 06adecd
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 62 deletions.
1 change: 1 addition & 0 deletions scripts/lint
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ ${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs
${PREFIX}autoflake --in-place --recursive starlette tests
${PREFIX}black starlette tests
${PREFIX}isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --apply starlette tests
${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs
7 changes: 3 additions & 4 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(
self.error_middleware = ServerErrorMiddleware(
self.exception_middleware, debug=debug
)
self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
if template_directory is not None:
from starlette.templating import Jinja2Templates
Expand Down Expand Up @@ -53,7 +52,7 @@ def schema(self) -> dict:
return self.schema_generator.get_schema(self.routes)

def on_event(self, event_type: str) -> typing.Callable:
return self.lifespan_middleware.on_event(event_type)
return self.router.lifespan.on_event(event_type)

def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
self.router.mount(path, app=app, name=name)
Expand All @@ -79,7 +78,7 @@ def add_exception_handler(
)

def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
self.lifespan_middleware.add_event_handler(event_type, func)
self.router.lifespan.add_event_handler(event_type, func)

def add_route(
self,
Expand Down Expand Up @@ -149,4 +148,4 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath:

def __call__(self, scope: Scope) -> ASGIInstance:
scope["app"] = self
return self.lifespan_middleware(scope)
return self.error_middleware(scope)
2 changes: 1 addition & 1 deletion starlette/middleware/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __call__(self, scope: Scope) -> ASGIInstance:
return LifespanHandler(
self.app, scope, self.startup_handlers, self.shutdown_handlers
)
return self.app(scope)
return self.app(scope) # pragma: no cover


class LifespanHandler:
Expand Down
159 changes: 102 additions & 57 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,10 @@ def __init__(
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
assert (
app is not None or routes is not None
), "Either 'app', or 'routes' must be specified"
), "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if routes is None:
assert app is not None
self.app = app
if app is not None:
self.app = app # type: ASGIApp
else:
self.app = Router(routes=routes)
self.name = name
Expand All @@ -303,23 +302,24 @@ def routes(self) -> typing.List[BaseRoute]:
return getattr(self.app, "routes", None)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path = scope["path"]
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {
"path_params": path_params,
"root_path": scope.get("root_path", "") + matched_path,
"path": remaining_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
if scope["type"] in ("http", "websocket"):
path = scope["path"]
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {
"path_params": path_params,
"root_path": scope.get("root_path", "") + matched_path,
"path": remaining_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
Expand Down Expand Up @@ -375,17 +375,18 @@ def routes(self) -> typing.List[BaseRoute]:
return getattr(self.app, "routes", None)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"path_params": path_params, "endpoint": self.app}
return Match.FULL, child_scope
if scope["type"] in ("http", "websocket"):
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"path_params": path_params, "endpoint": self.app}
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
Expand Down Expand Up @@ -426,6 +427,63 @@ def __eq__(self, other: typing.Any) -> bool:
)


class Lifespan(BaseRoute):
def __init__(
self, on_startup: typing.Callable = None, on_shutdown: typing.Callable = None
):
self.startup_handlers = [] if on_startup is None else [on_startup]
self.shutdown_handlers = [] if on_shutdown is None else [on_shutdown]

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
if scope["type"] == "lifespan":
return Match.FULL, {}
return Match.NONE, {}

def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
assert event_type in ("startup", "shutdown")

if event_type == "startup":
self.startup_handlers.append(func)
else:
assert event_type == "shutdown"
self.shutdown_handlers.append(func)

def on_event(self, event_type: str) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
self.add_event_handler(event_type, func)
return func

return decorator

async def startup(self) -> None:
for handler in self.startup_handlers:
if asyncio.iscoroutinefunction(handler):
await handler()
else:
handler()

async def shutdown(self) -> None:
for handler in self.shutdown_handlers:
if asyncio.iscoroutinefunction(handler):
await handler()
else:
handler()

def __call__(self, scope: Scope) -> ASGIInstance:
return self.asgi

async def asgi(self, receive: Receive, send: Send) -> None:
message = await receive()
assert message["type"] == "lifespan.startup"
await self.startup()
await send({"type": "lifespan.startup.complete"})

message = await receive()
assert message["type"] == "lifespan.shutdown"
await self.shutdown()
await send({"type": "lifespan.shutdown.complete"})


class Router:
def __init__(
self,
Expand All @@ -436,6 +494,7 @@ def __init__(
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
self.default = self.not_found if default is None else default
self.lifespan = Lifespan()

def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
route = Mount(path, app=app, name=name)
Expand Down Expand Up @@ -516,9 +575,6 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath:
def __call__(self, scope: Scope) -> ASGIInstance:
assert scope["type"] in ("http", "websocket", "lifespan")

if scope["type"] == "lifespan":
return LifespanHandler(scope)

if "router" not in scope:
scope["router"] = self

Expand All @@ -537,31 +593,20 @@ def __call__(self, scope: Scope) -> ASGIInstance:
scope.update(partial_scope)
return partial(scope)

if self.redirect_slashes and not scope["path"].endswith("/"):
redirect_scope = dict(scope)
redirect_scope["path"] += "/"
if scope["type"] == "http" and self.redirect_slashes:
if not scope["path"].endswith("/"):
redirect_scope = dict(scope)
redirect_scope["path"] += "/"

for route in self.routes:
match, child_scope = route.matches(redirect_scope)
if match != Match.NONE:
redirect_url = URL(scope=redirect_scope)
return RedirectResponse(url=str(redirect_url))
for route in self.routes:
match, child_scope = route.matches(redirect_scope)
if match != Match.NONE:
redirect_url = URL(scope=redirect_scope)
return RedirectResponse(url=str(redirect_url))

if scope["type"] == "lifespan":
return self.lifespan(scope)
return self.default(scope)

def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, Router) and self.routes == other.routes


class LifespanHandler:
def __init__(self, scope: Scope) -> None:
pass

async def __call__(self, receive: Receive, send: Send) -> None:
message = await receive()
assert message["type"] == "lifespan.startup"
await send({"type": "lifespan.startup.complete"})

message = await receive()
assert message["type"] == "lifespan.shutdown"
await send({"type": "lifespan.shutdown.complete"})
58 changes: 58 additions & 0 deletions tests/middleware/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from starlette.applications import Starlette
from starlette.middleware.lifespan import LifespanMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Lifespan, Route, Router
from starlette.testclient import TestClient


Expand Down Expand Up @@ -98,6 +100,38 @@ def test_raise_on_shutdown():
pass


def test_routed_lifespan():
startup_complete = False
shutdown_complete = False

def hello_world(request):
return PlainTextResponse("hello, world")

def run_startup():
nonlocal startup_complete
startup_complete = True

def run_shutdown():
nonlocal shutdown_complete
shutdown_complete = True

app = Router(
routes=[
Lifespan(on_startup=run_startup, on_shutdown=run_shutdown),
Route("/", hello_world),
]
)

assert not startup_complete
assert not shutdown_complete
with TestClient(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
assert startup_complete
assert shutdown_complete


def test_app_lifespan():
startup_complete = False
cleanup_complete = False
Expand All @@ -120,3 +154,27 @@ def run_cleanup():
assert not cleanup_complete
assert startup_complete
assert cleanup_complete


def test_app_async_lifespan():
startup_complete = False
cleanup_complete = False
app = Starlette()

@app.on_event("startup")
async def run_startup():
nonlocal startup_complete
startup_complete = True

@app.on_event("shutdown")
async def run_cleanup():
nonlocal cleanup_complete
cleanup_complete = True

assert not startup_complete
assert not cleanup_complete
with TestClient(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete

0 comments on commit 06adecd

Please sign in to comment.