Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 674 replace starlette middleware types #718

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/reference/middleware/0-base.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@
options:
members:
- __init__

::: starlite.middleware.AbstractMiddleware
options:
members:
- __init__
14 changes: 14 additions & 0 deletions docs/reference/middleware/1-http-middleware.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# HTTP Middleware

::: starlite.middleware.CallNext

::: starlite.middleware.DispatchCallable

::: starlite.middleware.BaseHTTPMiddleware
options:
members:
- __init__
- __call__
- dispatch

::: starlite.middleware.http_middleware
7 changes: 4 additions & 3 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,10 @@ nav:
- reference/handlers/3-asgi-handlers.md
- Middleware:
- reference/middleware/0-base.md
- reference/middleware/1-authentication-middleware.md
- reference/middleware/2-logging-middleware.md
- reference/middleware/3-rate-limit-middleware.md
- reference/middleware/1-http-middleware.md
- reference/middleware/2-authentication-middleware.md
- reference/middleware/3-logging-middleware.md
- reference/middleware/4-rate-limit-middleware.md
- Session Middleware:
- reference/middleware/session-middleware/0-middleware.md
- Backends:
Expand Down
19 changes: 15 additions & 4 deletions starlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,17 @@
route,
websocket,
)
from starlite.middleware.authentication import (
from starlite.middleware import (
AbstractAuthenticationMiddleware,
AbstractMiddleware,
AuthenticationResult,
BaseHTTPMiddleware,
CallNext,
DefineMiddleware,
DispatchCallable,
MiddlewareProtocol,
http_middleware,
)
from starlite.middleware.base import DefineMiddleware, MiddlewareProtocol
from starlite.openapi.controller import OpenAPIController
from starlite.openapi.datastructures import ResponseSpec
from starlite.params import Body, Dependency, Parameter
Expand All @@ -85,23 +91,26 @@
"ASGIRoute",
"ASGIRouteHandler",
"AbstractAuthenticationMiddleware",
"AbstractMiddleware",
"AuthenticationResult",
"BackgroundTask",
"BackgroundTasks",
"BaseHTTPMiddleware",
"BaseLoggingConfig",
"BaseRoute",
"BaseRouteHandler",
"Body",
"CORSConfig",
"CSRFConfig",
"CacheConfig",
"CallNext",
"CompressionConfig",
"Controller",
"Cookie",
"create_test_client",
"DTOFactory",
"DefineMiddleware",
"Dependency",
"DispatchCallable",
"File",
"FormMultiDict",
"HTTPException",
Expand Down Expand Up @@ -143,17 +152,19 @@
"StructLoggingConfig",
"Template",
"TemplateConfig",
"TooManyRequestsException",
"TestClient",
"TooManyRequestsException",
"UploadFile",
"ValidationException",
"WebSocket",
"WebSocketException",
"WebSocketRoute",
"WebsocketRouteHandler",
"asgi",
"create_test_client",
"delete",
"get",
"http_middleware",
"patch",
"post",
"put",
Expand Down
17 changes: 16 additions & 1 deletion starlite/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,38 @@
AbstractAuthenticationMiddleware,
AuthenticationResult,
)
from starlite.middleware.base import DefineMiddleware, MiddlewareProtocol
from starlite.middleware.base import (
AbstractMiddleware,
DefineMiddleware,
MiddlewareProtocol,
)
from starlite.middleware.compression import CompressionMiddleware
from starlite.middleware.csrf import CSRFMiddleware
from starlite.middleware.exceptions import ExceptionHandlerMiddleware
from starlite.middleware.http import (
BaseHTTPMiddleware,
CallNext,
DispatchCallable,
http_middleware,
)
from starlite.middleware.logging import LoggingMiddleware, LoggingMiddlewareConfig
from starlite.middleware.rate_limit import RateLimitConfig, RateLimitMiddleware

__all__ = (
"AbstractAuthenticationMiddleware",
"AbstractMiddleware",
"AuthenticationResult",
"BaseHTTPMiddleware",
"CSRFMiddleware",
"CallNext",
"CompressionMiddleware",
"DefineMiddleware",
"DispatchCallable",
"ExceptionHandlerMiddleware",
"LoggingMiddleware",
"LoggingMiddlewareConfig",
"MiddlewareProtocol",
"RateLimitConfig",
"RateLimitMiddleware",
"http_middleware",
)
8 changes: 5 additions & 3 deletions starlite/middleware/authentication.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Pattern, Union
from typing import TYPE_CHECKING, Any, List, Optional, Pattern, Set, Union

from pydantic import BaseConfig, BaseModel

from starlite.connection import ASGIConnection
from starlite.enums import ScopeType
from starlite.middleware.util import should_bypass_middleware
from starlite.middleware.utils import should_bypass_middleware
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should update CSRF middleware as well in this PR. It also relies on exclude opts to skip protection on certain routes.

I think @provinzkraut simply forgot to uppdate it in #630 when he introduced should_bypass_middleware function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, but it should be done in a followup

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, I did!


if TYPE_CHECKING:
from typing_extensions import Literal

from starlite.types import ASGIApp, Receive, Scope, Send


Expand All @@ -29,7 +31,7 @@ class Config(BaseConfig):


class AbstractAuthenticationMiddleware(ABC):
scopes = {ScopeType.HTTP, ScopeType.WEBSOCKET}
scopes: Set["Literal[ScopeType.HTTP, ScopeType.WEBSOCKET]"] = {ScopeType.HTTP, ScopeType.WEBSOCKET}
"""
Scopes supported by the middleware.
"""
Expand Down
111 changes: 108 additions & 3 deletions starlite/middleware/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
from typing import TYPE_CHECKING, Any, Callable

from typing_extensions import Protocol, runtime_checkable
import re
from abc import ABCMeta, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Pattern,
Set,
Tuple,
Type,
Union,
)

from typing_extensions import Literal, Protocol, runtime_checkable

from starlite.enums import ScopeType
from starlite.middleware.utils import should_bypass_middleware

if TYPE_CHECKING:
from starlite.types.asgi_types import ASGIApp, Receive, Scope, Send
Expand Down Expand Up @@ -65,3 +82,91 @@ def __call__(self, app: "ASGIApp") -> "ASGIApp":
"""

return self.middleware(*self.args, app=app, **self.kwargs)


class _AbstractMiddlewareMetaClass(ABCMeta):
app: "ASGIApp"
exclude: Optional[Union[str, List[str]]]
exclude_opt_key: Optional[str]
scopes: Set["Literal[ScopeType.HTTP, ScopeType.WEBSOCKET]"]

def __new__(cls, name: str, bases: Tuple[Type, ...], namespace: Dict[str, Callable], **kwargs: Any) -> Any:
Goldziher marked this conversation as resolved.
Show resolved Hide resolved
"""This metaclass override intercepts the creation of subclasses and
wraps their call method.

Notes:
- This is somewhat magical, and as such - suboptimal. There is no other way though to wrap the __call__
method because it is a read only attribute on class instances and classes that are already created.

Args:
name: The name of class that is being created.
bases: A tuple of super classes.
namespace: A mapping of method names to callables.
**kwargs: Any other kwargs passed to 'type()' call.
"""
if name != "AbstractMiddleware":
call_method = namespace.pop("__call__")

async def wrapped_call(self: Any, scope: "Scope", receive: "Receive", send: "Send") -> None:
if should_bypass_middleware(
scope=scope,
scopes=getattr(self, "scopes", {ScopeType.HTTP, ScopeType.WEBSOCKET}), # pyright: ignore
exclude_path_pattern=getattr(self, "exclude_pattern", None),
exclude_opt_key=getattr(self, "exclude_opt_key", None),
):
await self.app(scope, receive, send)
else:
await call_method(self, scope, receive, send)

namespace["__call__"] = wrapped_call

return super().__new__(cls, name, bases, namespace, **kwargs)
Goldziher marked this conversation as resolved.
Show resolved Hide resolved


class AbstractMiddleware(metaclass=_AbstractMiddlewareMetaClass):
__slots__ = ("app", "scopes", "exclude_opt_key", "exclude_pattern")

def __init__(
self,
app: "ASGIApp",
exclude: Optional[Union[str, List[str]]] = None,
exclude_opt_key: Optional[str] = None,
scopes: Optional[Set["Literal[ScopeType.HTTP, ScopeType.WEBSOCKET]"]] = None,
) -> None:
"""

Args:
app: The 'next' ASGI app to call.
exclude: A pattern or list of patterns to match against a request's path.
If a match is found, the middleware will be skipped. .
exclude_opt_key: An identifier that is set in the route handler
'opt' key which allows skipping the middleware.
scopes: ASGI scope types, should be a set including
either or both 'ScopeType.HTTP' and 'ScopeType.WEBSOCKET'.
"""
self.app = app
self.scopes = scopes or {ScopeType.HTTP, ScopeType.WEBSOCKET}

self.exclude_opt_key = exclude_opt_key
self.exclude_pattern: Optional[Pattern] = None
if exclude is not None:
self.exclude_pattern = re.compile("|".join(exclude)) if isinstance(exclude, list) else re.compile(exclude)

@abstractmethod
async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> None:
"""Executes the ASGI middleware.

Called by the previous middleware in the stack if a response is not awaited prior.

Upon completion, middleware should call the next ASGI handler and await it - or await a response created in its
closure.

Args:
scope: The ASGI connection scope.
receive: The ASGI receive function.
send: The ASGI send function.

Returns:
None
"""
raise NotImplementedError("abstract method must be implemented")
Loading