Skip to content

Commit

Permalink
refactor: ws receive disconnect exc type (#2690)
Browse files Browse the repository at this point in the history
Modifies `WebSocket.wrapped_receive()` to raise `WebSocketDisconnect` instead of `WebSocketException` where `receive()` is called after a disconnect message has already been received.
  • Loading branch information
peterschutt committed Nov 17, 2023
1 parent cf7e158 commit 38f3b19
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions litestar/connection/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
empty_send,
)
from litestar.datastructures.headers import Headers
from litestar.exceptions import WebSocketDisconnect, WebSocketException
from litestar.exceptions import WebSocketDisconnect
from litestar.serialization import decode_json, decode_msgpack, default_serializer, encode_json, encode_msgpack
from litestar.status_codes import WS_1000_NORMAL_CLOSURE

Expand Down Expand Up @@ -72,7 +72,7 @@ def receive_wrapper(self, receive: Receive) -> Receive:

async def wrapped_receive() -> ReceiveMessage:
if self.connection_state == "disconnect":
raise WebSocketException(detail=DISCONNECT_MESSAGE)
raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE)
message = await receive()
if message["type"] == "websocket.connect":
self.connection_state = "connect"
Expand Down Expand Up @@ -167,8 +167,6 @@ async def receive_data(self, mode: WebSocketMode) -> str | bytes:
event = cast("WebSocketReceiveEvent | WebSocketDisconnectEvent", await self.receive())
if event["type"] == "websocket.disconnect":
raise WebSocketDisconnect(detail="disconnect event", code=event["code"])
if self.connection_state == "disconnect":
raise WebSocketDisconnect(detail=DISCONNECT_MESSAGE) # pragma: no cover
return event.get("text") or "" if mode == "text" else event.get("bytes") or b""

@overload
Expand Down

0 comments on commit 38f3b19

Please sign in to comment.