Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support the WebSocket Denial Response ASGI extension #2041

Merged
merged 51 commits into from
Feb 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
83f863d
supply asgi_extensions to TestClient
kristjanvalur Feb 18, 2023
2e311be
Add WebSocket.send_response()
kristjanvalur Feb 18, 2023
936a075
Add response support for WebSocket testclient
kristjanvalur Feb 18, 2023
b0edc2b
fix test for filesystem line-endings
kristjanvalur Feb 18, 2023
a230ba5
lintint
kristjanvalur Feb 18, 2023
86ff3a1
support websocket.http.response extension by default
kristjanvalur Feb 18, 2023
9c6ddc4
Improve coverate
kristjanvalur Feb 18, 2023
9e9773c
Apply suggestions from code review
kristjanvalur Mar 17, 2023
73e29f7
Undo unrelated change
kristjanvalur Mar 17, 2023
5ee468a
fix incorrect error message
kristjanvalur Mar 17, 2023
e2f609d
Update starlette/websockets.py
kristjanvalur Mar 17, 2023
ce56778
formatting
kristjanvalur Mar 17, 2023
07c0e36
Re-introduce close-code and close-reason to WebSocketReject
kristjanvalur Mar 17, 2023
87c89a8
Make sure the "websocket.connect" message is received in tests
kristjanvalur Mar 22, 2023
1af1d8b
Deliver a websocket.disconnect message to the app even if it closes/r…
kristjanvalur Mar 22, 2023
48330b2
Add test for filling out missing `websocket.disconnect` code
kristjanvalur Mar 22, 2023
f075b02
Add rejection headers. Expand tests.
kristjanvalur Apr 3, 2023
f132f7b
Fix types, headers in message are `bytes` tuples.
kristjanvalur Apr 11, 2023
7e65b30
Minimal WebSocket Denial Response implementation
Kludex Dec 20, 2023
5ed6ade
Revert "Minimal WebSocket Denial Response implementation"
Kludex Jan 20, 2024
d882c19
Rename to send_denial_response and update documentation
Kludex Jan 20, 2024
2ffed52
Remove the app_disconnect_msg. This can be added later in a separate PR
kristjanvalur Jan 24, 2024
4c188cd
Remove status code 1005 from this PR
kristjanvalur Jan 28, 2024
24fc617
Assume that the application has tested for the extension before sendi…
kristjanvalur Jan 28, 2024
4fc34f1
Rename WebSocketReject to WebSocketDenialResponse
kristjanvalur Jan 28, 2024
ca0d0b2
Remove code and status from WebSocketDenialResponse.
kristjanvalur Jan 28, 2024
5bda56e
Raise an exception if attempting to send a http response and server d…
kristjanvalur Jan 28, 2024
71b76e3
WebSocketDenialClose and WebSocketDenialResponse
kristjanvalur Jan 28, 2024
a508902
Update starlette/testclient.py
kristjanvalur Jan 29, 2024
550b132
Revert "WebSocketDenialClose and WebSocketDenialResponse"
kristjanvalur Jan 29, 2024
3f76af9
Rename parameters, member variables
kristjanvalur Jan 29, 2024
ab85785
Use httpx.Response as the base for WebSocketDenialResponse.
kristjanvalur Jan 29, 2024
387cb15
Merge branch 'master' into kristjan/reject
Kludex Feb 3, 2024
f7f4497
Apply suggestions from code review
kristjanvalur Feb 3, 2024
7f0c902
Update sanity check message
kristjanvalur Feb 3, 2024
991abc5
Remove un-needed function
kristjanvalur Feb 3, 2024
e536a03
Expand error message test regex
kristjanvalur Feb 3, 2024
a561041
Add type hings to test methods
kristjanvalur Feb 3, 2024
1e70746
Add doc string to test.
kristjanvalur Feb 3, 2024
4f27aba
Fix mypy complaining about mismatching parent methods.
kristjanvalur Feb 3, 2024
f53efa6
nitpick & remove test
Kludex Feb 3, 2024
dd90d1b
Simplify the documentation
Kludex Feb 3, 2024
fa7c84e
Merge branch 'master' into kristjan/reject
Kludex Feb 3, 2024
9bd6db5
Update starlette/testclient.py
Kludex Feb 3, 2024
721a2d2
Update starlette/testclient.py
Kludex Feb 3, 2024
a70b7dd
Remove an unnecessary test
kristjanvalur Feb 4, 2024
0f4e1f2
there is no special "close because of rejection" in the testclient an…
kristjanvalur Feb 4, 2024
04817a6
Merge remote-tracking branch 'origin/kristjan/reject' into kristjan/r…
kristjanvalur Feb 4, 2024
bc3ecbd
Merge branch 'master' into kristjan/reject
kristjanvalur Feb 4, 2024
aa224dd
Merge branch 'master' into kristjan/reject
Kludex Feb 4, 2024
094f080
Merge branch 'master' into kristjan/reject
kristjanvalur Feb 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/websockets.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,17 @@ correctly updated.

* `await websocket.send(message)`
* `await websocket.receive()`

### Send Denial Response

If you call `websocket.close()` before calling `websocket.accept()` then
the server will automatically send a HTTP 403 error to the client.

If you want to send a different error response, you can use the
`websocket.send_denial_response()` method. This will send the response
and then close the connection.

* `await websocket.send_denial_response(response)`

This requires the ASGI server to support the WebSocket Denial Response
extension. If it is not supported a `RuntimeError` will be raised.
5 changes: 3 additions & 2 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,15 @@ def delete_cookie(
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
prefix = "websocket." if scope["type"] == "websocket" else ""
await send(
{
"type": "http.response.start",
"type": prefix + "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
await send({"type": "http.response.body", "body": self.body})
await send({"type": prefix + "http.response.body", "body": self.body})

if self.background is not None:
await self.background()
Expand Down
28 changes: 27 additions & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ def __init__(self, session: WebSocketTestSession) -> None:
self.session = session


class WebSocketDenialResponse( # type: ignore[misc]
httpx.Response,
WebSocketDisconnect,
):
"""
A special case of `WebSocketDisconnect`, raised in the `TestClient` if the
`WebSocket` is closed before being accepted with a `send_denial_response()`.
"""


class WebSocketTestSession:
def __init__(
self,
Expand Down Expand Up @@ -159,7 +169,22 @@ async def _asgi_send(self, message: Message) -> None:
def _raise_on_close(self, message: Message) -> None:
Kludex marked this conversation as resolved.
Show resolved Hide resolved
if message["type"] == "websocket.close":
raise WebSocketDisconnect(
message.get("code", 1000), message.get("reason", "")
code=message.get("code", 1000), reason=message.get("reason", "")
)
elif message["type"] == "websocket.http.response.start":
status_code: int = message["status"]
headers: list[tuple[bytes, bytes]] = message["headers"]
body: list[bytes] = []
while True:
message = self.receive()
assert message["type"] == "websocket.http.response.body"
body.append(message["body"])
if not message.get("more_body", False):
break
raise WebSocketDenialResponse(
status_code=status_code,
headers=headers,
content=b"".join(body),
)

def send(self, message: Message) -> None:
Expand Down Expand Up @@ -277,6 +302,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"server": [host, port],
"subprotocols": subprotocols,
"state": self.app_state.copy(),
"extensions": {"websocket.http.response": {}},
}
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)
Expand Down
33 changes: 30 additions & 3 deletions starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
import typing

from starlette.requests import HTTPConnection
from starlette.responses import Response
from starlette.types import Message, Receive, Scope, Send


class WebSocketState(enum.Enum):
CONNECTING = 0
CONNECTED = 1
DISCONNECTED = 2
RESPONSE = 3


class WebSocketDisconnect(Exception):
Expand Down Expand Up @@ -65,13 +67,20 @@ async def send(self, message: Message) -> None:
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
if message_type not in {"websocket.accept", "websocket.close"}:
if message_type not in {
"websocket.accept",
"websocket.close",
"websocket.http.response.start",
Kludex marked this conversation as resolved.
Show resolved Hide resolved
}:
raise RuntimeError(
'Expected ASGI message "websocket.accept" or '
f'"websocket.close", but got {message_type!r}'
'Expected ASGI message "websocket.accept",'
'"websocket.close" or "websocket.http.response.start",'
f"but got {message_type!r}"
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
elif message_type == "websocket.http.response.start":
self.application_state = WebSocketState.RESPONSE
else:
self.application_state = WebSocketState.CONNECTED
await self._send(message)
Expand All @@ -89,6 +98,16 @@ async def send(self, message: Message) -> None:
except IOError:
self.application_state = WebSocketState.DISCONNECTED
raise WebSocketDisconnect(code=1006)
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
if message_type != "websocket.http.response.body":
raise RuntimeError(
'Expected ASGI message "websocket.http.response.body", '
f"but got {message_type!r}"
)
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')

Expand Down Expand Up @@ -185,6 +204,14 @@ async def close(self, code: int = 1000, reason: str | None = None) -> None:
{"type": "websocket.close", "code": code, "reason": reason or ""}
)

async def send_denial_response(self, response: Response) -> None:
if "websocket.http.response" in self.scope.get("extensions", {}):
await response(self.scope, self.receive, self.send)
else:
raise RuntimeError(
"The server doesn't support the Websocket Denial Response extension."
)


class WebSocketClose:
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
Expand Down
110 changes: 109 additions & 1 deletion tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette import status
from starlette.testclient import TestClient
from starlette.responses import Response
from starlette.testclient import TestClient, WebSocketDenialResponse
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState

Expand Down Expand Up @@ -293,6 +294,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
def test_rejected_connection(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
await websocket.close(status.WS_1001_GOING_AWAY)

client = test_client_factory(app)
Expand All @@ -302,6 +305,111 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
assert exc.value.code == status.WS_1001_GOING_AWAY


def test_send_denial_response(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
response = Response(status_code=404, content="foo")
await websocket.send_denial_response(response)

client = test_client_factory(app)
with pytest.raises(WebSocketDenialResponse) as exc:
with client.websocket_connect("/"):
pass # pragma: no cover
assert exc.value.status_code == 404
assert exc.value.content == b"foo"


def test_send_response_multi(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
await websocket.send(
{
"type": "websocket.http.response.start",
"status": 404,
"headers": [(b"content-type", b"text/plain"), (b"foo", b"bar")],
}
)
await websocket.send(
{
"type": "websocket.http.response.body",
"body": b"hard",
"more_body": True,
}
)
await websocket.send(
{
"type": "websocket.http.response.body",
"body": b"body",
}
)

client = test_client_factory(app)
with pytest.raises(WebSocketDenialResponse) as exc:
with client.websocket_connect("/"):
pass # pragma: no cover
assert exc.value.status_code == 404
assert exc.value.content == b"hardbody"
assert exc.value.headers["foo"] == "bar"


def test_send_response_unsupported(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
del scope["extensions"]["websocket.http.response"]
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
response = Response(status_code=404, content="foo")
with pytest.raises(
RuntimeError,
match="The server doesn't support the Websocket Denial Response extension.",
):
await websocket.send_denial_response(response)
await websocket.close()

client = test_client_factory(app)
with pytest.raises(WebSocketDisconnect) as exc:
with client.websocket_connect("/"):
pass # pragma: no cover
assert exc.value.code == status.WS_1000_NORMAL_CLOSURE


def test_send_response_duplicate_start(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
response = Response(status_code=404, content="foo")
await websocket.send(
{
"type": "websocket.http.response.start",
"status": response.status_code,
"headers": response.raw_headers,
}
)
await websocket.send(
{
"type": "websocket.http.response.start",
"status": response.status_code,
"headers": response.raw_headers,
}
)

client = test_client_factory(app)
with pytest.raises(
RuntimeError,
match=(
'Expected ASGI message "websocket.http.response.body", but got '
"'websocket.http.response.start'"
),
):
with client.websocket_connect("/"):
pass # pragma: no cover


def test_subprotocol(test_client_factory: Callable[..., TestClient]):
async def app(scope: Scope, receive: Receive, send: Send) -> None:
websocket = WebSocket(scope, receive=receive, send=send)
Expand Down