Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace WebSocket assertions with RuntimeError #1472

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 36 additions & 9 deletions starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,21 @@ async def receive(self) -> Message:
if self.client_state == WebSocketState.CONNECTING:
message = await self._receive()
message_type = message["type"]
assert message_type == "websocket.connect"
if message_type != "websocket.connect":
raise RuntimeError(
'Expected ASGI message "websocket.connect",'
tomchristie marked this conversation as resolved.
Show resolved Hide resolved
f"but got {message_type!r}"
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do we think about this kind of style for the message_type case?...

f'Expected ASGI message "websocket.connect", but got {message_type!r}'

self.client_state = WebSocketState.CONNECTED
return message
elif self.client_state == WebSocketState.CONNECTED:
message = await self._receive()
message_type = message["type"]
assert message_type in {"websocket.receive", "websocket.disconnect"}
if message_type not in {"websocket.receive", "websocket.disconnect"}:
raise RuntimeError(
'Expected ASGI message "websocket.receive" or'
tomchristie marked this conversation as resolved.
Show resolved Hide resolved
f'"websocket.disconnect", but got {message_type!r}'
)
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
if message_type == "websocket.disconnect":
self.client_state = WebSocketState.DISCONNECTED
return message
Expand All @@ -55,15 +63,23 @@ async def send(self, message: Message) -> None:
"""
if self.application_state == WebSocketState.CONNECTING:
message_type = message["type"]
assert message_type in {"websocket.accept", "websocket.close"}
if message_type not in {"websocket.accept", "websocket.close"}:
raise RuntimeError(
'Expected ASGI message "websocket.connect",'
tomchristie marked this conversation as resolved.
Show resolved Hide resolved
f"but got {message_type!r}"
)
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
else:
self.application_state = WebSocketState.CONNECTED
await self._send(message)
elif self.application_state == WebSocketState.CONNECTED:
message_type = message["type"]
assert message_type in {"websocket.send", "websocket.close"}
if message_type not in {"websocket.send", "websocket.close"}:
raise RuntimeError(
'Expected ASGI message "websocket.send" or "websocket.close",'
tomchristie marked this conversation as resolved.
Show resolved Hide resolved
f"but got {message_type!r}"
)
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
if message_type == "websocket.close":
self.application_state = WebSocketState.DISCONNECTED
await self._send(message)
Expand All @@ -89,20 +105,30 @@ def _raise_on_disconnect(self, message: Message) -> None:
raise WebSocketDisconnect(message["code"])

async def receive_text(self) -> str:
assert self.application_state == WebSocketState.CONNECTED
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
message = await self.receive()
self._raise_on_disconnect(message)
return message["text"]

async def receive_bytes(self) -> bytes:
assert self.application_state == WebSocketState.CONNECTED
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
message = await self.receive()
self._raise_on_disconnect(message)
return message["bytes"]

async def receive_json(self, mode: str = "text") -> typing.Any:
assert mode in ["text", "binary"]
assert self.application_state == WebSocketState.CONNECTED
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
if self.application_state != WebSocketState.CONNECTED:
raise RuntimeError(
'WebSocket is not connected. Need to call "accept" first.'
)
message = await self.receive()
self._raise_on_disconnect(message)

Expand Down Expand Up @@ -140,7 +166,8 @@ async def send_bytes(self, data: bytes) -> None:
await self.send({"type": "websocket.send", "bytes": data})

async def send_json(self, data: typing.Any, mode: str = "text") -> None:
assert mode in ["text", "binary"]
if mode not in {"text", "binary"}:
raise RuntimeError('The "mode" argument should be "text" or "binary".')
text = json.dumps(data)
if mode == "text":
await self.send({"type": "websocket.send", "text": text})
Expand Down
134 changes: 133 additions & 1 deletion tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from starlette import status
from starlette.websockets import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState


def test_websocket_url(test_client_factory):
Expand Down Expand Up @@ -422,3 +422,135 @@ async def asgi(receive, send):
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Going Away"


def test_send_json_invalid_mode(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.send_json({}, mode="invalid")

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_receive_json_invalid_mode(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.receive_json(mode="invalid")

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_receive_text_before_accept(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_text()

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_receive_bytes_before_accept(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_bytes()

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_receive_json_before_accept(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.receive_json()

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_send_before_accept(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.send({"type": "websocket.send"})

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_send_wrong_message_type(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.send({"type": "websocket.accept"})
await websocket.send({"type": "websocket.accept"})

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/"):
pass # pragma: nocover


def test_receive_before_accept(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
websocket.client_state = WebSocketState.CONNECTING
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I couldn't test this without modifying the client state.

await websocket.receive()

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/") as websocket:
websocket.send({"type": "websocket.send"})


def test_receive_wrong_message_type(test_client_factory):
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.receive()

return asgi

client = test_client_factory(app)
with pytest.raises(RuntimeError):
with client.websocket_connect("/") as websocket:
websocket.send({"type": "websocket.connect"})