Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add middleware per Route/WebSocketRoute #2349

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def __init__(
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
include_in_schema: bool = True,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
Expand All @@ -236,6 +237,10 @@ def __init__(
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint

if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)

if methods is None:
self.methods = None
else:
Expand Down Expand Up @@ -309,6 +314,7 @@ def __init__(
endpoint: typing.Callable[..., typing.Any],
*,
name: typing.Optional[str] = None,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
Expand All @@ -325,6 +331,10 @@ def __init__(
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint

if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)

self.path_regex, self.path_format, self.param_convertors = compile_path(path)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
Expand Down
2 changes: 1 addition & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def delete( # type: ignore[override]

def websocket_connect(
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
) -> typing.Any:
) -> "WebSocketTestSession":
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
headers.setdefault("connection", "upgrade")
Expand Down
53 changes: 52 additions & 1 deletion tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,18 @@ def assert_middleware_header_route(request: Request) -> Response:
return Response()


route_with_middleware = Starlette(
routes=[
Route(
"/http",
endpoint=assert_middleware_header_route,
methods=["GET"],
middleware=[Middleware(AddHeadersMiddleware)],
),
Route("/home", homepage),
]
)

mounted_routes_with_middleware = Starlette(
routes=[
Mount(
Expand Down Expand Up @@ -960,9 +972,10 @@ def assert_middleware_header_route(request: Request) -> Response:
[
mounted_routes_with_middleware,
mounted_app_with_middleware,
route_with_middleware,
],
)
def test_mount_middleware(
def test_base_route_middleware(
test_client_factory: typing.Callable[..., TestClient],
app: Starlette,
) -> None:
Expand Down Expand Up @@ -1076,6 +1089,44 @@ async def modified_send(msg: Message) -> None:
assert "X-Mounted" in resp.headers


def test_websocket_route_middleware(
test_client_factory: typing.Callable[..., TestClient]
):
async def websocket_endpoint(session: WebSocket):
await session.accept()
await session.send_text("Hello, world!")
await session.close()

class WebsocketMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async def modified_send(msg: Message) -> None:
if msg["type"] == "websocket.accept":
msg["headers"].append((b"X-Test", b"Set by middleware"))
await send(msg)

await self.app(scope, receive, modified_send)

app = Starlette(
routes=[
WebSocketRoute(
"/ws",
endpoint=websocket_endpoint,
middleware=[Middleware(WebsocketMiddleware)],
)
]
)

client = test_client_factory(app)

with client.websocket_connect("/ws") as websocket:
text = websocket.receive_text()
assert text == "Hello, world!"
assert websocket.extra_headers == [(b"X-Test", b"Set by middleware")]


def test_route_repr() -> None:
route = Route("/welcome", endpoint=homepage)
assert (
Expand Down