Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow the use of a custom protocol #286

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
18 changes: 9 additions & 9 deletions mangum/adapter.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
import logging
from itertools import chain
from contextlib import ExitStack
from itertools import chain
from typing import List, Optional, Type

from mangum.protocols import HTTPCycle, LifespanCycle
from mangum.handlers import ALB, HTTPGateway, APIGateway, LambdaAtEdge
from mangum.exceptions import ConfigurationError
from mangum.handlers import ALB, APIGateway, HTTPGateway, LambdaAtEdge
from mangum.protocols import LifespanCycle
from mangum.types import (
ASGI,
LifespanMode,
LambdaConfig,
LambdaEvent,
LambdaContext,
LambdaEvent,
LambdaHandler,
LifespanMode,
)


logger = logging.getLogger("mangum")


Expand Down Expand Up @@ -65,6 +64,7 @@ def infer(self, event: LambdaEvent, context: LambdaContext) -> LambdaHandler:
for handler_cls in chain(self.custom_handlers, HANDLERS):
if handler_cls.infer(event, context, self.config):
return handler_cls(event, context, self.config)

raise RuntimeError( # pragma: no cover
"The adapter was unable to infer a handler to use for the event. This "
"is likely related to how the Lambda function was invoked. (Are you "
Expand All @@ -79,9 +79,9 @@ def __call__(self, event: LambdaEvent, context: LambdaContext) -> dict:
lifespan_cycle = LifespanCycle(self.app, self.lifespan)
stack.enter_context(lifespan_cycle)

http_cycle = HTTPCycle(handler.scope, handler.body)
http_response = http_cycle(self.app)
cycle = handler.cycle_cls(handler.scope, handler.body)
response = cycle(self.app)

return handler(http_response)
return handler(response)

assert False, "unreachable" # pragma: no cover
3 changes: 1 addition & 2 deletions mangum/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from mangum.handlers.api_gateway import APIGateway, HTTPGateway
from mangum.handlers.alb import ALB
from mangum.handlers.api_gateway import APIGateway, HTTPGateway
from mangum.handlers.lambda_at_edge import LambdaAtEdge


__all__ = ["APIGateway", "HTTPGateway", "ALB", "LambdaAtEdge"]
17 changes: 11 additions & 6 deletions mangum/handlers/alb.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from itertools import islice
from typing import Dict, Generator, List, Tuple
from urllib.parse import urlencode, unquote, unquote_plus
from typing import Dict, Generator, List, Tuple, Type
from urllib.parse import unquote, unquote_plus, urlencode

from mangum.handlers.utils import (
get_server_and_port,
handle_base64_response_body,
handle_exclude_headers,
maybe_encode_body,
)
from mangum.protocols import HTTPCycle
from mangum.types import (
Response,
Scope,
Cycle,
LambdaConfig,
LambdaEvent,
LambdaContext,
LambdaEvent,
QueryParams,
Response,
Scope,
)


Expand Down Expand Up @@ -95,6 +97,10 @@ def __init__(
self.context = context
self.config = config

@property
def cycle_cls(self) -> Type[Cycle]:
return HTTPCycle

@property
def body(self) -> bytes:
return maybe_encode_body(
Expand All @@ -104,7 +110,6 @@ def body(self) -> bytes:

@property
def scope(self) -> Scope:

headers = transform_headers(self.event)
list_headers = [list(x) for x in headers]
# Unique headers. If there are duplicates, it will use the last defined.
Expand Down
18 changes: 14 additions & 4 deletions mangum/handlers/api_gateway.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Type
from urllib.parse import urlencode

from mangum.handlers.utils import (
Expand All @@ -9,13 +9,15 @@
maybe_encode_body,
strip_api_gateway_path,
)
from mangum.protocols import HTTPCycle
from mangum.types import (
Response,
LambdaConfig,
Cycle,
Headers,
LambdaEvent,
LambdaConfig,
LambdaContext,
LambdaEvent,
QueryParams,
Response,
Scope,
)

Expand Down Expand Up @@ -78,6 +80,10 @@ def __init__(
self.context = context
self.config = config

@property
def cycle_cls(self) -> Type[Cycle]:
return HTTPCycle

@property
def body(self) -> bytes:
return maybe_encode_body(
Expand Down Expand Up @@ -144,6 +150,10 @@ def __init__(
self.context = context
self.config = config

@property
def cycle_cls(self) -> Type[Cycle]:
return HTTPCycle

@property
def body(self) -> bytes:
return maybe_encode_body(
Expand Down
16 changes: 14 additions & 2 deletions mangum/handlers/lambda_at_edge.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from typing import Dict, List
from typing import Dict, List, Type

from mangum.handlers.utils import (
handle_base64_response_body,
handle_exclude_headers,
handle_multi_value_headers,
maybe_encode_body,
)
from mangum.types import Scope, Response, LambdaConfig, LambdaEvent, LambdaContext
from mangum.protocols import HTTPCycle
from mangum.types import (
Cycle,
LambdaConfig,
LambdaContext,
LambdaEvent,
Response,
Scope,
)


class LambdaAtEdge:
Expand All @@ -31,6 +39,10 @@ def __init__(
self.context = context
self.config = config

@property
def cycle_cls(self) -> Type[Cycle]:
return HTTPCycle

@property
def body(self) -> bytes:
cf_request_body = self.event["Records"][0]["cf"]["request"].get("body", {})
Expand Down
2 changes: 1 addition & 1 deletion mangum/protocols/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .http import HTTPCycle
from .lifespan import LifespanCycleState, LifespanCycle
from .lifespan import LifespanCycle, LifespanCycleState

__all__ = ["HTTPCycle", "LifespanCycleState", "LifespanCycle"]
3 changes: 1 addition & 2 deletions mangum/protocols/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import logging
from io import BytesIO

from mangum.types import ASGI, Message, Scope, Response
from mangum.exceptions import UnexpectedMessage
from mangum.types import ASGI, Message, Response, Scope


class HTTPCycleState(enum.Enum):
Expand Down Expand Up @@ -93,7 +93,6 @@ async def send(self, message: Message) -> None:
self.state is HTTPCycleState.RESPONSE
and message["type"] == "http.response.body"
):

body = message.get("body", b"")
more_body = message.get("more_body", False)
self.buffer.write(body)
Expand Down
4 changes: 1 addition & 3 deletions mangum/protocols/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from types import TracebackType
from typing import Optional, Type

from mangum.exceptions import LifespanFailure, LifespanUnsupported, UnexpectedMessage
from mangum.types import ASGI, LifespanMode, Message
from mangum.exceptions import LifespanUnsupported, LifespanFailure, UnexpectedMessage


class LifespanCycleState(enum.Enum):
Expand Down Expand Up @@ -98,14 +98,12 @@ async def run(self) -> None:
async def receive(self) -> Message:
"""Awaited by the application to receive ASGI `lifespan` events."""
if self.state is LifespanCycleState.CONNECTING:

# Connection established. The next event returned by the queue will be
# `lifespan.startup` to inform the application that the connection is
# ready to receive lfiespan messages.
self.state = LifespanCycleState.STARTUP

elif self.state is LifespanCycleState.STARTUP:

# Connection shutting down. The next event returned by the queue will be
# `lifespan.shutdown` to inform the application that the connection is now
# closing so that it may perform cleanup.
Expand Down
36 changes: 29 additions & 7 deletions mangum/types.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from __future__ import annotations

from typing import (
List,
Dict,
Any,
Union,
Optional,
Sequence,
MutableMapping,
Awaitable,
Callable,
Dict,
List,
MutableMapping,
Optional,
Sequence,
Type,
Union,
)
from typing_extensions import Literal, Protocol, TypedDict, TypeAlias

from typing_extensions import Literal, Protocol, TypeAlias, TypedDict

LambdaEvent = Dict[str, Any]
QueryParams: TypeAlias = MutableMapping[str, Union[str, Sequence[str]]]
Expand Down Expand Up @@ -120,6 +121,23 @@ class LambdaConfig(TypedDict):
exclude_headers: List[str]


class Cycle(Protocol):
def __init__(self, scope: Scope, body: bytes) -> None:
... # pragma: no cover

def __call__(self, app: ASGI) -> Response:
... # pragma: no cover

async def run(self, app: ASGI) -> None:
... # pragma: no cover

async def receive(self) -> Message:
... # pragma: no cover

async def send(self, message: Message) -> None:
... # pragma: no cover


class LambdaHandler(Protocol):
def __init__(self, *args: Any) -> None:
... # pragma: no cover
Expand All @@ -130,6 +148,10 @@ def infer(
) -> bool:
... # pragma: no cover

@property
def cycle_cls(self) -> Type[Cycle]:
... # pragma: no cover

@property
def body(self) -> bytes:
... # pragma: no cover
Expand Down
8 changes: 1 addition & 7 deletions tests/handlers/test_custom.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
from mangum.types import (
Scope,
Headers,
LambdaConfig,
LambdaContext,
LambdaEvent,
)
from mangum.types import Headers, LambdaConfig, LambdaContext, LambdaEvent, Scope


class CustomHandler:
Expand Down
4 changes: 1 addition & 3 deletions tests/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@
import gzip
import json

import pytest

import brotli
import pytest
from brotli_asgi import BrotliMiddleware

from starlette.applications import Starlette
from starlette.middleware.gzip import GZipMiddleware
from starlette.responses import PlainTextResponse
Expand Down
4 changes: 1 addition & 3 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import logging

import pytest

from quart import Quart
from starlette.applications import Starlette
from starlette.responses import PlainTextResponse

from mangum import Mangum
from mangum.exceptions import LifespanFailure

from quart import Quart


@pytest.mark.parametrize(
"mock_aws_api_gateway_event,lifespan",
Expand Down