diff --git a/mangum/adapter.py b/mangum/adapter.py index bb99cfb..b00929d 100644 --- a/mangum/adapter.py +++ b/mangum/adapter.py @@ -1,21 +1,20 @@ import logging -from itertools import chain from contextlib import ExitStack +from itertools import chain from typing import List, Optional, Type -from mangum.protocols import HTTPCycle, LifespanCycle -from mangum.handlers import ALB, HTTPGateway, APIGateway, LambdaAtEdge from mangum.exceptions import ConfigurationError +from mangum.handlers import ALB, APIGateway, HTTPGateway, LambdaAtEdge +from mangum.protocols import LifespanCycle from mangum.types import ( ASGI, - LifespanMode, LambdaConfig, - LambdaEvent, LambdaContext, + LambdaEvent, LambdaHandler, + LifespanMode, ) - logger = logging.getLogger("mangum") @@ -65,6 +64,7 @@ def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler: for handler_cls in chain(self.custom_handlers, HANDLERS): if handler_cls.infer(event, context, self.config): return handler_cls(event, context, self.config) + raise RuntimeError( # pragma: no cover "The adapter was unable to infer a handler to use for the event. This " "is likely related to how the Lambda function was invoked. (Are you " @@ -79,9 +79,9 @@ def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict: lifespan_cycle = LifespanCycle(self.app, self.lifespan) stack.enter_context(lifespan_cycle) - http_cycle = HTTPCycle(handler.scope, handler.body) - http_response = http_cycle(self.app) + cycle = handler.cycle_cls(handler.scope, handler.body) + response = cycle(self.app) - return handler(http_response) + return handler(response) assert False, "unreachable" # pragma: no cover diff --git a/mangum/handlers/__init__.py b/mangum/handlers/__init__.py index d92f864..5afca3c 100644 --- a/mangum/handlers/__init__.py +++ b/mangum/handlers/__init__.py @@ -1,6 +1,5 @@ -from mangum.handlers.api_gateway import APIGateway, HTTPGateway from mangum.handlers.alb import ALB +from mangum.handlers.api_gateway import APIGateway, HTTPGateway from mangum.handlers.lambda_at_edge import LambdaAtEdge - __all__ = ["APIGateway", "HTTPGateway", "ALB", "LambdaAtEdge"] diff --git a/mangum/handlers/alb.py b/mangum/handlers/alb.py index 875c4ee..b6d96bf 100644 --- a/mangum/handlers/alb.py +++ b/mangum/handlers/alb.py @@ -1,6 +1,6 @@ from itertools import islice -from typing import Dict, Generator, List, Tuple -from urllib.parse import urlencode, unquote, unquote_plus +from typing import Dict, Generator, List, Tuple, Type +from urllib.parse import unquote, unquote_plus, urlencode from mangum.handlers.utils import ( get_server_and_port, @@ -8,13 +8,15 @@ handle_exclude_headers, maybe_encode_body, ) +from mangum.protocols import HTTPCycle from mangum.types import ( - Response, - Scope, + Cycle, LambdaConfig, - LambdaEvent, LambdaContext, + LambdaEvent, QueryParams, + Response, + Scope, ) @@ -95,6 +97,10 @@ def __init__( self.context = context self.config = config + @property + def cycle_cls(self) -> Type[Cycle]: + return HTTPCycle + @property def body(self) -> bytes: return maybe_encode_body( @@ -104,7 +110,6 @@ def body(self) -> bytes: @property def scope(self) -> Scope: - headers = transform_headers(self.event) list_headers = [list(x) for x in headers] # Unique headers. If there are duplicates, it will use the last defined. diff --git a/mangum/handlers/api_gateway.py b/mangum/handlers/api_gateway.py index d9b30c0..ce8aa28 100644 --- a/mangum/handlers/api_gateway.py +++ b/mangum/handlers/api_gateway.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Type from urllib.parse import urlencode from mangum.handlers.utils import ( @@ -9,13 +9,15 @@ maybe_encode_body, strip_api_gateway_path, ) +from mangum.protocols import HTTPCycle from mangum.types import ( - Response, - LambdaConfig, + Cycle, Headers, - LambdaEvent, + LambdaConfig, LambdaContext, + LambdaEvent, QueryParams, + Response, Scope, ) @@ -78,6 +80,10 @@ def __init__( self.context = context self.config = config + @property + def cycle_cls(self) -> Type[Cycle]: + return HTTPCycle + @property def body(self) -> bytes: return maybe_encode_body( @@ -144,6 +150,10 @@ def __init__( self.context = context self.config = config + @property + def cycle_cls(self) -> Type[Cycle]: + return HTTPCycle + @property def body(self) -> bytes: return maybe_encode_body( diff --git a/mangum/handlers/lambda_at_edge.py b/mangum/handlers/lambda_at_edge.py index 89a3709..1bbb09e 100644 --- a/mangum/handlers/lambda_at_edge.py +++ b/mangum/handlers/lambda_at_edge.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from typing import Dict, List, Type from mangum.handlers.utils import ( handle_base64_response_body, @@ -6,7 +6,15 @@ handle_multi_value_headers, maybe_encode_body, ) -from mangum.types import Scope, Response, LambdaConfig, LambdaEvent, LambdaContext +from mangum.protocols import HTTPCycle +from mangum.types import ( + Cycle, + LambdaConfig, + LambdaContext, + LambdaEvent, + Response, + Scope, +) class LambdaAtEdge: @@ -31,6 +39,10 @@ def __init__( self.context = context self.config = config + @property + def cycle_cls(self) -> Type[Cycle]: + return HTTPCycle + @property def body(self) -> bytes: cf_request_body = self.event["Records"][0]["cf"]["request"].get("body", {}) diff --git a/mangum/protocols/__init__.py b/mangum/protocols/__init__.py index f8d8384..12571c3 100644 --- a/mangum/protocols/__init__.py +++ b/mangum/protocols/__init__.py @@ -1,4 +1,4 @@ from .http import HTTPCycle -from .lifespan import LifespanCycleState, LifespanCycle +from .lifespan import LifespanCycle, LifespanCycleState __all__ = ["HTTPCycle", "LifespanCycleState", "LifespanCycle"] diff --git a/mangum/protocols/http.py b/mangum/protocols/http.py index b43b11b..6ab6b37 100644 --- a/mangum/protocols/http.py +++ b/mangum/protocols/http.py @@ -3,8 +3,8 @@ import logging from io import BytesIO -from mangum.types import ASGI, Message, Scope, Response from mangum.exceptions import UnexpectedMessage +from mangum.types import ASGI, Message, Response, Scope class HTTPCycleState(enum.Enum): @@ -93,7 +93,6 @@ async def send(self, message: Message) -> None: self.state is HTTPCycleState.RESPONSE and message["type"] == "http.response.body" ): - body = message.get("body", b"") more_body = message.get("more_body", False) self.buffer.write(body) diff --git a/mangum/protocols/lifespan.py b/mangum/protocols/lifespan.py index ca87392..9b9b72d 100644 --- a/mangum/protocols/lifespan.py +++ b/mangum/protocols/lifespan.py @@ -4,8 +4,8 @@ from types import TracebackType from typing import Optional, Type +from mangum.exceptions import LifespanFailure, LifespanUnsupported, UnexpectedMessage from mangum.types import ASGI, LifespanMode, Message -from mangum.exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage class LifespanCycleState(enum.Enum): @@ -98,14 +98,12 @@ async def run(self) -> None: async def receive(self) -> Message: """Awaited by the application to receive ASGI `lifespan` events.""" if self.state is LifespanCycleState.CONNECTING: - # Connection established. The next event returned by the queue will be # `lifespan.startup` to inform the application that the connection is # ready to receive lfiespan messages. self.state = LifespanCycleState.STARTUP elif self.state is LifespanCycleState.STARTUP: - # Connection shutting down. The next event returned by the queue will be # `lifespan.shutdown` to inform the application that the connection is now # closing so that it may perform cleanup. diff --git a/mangum/types.py b/mangum/types.py index 0ff436c..cc4b3f2 100644 --- a/mangum/types.py +++ b/mangum/types.py @@ -1,18 +1,19 @@ from __future__ import annotations from typing import ( - List, - Dict, Any, - Union, - Optional, - Sequence, - MutableMapping, Awaitable, Callable, + Dict, + List, + MutableMapping, + Optional, + Sequence, + Type, + Union, ) -from typing_extensions import Literal, Protocol, TypedDict, TypeAlias +from typing_extensions import Literal, Protocol, TypeAlias, TypedDict LambdaEvent = Dict[str, Any] QueryParams: TypeAlias = MutableMapping[str, Union[str, Sequence[str]]] @@ -120,6 +121,23 @@ class LambdaConfig(TypedDict): exclude_headers: List[str] +class Cycle(Protocol): + def __init__(self, scope: Scope, body: bytes) -> None: + ... # pragma: no cover + + def __call__(self, app: ASGI) -> Response: + ... # pragma: no cover + + async def run(self, app: ASGI) -> None: + ... # pragma: no cover + + async def receive(self) -> Message: + ... # pragma: no cover + + async def send(self, message: Message) -> None: + ... # pragma: no cover + + class LambdaHandler(Protocol): def __init__(self, *args: Any) -> None: ... # pragma: no cover @@ -130,6 +148,10 @@ def infer( ) -> bool: ... # pragma: no cover + @property + def cycle_cls(self) -> Type[Cycle]: + ... # pragma: no cover + @property def body(self) -> bytes: ... # pragma: no cover diff --git a/tests/handlers/test_custom.py b/tests/handlers/test_custom.py index c330bb6..e73c71f 100644 --- a/tests/handlers/test_custom.py +++ b/tests/handlers/test_custom.py @@ -1,10 +1,4 @@ -from mangum.types import ( - Scope, - Headers, - LambdaConfig, - LambdaContext, - LambdaEvent, -) +from mangum.types import Headers, LambdaConfig, LambdaContext, LambdaEvent, Scope class CustomHandler: diff --git a/tests/test_http.py b/tests/test_http.py index 4179805..126e98d 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -2,11 +2,9 @@ import gzip import json -import pytest - import brotli +import pytest from brotli_asgi import BrotliMiddleware - from starlette.applications import Starlette from starlette.middleware.gzip import GZipMiddleware from starlette.responses import PlainTextResponse diff --git a/tests/test_lifespan.py b/tests/test_lifespan.py index 2bac53e..682426a 100644 --- a/tests/test_lifespan.py +++ b/tests/test_lifespan.py @@ -1,15 +1,13 @@ import logging import pytest - +from quart import Quart from starlette.applications import Starlette from starlette.responses import PlainTextResponse from mangum import Mangum from mangum.exceptions import LifespanFailure -from quart import Quart - @pytest.mark.parametrize( "mock_aws_api_gateway_event,lifespan",