Skip to content

Commit

Permalink
Merge pull request #105 from slavugan/middleware_call_order
Browse files Browse the repository at this point in the history
changed middleware call order
  • Loading branch information
Goldziher committed Apr 25, 2022
2 parents 2eeb642 + a14eacc commit 2fd16a5
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 10 deletions.
27 changes: 23 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ sqlalchemy = {extras = ["mypy"], version = "*"}
Jinja2 = "*"
Mako = "*"
freezegun = "*"
pytest-mock = "^3.7.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
13 changes: 8 additions & 5 deletions starlite/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,15 +257,18 @@ def build_middleware_stack(
Builds the middleware stack by passing middlewares in a specific order
"""
current_app: ASGIApp = self.asgi_router
if allowed_hosts:
current_app = TrustedHostMiddleware(app=current_app, allowed_hosts=allowed_hosts)
if cors_config:
current_app = CORSMiddleware(app=current_app, **cors_config.dict())
for middleware in user_middleware:
# last added middleware will be on the top of stack and it will therefore be called first.
# we therefore need to reverse the middlewares to keep the call order according to
# the middlewares' list provided by the user
for middleware in reversed(user_middleware):
if isinstance(middleware, Middleware):
current_app = middleware.cls(app=current_app, **middleware.options)
else:
current_app = middleware(app=current_app)
if allowed_hosts:
current_app = TrustedHostMiddleware(app=current_app, allowed_hosts=allowed_hosts)
if cors_config:
current_app = CORSMiddleware(app=current_app, **cors_config.dict())
return current_app

def default_http_exception_handler(self, request: Request, exc: Exception) -> StarletteResponse:
Expand Down
2 changes: 1 addition & 1 deletion starlite/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
status_code=status_code,
headers=headers or {},
media_type=media_type,
background=background, # type: ignore
background=background,
)

@staticmethod
Expand Down
21 changes: 21 additions & 0 deletions tests/app/test_middleware_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from _pytest.logging import LogCaptureFixture
from pydantic import BaseModel
from pytest_mock import MockerFixture
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.middleware.cors import CORSMiddleware
Expand Down Expand Up @@ -135,3 +136,23 @@ def test_request_body_logging_middleware(caplog: LogCaptureFixture) -> None:
response = client.post("/", json={"name": "moishe zuchmir", "age": 40, "programmer": True})
assert response.status_code == 201
assert "test logging" in caplog.text


def test_middleware_call_order(mocker: MockerFixture) -> None:
"""Test that middlewares are called in the order they have been passed"""
m1 = mocker.spy(BaseMiddlewareRequestLoggingMiddleware, "dispatch")
m2 = mocker.spy(CustomHeaderMiddleware, "dispatch")
manager = mocker.Mock()
manager.attach_mock(m1, "m1")
manager.attach_mock(m2, "m2")

client = create_test_client(
route_handlers=[handler],
middleware=[
BaseMiddlewareRequestLoggingMiddleware,
Middleware(CustomHeaderMiddleware, header_value="Customized"),
],
)
client.get("/")

manager.assert_has_calls([mocker.call.m1(*m1.call_args[0]), mocker.call.m2(*m2.call_args[0])], any_order=False)

0 comments on commit 2fd16a5

Please sign in to comment.