diff --git a/docs/_newsfragments/2025.newandimproved.rst b/docs/_newsfragments/2025.newandimproved.rst new file mode 100644 index 000000000..e06826b19 --- /dev/null +++ b/docs/_newsfragments/2025.newandimproved.rst @@ -0,0 +1 @@ +Support closing a :class:`falcon.asgi.WebSocket` with a reason. diff --git a/falcon/asgi/app.py b/falcon/asgi/app.py index 92b3ac2d4..a11a718c3 100644 --- a/falcon/asgi/app.py +++ b/falcon/asgi/app.py @@ -48,6 +48,7 @@ from .request import Request from .response import Response from .structures import SSEvent +from .ws import check_support_reason from .ws import http_status_to_ws_code from .ws import WebSocket from .ws import WebSocketOptions @@ -988,7 +989,11 @@ async def _handle_websocket(self, ver, scope, receive, send): # we don't support, so bail out. This also fulfills the ASGI # spec requirement to only process the request after # receiving and verifying the first event. - await send({'type': EventType.WS_CLOSE, 'code': WSCloseCode.SERVER_ERROR}) + response = {'type': EventType.WS_CLOSE, 'code': WSCloseCode.SERVER_ERROR} + if check_support_reason(ver): + response['reason'] = 'Internal Server Error' + + await send(response) return req = self._request_type(scope, receive, options=self.req_options) @@ -1000,6 +1005,7 @@ async def _handle_websocket(self, ver, scope, receive, send): send, self.ws_options.media_handlers, self.ws_options.max_receive_queue, + self.ws_options.default_close_reasons, ) on_websocket = None diff --git a/falcon/asgi/ws.py b/falcon/asgi/ws.py index 3879a6e52..05e9a6fed 100644 --- a/falcon/asgi/ws.py +++ b/falcon/asgi/ws.py @@ -50,7 +50,9 @@ class WebSocket: '_asgi_send', '_buffered_receiver', '_close_code', + '_close_reasons', '_supports_accept_headers', + '_supports_reason', '_mh_bin_deserialize', '_mh_bin_serialize', '_mh_text_deserialize', @@ -70,8 +72,10 @@ def __init__( Union[media.BinaryBaseHandlerWS, media.TextBaseHandlerWS], ], max_receive_queue: int, + default_close_reasons: Dict[Optional[int], str], ): self._supports_accept_headers = ver != '2.0' + self._supports_reason = check_support_reason(ver) # NOTE(kgriffs): Normalize the iterable to a stable tuple; note that # ordering is significant, and so we preserve it here. @@ -95,6 +99,7 @@ def __init__( self._mh_bin_serialize = mh_bin.serialize self._mh_bin_deserialize = mh_bin.deserialize + self._close_reasons = default_close_reasons self._state = _WebSocketState.HANDSHAKE self._close_code = None # type: Optional[int] @@ -258,12 +263,15 @@ async def close(self, code: Optional[int] = None) -> None: if self.closed: return - await self._asgi_send( - { - 'type': EventType.WS_CLOSE, - 'code': code, - } - ) + response = {'type': EventType.WS_CLOSE, 'code': code} + + if self._supports_reason: + if code in self._close_reasons: + response['reason'] = self._close_reasons[code] + elif 3100 <= code <= 3999: + response['reason'] = falcon.util.code_to_http_status(code - 3000) + + await self._asgi_send(response) self._state = _WebSocketState.CLOSED self._close_code = code @@ -513,6 +521,10 @@ class WebSocketOptions: unhandled error is raised while handling a WebSocket connection (default ``1011``). For a list of valid close codes and ranges, see also: https://tools.ietf.org/html/rfc6455#section-7.4 + default_close_reasons (dict): A default mapping between the Websocket + close code and the reason why the connection is close. Close codes + corresponding to HTTPErrors are not included as they will be rendered + automatically using HTTP status. media_handlers (dict): A dict-like object for configuring media handlers according to the WebSocket payload type (TEXT vs. BINARY) of a given message. See also: :ref:`ws_media_handlers`. @@ -529,7 +541,12 @@ class WebSocketOptions: """ - __slots__ = ['error_close_code', 'max_receive_queue', 'media_handlers'] + __slots__ = [ + 'error_close_code', + 'default_close_reasons', + 'max_receive_queue', + 'media_handlers', + ] def __init__(self) -> None: try: @@ -560,6 +577,12 @@ def __init__(self) -> None: # self.error_close_code: int = WSCloseCode.SERVER_ERROR + self.default_close_reasons: Dict[int, str] = { + 1000: 'Normal Closure', + 1011: 'Internal Server Error', + 3011: 'Internal Server Error', + } + # NOTE(kgriffs): The websockets library itself will buffer, so we keep # this value fairly small by default to mitigate buffer bloat. But in # the case that we have a large spillover from the websocket server's @@ -706,6 +729,18 @@ async def _pump(self): self._pop_message_waiter = None +def check_support_reason(asgi_ver): + """Check if the websocket version support a close reason.""" + target_ver = [2, 3] + current_ver = asgi_ver.split('.') + + for i in range(2): + if int(current_ver[i]) < target_ver[i]: + return False + + return True + + def http_status_to_ws_code(http_status: int) -> int: """Convert the provided http status to a websocket close code by adding 3000.""" return http_status + 3000 diff --git a/falcon/testing/helpers.py b/falcon/testing/helpers.py index 1b38e975e..2afee75c5 100644 --- a/falcon/testing/helpers.py +++ b/falcon/testing/helpers.py @@ -389,6 +389,8 @@ class ASGIWebSocketSimulator: denied or closed by the app, or the client has disconnected. close_code (int): The WebSocket close code provided by the app if the connection is closed, or ``None`` if the connection is open. + close_reason (str): The WebSocket close reason provided by the app if + the connection is closed, or ``None`` if the connection is open. subprotocol (str): The subprotocol the app wishes to accept, or ``None`` if not specified. headers (Iterable[Iterable[bytes]]): An iterable of ``[name, value]`` @@ -406,6 +408,7 @@ def __init__(self): self._state = _WebSocketState.CONNECT self._disconnect_emitted = False self._close_code = None + self._close_reason = None self._accepted_subprotocol = None self._accepted_headers = None self._collected_server_events = deque() @@ -425,6 +428,10 @@ def closed(self) -> bool: def close_code(self) -> int: return self._close_code + @property + def close_reason(self) -> str: + return self._close_reason + @property def subprotocol(self) -> str: return self._accepted_subprotocol @@ -463,12 +470,14 @@ async def wait_ready(self, timeout: Optional[int] = None): # NOTE(kgriffs): This is a coroutine just in case we need it to be # in a future code revision. It also makes it more consistent # with the other methods. - async def close(self, code: Optional[int] = None): + async def close(self, code: Optional[int] = None, reason: Optional[str] = None): """Close the simulated connection. Keyword Args: code (int): The WebSocket close code to send to the application per the WebSocket spec (default: ``1000``). + reason (str): The WebSocket close reason to send to the application + per the WebSocket spec (default: empty string). """ # NOTE(kgriffs): Give our collector a chance in case the @@ -487,8 +496,12 @@ async def close(self, code: Optional[int] = None): if code is None: code = WSCloseCode.NORMAL + if reason is None: + reason = '' + self._state = _WebSocketState.CLOSED self._close_code = code + self._close_reason = reason async def send_text(self, payload: str): """Send a message to the app with a Unicode string payload. @@ -726,6 +739,7 @@ async def _collect(self, event: Dict[str, Any]): self._state = _WebSocketState.DENIED desired_code = event.get('code', WSCloseCode.NORMAL) + reason = event.get('reason', '') if desired_code == WSCloseCode.SERVER_ERROR or ( 3000 <= desired_code < 4000 ): @@ -734,12 +748,16 @@ async def _collect(self, event: Dict[str, Any]): # different raised error types or to pass through a # raised HTTPError status code. self._close_code = desired_code + self._close_reason = reason else: # NOTE(kgriffs): Force the close code to this since it is # similar to what happens with a real web server (the HTTP # connection is closed with a 403 and there is no websocket # close code). self._close_code = WSCloseCode.FORBIDDEN + self._close_reason = falcon.util.code_to_http_status( + WSCloseCode.FORBIDDEN - 3000 + ) self._event_handshake_complete.set() @@ -754,6 +772,7 @@ async def _collect(self, event: Dict[str, Any]): if event_type == EventType.WS_CLOSE: self._state = _WebSocketState.CLOSED self._close_code = event.get('code', WSCloseCode.NORMAL) + self._close_reason = event.get('reason', '') else: assert event_type == EventType.WS_SEND self._collected_server_events.append(event) @@ -779,7 +798,12 @@ def _create_checked_disconnect(self) -> Dict[str, Any]: ) self._disconnect_emitted = True - return {'type': EventType.WS_DISCONNECT, 'code': self._close_code} + response = {'type': EventType.WS_DISCONNECT, 'code': self._close_code} + + if self._close_reason: + response['reason'] = self._close_reason + + return response # get_encoding_from_headers() is Copyright 2016 Kenneth Reitz, and is diff --git a/tests/asgi/test_ws.py b/tests/asgi/test_ws.py index e83911043..8c350f24d 100644 --- a/tests/asgi/test_ws.py +++ b/tests/asgi/test_ws.py @@ -893,11 +893,11 @@ async def test_bad_http_version(version, conductor): @pytest.mark.asyncio -async def test_bad_first_event(): +@pytest.mark.parametrize('version', ['2.1', '2.3', '2.10.3']) +async def test_bad_first_event(version): app = App() - scope = testing.create_scope_ws() - del scope['asgi']['spec_version'] + scope = testing.create_scope_ws(spec_version=version) ws = testing.ASGIWebSocketSimulator() wrapped_emit = ws._emit @@ -917,6 +917,10 @@ async def _emit(): assert ws.closed assert ws.close_code == CloseCode.SERVER_ERROR + if version != '2.1': + assert ws.close_reason == 'Internal Server Error' + else: + assert ws.close_reason == '' @pytest.mark.asyncio @@ -1129,6 +1133,76 @@ def test_msgpack_missing(): @pytest.mark.asyncio +@pytest.mark.parametrize('reason', ['Client closing connection', '', None]) +async def test_client_close_with_reason(reason, conductor): + class Resource: + def __init__(self): + pass + + async def on_websocket(self, req, ws): + await ws.accept() + while True: + try: + await ws.receive_data() + + except falcon.WebSocketDisconnected: + break + + resource = Resource() + conductor.app.add_route('/', resource) + + async with conductor as c: + async with c.simulate_ws('/', spec_version='2.3') as ws: + await ws.close(4099, reason) + + assert ws.close_code == 4099 + if reason: + assert ws.close_reason == reason + else: + assert ws.close_reason == '' + + +@pytest.mark.asyncio +@pytest.mark.parametrize('no_default', [True, False]) +@pytest.mark.parametrize('code', [None, 1011, 4099, 4042, 3405]) +async def test_no_reason_mapping(no_default, code, conductor): + class Resource: + def __init__(self): + pass + + async def on_websocket(self, req, ws): + await ws.accept() + await ws.close(code) + + resource = Resource() + conductor.app.add_route('/', resource) + if no_default: + conductor.app.ws_options.default_close_reasons = {} + else: + conductor.app.ws_options.default_close_reasons[4099] = '4099 reason' + + async with conductor as c: + with pytest.raises(falcon.WebSocketDisconnected): + async with c.simulate_ws('/', spec_version='2.10.3') as ws: + await ws.receive_data() + + if code: + assert ws.close_code == code + else: + assert ws.close_code == CloseCode.NORMAL + + if 3100 <= ws.close_code <= 3999: + assert ws.close_reason == falcon.util.code_to_http_status(ws.close_code - 3000) + elif ( + no_default + or ws.close_code not in conductor.app.ws_options.default_close_reasons + ): + assert ws.close_reason == '' + else: + reason = conductor.app.ws_options.default_close_reasons[ws.close_code] + assert ws.close_reason == reason + + @pytest.mark.parametrize('status', [200, 500, 422, 400]) @pytest.mark.parametrize('thing', [falcon.HTTPStatus, falcon.HTTPError]) @pytest.mark.parametrize('accept', [True, False])