Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lifespan route instance #401

Merged
merged 6 commits into from
Feb 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/lint
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ ${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs
${PREFIX}autoflake --in-place --recursive starlette tests
${PREFIX}black starlette tests
${PREFIX}isort --multi-line=3 --trailing-comma --force-grid-wrap=0 --combine-as --line-width 88 --recursive --apply starlette tests
${PREFIX}mypy starlette --ignore-missing-imports --disallow-untyped-defs
7 changes: 3 additions & 4 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(
self.error_middleware = ServerErrorMiddleware(
self.exception_middleware, debug=debug
)
self.lifespan_middleware = LifespanMiddleware(self.error_middleware)
self.schema_generator = None # type: typing.Optional[BaseSchemaGenerator]
if template_directory is not None:
from starlette.templating import Jinja2Templates
Expand Down Expand Up @@ -53,7 +52,7 @@ def schema(self) -> dict:
return self.schema_generator.get_schema(self.routes)

def on_event(self, event_type: str) -> typing.Callable:
return self.lifespan_middleware.on_event(event_type)
return self.router.lifespan.on_event(event_type)

def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
self.router.mount(path, app=app, name=name)
Expand All @@ -79,7 +78,7 @@ def add_exception_handler(
)

def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
self.lifespan_middleware.add_event_handler(event_type, func)
self.router.lifespan.add_event_handler(event_type, func)

def add_route(
self,
Expand Down Expand Up @@ -149,4 +148,4 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath:

def __call__(self, scope: Scope) -> ASGIInstance:
scope["app"] = self
return self.lifespan_middleware(scope)
return self.error_middleware(scope)
2 changes: 1 addition & 1 deletion starlette/middleware/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __call__(self, scope: Scope) -> ASGIInstance:
return LifespanHandler(
self.app, scope, self.startup_handlers, self.shutdown_handlers
)
return self.app(scope)
return self.app(scope) # pragma: no cover


class LifespanHandler:
Expand Down
159 changes: 102 additions & 57 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,10 @@ def __init__(
assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
assert (
app is not None or routes is not None
), "Either 'app', or 'routes' must be specified"
), "Either 'app=...', or 'routes=' must be specified"
self.path = path.rstrip("/")
if routes is None:
assert app is not None
self.app = app
if app is not None:
self.app = app # type: ASGIApp
else:
self.app = Router(routes=routes)
self.name = name
Expand All @@ -303,23 +302,24 @@ def routes(self) -> typing.List[BaseRoute]:
return getattr(self.app, "routes", None)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path = scope["path"]
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {
"path_params": path_params,
"root_path": scope.get("root_path", "") + matched_path,
"path": remaining_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
if scope["type"] in ("http", "websocket"):
path = scope["path"]
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {
"path_params": path_params,
"root_path": scope.get("root_path", "") + matched_path,
"path": remaining_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
Expand Down Expand Up @@ -375,17 +375,18 @@ def routes(self) -> typing.List[BaseRoute]:
return getattr(self.app, "routes", None)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"path_params": path_params, "endpoint": self.app}
return Match.FULL, child_scope
if scope["type"] in ("http", "websocket"):
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"path_params": path_params, "endpoint": self.app}
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
Expand Down Expand Up @@ -426,6 +427,63 @@ def __eq__(self, other: typing.Any) -> bool:
)


class Lifespan(BaseRoute):
def __init__(
self, on_startup: typing.Callable = None, on_shutdown: typing.Callable = None
):
self.startup_handlers = [] if on_startup is None else [on_startup]
self.shutdown_handlers = [] if on_shutdown is None else [on_shutdown]

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
if scope["type"] == "lifespan":
return Match.FULL, {}
return Match.NONE, {}

def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
assert event_type in ("startup", "shutdown")

if event_type == "startup":
self.startup_handlers.append(func)
else:
assert event_type == "shutdown"
self.shutdown_handlers.append(func)

def on_event(self, event_type: str) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
self.add_event_handler(event_type, func)
return func

return decorator

async def startup(self) -> None:
for handler in self.startup_handlers:
if asyncio.iscoroutinefunction(handler):
await handler()
else:
handler()

async def shutdown(self) -> None:
for handler in self.shutdown_handlers:
if asyncio.iscoroutinefunction(handler):
await handler()
else:
handler()

def __call__(self, scope: Scope) -> ASGIInstance:
return self.asgi

async def asgi(self, receive: Receive, send: Send) -> None:
message = await receive()
assert message["type"] == "lifespan.startup"
await self.startup()
await send({"type": "lifespan.startup.complete"})

message = await receive()
assert message["type"] == "lifespan.shutdown"
await self.shutdown()
await send({"type": "lifespan.shutdown.complete"})


class Router:
def __init__(
self,
Expand All @@ -436,6 +494,7 @@ def __init__(
self.routes = [] if routes is None else list(routes)
self.redirect_slashes = redirect_slashes
self.default = self.not_found if default is None else default
self.lifespan = Lifespan()

def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
route = Mount(path, app=app, name=name)
Expand Down Expand Up @@ -516,9 +575,6 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath:
def __call__(self, scope: Scope) -> ASGIInstance:
assert scope["type"] in ("http", "websocket", "lifespan")

if scope["type"] == "lifespan":
return LifespanHandler(scope)

if "router" not in scope:
scope["router"] = self

Expand All @@ -537,31 +593,20 @@ def __call__(self, scope: Scope) -> ASGIInstance:
scope.update(partial_scope)
return partial(scope)

if self.redirect_slashes and not scope["path"].endswith("/"):
redirect_scope = dict(scope)
redirect_scope["path"] += "/"
if scope["type"] == "http" and self.redirect_slashes:
if not scope["path"].endswith("/"):
redirect_scope = dict(scope)
redirect_scope["path"] += "/"

for route in self.routes:
match, child_scope = route.matches(redirect_scope)
if match != Match.NONE:
redirect_url = URL(scope=redirect_scope)
return RedirectResponse(url=str(redirect_url))
for route in self.routes:
match, child_scope = route.matches(redirect_scope)
if match != Match.NONE:
redirect_url = URL(scope=redirect_scope)
return RedirectResponse(url=str(redirect_url))

if scope["type"] == "lifespan":
return self.lifespan(scope)
return self.default(scope)

def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, Router) and self.routes == other.routes


class LifespanHandler:
def __init__(self, scope: Scope) -> None:
pass

async def __call__(self, receive: Receive, send: Send) -> None:
message = await receive()
assert message["type"] == "lifespan.startup"
await send({"type": "lifespan.startup.complete"})

message = await receive()
assert message["type"] == "lifespan.shutdown"
await send({"type": "lifespan.shutdown.complete"})
58 changes: 58 additions & 0 deletions tests/middleware/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from starlette.applications import Starlette
from starlette.middleware.lifespan import LifespanMiddleware
from starlette.responses import PlainTextResponse
from starlette.routing import Lifespan, Route, Router
from starlette.testclient import TestClient


Expand Down Expand Up @@ -98,6 +100,38 @@ def test_raise_on_shutdown():
pass


def test_routed_lifespan():
startup_complete = False
shutdown_complete = False

def hello_world(request):
return PlainTextResponse("hello, world")

def run_startup():
nonlocal startup_complete
startup_complete = True

def run_shutdown():
nonlocal shutdown_complete
shutdown_complete = True

app = Router(
routes=[
Lifespan(on_startup=run_startup, on_shutdown=run_shutdown),
Route("/", hello_world),
]
)

assert not startup_complete
assert not shutdown_complete
with TestClient(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
assert startup_complete
assert shutdown_complete


def test_app_lifespan():
startup_complete = False
cleanup_complete = False
Expand All @@ -120,3 +154,27 @@ def run_cleanup():
assert not cleanup_complete
assert startup_complete
assert cleanup_complete


def test_app_async_lifespan():
startup_complete = False
cleanup_complete = False
app = Starlette()

@app.on_event("startup")
async def run_startup():
nonlocal startup_complete
startup_complete = True

@app.on_event("shutdown")
async def run_cleanup():
nonlocal cleanup_complete
cleanup_complete = True

assert not startup_complete
assert not cleanup_complete
with TestClient(app):
assert startup_complete
assert not cleanup_complete
assert startup_complete
assert cleanup_complete