diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index 7dba44f6b..11960b062 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -3,9 +3,22 @@ import httpx import pytest +from tests.protocols.test_http import HTTP_PROTOCOLS from tests.response import Response +from tests.utils import run_server from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope +from uvicorn.config import Config from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware +from uvicorn.protocols.websockets.wsproto_impl import WSProtocol + +try: + import websockets.client + + from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol + + WS_PROTOCOLS = [WSProtocol, WebSocketProtocol] +except ImportError: # pragma: nocover + WS_PROTOCOLS = [] async def app( @@ -103,3 +116,34 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None: response = await client.get("/", headers=headers) assert response.status_code == 200 assert response.text == "Remote: https://1.2.3.4:0" + + +@pytest.mark.anyio +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +@pytest.mark.skipif(not WS_PROTOCOLS, reason="websockets module not installed.") +async def test_proxy_headers_websocket_x_forwarded_proto( + ws_protocol_cls, http_protocol_cls, unused_tcp_port: int +) -> None: + async def websocket_app(scope, receive, send): + scheme = scope["scheme"] + host, port = scope["client"] + addr = "%s://%s:%d" % (scheme, host, port) + await send({"type": "websocket.accept"}) + await send({"type": "websocket.send", "text": addr}) + + app_with_middleware = ProxyHeadersMiddleware(websocket_app, trusted_hosts="*") + config = Config( + app=app_with_middleware, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + + async with run_server(config): + url = f"ws://127.0.0.1:{unused_tcp_port}" + headers = {"X-Forwarded-Proto": "https", "X-Forwarded-For": "1.2.3.4"} + async with websockets.client.connect(url, extra_headers=headers) as websocket: + data = await websocket.recv() + assert data == "wss://1.2.3.4:0" diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index 3a1e53399..28277e1d6 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -59,8 +59,15 @@ async def __call__( if b"x-forwarded-proto" in headers: # Determine if the incoming request was http or https based on # the X-Forwarded-Proto header. - x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1") - scope["scheme"] = x_forwarded_proto.strip() + x_forwarded_proto = ( + headers[b"x-forwarded-proto"].decode("latin1").strip() + ) + if scope["type"] == "websocket": + scope["scheme"] = ( + "wss" if x_forwarded_proto == "https" else "ws" + ) + else: + scope["scheme"] = x_forwarded_proto if b"x-forwarded-for" in headers: # Determine the client address from the last trusted IP in the