From 158e69219bbeff81b6a84f08308563dbc633b462 Mon Sep 17 00:00:00 2001 From: Jaakko Lappalainen Date: Sat, 10 Apr 2021 19:22:36 +0000 Subject: [PATCH 1/3] added typing --- uvicorn/protocols/http/h11_impl.py | 79 +++++++++++++++++------------- 1 file changed, 44 insertions(+), 35 deletions(-) diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index dc887cd6e..95067ac78 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -1,10 +1,13 @@ import asyncio import http import logging +from typing import Any, ByteString, Callable from urllib.parse import unquote import h11 +from uvicorn._types import HTTPScope +from uvicorn.config import Config from uvicorn.protocols.utils import ( get_client_addr, get_local_addr, @@ -12,9 +15,10 @@ get_remote_addr, is_ssl, ) +from uvicorn.server import ServerState -def _get_status_phrase(status_code): +def _get_status_phrase(status_code: http.HTTPStatus) -> ByteString: try: return http.HTTPStatus(status_code).phrase.encode() except ValueError: @@ -33,38 +37,38 @@ def _get_status_phrase(status_code): class FlowControl: - def __init__(self, transport): + def __init__(self, transport: asyncio.Transport) -> None: self._transport = transport self.read_paused = False self.write_paused = False self._is_writable_event = asyncio.Event() self._is_writable_event.set() - async def drain(self): + async def drain(self) -> bool: await self._is_writable_event.wait() - def pause_reading(self): + def pause_reading(self) -> None: if not self.read_paused: self.read_paused = True self._transport.pause_reading() - def resume_reading(self): + def resume_reading(self) -> None: if self.read_paused: self.read_paused = False self._transport.resume_reading() - def pause_writing(self): + def pause_writing(self) -> None: if not self.write_paused: self.write_paused = True self._is_writable_event.clear() - def resume_writing(self): + def resume_writing(self) -> None: if self.write_paused: self.write_paused = False self._is_writable_event.set() -async def service_unavailable(scope, receive, send): +async def service_unavailable(scope, receive, send) -> None: await send( { "type": "http.response.start", @@ -79,7 +83,12 @@ async def service_unavailable(scope, receive, send): class H11Protocol(asyncio.Protocol): - def __init__(self, config, server_state, _loop=None): + def __init__( + self, + config: Config, + server_state: ServerState, + _loop: asyncio.AbstractEventLoop = None, + ) -> None: if not config.loaded: config.load() @@ -117,7 +126,7 @@ def __init__(self, config, server_state, _loop=None): self.cycle = None # Protocol interface - def connection_made(self, transport): + def connection_made(self, transport: asyncio.Transport) -> None: self.connections.add(self) self.transport = transport @@ -130,7 +139,7 @@ def connection_made(self, transport): prefix = "%s:%d - " % tuple(self.client) if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sConnection made", prefix) - def connection_lost(self, exc): + def connection_lost(self, exc: Any) -> None: self.connections.discard(self) if self.logger.level <= TRACE_LOG_LEVEL: @@ -152,21 +161,21 @@ def connection_lost(self, exc): if self.flow is not None: self.flow.resume_writing() - def eof_received(self): + def eof_received(self) -> None: pass - def _unset_keepalive_if_required(self): + def _unset_keepalive_if_required(self) -> None: if self.timeout_keep_alive_task is not None: self.timeout_keep_alive_task.cancel() self.timeout_keep_alive_task = None - def data_received(self, data): + def data_received(self, data) -> None: self._unset_keepalive_if_required() self.conn.receive_data(data) self.handle_events() - def handle_events(self): + def handle_events(self) -> None: while True: try: event = self.conn.next_event() @@ -259,7 +268,7 @@ def handle_events(self): self.cycle.more_body = False self.cycle.message_event.set() - def handle_upgrade(self, event): + def handle_upgrade(self, event: asyncio.Event) -> None: upgrade_value = None for name, value in self.headers: if name == b"upgrade": @@ -304,7 +313,7 @@ def handle_upgrade(self, event): protocol.data_received(b"".join(output)) self.transport.set_protocol(protocol) - def on_response_complete(self): + def on_response_complete(self) -> None: self.server_state.total_requests += 1 if self.transport.is_closing(): @@ -325,7 +334,7 @@ def on_response_complete(self): self.conn.start_next_cycle() self.handle_events() - def shutdown(self): + def shutdown(self) -> None: """ Called by the server to commence a graceful shutdown. """ @@ -336,19 +345,19 @@ def shutdown(self): else: self.cycle.keep_alive = False - def pause_writing(self): + def pause_writing(self) -> None: """ Called by the transport when the write buffer exceeds the high water mark. """ self.flow.pause_writing() - def resume_writing(self): + def resume_writing(self) -> None: """ Called by the transport when the write buffer drops below the low water mark. """ self.flow.resume_writing() - def timeout_keep_alive_handler(self): + def timeout_keep_alive_handler(self) -> None: """ Called on a keep-alive connection if no new data is received after a short delay. @@ -362,16 +371,16 @@ def timeout_keep_alive_handler(self): class RequestResponseCycle: def __init__( self, - scope, - conn, - transport, - flow, - logger, - access_logger, - access_log, - default_headers, - message_event, - on_response, + scope: HTTPScope, + conn: h11.Connection, + transport: asyncio.Transport, + flow: FlowControl, + logger: logging.Logger, + access_logger: logging.Logger, + access_log: bool, + default_headers: list, + message_event: asyncio.Event, + on_response: Callable, ): self.scope = scope self.conn = conn @@ -398,7 +407,7 @@ def __init__( self.response_complete = False # ASGI exception wrapper - async def run_asgi(self, app): + async def run_asgi(self, app: Callable) -> None: try: result = await app(self.scope, self.receive, self.send) except BaseException as exc: @@ -424,7 +433,7 @@ async def run_asgi(self, app): finally: self.on_response = None - async def send_500_response(self): + async def send_500_response(self) -> None: await self.send( { "type": "http.response.start", @@ -440,7 +449,7 @@ async def send_500_response(self): ) # ASGI interface - async def send(self, message): + async def send(self, message: dict) -> None: message_type = message["type"] if self.flow.write_paused and not self.disconnected: @@ -519,7 +528,7 @@ async def send(self, message): self.transport.close() self.on_response() - async def receive(self): + async def receive(self) -> dict: if self.waiting_for_100_continue and not self.transport.is_closing(): event = h11.InformationalResponse( status_code=100, headers=[], reason="Continue" From e133cb1e0457fff2904b25350162cc5c559dc422 Mon Sep 17 00:00:00 2001 From: Jaakko Lappalainen Date: Sun, 11 Apr 2021 14:32:39 +0000 Subject: [PATCH 2/3] bring new types --- uvicorn/_types.py | 155 ++++++++++++++++++++++++++++- uvicorn/protocols/http/h11_impl.py | 6 +- 2 files changed, 155 insertions(+), 6 deletions(-) diff --git a/uvicorn/_types.py b/uvicorn/_types.py index 501d04337..e63f4723c 100644 --- a/uvicorn/_types.py +++ b/uvicorn/_types.py @@ -1,10 +1,10 @@ import sys -from typing import Dict, Iterable, Optional, Tuple, Union +from typing import Awaitable, Callable, Dict, Iterable, Optional, Tuple, Type, Union if sys.version_info < (3, 8): - from typing_extensions import Literal, TypedDict + from typing_extensions import Literal, Protocol, TypedDict else: - from typing import Literal, TypedDict + from typing import Literal, Protocol, TypedDict class ASGISpecInfo(TypedDict): @@ -65,3 +65,152 @@ class WebsocketScope(TypedDict): WWWScope = Union[HTTPScope, WebsocketScope] Scope = Union[HTTPScope, WebsocketScope, LifespanScope] + + +class HTTPRequestEvent(TypedDict): + type: Literal["http.request"] + body: bytes + more_body: bool + + +class HTTPResponseStartEvent(TypedDict): + type: Literal["http.response.start"] + status: int + headers: Iterable[Tuple[bytes, bytes]] + + +class HTTPResponseBodyEvent(TypedDict): + type: Literal["http.response.body"] + body: bytes + more_body: bool + + +class HTTPServerPushEvent(TypedDict): + type: Literal["http.response.push"] + path: str + headers: Iterable[Tuple[bytes, bytes]] + + +class HTTPDisconnectEvent(TypedDict): + type: Literal["http.disconnect"] + + +class WebsocketConnectEvent(TypedDict): + type: Literal["websocket.connect"] + + +class WebsocketAcceptEvent(TypedDict): + type: Literal["websocket.accept"] + subprotocol: Optional[str] + headers: Iterable[Tuple[bytes, bytes]] + + +class WebsocketReceiveEvent(TypedDict): + type: Literal["websocket.receive"] + bytes: Optional[bytes] + text: Optional[str] + + +class WebsocketSendEvent(TypedDict): + type: Literal["websocket.send"] + bytes: Optional[bytes] + text: Optional[str] + + +class WebsocketResponseStartEvent(TypedDict): + type: Literal["websocket.http.response.start"] + status: int + headers: Iterable[Tuple[bytes, bytes]] + + +class WebsocketResponseBodyEvent(TypedDict): + type: Literal["websocket.http.response.body"] + body: bytes + more_body: bool + + +class WebsocketDisconnectEvent(TypedDict): + type: Literal["websocket.disconnect"] + code: int + + +class WebsocketCloseEvent(TypedDict): + type: Literal["websocket.close"] + code: int + reason: Optional[str] + + +class LifespanStartupEvent(TypedDict): + type: Literal["lifespan.startup"] + + +class LifespanShutdownEvent(TypedDict): + type: Literal["lifespan.shutdown"] + + +class LifespanStartupCompleteEvent(TypedDict): + type: Literal["lifespan.startup.complete"] + + +class LifespanStartupFailedEvent(TypedDict): + type: Literal["lifespan.startup.failed"] + message: str + + +class LifespanShutdownCompleteEvent(TypedDict): + type: Literal["lifespan.shutdown.complete"] + + +class LifespanShutdownFailedEvent(TypedDict): + type: Literal["lifespan.shutdown.failed"] + message: str + + +ASGIReceiveEvent = Union[ + HTTPRequestEvent, + HTTPDisconnectEvent, + WebsocketConnectEvent, + WebsocketReceiveEvent, + WebsocketDisconnectEvent, + LifespanStartupEvent, + LifespanShutdownEvent, +] + + +ASGISendEvent = Union[ + HTTPResponseStartEvent, + HTTPResponseBodyEvent, + HTTPServerPushEvent, + HTTPDisconnectEvent, + WebsocketAcceptEvent, + WebsocketSendEvent, + WebsocketResponseStartEvent, + WebsocketResponseBodyEvent, + WebsocketCloseEvent, + LifespanStartupCompleteEvent, + LifespanStartupFailedEvent, + LifespanShutdownCompleteEvent, + LifespanShutdownFailedEvent, +] + + +ASGIReceiveCallable = Callable[[], Awaitable[ASGIReceiveEvent]] +ASGISendCallable = Callable[[ASGISendEvent], Awaitable[None]] + + +class ASGI2Protocol(Protocol): + def __init__(self, scope: Scope) -> None: + ... + + async def __call__( + self, receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> None: + ... + + +ASGI2Application = Type[ASGI2Protocol] +ASGI3Application = Callable[ + [Scope, ASGIReceiveCallable, ASGISendCallable], + Awaitable[None], +] +ASGIApplication = Union[ASGI2Application, ASGI3Application] diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 95067ac78..015758dcd 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -6,7 +6,7 @@ import h11 -from uvicorn._types import HTTPScope +from uvicorn._types import ASGIReceiveEvent, ASGISendEvent, HTTPScope from uvicorn.config import Config from uvicorn.protocols.utils import ( get_client_addr, @@ -449,7 +449,7 @@ async def send_500_response(self) -> None: ) # ASGI interface - async def send(self, message: dict) -> None: + async def send(self, message: ASGISendEvent) -> None: message_type = message["type"] if self.flow.write_paused and not self.disconnected: @@ -528,7 +528,7 @@ async def send(self, message: dict) -> None: self.transport.close() self.on_response() - async def receive(self) -> dict: + async def receive(self) -> ASGIReceiveEvent: if self.waiting_for_100_continue and not self.transport.is_closing(): event = h11.InformationalResponse( status_code=100, headers=[], reason="Continue" From 09465a61575a54ce64d7563b22ad122de3366e1b Mon Sep 17 00:00:00 2001 From: Jaakko Lappalainen Date: Sun, 11 Apr 2021 14:36:34 +0000 Subject: [PATCH 3/3] add type to app --- uvicorn/protocols/http/h11_impl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 015758dcd..4b6eaff2a 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -6,7 +6,7 @@ import h11 -from uvicorn._types import ASGIReceiveEvent, ASGISendEvent, HTTPScope +from uvicorn._types import ASGI3Application, ASGIReceiveEvent, ASGISendEvent, HTTPScope from uvicorn.config import Config from uvicorn.protocols.utils import ( get_client_addr, @@ -407,7 +407,7 @@ def __init__( self.response_complete = False # ASGI exception wrapper - async def run_asgi(self, app: Callable) -> None: + async def run_asgi(self, app: ASGI3Application) -> None: try: result = await app(self.scope, self.receive, self.send) except BaseException as exc: