Skip to content

Commit

Permalink
Merge pull request #230 from lsst-sqre/tickets/DM-42527
Browse files Browse the repository at this point in the history
DM-42527: Rewrite middleware as pure ASGI middleware
  • Loading branch information
rra committed Jan 17, 2024
2 parents afef104 + 93f3c1b commit 3ed26c8
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 88 deletions.
3 changes: 3 additions & 0 deletions changelog.d/20240116_164958_rra_DM_42527.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### Bug fixes

- Rewrite `CaseInsensitiveQueryMiddleware` and `XForwardedMiddleware` as pure ASGI middleware rather than using the Starlette `BaseHTTPMiddleware` class. The latter seems to be behind some poor error reporting of application exceptions, has caused problems in the past due to its complexity, and is not used internally by Starlette middleware.
35 changes: 20 additions & 15 deletions src/safir/middleware/ivoa.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
"""Middleware for IVOA services."""

from collections.abc import Awaitable, Callable
from urllib.parse import urlencode
from copy import copy
from urllib.parse import parse_qsl, urlencode

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send

__all__ = ["CaseInsensitiveQueryMiddleware"]


class CaseInsensitiveQueryMiddleware(BaseHTTPMiddleware):
class CaseInsensitiveQueryMiddleware:
"""Make query parameter keys all lowercase.
Unfortunately, several IVOA standards require that query parameters be
case-insensitive, which is not supported by modern HTTP web frameworks.
This middleware attempts to work around this by lowercasing the query
parameter keys before the request is processed, allowing normal FastAPI
query parsing to then work without regard for case. This, in turn,
permits FastAPI to perform input validation on GET parameters, which would
query parsing to then work without regard for case. This, in turn, permits
FastAPI to perform input validation on GET parameters, which would
otherwise only happen if the case used in the request happened to match
the case used in the function signature.
Expand All @@ -28,11 +27,17 @@ class CaseInsensitiveQueryMiddleware(BaseHTTPMiddleware):
Based on `fastapi#826 <https://github.com/tiangolo/fastapi/issues/826>`__.
"""

async def dispatch(
self,
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
params = [(k.lower(), v) for k, v in request.query_params.items()]
request.scope["query_string"] = urlencode(params).encode()
return await call_next(request)
def __init__(self, app: ASGIApp) -> None:
self._app = app

async def __call__(
self, scope: Scope, receive: Receive, send: Send
) -> None:
if scope["type"] != "http" or not scope.get("query_string"):
await self._app(scope, receive, send)
return
scope = copy(scope)
params = [(k.lower(), v) for k, v in parse_qsl(scope["query_string"])]
scope["query_string"] = urlencode(params).encode()
await self._app(scope, receive, send)
return
121 changes: 54 additions & 67 deletions src/safir/middleware/x_forwarded.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,78 +2,65 @@

from __future__ import annotations

from collections.abc import Awaitable, Callable
from copy import copy
from ipaddress import _BaseAddress, _BaseNetwork, ip_address

from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.datastructures import Headers
from starlette.types import ASGIApp, Receive, Scope, Send

__all__ = ["XForwardedMiddleware"]


class XForwardedMiddleware(BaseHTTPMiddleware):
"""Middleware to update the request based on ``X-Forwarded-For``.
class XForwardedMiddleware:
"""ASGI middleware to update the request based on ``X-Forwarded-For``.
The remote IP address will be replaced with the right-most IP address in
``X-Forwarded-For`` that is not contained within one of the trusted
proxy networks.
If ``X-Forwarded-For`` is found and ``X-Forwarded-Proto`` is also present,
the corresponding entry of ``X-Forwarded-Proto`` is used to replace the
scheme in the request scope. If ``X-Forwarded-Proto`` only has one entry
scheme in the request scope. If ``X-Forwarded-Proto`` only has one entry
(ingress-nginx has this behavior), that one entry will become the new
scheme in the request scope.
The contents of ``X-Forwarded-Host`` will be stored as ``forwarded_host``
in the request state if it and ``X-Forwarded-For`` are present. Normally
in the request state if it and ``X-Forwarded-For`` are present. Normally
this is not needed since NGINX will pass the original ``Host`` header
without modification.
Parameters
----------
proxies
The networks of the trusted proxies. If not specified, defaults to
the empty list, which means only the immediately upstream proxy will
be trusted.
The networks of the trusted proxies. If not specified, defaults to the
empty list, which means only the immediately upstream proxy will be
trusted.
"""

def __init__(
self, app: FastAPI, *, proxies: list[_BaseNetwork] | None = None
self, app: ASGIApp, *, proxies: list[_BaseNetwork] | None = None
) -> None:
super().__init__(app)
if proxies:
self.proxies = proxies
else:
self.proxies = []

async def dispatch(
self,
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
"""Middleware to update the request based on ``X-Forwarded-For``.
Parameters
----------
request
The incoming request.
call_next
The next step in the processing stack.
self._app = app
self._proxies = proxies if proxies else []

Returns
-------
``fastapi.Response``
The response with additional information about proxy headers.
"""
forwarded_for = list(reversed(self._get_forwarded_for(request)))
async def __call__(
self, scope: Scope, receive: Receive, send: Send
) -> None:
if scope["type"] != "http":
await self._app(scope, receive, send)
return
scope = copy(scope)
scope.setdefault("state", {})
headers = Headers(scope=scope)
forwarded_for = list(reversed(self._get_forwarded_for(headers)))
if not forwarded_for:
request.state.forwarded_host = None
request.state.forwarded_proto = None
return await call_next(request)
scope["state"]["forwarded_host"] = None
await self._app(scope, receive, send)
return

client = None
for n, ip in enumerate(forwarded_for):
if any(ip in network for network in self.proxies):
if any(ip in network for network in self._proxies):
continue
client = str(ip)
index = n
Expand All @@ -85,45 +72,45 @@ async def dispatch(
client = str(forwarded_for[-1])
index = -1

# Update the request's understanding of the client IP. This uses an
# undocumented interface; hopefully it will keep working.
if request.client:
request.scope["client"] = (client, request.client.port)
# Update the request's understanding of the client IP.
if scope.get("client"):
scope["client"] = (client, scope["client"][1])
else:
request.scope["client"] = (client, None)
scope["client"] = (client, None)

# Ideally this should take the scheme corresponding to the entry in
# X-Forwarded-For that was chosen, but some proxies (the Kubernetes
# NGINX ingress, for example) only retain one element in
# X-Forwarded-Proto. In that case, use what we have.
proto = list(reversed(self._get_forwarded_proto(request)))
# X-Forwarded-Proto. In that case, use what we have.
proto = list(reversed(self._get_forwarded_proto(headers)))
if proto:
if index >= len(proto):
index = -1
request.scope["scheme"] = proto[index]
scope["scheme"] = proto[index]

# Rather than one entry per hop, NGINX seems to add only a single
# X-Forwarded-Host header with the original hostname.
request.state.forwarded_host = self._get_forwarded_host(request)
# Record what appears to be the client host for logging purposes.
scope["state"]["forwarded_host"] = self._get_forwarded_host(headers)

return await call_next(request)
# Perform the rest of the request processing.
await self._app(scope, receive, send)
return

def _get_forwarded_for(self, request: Request) -> list[_BaseAddress]:
def _get_forwarded_for(self, headers: Headers) -> list[_BaseAddress]:
"""Retrieve the ``X-Forwarded-For`` entries from the request.
Parameters
----------
request
The incoming request.
scope
Request headers.
Returns
-------
list of ipaddress._BaseAddress
The list of addresses found in the header. If there are multiple
The list of addresses found in the header. If there are multiple
``X-Forwarded-For`` headers, we don't know which one is correct,
so act as if there are no headers.
"""
forwarded_for_str = request.headers.getlist("X-Forwarded-For")
forwarded_for_str = headers.getlist("X-Forwarded-For")
if not forwarded_for_str or len(forwarded_for_str) > 1:
return []
return [
Expand All @@ -132,43 +119,43 @@ def _get_forwarded_for(self, request: Request) -> list[_BaseAddress]:
if addr
]

def _get_forwarded_host(self, request: Request) -> str | None:
def _get_forwarded_host(self, headers: Headers) -> str | None:
"""Retrieve the ``X-Forwarded-Host`` header.
Parameters
----------
request
The incoming request.
headers
Request headers.
Returns
-------
str
The value of the ``X-Forwarded-Host`` header, if present and if
there is only one header. If there are multiple
there is only one header. If there are multiple
``X-Forwarded-Host`` headers, we don't know which one is correct,
so act as if there are no headers.
"""
forwarded_host = request.headers.getlist("X-Forwarded-Host")
forwarded_host = headers.getlist("X-Forwarded-Host")
if not forwarded_host or len(forwarded_host) > 1:
return None
return forwarded_host[0].strip()

def _get_forwarded_proto(self, request: Request) -> list[str]:
def _get_forwarded_proto(self, headers: Headers) -> list[str]:
"""Retrieve the ``X-Forwarded-Proto`` entries from the request.
Parameters
----------
request
The incoming request.
headers
Request headers.
Returns
-------
list of str
The list of schemes found in the header. If there are multiple
The list of schemes found in the header. If there are multiple
``X-Forwarded-Proto`` headers, we don't know which one is correct,
so act as if there are no headers.
"""
forwarded_proto_str = request.headers.getlist("X-Forwarded-Proto")
forwarded_proto_str = headers.getlist("X-Forwarded-Proto")
if not forwarded_proto_str or len(forwarded_proto_str) > 1:
return []
return [p.strip() for p in forwarded_proto_str[0].split(",")]
25 changes: 24 additions & 1 deletion tests/middleware/ivoa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from __future__ import annotations

from typing import Annotated

import pytest
from fastapi import FastAPI
from fastapi import FastAPI, Query
from httpx import AsyncClient

from safir.middleware.ivoa import CaseInsensitiveQueryMiddleware
Expand All @@ -24,6 +26,16 @@ async def test_case_insensitive() -> None:
async def handler(param: str) -> dict[str, str]:
return {"param": param}

@app.get("/simple")
async def simple_handler() -> dict[str, str]:
return {"foo": "bar"}

@app.get("/list")
async def list_handler(
param: Annotated[list[str], Query()]
) -> dict[str, list[str]]:
return {"param": param}

async with AsyncClient(app=app, base_url="https://example.com") as client:
r = await client.get("/", params={"param": "foo"})
assert r.status_code == 200
Expand All @@ -39,3 +51,14 @@ async def handler(param: str) -> dict[str, str]:

r = await client.get("/", params={"paramX": "foo"})
assert r.status_code == 422

r = await client.get("/simple")
assert r.status_code == 200
assert r.json() == {"foo": "bar"}

r = await client.get(
"/list",
params=[("param", "foo"), ("PARAM", "BAR"), ("parAM", "baZ")],
)
assert r.status_code == 200
assert r.json() == {"param": ["foo", "BAR", "baZ"]}
11 changes: 6 additions & 5 deletions tests/middleware/x_forwarded_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from fastapi import FastAPI, Request
from httpx import AsyncClient
from starlette.datastructures import Headers

from safir.middleware.x_forwarded import XForwardedMiddleware

Expand Down Expand Up @@ -159,7 +160,7 @@ async def test_too_many_headers() -> None:
end. Instead, test by generating a mock request and then calling the
underling middleware functions directly.
"""
state = {
scope = {
"type": "http",
"headers": [
("X-Forwarded-For", "10.10.10.10"),
Expand All @@ -170,9 +171,9 @@ async def test_too_many_headers() -> None:
("X-Forwarded-Host", "example.com"),
],
}
request = Request(state)
headers = Headers(scope=scope)
app = FastAPI()
middleware = XForwardedMiddleware(app, proxies=[ip_network("10.0.0.0/8")])
assert middleware._get_forwarded_for(request) == []
assert middleware._get_forwarded_proto(request) == []
assert not middleware._get_forwarded_host(request)
assert middleware._get_forwarded_for(headers) == []
assert middleware._get_forwarded_proto(headers) == []
assert not middleware._get_forwarded_host(headers)

0 comments on commit 3ed26c8

Please sign in to comment.