Skip to content

Commit

Permalink
Deliver a websocket.disconnect message to the app even if it closes/r…
Browse files Browse the repository at this point in the history
…ejects itself.
  • Loading branch information
kristjanvalur committed Mar 22, 2023
1 parent 39ba68c commit 463564a
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
15 changes: 15 additions & 0 deletions starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
self._send = send
self.client_state = WebSocketState.CONNECTING
self.application_state = WebSocketState.CONNECTING
self.app_disconnect_msg: typing.Optional[Message] = None

def _have_response_extension(self) -> bool:
return "websocket.http.response" in self.scope.get("extensions", {})
Expand All @@ -36,6 +37,11 @@ async def receive(self) -> Message:
"""
Receive ASGI websocket messages, ensuring valid state transitions.
"""
if self.app_disconnect_msg is not None:
# return message which resulted from app disconnect
msg = self.app_disconnect_msg
self.app_disconnect_msg = None
return msg
if self.client_state == WebSocketState.CONNECTING:
message = await self._receive()
message_type = message["type"]
Expand All @@ -56,6 +62,8 @@ async def receive(self) -> Message:
)
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
if "code" not in message:
message["code"] = 1005 # websocket spec
return message
else:
raise RuntimeError(
Expand All @@ -80,6 +88,8 @@ async def send(self, message: Message) -> None:
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
# no close frame is sent, then the default is 1006
self.app_disconnect_msg = {"type": "websocket.disconnect", "code": 1006}
elif message_type == "websocket.http.response.start":
self.application_state = WebSocketState.RESPONSE
else:
Expand All @@ -94,6 +104,10 @@ async def send(self, message: Message) -> None:
)
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
self.app_disconnect_msg = {
"type": "websocket.disconnect",
"code": message.get("code", 1000),
}
await self._send(message)
elif self.application_state == WebSocketState.RESPONSE:
message_type = message["type"]
Expand All @@ -104,6 +118,7 @@ async def send(self, message: Message) -> None:
)
if not message.get("more_body", False):
self.application_state = WebSocketState.DISCONNECTED
self.app_disconnect_msg = {"type": "websocket.disconnect", "code": 1006}
await self._send(message)
else:
raise RuntimeError('Cannot call "send" once a close message has been sent.')
Expand Down
34 changes: 33 additions & 1 deletion tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,63 +226,95 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:


def test_application_close(test_client_factory):
close_msg = {}

async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal close_msg
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.close(status.WS_1001_GOING_AWAY)
close_msg = await websocket.receive()

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.code == status.WS_1001_GOING_AWAY
assert close_msg == {
"type": "websocket.disconnect",
"code": status.WS_1001_GOING_AWAY,
}


def test_rejected_connection(test_client_factory):
close_msg = {}

async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal close_msg
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
await websocket.close(status.WS_1001_GOING_AWAY)
close_msg = await websocket.receive()

client = test_client_factory(app)
with pytest.raises(WebSocketReject) as exc:
with client.websocket_connect("/"):
pass # pragma: nocover
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.response_status == 403
assert close_msg == {
"type": "websocket.disconnect",
"code": status.WS_1006_ABNORMAL_CLOSURE,
}


def test_send_response(test_client_factory):
close_msg = {}

async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal close_msg
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
response = Response(status_code=404, content="foo")
await websocket.send_response(response)
close_msg = await websocket.receive()

client = test_client_factory(app)
with pytest.raises(WebSocketReject) as exc:
with client.websocket_connect("/"):
pass # pragma: nocover
assert exc.value.response_status == 404
assert exc.value.response_body == b"foo"
assert close_msg == {
"type": "websocket.disconnect",
"code": status.WS_1006_ABNORMAL_CLOSURE,
}


def test_send_response_unsupported(test_client_factory):
close_msg = {}

async def app(scope: Scope, receive: Receive, send: Send) -> None:
nonlocal close_msg
del scope["extensions"]["websocket.http.response"]
websocket = WebSocket(scope, receive=receive, send=send)
msg = await websocket.receive()
assert msg == {"type": "websocket.connect"}
response = Response(status_code=404, content="foo")
await websocket.send_response(response)
close_msg = await websocket.receive()

client = test_client_factory(app)
with pytest.raises(WebSocketDisconnect) as exc:
with client.websocket_connect("/"):
pass # pragma: nocover
assert exc.value.code == status.WS_1008_POLICY_VIOLATION
assert close_msg == {
"type": "websocket.disconnect",
"code": status.WS_1006_ABNORMAL_CLOSURE,
}


def test_send_response_invalid(test_client_factory):
Expand Down

0 comments on commit 463564a

Please sign in to comment.