Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reason to WebSocket closure #1417

Merged
merged 20 commits into from Jan 22, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/websockets.md
Expand Up @@ -75,7 +75,7 @@ Use `websocket.receive_json(data, mode="binary")` to receive JSON over binary da

### Closing the connection

* `await websocket.close(code=1000)`
* `await websocket.close(code=1000, reason="Normal Closure")`
aminalaee marked this conversation as resolved.
Show resolved Hide resolved

### Sending and receiving messages

Expand Down
2 changes: 1 addition & 1 deletion starlette/testclient.py
Expand Up @@ -352,7 +352,7 @@ async def _asgi_send(self, message: Message) -> None:

def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(message.get("code", 1000))
raise WebSocketDisconnect(message.get("code", 1000), message.get("reason"))
aminalaee marked this conversation as resolved.
Show resolved Hide resolved

def send(self, message: Message) -> None:
self._receive_queue.put(message)
Expand Down
26 changes: 21 additions & 5 deletions starlette/websockets.py
Expand Up @@ -13,8 +13,9 @@ class WebSocketState(enum.Enum):


class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000) -> None:
def __init__(self, code: int = None, reason: str = None) -> None:
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
self.code = code
self.reason = reason


class WebSocket(HTTPConnection):
Expand Down Expand Up @@ -144,13 +145,28 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None:
else:
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

async def close(self, code: int = 1000) -> None:
await self.send({"type": "websocket.close", "code": code})
async def close(self, code: int = None, reason: str = None) -> None:
message: dict = {"type": "websocket.close"}

if code is not None:
message["code"] = code
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
if reason is not None:
message["reason"] = reason

await self.send(message)


class WebSocketClose:
def __init__(self, code: int = 1000) -> None:
def __init__(self, code: int = None, reason: str = None) -> None:
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
self.code = code
self.reason = reason

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "websocket.close", "code": self.code})
message: dict = {"type": "websocket.close"}

if self.code is not None:
message["code"] = self.code
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
if self.reason is not None:
message["reason"] = self.reason

await send(message)
38 changes: 37 additions & 1 deletion tests/test_websockets.py
Expand Up @@ -2,7 +2,7 @@
import pytest

from starlette import status
from starlette.websockets import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocket, WebSocketClose, WebSocketDisconnect


def test_websocket_url(test_client_factory):
Expand Down Expand Up @@ -391,3 +391,39 @@ async def mock_send(message):
assert websocket == websocket
assert websocket in {websocket}
assert {websocket} == {websocket}


def test_websocket_close_reason(test_client_factory) -> None:
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.close(code=1001, reason="Closing")

return asgi

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Closing"


def test_websocket_close_reason_manual(test_client_factory) -> None:
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()

websocket_close = WebSocketClose(code=1001, reason="Closing")
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
await websocket_close(scope, receive, send)

return asgi

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Closing"