Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce lifespan state #1818

Merged
merged 23 commits into from
Mar 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions tests/protocols/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
import socket
import threading
import time
from typing import Optional, Union

import pytest

from tests.response import Response
from uvicorn import Server
from uvicorn.config import WS_PROTOCOLS, Config
from uvicorn.lifespan.off import LifespanOff
from uvicorn.lifespan.on import LifespanOn
from uvicorn.main import ServerState
from uvicorn.protocols.http.h11_impl import H11Protocol

Expand Down Expand Up @@ -184,12 +187,23 @@ def add_done_callback(self, callback):
pass


def get_connected_protocol(app, protocol_cls, **kwargs):
def get_connected_protocol(
app,
protocol_cls,
lifespan: Optional[Union[LifespanOff, LifespanOn]] = None,
**kwargs,
):
loop = MockLoop()
transport = MockTransport()
config = Config(app=app, **kwargs)
lifespan = lifespan or LifespanOff(config)
server_state = ServerState()
protocol = protocol_cls(config=config, server_state=server_state, _loop=loop)
protocol = protocol_cls(
config=config,
server_state=server_state,
app_state=lifespan.state.copy(),
_loop=loop,
)
protocol.connection_made(transport)
return protocol

Expand Down Expand Up @@ -980,3 +994,32 @@ async def app(scope, receive, send):
protocol.data_received(SIMPLE_GET_REQUEST)
await protocol.loop.run_one()
assert b"x-test-header: test value" in protocol.transport.buffer


@pytest.mark.anyio
@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
async def test_lifespan_state(protocol_cls):
expected_states = [{"a": 123, "b": [1]}, {"a": 123, "b": [1, 2]}]

async def app(scope, receive, send):
expected_state = expected_states.pop(0)
assert scope["state"] == expected_state
# modifications to keys are not preserved
scope["state"]["a"] = 456
# unless of course the value itself is mutated
scope["state"]["b"].append(2)
return await Response("Hi!")(scope, receive, send)

lifespan = LifespanOn(config=Config(app=app))
# skip over actually running the lifespan, that is tested
# in the lifespan tests
lifespan.state.update({"a": 123, "b": [1]})

for _ in range(2):
protocol = get_connected_protocol(app, protocol_cls, lifespan=lifespan)
protocol.data_received(SIMPLE_GET_REQUEST)
await protocol.loop.run_one()
assert b"HTTP/1.1 200 OK" in protocol.transport.buffer
assert b"Hi!" in protocol.transport.buffer

assert not expected_states # consumed
55 changes: 55 additions & 0 deletions tests/protocols/test_websocket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import typing
from copy import deepcopy

import httpx
import pytest
Expand Down Expand Up @@ -1087,3 +1088,57 @@ async def open_connection(url):
async with run_server(config):
headers = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"]


@pytest.mark.anyio
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
async def test_lifespan_state(ws_protocol_cls, http_protocol_cls, unused_tcp_port: int):
expected_states = [
{"a": 123, "b": [1]},
{"a": 123, "b": [1, 2]},
]

actual_states = []

async def lifespan_app(scope, receive, send):
message = await receive()
assert message["type"] == "lifespan.startup"
scope["state"]["a"] = 123
scope["state"]["b"] = [1]
await send({"type": "lifespan.startup.complete"})
message = await receive()
assert message["type"] == "lifespan.shutdown"
await send({"type": "lifespan.shutdown.complete"})

class App(WebSocketResponse):
async def websocket_connect(self, message):
actual_states.append(deepcopy(self.scope["state"]))
self.scope["state"]["a"] = 456
self.scope["state"]["b"].append(2)
await self.send({"type": "websocket.accept"})

async def open_connection(url):
async with websockets.connect(url) as websocket:
return websocket.open

async def app_wrapper(scope, receive, send):
if scope["type"] == "lifespan":
return await lifespan_app(scope, receive, send)
else:
return await App(scope, receive, send)

config = Config(
app=app_wrapper,
ws=ws_protocol_cls,
http=http_protocol_cls,
lifespan="on",
port=unused_tcp_port,
)
async with run_server(config):
is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
assert is_open
is_open = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}")
assert is_open

assert expected_states == actual_states
6 changes: 4 additions & 2 deletions tests/test_auto_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_loop_auto():
async def test_http_auto():
config = Config(app=app)
server_state = ServerState()
protocol = AutoHTTPProtocol(config=config, server_state=server_state)
protocol = AutoHTTPProtocol(config=config, server_state=server_state, app_state={})
expected_http = "H11Protocol" if httptools is None else "HttpToolsProtocol"
assert type(protocol).__name__ == expected_http

Expand All @@ -54,6 +54,8 @@ async def test_http_auto():
async def test_websocket_auto():
config = Config(app=app)
server_state = ServerState()
protocol = AutoWebSocketsProtocol(config=config, server_state=server_state)
protocol = AutoWebSocketsProtocol(
config=config, server_state=server_state, app_state={}
)
expected_websockets = "WSProtocol" if websockets is None else "WebSocketProtocol"
assert type(protocol).__name__ == expected_websockets
25 changes: 25 additions & 0 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ async def asgi3app(scope, receive, send):
assert scope == {
"type": "lifespan",
"asgi": {"version": "3.0", "spec_version": "2.0"},
"state": {},
}

async def test():
Expand All @@ -188,6 +189,7 @@ def asgi2app(scope):
assert scope == {
"type": "lifespan",
"asgi": {"version": "2.0", "spec_version": "2.0"},
"state": {},
}

async def asgi(receive, send):
Expand Down Expand Up @@ -245,3 +247,26 @@ async def test():
assert "the lifespan event failed" in error_messages.pop(0)
assert "Application shutdown failed. Exiting." in error_messages.pop(0)
loop.close()


def test_lifespan_state():
async def app(scope, receive, send):
message = await receive()
assert message["type"] == "lifespan.startup"
await send({"type": "lifespan.startup.complete"})
scope["state"]["foo"] = 123
message = await receive()
assert message["type"] == "lifespan.shutdown"
await send({"type": "lifespan.shutdown.complete"})

async def test():
config = Config(app=app, lifespan="on")
lifespan = LifespanOn(config)

await lifespan.startup()
assert lifespan.state == {"foo": 123}
await lifespan.shutdown()

loop = asyncio.new_event_loop()
loop.run_until_complete(test())
loop.close()
3 changes: 3 additions & 0 deletions uvicorn/lifespan/off.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Any, Dict

from uvicorn import Config


class LifespanOff:
def __init__(self, config: Config) -> None:
self.should_exit = False
self.state: Dict[str, Any] = {}

async def startup(self) -> None:
pass
Expand Down
6 changes: 4 additions & 2 deletions uvicorn/lifespan/on.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
from asyncio import Queue
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Any, Dict, Union

from uvicorn import Config

Expand Down Expand Up @@ -42,6 +42,7 @@ def __init__(self, config: Config) -> None:
self.startup_failed = False
self.shutdown_failed = False
self.should_exit = False
self.state: Dict[str, Any] = {}

async def startup(self) -> None:
self.logger.info("Waiting for application startup.")
Expand Down Expand Up @@ -79,9 +80,10 @@ async def shutdown(self) -> None:
async def main(self) -> None:
try:
app = self.config.loaded_app
scope: LifespanScope = {
scope: LifespanScope = { # type: ignore[typeddict-item]
"type": "lifespan",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.0"},
"state": self.state,
}
await app(scope, self.receive, self.send)
except BaseException as exc:
Expand Down
20 changes: 18 additions & 2 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@
import http
import logging
import sys
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
from urllib.parse import unquote

import h11
Expand Down Expand Up @@ -42,6 +52,7 @@
HTTPScope,
)


H11Event = Union[
h11.Request,
h11.InformationalResponse,
Expand Down Expand Up @@ -69,6 +80,7 @@ def __init__(
self,
config: Config,
server_state: ServerState,
app_state: Dict[str, Any],
adriangb marked this conversation as resolved.
Show resolved Hide resolved
_loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
Expand All @@ -89,6 +101,7 @@ def __init__(
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
self.app_state = app_state

# Timeouts
self.timeout_keep_alive_task: Optional[asyncio.TimerHandle] = None
Expand Down Expand Up @@ -229,6 +242,7 @@ def handle_events(self) -> None:
"raw_path": raw_path,
"query_string": query_string,
"headers": self.headers,
"state": self.app_state,
}

upgrade = self._get_upgrade()
Expand Down Expand Up @@ -290,7 +304,9 @@ def handle_websocket_upgrade(self, event: H11Event) -> None:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
config=self.config, server_state=self.server_state
config=self.config,
server_state=self.server_state,
app_state=self.app_state,
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
Expand Down
21 changes: 19 additions & 2 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
import urllib
from asyncio.events import TimerHandle
from collections import deque
from typing import TYPE_CHECKING, Callable, Deque, List, Optional, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Deque,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)

import httptools

Expand Down Expand Up @@ -44,6 +55,7 @@
HTTPScope,
)


HEADER_RE = re.compile(b'[\x00-\x1F\x7F()<>@,;:[]={} \t\\"]')
HEADER_VALUE_RE = re.compile(b"[\x00-\x1F\x7F]")

Expand All @@ -66,6 +78,7 @@ def __init__(
self,
config: Config,
server_state: ServerState,
app_state: Dict[str, Any],
_loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if not config.loaded:
Expand All @@ -81,6 +94,7 @@ def __init__(
self.ws_protocol_class = config.ws_protocol_class
self.root_path = config.root_path
self.limit_concurrency = config.limit_concurrency
self.app_state = app_state

# Timeouts
self.timeout_keep_alive_task: Optional[TimerHandle] = None
Expand Down Expand Up @@ -201,7 +215,9 @@ def handle_websocket_upgrade(self) -> None:
output += [name, b": ", value, b"\r\n"]
output.append(b"\r\n")
protocol = self.ws_protocol_class( # type: ignore[call-arg, misc]
config=self.config, server_state=self.server_state
config=self.config,
server_state=self.server_state,
app_state=self.app_state,
)
protocol.connection_made(self.transport)
protocol.data_received(b"".join(output))
Expand Down Expand Up @@ -237,6 +253,7 @@ def on_message_begin(self) -> None:
"scheme": self.scheme,
"root_path": self.root_path,
"headers": self.headers,
"state": self.app_state,
}

# Parser callbacks
Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/websockets/auto.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import typing

AutoWebSocketsProtocol: typing.Optional[typing.Type[asyncio.Protocol]]
AutoWebSocketsProtocol: typing.Optional[typing.Callable[..., asyncio.Protocol]]
try:
import websockets # noqa
except ImportError: # pragma: no cover
Expand Down
Loading