From f769acd46dbdd47829cace5cb2d3796c956a0a8e Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Thu, 3 Feb 2022 12:40:00 +0100 Subject: [PATCH 1/4] Switch WebSocket assertions with RuntimeError --- starlette/websockets.py | 43 +++++++++++++++++++----- tests/test_websockets.py | 72 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 9 deletions(-) diff --git a/starlette/websockets.py b/starlette/websockets.py index da7406047..58eb23c4b 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -34,13 +34,20 @@ async def receive(self) -> Message: if self.client_state == WebSocketState.CONNECTING: message = await self._receive() message_type = message["type"] - assert message_type == "websocket.connect" + if message_type != "websocket.connect": + raise RuntimeError( + 'WebSocket is not connected. Need to call "accept" first.' + ) self.client_state = WebSocketState.CONNECTED return message elif self.client_state == WebSocketState.CONNECTED: message = await self._receive() message_type = message["type"] - assert message_type in {"websocket.receive", "websocket.disconnect"} + if message_type not in {"websocket.receive", "websocket.disconnect"}: + raise RuntimeError( + "Websocket is connected." + 'Message type should be either "receive" or "disconnect".' + ) if message_type == "websocket.disconnect": self.client_state = WebSocketState.DISCONNECTED return message @@ -55,7 +62,10 @@ async def send(self, message: Message) -> None: """ if self.application_state == WebSocketState.CONNECTING: message_type = message["type"] - assert message_type in {"websocket.accept", "websocket.close"} + if message_type not in {"websocket.accept", "websocket.close"}: + raise RuntimeError( + 'WebSocket is not connected. Need to call "accept" first.' + ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED else: @@ -63,7 +73,11 @@ async def send(self, message: Message) -> None: await self._send(message) elif self.application_state == WebSocketState.CONNECTED: message_type = message["type"] - assert message_type in {"websocket.send", "websocket.close"} + if message_type not in {"websocket.send", "websocket.close"}: + raise RuntimeError( + "Websocket is connected." + 'Message type should be either "send" or "close".' + ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED await self._send(message) @@ -89,20 +103,30 @@ def _raise_on_disconnect(self, message: Message) -> None: raise WebSocketDisconnect(message["code"]) async def receive_text(self) -> str: - assert self.application_state == WebSocketState.CONNECTED + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError( + 'WebSocket is not connected. Need to call "accept" first.' + ) message = await self.receive() self._raise_on_disconnect(message) return message["text"] async def receive_bytes(self) -> bytes: - assert self.application_state == WebSocketState.CONNECTED + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError( + 'WebSocket is not connected. Need to call "accept" first.' + ) message = await self.receive() self._raise_on_disconnect(message) return message["bytes"] async def receive_json(self, mode: str = "text") -> typing.Any: - assert mode in ["text", "binary"] - assert self.application_state == WebSocketState.CONNECTED + if mode not in {"text", "binary"}: + raise RuntimeError('The "mode" argument should be "text" or "binary".') + if self.application_state != WebSocketState.CONNECTED: + raise RuntimeError( + 'WebSocket is not connected. Need to call "accept" first.' + ) message = await self.receive() self._raise_on_disconnect(message) @@ -140,7 +164,8 @@ async def send_bytes(self, data: bytes) -> None: await self.send({"type": "websocket.send", "bytes": data}) async def send_json(self, data: typing.Any, mode: str = "text") -> None: - assert mode in ["text", "binary"] + if mode not in {"text", "binary"}: + raise RuntimeError('The "mode" argument should be "text" or "binary".') text = json.dumps(data) if mode == "text": await self.send({"type": "websocket.send", "text": text}) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index b11685cbc..edccd50f5 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -422,3 +422,75 @@ async def asgi(receive, send): websocket.receive_text() assert exc.value.code == status.WS_1001_GOING_AWAY assert exc.value.reason == "Going Away" + + +def test_send_json_invalid_mode(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.send_json({"url": str(websocket.url)}, mode="invalid") + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError) as exc: + with client.websocket_connect("/"): + assert exc.match('The "mode" argument should be "text" or "binary".') + + +def test_receive_json_invalid_mode(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.receive_json(mode="invalid") + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_receive_text_before_accept(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.receive_text() + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_receive_bytes_before_accept(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.receive_bytes() + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_receive_json_before_accept(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.receive_json() + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover From e698099f0c4e4154cd55faa0ac0dfd8ace6b6b47 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Thu, 3 Feb 2022 16:16:53 +0100 Subject: [PATCH 2/4] uplift coverage --- tests/test_websockets.py | 68 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 4 deletions(-) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index edccd50f5..e3a52762a 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -2,7 +2,7 @@ import pytest from starlette import status -from starlette.websockets import WebSocket, WebSocketDisconnect +from starlette.websockets import WebSocket, WebSocketDisconnect, WebSocketState def test_websocket_url(test_client_factory): @@ -429,14 +429,14 @@ def app(scope): async def asgi(receive, send): websocket = WebSocket(scope, receive=receive, send=send) await websocket.accept() - await websocket.send_json({"url": str(websocket.url)}, mode="invalid") + await websocket.send_json({}, mode="invalid") return asgi client = test_client_factory(app) - with pytest.raises(RuntimeError) as exc: + with pytest.raises(RuntimeError): with client.websocket_connect("/"): - assert exc.match('The "mode" argument should be "text" or "binary".') + pass # pragma: nocover def test_receive_json_invalid_mode(test_client_factory): @@ -494,3 +494,63 @@ async def asgi(receive, send): with pytest.raises(RuntimeError): with client.websocket_connect("/"): pass # pragma: nocover + + +def test_send_before_accept(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.send({"type": "websocket.send"}) + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_send_wrong_message_type(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.send({"type": "websocket.accept"}) + await websocket.send({"type": "websocket.accept"}) + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/"): + pass # pragma: nocover + + +def test_receive_before_accept(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + websocket.client_state = WebSocketState.CONNECTING + await websocket.receive() + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/") as websocket: + websocket.send({"type": "websocket.send"}) + + +def test_receive_wrong_message_type(test_client_factory): + def app(scope): + async def asgi(receive, send): + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await websocket.receive() + + return asgi + + client = test_client_factory(app) + with pytest.raises(RuntimeError): + with client.websocket_connect("/") as websocket: + websocket.send({"type": "websocket.connect"}) From 430b9fd14a4276860625cb6b20671395f07eb8c0 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Fri, 4 Feb 2022 11:44:11 +0100 Subject: [PATCH 3/4] update errors messages for send and receive --- starlette/websockets.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/starlette/websockets.py b/starlette/websockets.py index 58eb23c4b..736b5c2c7 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -36,7 +36,8 @@ async def receive(self) -> Message: message_type = message["type"] if message_type != "websocket.connect": raise RuntimeError( - 'WebSocket is not connected. Need to call "accept" first.' + 'Expected ASGI message "websocket.connect",' + f"but got {message_type!r}" ) self.client_state = WebSocketState.CONNECTED return message @@ -45,8 +46,8 @@ async def receive(self) -> Message: message_type = message["type"] if message_type not in {"websocket.receive", "websocket.disconnect"}: raise RuntimeError( - "Websocket is connected." - 'Message type should be either "receive" or "disconnect".' + 'Expected ASGI message "websocket.receive" or' + f'"websocket.disconnect", but got {message_type!r}' ) if message_type == "websocket.disconnect": self.client_state = WebSocketState.DISCONNECTED @@ -64,7 +65,8 @@ async def send(self, message: Message) -> None: message_type = message["type"] if message_type not in {"websocket.accept", "websocket.close"}: raise RuntimeError( - 'WebSocket is not connected. Need to call "accept" first.' + 'Expected ASGI message "websocket.connect",' + f"but got {message_type!r}" ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED @@ -75,8 +77,8 @@ async def send(self, message: Message) -> None: message_type = message["type"] if message_type not in {"websocket.send", "websocket.close"}: raise RuntimeError( - "Websocket is connected." - 'Message type should be either "send" or "close".' + 'Expected ASGI message "websocket.send" or "websocket.close",' + f"but got {message_type!r}" ) if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED From 655854b7edf0e0a2f8c1c2547428afbf06d3a3d1 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 4 Feb 2022 10:48:26 +0000 Subject: [PATCH 4/4] Tweak exception message - whitespace after comma. Co-authored-by: Marcelo Trylesinski --- starlette/websockets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/starlette/websockets.py b/starlette/websockets.py index 736b5c2c7..03ed19972 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -36,7 +36,7 @@ async def receive(self) -> Message: message_type = message["type"] if message_type != "websocket.connect": raise RuntimeError( - 'Expected ASGI message "websocket.connect",' + 'Expected ASGI message "websocket.connect", ' f"but got {message_type!r}" ) self.client_state = WebSocketState.CONNECTED @@ -46,7 +46,7 @@ async def receive(self) -> Message: message_type = message["type"] if message_type not in {"websocket.receive", "websocket.disconnect"}: raise RuntimeError( - 'Expected ASGI message "websocket.receive" or' + 'Expected ASGI message "websocket.receive" or ' f'"websocket.disconnect", but got {message_type!r}' ) if message_type == "websocket.disconnect": @@ -65,7 +65,7 @@ async def send(self, message: Message) -> None: message_type = message["type"] if message_type not in {"websocket.accept", "websocket.close"}: raise RuntimeError( - 'Expected ASGI message "websocket.connect",' + 'Expected ASGI message "websocket.connect", ' f"but got {message_type!r}" ) if message_type == "websocket.close": @@ -77,7 +77,7 @@ async def send(self, message: Message) -> None: message_type = message["type"] if message_type not in {"websocket.send", "websocket.close"}: raise RuntimeError( - 'Expected ASGI message "websocket.send" or "websocket.close",' + 'Expected ASGI message "websocket.send" or "websocket.close", ' f"but got {message_type!r}" ) if message_type == "websocket.close":