Skip to content

Commit

Permalink
Improve type hints on WebSockets implementations (#2335)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed May 14, 2024
1 parent 14bdf04 commit b9c03a8
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 51 deletions.
15 changes: 2 additions & 13 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,8 @@
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
FlowControl,
service_unavailable,
)
from uvicorn.protocols.utils import (
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.server import ServerState


Expand Down
15 changes: 2 additions & 13 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,8 @@
)
from uvicorn.config import Config
from uvicorn.logging import TRACE_LOG_LEVEL
from uvicorn.protocols.http.flow_control import (
CLOSE_HEADER,
HIGH_WATER_LIMIT,
FlowControl,
service_unavailable,
)
from uvicorn.protocols.utils import (
get_client_addr,
get_local_addr,
get_path_with_query_string,
get_remote_addr,
is_ssl,
)
from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable
from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl
from uvicorn.server import ServerState

HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]')
Expand Down
33 changes: 17 additions & 16 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
from urllib.parse import unquote

import websockets
import websockets.legacy.handshake
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.extensions.base import ServerExtensionFactory
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.legacy.server import HTTPResponse
from websockets.server import WebSocketServerProtocol
from websockets.typing import Subprotocol

from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
Expand Down Expand Up @@ -53,6 +56,7 @@ def is_serving(self) -> bool:

class WebSocketProtocol(WebSocketServerProtocol):
extra_headers: list[tuple[str, str]]
logger: logging.Logger | logging.LoggerAdapter[Any]

def __init__(
self,
Expand All @@ -65,7 +69,7 @@ def __init__(
config.load()

self.config = config
self.app = config.loaded_app
self.app = cast(ASGI3Application, config.loaded_app)
self.loop = _loop or asyncio.get_event_loop()
self.root_path = config.root_path
self.app_state = app_state
Expand All @@ -92,7 +96,7 @@ def __init__(

self.ws_server: Server = Server() # type: ignore[assignment]

extensions = []
extensions: list[ServerExtensionFactory] = []
if self.config.ws_per_message_deflate:
extensions.append(ServerPerMessageDeflateFactory())

Expand Down Expand Up @@ -147,10 +151,10 @@ def shutdown(self) -> None:
self.send_500_response()
self.transport.close()

def on_task_complete(self, task: asyncio.Task) -> None:
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)

async def process_request(self, path: str, headers: Headers) -> HTTPResponse | None:
async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None:
"""
This hook is called to determine if the websocket should return
an HTTP response and close.
Expand All @@ -161,15 +165,15 @@ async def process_request(self, path: str, headers: Headers) -> HTTPResponse | N
"""
path_portion, _, query_string = path.partition("?")

websockets.legacy.handshake.check_request(headers)
websockets.legacy.handshake.check_request(request_headers)

subprotocols = []
for header in headers.get_all("Sec-WebSocket-Protocol"):
subprotocols: list[str] = []
for header in request_headers.get_all("Sec-WebSocket-Protocol"):
subprotocols.extend([token.strip() for token in header.split(",")])

asgi_headers = [
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
for name, value in headers.raw_items()
for name, value in request_headers.raw_items()
]
path = unquote(path_portion)
full_path = self.root_path + path
Expand Down Expand Up @@ -237,14 +241,13 @@ async def run_asgi(self) -> None:
termination states.
"""
try:
result = await self.app(self.scope, self.asgi_receive, self.asgi_send)
result = await self.app(self.scope, self.asgi_receive, self.asgi_send) # type: ignore[func-returns-value]
except ClientDisconnected:
self.closed_event.set()
self.transport.close()
except BaseException as exc:
except BaseException:
self.closed_event.set()
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
self.logger.exception("Exception in ASGI application\n")
if not self.handshake_started_event.is_set():
self.send_500_response()
else:
Expand All @@ -253,13 +256,11 @@ async def run_asgi(self) -> None:
else:
self.closed_event.set()
if not self.handshake_started_event.is_set():
msg = "ASGI callable returned without sending handshake."
self.logger.error(msg)
self.logger.error("ASGI callable returned without sending handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
await self.handshake_completed_event.wait()
self.transport.close()

Expand Down
17 changes: 8 additions & 9 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging
import typing
from typing import Literal
from typing import Literal, cast
from urllib.parse import unquote

import wsproto
Expand All @@ -13,6 +13,7 @@
from wsproto.utilities import LocalProtocolError, RemoteProtocolError

from uvicorn._types import (
ASGI3Application,
ASGISendEvent,
WebSocketAcceptEvent,
WebSocketCloseEvent,
Expand Down Expand Up @@ -46,7 +47,7 @@ def __init__(
config.load()

self.config = config
self.app = config.loaded_app
self.app = cast(ASGI3Application, config.loaded_app)
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path
Expand Down Expand Up @@ -156,7 +157,7 @@ def shutdown(self) -> None:
self.send_500_response()
self.transport.close()

def on_task_complete(self, task: asyncio.Task) -> None:
def on_task_complete(self, task: asyncio.Task[None]) -> None:
self.tasks.discard(task)

# Event handlers
Expand Down Expand Up @@ -220,7 +221,7 @@ def handle_ping(self, event: events.Ping) -> None:
def send_500_response(self) -> None:
if self.response_started or self.handshake_complete:
return # we cannot send responses anymore
headers = [
headers: list[tuple[bytes, bytes]] = [
(b"content-type", b"text/plain; charset=utf-8"),
(b"connection", b"close"),
]
Expand All @@ -230,7 +231,7 @@ def send_500_response(self) -> None:

async def run_asgi(self) -> None:
try:
result = await self.app(self.scope, self.receive, self.send)
result = await self.app(self.scope, self.receive, self.send) # type: ignore[func-returns-value]
except ClientDisconnected:
self.transport.close()
except BaseException:
Expand All @@ -239,13 +240,11 @@ async def run_asgi(self) -> None:
self.transport.close()
else:
if not self.handshake_complete:
msg = "ASGI callable returned without completing handshake."
self.logger.error(msg)
self.logger.error("ASGI callable returned without completing handshake.")
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
self.transport.close()

async def send(self, message: ASGISendEvent) -> None:
Expand Down

0 comments on commit b9c03a8

Please sign in to comment.