diff --git a/requirements.txt b/requirements.txt index 3a450fb2d3..77c9e2d494 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,7 @@ pytest pytest-mock mypy types-click +types-pyyaml trustme cryptography coverage diff --git a/setup.cfg b/setup.cfg index edd940a640..ed475e94a2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,9 +9,12 @@ follow_imports = silent files = uvicorn/lifespan, tests/test_lifespan.py, + uvicorn/config.py, + tests/test_config.py, uvicorn/middleware/message_logger.py, uvicorn/supervisors/basereload.py, uvicorn/importer.py, + tests/importer/test_importer.py, uvicorn/protocols/utils.py, uvicorn/loops, uvicorn/main.py, diff --git a/tests/importer/test_importer.py b/tests/importer/test_importer.py index b20317c9d9..d9eb3a86b7 100644 --- a/tests/importer/test_importer.py +++ b/tests/importer/test_importer.py @@ -3,40 +3,40 @@ from uvicorn.importer import ImportFromStringError, import_from_string -def test_invalid_format(): +def test_invalid_format() -> None: with pytest.raises(ImportFromStringError) as exc_info: import_from_string("example:") expected = 'Import string "example:" must be in format ":".' assert expected in str(exc_info.value) -def test_invalid_module(): +def test_invalid_module() -> None: with pytest.raises(ImportFromStringError) as exc_info: import_from_string("module_does_not_exist:myattr") expected = 'Could not import module "module_does_not_exist".' assert expected in str(exc_info.value) -def test_invalid_attr(): +def test_invalid_attr() -> None: with pytest.raises(ImportFromStringError) as exc_info: import_from_string("tempfile:attr_does_not_exist") expected = 'Attribute "attr_does_not_exist" not found in module "tempfile".' assert expected in str(exc_info.value) -def test_internal_import_error(): +def test_internal_import_error() -> None: with pytest.raises(ImportError): import_from_string("tests.importer.raise_import_error:myattr") -def test_valid_import(): +def test_valid_import() -> None: instance = import_from_string("tempfile:TemporaryFile") from tempfile import TemporaryFile assert instance == TemporaryFile -def test_no_import_needed(): +def test_no_import_needed() -> None: from tempfile import TemporaryFile instance = import_from_string(TemporaryFile) diff --git a/tests/test_config.py b/tests/test_config.py index 0b7383e526..761eec4048 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,11 +2,23 @@ import logging import os import socket +import sys +import typing from copy import deepcopy +from pathlib import Path +from unittest.mock import MagicMock + +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal import pytest import yaml +from asgiref.typing import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Scope +from pytest_mock import MockerFixture +from uvicorn._types import Environ, StartResponse from uvicorn.config import LOGGING_CONFIG, Config from uvicorn.middleware.debug import DebugMiddleware from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware @@ -15,34 +27,36 @@ @pytest.fixture -def mocked_logging_config_module(mocker): +def mocked_logging_config_module(mocker: MockerFixture) -> MagicMock: return mocker.patch("logging.config") @pytest.fixture(scope="function") -def logging_config(): +def logging_config() -> dict: return deepcopy(LOGGING_CONFIG) @pytest.fixture -def json_logging_config(logging_config): +def json_logging_config(logging_config: dict) -> str: return json.dumps(logging_config) @pytest.fixture -def yaml_logging_config(logging_config): +def yaml_logging_config(logging_config: dict) -> str: return yaml.dump(logging_config) -async def asgi_app(scope, receive, send): +async def asgi_app( + scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable +) -> None: pass # pragma: nocover -def wsgi_app(environ, start_response): +def wsgi_app(environ: Environ, start_response: StartResponse) -> None: pass # pragma: nocover -def test_debug_app(): +def test_debug_app() -> None: config = Config(app=asgi_app, debug=True, proxy_headers=False) config.load() @@ -54,7 +68,9 @@ def test_debug_app(): "app, expected_should_reload", [(asgi_app, False), ("tests.test_config:asgi_app", True)], ) -def test_config_should_reload_is_set(app, expected_should_reload): +def test_config_should_reload_is_set( + app: ASGIApplication, expected_should_reload: bool +) -> None: config_debug = Config(app=app, debug=True) assert config_debug.debug is True assert config_debug.should_reload is expected_should_reload @@ -64,12 +80,12 @@ def test_config_should_reload_is_set(app, expected_should_reload): assert config_reload.should_reload is expected_should_reload -def test_reload_dir_is_set(): +def test_reload_dir_is_set() -> None: config = Config(app=asgi_app, reload=True, reload_dirs="reload_me") assert config.reload_dirs == ["reload_me"] -def test_wsgi_app(): +def test_wsgi_app() -> None: config = Config(app=wsgi_app, interface="wsgi", proxy_headers=False) config.load() @@ -78,7 +94,7 @@ def test_wsgi_app(): assert config.asgi_version == "3.0" -def test_proxy_headers(): +def test_proxy_headers() -> None: config = Config(app=asgi_app) config.load() @@ -86,13 +102,13 @@ def test_proxy_headers(): assert isinstance(config.loaded_app, ProxyHeadersMiddleware) -def test_app_unimportable_module(): +def test_app_unimportable_module() -> None: config = Config(app="no.such:app") with pytest.raises(ImportError): config.load() -def test_app_unimportable_other(caplog): +def test_app_unimportable_other(caplog: pytest.LogCaptureFixture) -> None: config = Config(app="tests.test_config:app") with pytest.raises(SystemExit): config.load() @@ -107,8 +123,8 @@ def test_app_unimportable_other(caplog): ) -def test_app_factory(caplog): - def create_app(): +def test_app_factory(caplog: pytest.LogCaptureFixture) -> None: + def create_app() -> ASGIApplication: return asgi_app config = Config(app=create_app, factory=True, proxy_headers=False) @@ -131,13 +147,13 @@ def create_app(): config.load() -def test_concrete_http_class(): +def test_concrete_http_class() -> None: config = Config(app=asgi_app, http=H11Protocol) config.load() assert config.http_protocol_class is H11Protocol -def test_socket_bind(): +def test_socket_bind() -> None: config = Config(app=asgi_app) config.load() sock = config.bind_socket() @@ -145,7 +161,10 @@ def test_socket_bind(): sock.close() -def test_ssl_config(tls_ca_certificate_pem_path, tls_ca_certificate_private_key_path): +def test_ssl_config( + tls_ca_certificate_pem_path: str, + tls_ca_certificate_private_key_path: str, +) -> None: config = Config( app=asgi_app, ssl_certfile=tls_ca_certificate_pem_path, @@ -156,7 +175,7 @@ def test_ssl_config(tls_ca_certificate_pem_path, tls_ca_certificate_private_key_ assert config.is_ssl is True -def test_ssl_config_combined(tls_certificate_pem_path): +def test_ssl_config_combined(tls_certificate_pem_path: str) -> None: config = Config( app=asgi_app, ssl_certfile=tls_certificate_pem_path, @@ -166,8 +185,10 @@ def test_ssl_config_combined(tls_certificate_pem_path): assert config.is_ssl is True -def asgi2_app(scope): - async def asgi(receive, send): # pragma: nocover +def asgi2_app(scope: Scope) -> typing.Callable: + async def asgi( + receive: ASGIReceiveCallable, send: ASGISendCallable + ) -> None: # pragma: nocover pass return asgi # pragma: nocover @@ -176,7 +197,9 @@ async def asgi(receive, send): # pragma: nocover @pytest.mark.parametrize( "app, expected_interface", [(asgi_app, "3.0"), (asgi2_app, "2.0")] ) -def test_asgi_version(app, expected_interface): +def test_asgi_version( + app: ASGIApplication, expected_interface: Literal["2.0", "3.0"] +) -> None: config = Config(app=app) config.load() assert config.asgi_version == expected_interface @@ -191,7 +214,11 @@ def test_asgi_version(app, expected_interface): pytest.param(False, False, id="use_colors_disabled"), ], ) -def test_log_config_default(mocked_logging_config_module, use_colors, expected): +def test_log_config_default( + mocked_logging_config_module: MagicMock, + use_colors: typing.Optional[bool], + expected: typing.Optional[bool], +) -> None: """ Test that one can specify the use_colors option when using the default logging config. @@ -206,8 +233,11 @@ def test_log_config_default(mocked_logging_config_module, use_colors, expected): def test_log_config_json( - mocked_logging_config_module, logging_config, json_logging_config, mocker -): + mocked_logging_config_module: MagicMock, + logging_config: dict, + json_logging_config: str, + mocker: MockerFixture, +) -> None: """ Test that one can load a json config from disk. """ @@ -224,12 +254,12 @@ def test_log_config_json( @pytest.mark.parametrize("config_filename", ["log_config.yml", "log_config.yaml"]) def test_log_config_yaml( - mocked_logging_config_module, - logging_config, - yaml_logging_config, - mocker, - config_filename, -): + mocked_logging_config_module: MagicMock, + logging_config: dict, + yaml_logging_config: str, + mocker: MockerFixture, + config_filename: str, +) -> None: """ Test that one can load a yaml config from disk. """ @@ -244,7 +274,7 @@ def test_log_config_yaml( mocked_logging_config_module.dictConfig.assert_called_once_with(logging_config) -def test_log_config_file(mocked_logging_config_module): +def test_log_config_file(mocked_logging_config_module: MagicMock) -> None: """ Test that one can load a configparser config from disk. """ @@ -257,20 +287,25 @@ def test_log_config_file(mocked_logging_config_module): @pytest.fixture(params=[0, 1]) -def web_concurrency(request): - yield request.param +def web_concurrency(request: pytest.FixtureRequest) -> typing.Iterator[int]: + yield getattr(request, "param") if os.getenv("WEB_CONCURRENCY"): del os.environ["WEB_CONCURRENCY"] @pytest.fixture(params=["127.0.0.1", "127.0.0.2"]) -def forwarded_allow_ips(request): - yield request.param +def forwarded_allow_ips(request: pytest.FixtureRequest) -> typing.Iterator[str]: + yield getattr(request, "param") if os.getenv("FORWARDED_ALLOW_IPS"): del os.environ["FORWARDED_ALLOW_IPS"] -def test_env_file(web_concurrency: int, forwarded_allow_ips: str, caplog, tmp_path): +def test_env_file( + web_concurrency: int, + forwarded_allow_ips: str, + caplog: pytest.LogCaptureFixture, + tmp_path: Path, +) -> None: """ Test that one can load environment variables using an env file. """ @@ -284,7 +319,7 @@ def test_env_file(web_concurrency: int, forwarded_allow_ips: str, caplog, tmp_pa config = Config(app=asgi_app, env_file=fp) config.load() - assert config.workers == int(os.getenv("WEB_CONCURRENCY")) + assert config.workers == int(str(os.getenv("WEB_CONCURRENCY"))) assert config.forwarded_allow_ips == os.getenv("FORWARDED_ALLOW_IPS") assert len(caplog.records) == 1 assert f"Loading environment from '{fp}'" in caplog.records[0].message @@ -297,7 +332,7 @@ def test_env_file(web_concurrency: int, forwarded_allow_ips: str, caplog, tmp_pa pytest.param(False, 0, id="access log disabled shouldn't have handlers"), ], ) -def test_config_access_log(access_log: bool, handlers: int): +def test_config_access_log(access_log: bool, handlers: int) -> None: config = Config(app=asgi_app, access_log=access_log) config.load() @@ -306,7 +341,7 @@ def test_config_access_log(access_log: bool, handlers: int): @pytest.mark.parametrize("log_level", [5, 10, 20, 30, 40, 50]) -def test_config_log_level(log_level): +def test_config_log_level(log_level: int) -> None: config = Config(app=asgi_app, log_level=log_level) config.load() @@ -316,7 +351,7 @@ def test_config_log_level(log_level): assert config.log_level == log_level -def test_ws_max_size(): +def test_ws_max_size() -> None: config = Config(app=asgi_app, ws_max_size=1000) config.load() assert config.ws_max_size == 1000 diff --git a/uvicorn/_types.py b/uvicorn/_types.py new file mode 100644 index 0000000000..20e9014d5a --- /dev/null +++ b/uvicorn/_types.py @@ -0,0 +1,11 @@ +import types +import typing + +# WSGI +Environ = typing.MutableMapping[str, typing.Any] +ExcInfo = typing.Tuple[ + typing.Type[BaseException], BaseException, typing.Optional[types.TracebackType] +] +StartResponse = typing.Callable[ + [str, typing.Iterable[typing.Tuple[str, str]], typing.Optional[ExcInfo]], None +] diff --git a/uvicorn/config.py b/uvicorn/config.py index d54c49fca0..173492e5ae 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -7,7 +7,7 @@ import socket import ssl import sys -from typing import List, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Type, Union from uvicorn.logging import TRACE_LOG_LEVEL @@ -17,6 +17,7 @@ from typing import Literal import click +from asgiref.typing import ASGIApplication try: import yaml @@ -33,7 +34,13 @@ from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware from uvicorn.middleware.wsgi import WSGIMiddleware -LOG_LEVELS = { +HTTPProtocolType = Literal["auto", "h11", "httptools"] +WSProtocolType = Literal["auto", "none", "websockets", "wsproto"] +LifespanType = Literal["auto", "on", "off"] +LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"] +InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"] + +LOG_LEVELS: Dict[str, int] = { "critical": logging.CRITICAL, "error": logging.ERROR, "warning": logging.WARNING, @@ -41,36 +48,36 @@ "debug": logging.DEBUG, "trace": TRACE_LOG_LEVEL, } -HTTP_PROTOCOLS = { +HTTP_PROTOCOLS: Dict[HTTPProtocolType, str] = { "auto": "uvicorn.protocols.http.auto:AutoHTTPProtocol", "h11": "uvicorn.protocols.http.h11_impl:H11Protocol", "httptools": "uvicorn.protocols.http.httptools_impl:HttpToolsProtocol", } -WS_PROTOCOLS = { +WS_PROTOCOLS: Dict[WSProtocolType, Optional[str]] = { "auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol", "none": None, "websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", "wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol", } -LIFESPAN = { +LIFESPAN: Dict[LifespanType, str] = { "auto": "uvicorn.lifespan.on:LifespanOn", "on": "uvicorn.lifespan.on:LifespanOn", "off": "uvicorn.lifespan.off:LifespanOff", } -LOOP_SETUPS = { +LOOP_SETUPS: Dict[LoopSetupType, Optional[str]] = { "none": None, "auto": "uvicorn.loops.auto:auto_loop_setup", "asyncio": "uvicorn.loops.asyncio:asyncio_setup", "uvloop": "uvicorn.loops.uvloop:uvloop_setup", } -INTERFACES = ["auto", "asgi3", "asgi2", "wsgi"] +INTERFACES: List[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"] # Fallback to 'ssl.PROTOCOL_SSLv23' in order to support Python < 3.5.3. -SSL_PROTOCOL_VERSION = getattr(ssl, "PROTOCOL_TLS", ssl.PROTOCOL_SSLv23) +SSL_PROTOCOL_VERSION: int = getattr(ssl, "PROTOCOL_TLS", ssl.PROTOCOL_SSLv23) -LOGGING_CONFIG = { +LOGGING_CONFIG: dict = { "version": 1, "disable_existing_loggers": False, "formatters": { @@ -107,8 +114,14 @@ def create_ssl_context( - certfile, keyfile, password, ssl_version, cert_reqs, ca_certs, ciphers -): + certfile: Union[str, os.PathLike], + keyfile: Optional[Union[str, os.PathLike]], + password: Optional[str], + ssl_version: int, + cert_reqs: int, + ca_certs: Optional[Union[str, os.PathLike]], + ciphers: Optional[str], +) -> ssl.SSLContext: ctx = ssl.SSLContext(ssl_version) get_password = (lambda: password) if password else None ctx.load_cert_chain(certfile, keyfile, get_password) @@ -123,49 +136,49 @@ def create_ssl_context( class Config: def __init__( self, - app, - host="127.0.0.1", - port=8000, - uds=None, - fd=None, - loop="auto", - http="auto", - ws="auto", - ws_max_size=16 * 1024 * 1024, - ws_ping_interval=20, - ws_ping_timeout=20, - lifespan="auto", - env_file=None, - log_config=LOGGING_CONFIG, - log_level=None, - access_log=True, - use_colors=None, - interface="auto", - debug=False, - reload=False, - reload_dirs=None, - reload_delay=None, - workers=None, - proxy_headers=True, - server_header=True, - date_header=True, - forwarded_allow_ips=None, - root_path="", - limit_concurrency=None, - limit_max_requests=None, - backlog=2048, - timeout_keep_alive=5, - timeout_notify=30, - callback_notify=None, - ssl_keyfile=None, - ssl_certfile=None, - ssl_keyfile_password=None, - ssl_version=SSL_PROTOCOL_VERSION, - ssl_cert_reqs=ssl.CERT_NONE, - ssl_ca_certs=None, - ssl_ciphers="TLSv1", - headers=None, - factory=False, + app: Union[ASGIApplication, Callable, str], + host: str = "127.0.0.1", + port: int = 8000, + uds: Optional[str] = None, + fd: Optional[int] = None, + loop: LoopSetupType = "auto", + http: Union[Type[asyncio.Protocol], HTTPProtocolType] = "auto", + ws: Union[Type[asyncio.Protocol], WSProtocolType] = "auto", + ws_max_size: int = 16 * 1024 * 1024, + ws_ping_interval: int = 20, + ws_ping_timeout: int = 20, + lifespan: LifespanType = "auto", + env_file: Optional[Union[str, os.PathLike]] = None, + log_config: Optional[Union[dict, str]] = LOGGING_CONFIG, + log_level: Optional[Union[str, int]] = None, + access_log: bool = True, + use_colors: Optional[bool] = None, + interface: InterfaceType = "auto", + debug: bool = False, + reload: bool = False, + reload_dirs: Optional[Union[List[str], str]] = None, + reload_delay: Optional[float] = None, + workers: Optional[int] = None, + proxy_headers: bool = True, + server_header: bool = True, + date_header: bool = True, + forwarded_allow_ips: Optional[str] = None, + root_path: str = "", + limit_concurrency: Optional[int] = None, + limit_max_requests: Optional[int] = None, + backlog: int = 2048, + timeout_keep_alive: int = 5, + timeout_notify: int = 30, + callback_notify: Callable[..., None] = None, + ssl_keyfile: Optional[str] = None, + ssl_certfile: Optional[Union[str, os.PathLike]] = None, + ssl_keyfile_password: Optional[str] = None, + ssl_version: int = SSL_PROTOCOL_VERSION, + ssl_cert_reqs: int = ssl.CERT_NONE, + ssl_ca_certs: Optional[str] = None, + ssl_ciphers: str = "TLSv1", + headers: Optional[List[List[str]]] = None, + factory: bool = False, ): self.app = app self.host = host @@ -205,8 +218,8 @@ def __init__( self.ssl_cert_reqs = ssl_cert_reqs self.ssl_ca_certs = ssl_ca_certs self.ssl_ciphers = ssl_ciphers - self.headers = headers if headers else [] # type: List[str] - self.encoded_headers = None # type: List[Tuple[bytes, bytes]] + self.headers: List[List[str]] = headers or [] + self.encoded_headers: Optional[List[Tuple[bytes, bytes]]] = None self.factory = factory self.loaded = False @@ -237,14 +250,19 @@ def __init__( self.forwarded_allow_ips = forwarded_allow_ips @property - def asgi_version(self) -> Union[Literal["2.0"], Literal["3.0"]]: - return {"asgi2": "2.0", "asgi3": "3.0", "wsgi": "3.0"}[self.interface] + def asgi_version(self) -> Literal["2.0", "3.0"]: + mapping: Dict[str, Literal["2.0", "3.0"]] = { + "asgi2": "2.0", + "asgi3": "3.0", + "wsgi": "3.0", + } + return mapping[self.interface] @property def is_ssl(self) -> bool: return bool(self.ssl_keyfile or self.ssl_certfile) - def configure_logging(self): + def configure_logging(self) -> None: logging.addLevelName(TRACE_LOG_LEVEL, "TRACE") if self.log_config is not None: @@ -284,11 +302,12 @@ def configure_logging(self): logging.getLogger("uvicorn.access").handlers = [] logging.getLogger("uvicorn.access").propagate = False - def load(self): + def load(self) -> None: assert not self.loaded if self.is_ssl: - self.ssl = create_ssl_context( + assert self.ssl_certfile + self.ssl: Optional[ssl.SSLContext] = create_ssl_context( keyfile=self.ssl_keyfile, certfile=self.ssl_certfile, password=self.ssl_keyfile_password, @@ -308,15 +327,17 @@ def load(self): [(b"server", b"uvicorn")] + encoded_headers if b"server" not in dict(encoded_headers) and self.server_header else encoded_headers - ) # type: List[Tuple[bytes, bytes]] + ) if isinstance(self.http, str): - self.http_protocol_class = import_from_string(HTTP_PROTOCOLS[self.http]) + http_protocol_class = import_from_string(HTTP_PROTOCOLS[self.http]) + self.http_protocol_class: Type[asyncio.Protocol] = http_protocol_class else: self.http_protocol_class = self.http if isinstance(self.ws, str): - self.ws_protocol_class = import_from_string(WS_PROTOCOLS[self.ws]) + ws_protocol_class = import_from_string(WS_PROTOCOLS[self.ws]) + self.ws_protocol_class: Optional[Type[asyncio.Protocol]] = ws_protocol_class else: self.ws_protocol_class = self.ws @@ -368,12 +389,12 @@ def load(self): self.loaded = True - def setup_event_loop(self): - loop_setup = import_from_string(LOOP_SETUPS[self.loop]) + def setup_event_loop(self) -> None: + loop_setup: Optional[Callable] = import_from_string(LOOP_SETUPS[self.loop]) if loop_setup is not None: loop_setup() - def bind_socket(self): + def bind_socket(self) -> socket.socket: family = socket.AF_INET addr_format = "%s://%s:%d" @@ -408,5 +429,5 @@ def bind_socket(self): return sock @property - def should_reload(self): + def should_reload(self) -> bool: return isinstance(self.app, str) and (self.debug or self.reload) diff --git a/uvicorn/importer.py b/uvicorn/importer.py index 8cd18a8d44..e612bf134c 100644 --- a/uvicorn/importer.py +++ b/uvicorn/importer.py @@ -1,13 +1,12 @@ import importlib -from types import ModuleType -from typing import Union +from typing import Any class ImportFromStringError(Exception): pass -def import_from_string(import_str: Union[ModuleType, str]) -> ModuleType: +def import_from_string(import_str: Any) -> Any: if not isinstance(import_str, str): return import_str diff --git a/uvicorn/workers.py b/uvicorn/workers.py index 4ffe2ebd94..1b23f899c4 100644 --- a/uvicorn/workers.py +++ b/uvicorn/workers.py @@ -30,7 +30,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: logger.setLevel(self.log.access_log.level) logger.propagate = False - config_kwargs = { + config_kwargs: dict = { "app": None, "log_config": None, "timeout_keep_alive": self.cfg.keepalive,