diff --git a/local-requirements.txt b/local-requirements.txt index 37f448669..70384b335 100644 --- a/local-requirements.txt +++ b/local-requirements.txt @@ -1,3 +1,4 @@ +autobahn==20.7.1 pytest==6.1.0 pytest-asyncio==0.14.0 pytest-cov==2.10.1 diff --git a/playwright/async_api.py b/playwright/async_api.py index 5360ddd73..0de64d3a0 100644 --- a/playwright/async_api.py +++ b/playwright/async_api.py @@ -506,6 +506,54 @@ def url(self) -> str: """ return mapping.from_maybe_impl(self._impl_obj.url) + async def waitForEvent( + self, + event: str, + predicate: typing.Union[typing.Callable[[typing.Any], bool]] = None, + timeout: int = None, + ) -> typing.Any: + """WebSocket.waitForEvent + + Waits for event to fire and passes its value into the predicate function. Resolves when the predicate returns truthy value. Will throw an error if the webSocket is closed before the event + is fired. + + Parameters + ---------- + event : str + Event name, same one would pass into `webSocket.on(event)`. + + Returns + ------- + Any + Promise which resolves to the event data value. + """ + return mapping.from_maybe_impl( + await self._impl_obj.waitForEvent( + event=event, predicate=self._wrap_handler(predicate), timeout=timeout + ) + ) + + def expect_event( + self, + event: str, + predicate: typing.Union[typing.Callable[[typing.Any], bool]] = None, + timeout: int = None, + ) -> AsyncEventContextManager: + return AsyncEventContextManager( + self._impl_obj.waitForEvent(event, predicate, timeout) + ) + + def isClosed(self) -> bool: + """WebSocket.isClosed + + Indicates that the web socket has been closed. + + Returns + ------- + bool + """ + return mapping.from_maybe_impl(self._impl_obj.isClosed()) + mapping.register(WebSocketImpl, WebSocket) diff --git a/playwright/network.py b/playwright/network.py index cc96ba515..407ad3a04 100644 --- a/playwright/network.py +++ b/playwright/network.py @@ -17,10 +17,11 @@ import mimetypes from pathlib import Path from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast from urllib import parse from playwright.connection import ChannelOwner, from_channel, from_nullable_channel +from playwright.event_context_manager import EventContextManagerImpl from playwright.helper import ( ContinueParameters, Error, @@ -29,6 +30,7 @@ ResourceTiming, locals_to_params, ) +from playwright.wait_helper import WaitHelper if TYPE_CHECKING: # pragma: no cover from playwright.frame import Frame @@ -271,6 +273,7 @@ def __init__( self, parent: ChannelOwner, type: str, guid: str, initializer: Dict ) -> None: super().__init__(parent, type, guid, initializer) + self._is_closed = False self._channel.on( "frameSent", lambda params: self._on_frame_sent(params["opcode"], params["data"]), @@ -282,12 +285,40 @@ def __init__( self._channel.on( "error", lambda params: self.emit(WebSocket.Events.Error, params["error"]) ) - self._channel.on("close", lambda params: self.emit(WebSocket.Events.Close)) + self._channel.on("close", lambda params: self._on_close()) @property def url(self) -> str: return self._initializer["url"] + async def waitForEvent( + self, event: str, predicate: Callable[[Any], bool] = None, timeout: int = None + ) -> Any: + if timeout is None: + timeout = cast(Any, self._parent)._timeout_settings.timeout() + wait_helper = WaitHelper(self._loop) + wait_helper.reject_on_timeout( + timeout, f'Timeout while waiting for event "${event}"' + ) + if event != WebSocket.Events.Close: + wait_helper.reject_on_event( + self, WebSocket.Events.Close, Error("Socket closed") + ) + if event != WebSocket.Events.Error: + wait_helper.reject_on_event( + self, WebSocket.Events.Error, Error("Socket error") + ) + wait_helper.reject_on_event(self._parent, "close", Error("Page closed")) + return await wait_helper.wait_for_event(self, event, predicate) + + def expect_event( + self, + event: str, + predicate: Callable[[Any], bool] = None, + timeout: int = None, + ) -> EventContextManagerImpl: + return EventContextManagerImpl(self.waitForEvent(event, predicate, timeout)) + def _on_frame_sent(self, opcode: int, data: str) -> None: if opcode == 2: self.emit(WebSocket.Events.FrameSent, base64.b64decode(data)) @@ -300,6 +331,13 @@ def _on_frame_received(self, opcode: int, data: str) -> None: else: self.emit(WebSocket.Events.FrameReceived, data) + def isClosed(self) -> bool: + return self._is_closed + + def _on_close(self) -> None: + self._is_closed = True + self.emit(WebSocket.Events.Close) + def serialize_headers(headers: Dict[str, str]) -> List[Header]: return [{"name": name, "value": value} for name, value in headers.items()] diff --git a/playwright/page.py b/playwright/page.py index 0abb88e01..3ceb1c0da 100644 --- a/playwright/page.py +++ b/playwright/page.py @@ -294,10 +294,6 @@ def _add_event_handler(self, event: str, k: Any, v: Any) -> None: self._channel.send_no_reply( "setFileChooserInterceptedNoReply", {"intercepted": True} ) - if event == Page.Events.WebSocket and len(self.listeners(event)) == 0: - self._channel.send_no_reply( - "setWebSocketFramesReportingEnabledNoReply", {"enabled": True} - ) super()._add_event_handler(event, k, v) def remove_listener(self, event: str, f: Any) -> None: @@ -306,9 +302,6 @@ def remove_listener(self, event: str, f: Any) -> None: self._channel.send_no_reply( "setFileChooserInterceptedNoReply", {"intercepted": False} ) - # Note: we do not stop reporting web socket frames, since - # user might not listen to 'websocket' anymore, but still have - # a functioning WebSocket object. @property def context(self) -> "BrowserContext": diff --git a/playwright/sync_api.py b/playwright/sync_api.py index 85204247d..040c985f3 100644 --- a/playwright/sync_api.py +++ b/playwright/sync_api.py @@ -512,6 +512,58 @@ def url(self) -> str: """ return mapping.from_maybe_impl(self._impl_obj.url) + def waitForEvent( + self, + event: str, + predicate: typing.Union[typing.Callable[[typing.Any], bool]] = None, + timeout: int = None, + ) -> typing.Any: + """WebSocket.waitForEvent + + Waits for event to fire and passes its value into the predicate function. Resolves when the predicate returns truthy value. Will throw an error if the webSocket is closed before the event + is fired. + + Parameters + ---------- + event : str + Event name, same one would pass into `webSocket.on(event)`. + + Returns + ------- + Any + Promise which resolves to the event data value. + """ + return mapping.from_maybe_impl( + self._sync( + self._impl_obj.waitForEvent( + event=event, + predicate=self._wrap_handler(predicate), + timeout=timeout, + ) + ) + ) + + def expect_event( + self, + event: str, + predicate: typing.Union[typing.Callable[[typing.Any], bool]] = None, + timeout: int = None, + ) -> EventContextManager: + return EventContextManager( + self._loop, self._impl_obj.waitForEvent(event, predicate, timeout) + ) + + def isClosed(self) -> bool: + """WebSocket.isClosed + + Indicates that the web socket has been closed. + + Returns + ------- + bool + """ + return mapping.from_maybe_impl(self._impl_obj.isClosed()) + mapping.register(WebSocketImpl, WebSocket) diff --git a/scripts/documentation_provider.py b/scripts/documentation_provider.py index 644c2d7b0..6269854ea 100644 --- a/scripts/documentation_provider.py +++ b/scripts/documentation_provider.py @@ -89,6 +89,11 @@ def print_entry( or super_clazz["methods"].get(method_name) ) fqname = f"{class_name}.{method_name}" + + if not method: + self.errors.add(f"Method not documented: {fqname}") + return + indent = " " * 8 print(f'{indent}"""{class_name}.{original_method_name}') if method.get("comment"): diff --git a/scripts/expected_api_mismatch.txt b/scripts/expected_api_mismatch.txt index ec9b20ef9..c62aac506 100644 --- a/scripts/expected_api_mismatch.txt +++ b/scripts/expected_api_mismatch.txt @@ -19,8 +19,6 @@ Method not implemented: Download.createReadStream Method not implemented: Logger.isEnabled Method not implemented: Logger.log Method not implemented: Page.coverage -Method not implemented: WebSocket.isClosed -Method not implemented: WebSocket.waitForEvent # Parameter overloads Parameter not documented: BrowserContext.waitForEvent(predicate=) @@ -30,6 +28,8 @@ Parameter not documented: Page.waitForEvent(timeout=) Parameter not documented: Page.waitForRequest(predicate=) Parameter not documented: Page.waitForResponse(predicate=) Parameter not documented: Selectors.register(path=) +Parameter not documented: WebSocket.waitForEvent(timeout=) +Parameter not documented: WebSocket.waitForEvent(predicate=) # Documented as Dict / Any Parameter type mismatch in BrowserContext.setGeolocation(geolocation=): documented as Optional[Dict], code has Optional[{"latitude": float, "longitude": float, "accuracy": Optional[float]}] @@ -42,6 +42,7 @@ Parameter type mismatch in Page.viewportSize(return=): documented as Optional[Di Parameter type mismatch in Page.waitForEvent(return=): documented as Dict, code has Any Parameter type mismatch in Request.failure(return=): documented as Optional[Dict], code has Optional[{"errorText": str}] Parameter type mismatch in Response.json(return=): documented as Any, code has Union[Dict, List] +Parameter type mismatch in WebSocket.waitForEvent(return=): documented as Dict, code has Any # Pathlib Parameter type mismatch in BrowserType.launch(executablePath=): documented as Optional[str], code has Union[str, pathlib.Path, NoneType] @@ -118,3 +119,4 @@ Method not implemented: BrowserType.connect # OptionsOr Parameter not implemented: Page.waitForEvent(optionsOrPredicate=) Parameter not implemented: BrowserContext.waitForEvent(optionsOrPredicate=) +Parameter not implemented: WebSocket.waitForEvent(optionsOrPredicate=) diff --git a/tests/async/test_websocket.py b/tests/async/test_websocket.py new file mode 100644 index 000000000..8f8bcd18e --- /dev/null +++ b/tests/async/test_websocket.py @@ -0,0 +1,124 @@ +# Copyright (c) Microsoft Corporation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from playwright import Error + + +async def test_should_work(page, ws_server): + value = await page.evaluate( + """port => { + let cb; + const result = new Promise(f => cb = f); + const ws = new WebSocket('ws://localhost:' + port + '/ws'); + ws.addEventListener('message', data => { ws.close(); cb(data.data); }); + return result; + }""", + ws_server.PORT, + ) + assert value == "incoming" + pass + + +async def test_should_emit_close_events(page, ws_server): + async with page.expect_event("websocket") as ws_info: + await page.evaluate( + """port => { + let cb; + const result = new Promise(f => cb = f); + const ws = new WebSocket('ws://localhost:' + port + '/ws'); + ws.addEventListener('message', data => { ws.close(); cb(data.data); }); + return result; + }""", + ws_server.PORT, + ) + ws = await ws_info.value + assert ws.url == f"ws://localhost:{ws_server.PORT}/ws" + if not ws.isClosed(): + await ws.waitForEvent("close") + assert ws.isClosed() + + +async def test_should_emit_frame_events(page, ws_server): + sent = [] + received = [] + + def on_web_socket(ws): + ws.on("framesent", lambda payload: sent.append(payload)) + ws.on("framereceived", lambda payload: received.append(payload)) + + page.on("websocket", on_web_socket) + async with page.expect_event("websocket") as ws_info: + await page.evaluate( + """port => { + const ws = new WebSocket('ws://localhost:' + port + '/ws'); + ws.addEventListener('open', () => { + ws.send('echo-text'); + }); + }""", + ws_server.PORT, + ) + ws = await ws_info.value + if not ws.isClosed(): + await ws.waitForEvent("close") + + assert sent == ["echo-text"] + assert received == ["incoming", "text"] + + +async def test_should_emit_binary_frame_events(page, ws_server): + sent = [] + received = [] + + def on_web_socket(ws): + ws.on("framesent", lambda payload: sent.append(payload)) + ws.on("framereceived", lambda payload: received.append(payload)) + + page.on("websocket", on_web_socket) + async with page.expect_event("websocket") as ws_info: + await page.evaluate( + """port => { + const ws = new WebSocket('ws://localhost:' + port + '/ws'); + ws.addEventListener('open', () => { + const binary = new Uint8Array(5); + for (let i = 0; i < 5; ++i) + binary[i] = i; + ws.send(binary); + ws.send('echo-bin'); + }); + }""", + ws_server.PORT, + ) + ws = await ws_info.value + if not ws.isClosed(): + await ws.waitForEvent("close") + assert sent == [b"\x00\x01\x02\x03\x04", "echo-bin"] + assert received == ["incoming", b"\x04\x02"] + + +async def test_should_reject_wait_for_event_on_close_and_error(page, ws_server): + async with page.expect_event("websocket") as ws_info: + await page.evaluate( + """port => { + window.ws = new WebSocket('ws://localhost:' + port + '/ws'); + }""", + ws_server.PORT, + ) + ws = await ws_info.value + await ws.waitForEvent("framereceived") + with pytest.raises(Error) as exc_info: + async with ws.expect_event("framesent"): + await page.evaluate("window.ws.close()") + assert exc_info.value.message == "Socket closed" diff --git a/tests/conftest.py b/tests/conftest.py index 4225100b1..44934b078 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -65,6 +65,11 @@ def https_server(): yield test_server.https_server +@pytest.fixture +def ws_server(): + yield test_server.ws_server + + @pytest.fixture def utils(): yield utils_object diff --git a/tests/server.py b/tests/server.py index a052fa824..e69a42f5f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -23,6 +23,7 @@ from http import HTTPStatus import greenlet +from autobahn.twisted.websocket import WebSocketServerFactory, WebSocketServerProtocol from OpenSSL import crypto from twisted.internet import reactor, ssl from twisted.web import http @@ -212,14 +213,48 @@ def listen(self, factory): reactor.listenSSL(self.PORT, factory, contextFactory) +class WebSocketServerServer(WebSocketServerProtocol): + def __init__(self) -> None: + super().__init__() + self.PORT = _find_free_port() + + def start(self): + ws = WebSocketServerFactory("ws://127.0.0.1:" + str(self.PORT)) + ws.protocol = WebSocketProtocol + reactor.listenTCP(self.PORT, ws) + + +class WebSocketProtocol(WebSocketServerProtocol): + def onConnect(self, request): + pass + + def onOpen(self): + self.sendMessage(b"incoming") + + def onMessage(self, payload, isBinary): + if payload == b"echo-bin": + self.sendMessage(b"\x04\x02", True) + self.sendClose() + if payload == b"echo-text": + self.sendMessage(b"text", False) + self.sendClose() + if payload == b"close": + self.sendClose() + + def onClose(self, wasClean, code, reason): + pass + + class TestServer: def __init__(self) -> None: self.server = HTTPServer() self.https_server = HTTPSServer() + self.ws_server = WebSocketServerServer() def start(self) -> None: self.server.start() self.https_server.start() + self.ws_server.start() self.thread = threading.Thread( target=lambda: reactor.run(installSignalHandlers=0) )