-
-
Notifications
You must be signed in to change notification settings - Fork 717
/
proxy_headers.py
84 lines (71 loc) · 3.18 KB
/
proxy_headers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
This middleware can be used when a known proxy is fronting the application,
and is trusted to be properly setting the `X-Forwarded-Proto` and
`X-Forwarded-For` headers with the connecting client information.
Modifies the `client` and `scheme` information so that they reference
the connecting client, rather than the connecting proxy.
https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies
"""
from typing import List, Optional, Tuple, Union, cast
from uvicorn._types import (
ASGI3Application,
ASGIReceiveCallable,
ASGISendCallable,
HTTPScope,
Scope,
WebSocketScope,
)
class ProxyHeadersMiddleware:
    """ASGI middleware that applies ``X-Forwarded-*`` headers from trusted proxies.

    For ``http`` and ``websocket`` scopes whose directly connecting peer is a
    trusted host (or when ``"*"`` is configured), ``scope["scheme"]`` is taken
    from ``X-Forwarded-Proto`` and ``scope["client"]`` from ``X-Forwarded-For``.
    The scope is modified in place. All other scope types pass through unchanged.
    """

    # X-Forwarded-Proto values that are honoured. Anything else is ignored, so
    # arbitrary header text can never end up in scope["scheme"].
    _FORWARDED_PROTOS = frozenset({"http", "https", "ws", "wss"})
    # Subset of the above that denotes a TLS-protected connection.
    _SECURE_PROTOS = frozenset({"https", "wss"})

    def __init__(
        self,
        app: "ASGI3Application",
        trusted_hosts: Union[List[str], str] = "127.0.0.1",
    ) -> None:
        """Wrap ``app``.

        :param app: the downstream ASGI application.
        :param trusted_hosts: either a comma-separated string of hosts (each
            entry is whitespace-stripped) or a list of hosts. The value ``"*"``
            trusts every peer.
        """
        self.app = app
        if isinstance(trusted_hosts, str):
            self.trusted_hosts = {item.strip() for item in trusted_hosts.split(",")}
        else:
            self.trusted_hosts = set(trusted_hosts)
        self.always_trust = "*" in self.trusted_hosts

    def get_trusted_client_host(
        self, x_forwarded_for_hosts: List[str]
    ) -> Optional[str]:
        """Select the originating client from a parsed X-Forwarded-For list.

        Each proxy appends the address it received the request from, so the
        list is scanned right-to-left. The first entry that is not a trusted
        host is returned. When every peer is trusted (``"*"``), the left-most
        entry is returned. When every listed hop is a trusted host, the
        left-most entry is also returned, because the client itself is then a
        trusted proxy.

        :param x_forwarded_for_hosts: host strings in header order.
        :returns: the chosen host, or ``None`` if the list is empty.
        """
        if not x_forwarded_for_hosts:
            return None
        if self.always_trust:
            return x_forwarded_for_hosts[0]
        for host in reversed(x_forwarded_for_hosts):
            if host not in self.trusted_hosts:
                return host
        # Every hop is trusted. Previously this returned None and the scope's
        # client became (None, 0).
        return x_forwarded_for_hosts[0]

    async def __call__(
        self, scope: "Scope", receive: "ASGIReceiveCallable", send: "ASGISendCallable"
    ) -> None:
        """Rewrite ``scheme``/``client`` on trusted connections, then delegate."""
        if scope["type"] in ("http", "websocket"):
            scope = cast(Union["HTTPScope", "WebSocketScope"], scope)
            client_addr: Optional[Tuple[str, int]] = scope.get("client")
            client_host = client_addr[0] if client_addr else None

            if self.always_trust or client_host in self.trusted_hosts:
                headers = dict(scope["headers"])

                if b"x-forwarded-proto" in headers:
                    # Determine if the incoming request was secure based on
                    # X-Forwarded-Proto. Schemes are case-insensitive.
                    x_forwarded_proto = (
                        headers[b"x-forwarded-proto"].decode("latin1").strip().lower()
                    )
                    if x_forwarded_proto in self._FORWARDED_PROTOS:
                        secure = x_forwarded_proto in self._SECURE_PROTOS
                        # Map onto the scheme family matching the scope type,
                        # e.g. "https" or "wss" -> "wss" for websockets.
                        if scope["type"] == "websocket":
                            scope["scheme"] = "wss" if secure else "ws"
                        else:
                            scope["scheme"] = "https" if secure else "http"

                if b"x-forwarded-for" in headers:
                    x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
                    # Drop blank entries so an empty or malformed header
                    # cannot produce an empty client host.
                    x_forwarded_for_hosts = [
                        host
                        for host in (item.strip() for item in x_forwarded_for.split(","))
                        if host
                    ]
                    host = self.get_trusted_client_host(x_forwarded_for_hosts)
                    if host:
                        # The client's port is not carried by X-Forwarded-For,
                        # so 0 is used as a placeholder.
                        scope["client"] = (host, 0)

        return await self.app(scope, receive, send)