Skip to content

Commit

Permalink
Add reason to WebSocket closure (#1417)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
aminalaee and Kludex committed Jan 22, 2022
1 parent 9d282a9 commit 34d9f0f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/websockets.md
Expand Up @@ -75,7 +75,7 @@ Use `websocket.receive_json(data, mode="binary")` to receive JSON over binary da

### Closing the connection

* `await websocket.close(code=1000)`
* `await websocket.close(code=1000, reason=None)`

### Sending and receiving messages

Expand Down
4 changes: 3 additions & 1 deletion starlette/testclient.py
Expand Up @@ -352,7 +352,9 @@ async def _asgi_send(self, message: Message) -> None:

def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(message.get("code", 1000))
raise WebSocketDisconnect(
message.get("code", 1000), message.get("reason", "")
)

def send(self, message: Message) -> None:
self._receive_queue.put(message)
Expand Down
16 changes: 11 additions & 5 deletions starlette/websockets.py
Expand Up @@ -13,8 +13,9 @@ class WebSocketState(enum.Enum):


class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000) -> None:
def __init__(self, code: int = 1000, reason: str = None) -> None:
self.code = code
self.reason = reason or ""


class WebSocket(HTTPConnection):
Expand Down Expand Up @@ -146,13 +147,18 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None:
else:
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

async def close(self, code: int = 1000) -> None:
await self.send({"type": "websocket.close", "code": code})
async def close(self, code: int = 1000, reason: str = None) -> None:
await self.send(
{"type": "websocket.close", "code": code, "reason": reason or ""}
)


class WebSocketClose:
def __init__(self, code: int = 1000) -> None:
def __init__(self, code: int = 1000, reason: str = None) -> None:
self.code = code
self.reason = reason or ""

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "websocket.close", "code": self.code})
await send(
{"type": "websocket.close", "code": self.code, "reason": self.reason}
)
17 changes: 17 additions & 0 deletions tests/test_websockets.py
Expand Up @@ -405,3 +405,20 @@ async def mock_send(message):
assert websocket == websocket
assert websocket in {websocket}
assert {websocket} == {websocket}


def test_websocket_close_reason(test_client_factory) -> None:
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.close(code=status.WS_1001_GOING_AWAY, reason="Going Away")

return asgi

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Going Away"

0 comments on commit 34d9f0f

Please sign in to comment.