Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion playwright/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
webkit = playwright_object.webkit
devices = playwright_object.devices
browser_types = playwright_object.browser_types
Error = helper.Error
TimeoutError = helper.TimeoutError

__all__ = ["browser_types", "chromium", "firefox", "webkit", "devices", "TimeoutError"]
__all__ = [
"browser_types",
"chromium",
"firefox",
"webkit",
"devices",
"Error",
"TimeoutError",
]
35 changes: 19 additions & 16 deletions playwright/browser_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
URLMatcher,
)
from playwright.network import Request, Route
from playwright.page import BindingCall, Page
from playwright.page import BindingCall, Page, wait_for_event
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING

Expand Down Expand Up @@ -85,14 +85,21 @@ def _on_binding(self, binding_call: BindingCall) -> None:
func = self._bindings.get(binding_call._initializer["name"])
if func is None:
return
binding_call.call(func)
asyncio.ensure_future(binding_call.call(func))

def setDefaultNavigationTimeout(self, timeout: int) -> None:
self._channel.send("setDefaultNavigationTimeoutNoReply", dict(timeout=timeout))
self._timeout_settings.set_navigation_timeout(timeout)
asyncio.ensure_future(
self._channel.send(
"setDefaultNavigationTimeoutNoReply", dict(timeout=timeout)
)
)

def setDefaultTimeout(self, timeout: int) -> None:
self._timeout_settings.set_default_timeout(timeout)
self._channel.send("setDefaultTimeoutNoReply", dict(timeout=timeout))
self._timeout_settings.set_timeout(timeout)
asyncio.ensure_future(
self._channel.send("setDefaultTimeoutNoReply", dict(timeout=timeout))
)

@property
def pages(self) -> List[Page]:
Expand Down Expand Up @@ -175,15 +182,12 @@ async def unroute(self, match: URLMatch, handler: Optional[RouteHandler]) -> Non
"setNetworkInterceptionEnabled", dict(enabled=False)
)

async def waitForEvent(self, event: str) -> None:
# TODO: implement timeout race
future = self._scope._loop.create_future()
self.once(event, lambda e: future.set_result(e))
pending_event = PendingWaitEvent(event, future)
self._pending_wait_for_events.append(pending_event)
result = await future
self._pending_wait_for_events.remove(pending_event)
return result
async def waitForEvent(
self, event: str, predicate: Callable[[Any], bool] = None, timeout: int = None
) -> Any:
return await wait_for_event(
self, self._timeout_settings, event, predicate=predicate, timeout=timeout
)

def _on_close(self):
if self._browser:
Expand All @@ -192,9 +196,8 @@ def _on_close(self):
for pending_event in self._pending_wait_for_events:
if pending_event.event == BrowserContext.Events.Close:
continue
pending_event.future.set_exception(Error("Context closed"))
pending_event.reject(False, "Context")

self._pending_wait_for_events.clear()
self.emit(BrowserContext.Events.Close)
self._scope.dispose()

Expand Down
19 changes: 14 additions & 5 deletions playwright/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import traceback
from playwright.helper import parse_error, ParsedMessagePayload
from playwright.transport import Transport
from pyee import BaseEventEmitter
Expand Down Expand Up @@ -96,6 +97,12 @@ def create_remote_object(self, type: str, guid: str, initializer: Dict) -> Any:
return result


class ProtocolCallback:
def __init__(self, loop: asyncio.AbstractEventLoop):
self.stack_trace = "".join(traceback.format_stack()[-10:])
self.future = loop.create_future()


class Connection:
def __init__(
self,
Expand All @@ -111,7 +118,7 @@ def __init__(
self._loop = loop
self._objects: Dict[str, ChannelOwner] = dict()
self._scopes: Dict[str, ConnectionScope] = dict()
self._callbacks: Dict[int, asyncio.Future] = dict()
self._callbacks: Dict[int, ProtocolCallback] = dict()
self._root_scope = self.create_scope("", None)
self._object_factory = object_factory

Expand All @@ -134,9 +141,9 @@ async def _send_message_to_server(
params=self._replace_channels_with_guids(params),
)
self._transport.send(message)
callback = self._loop.create_future()
callback = ProtocolCallback(self._loop)
self._callbacks[id] = callback
return await callback
return await callback.future

def _dispatch(self, msg: ParsedMessagePayload):

Expand All @@ -145,10 +152,12 @@ def _dispatch(self, msg: ParsedMessagePayload):
callback = self._callbacks.pop(id)
error = msg.get("error")
if error:
callback.set_exception(parse_error(error))
parsed_error = parse_error(error)
parsed_error.stack = callback.stack_trace
callback.future.set_exception(parsed_error)
else:
result = self._replace_guids_with_channels(msg.get("result"))
callback.set_result(result)
callback.future.set_result(result)
return

guid = msg["guid"]
Expand Down
3 changes: 2 additions & 1 deletion playwright/dialog.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from playwright.connection import ChannelOwner, ConnectionScope
from playwright.helper import locals_to_params
from typing import Dict


Expand All @@ -33,7 +34,7 @@ def defaultValue(self) -> str:
return self._initializer["defaultValue"]

async def accept(self, prompt_text: str = None) -> None:
await self._channel.send("accept", dict(promptText=prompt_text))
await self._channel.send("accept", locals_to_params(locals()))

async def dismiss(self) -> None:
await self._channel.send("dismiss")
53 changes: 44 additions & 9 deletions playwright/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import asyncio
import fnmatch
import re
import traceback

from typing import (
Any,
Expand All @@ -40,7 +41,7 @@
from playwright.network import Route, Request

Cookie = List[Dict[str, Union[str, int, bool]]]
URLMatch = Union[str, Callable[[str], bool]]
URLMatch = Union[str, Pattern, Callable[[str], bool]]
RouteHandler = Callable[["Route", "Request"], None]
FunctionWithSource = Callable[[Dict], Any]

Expand Down Expand Up @@ -101,8 +102,10 @@ def __init__(self, match: URLMatch):
self._callback: Optional[Callable[[str], bool]] = None
self._regex_obj: Optional[Pattern] = None
if isinstance(match, str):
regex = "(?:http://|https://)" + fnmatch.translate(match)
regex = fnmatch.translate(match)
self._regex_obj = re.compile(regex)
elif isinstance(match, Pattern):
self._regex_obj = match
else:
self._callback = match
self.match = match
Expand All @@ -111,16 +114,35 @@ def matches(self, url: str) -> bool:
if self._callback:
return self._callback(url)
if self._regex_obj:
return cast(bool, self._regex_obj.match(url))
return cast(bool, self._regex_obj.search(url))
return False


class TimeoutSettings:
def __init__(self, parent: Optional["TimeoutSettings"]) -> None:
self._parent = parent
self._timeout = 30000
self._navigation_timeout = 30000

def set_default_timeout(self, timeout):
self.timeout = timeout
def set_timeout(self, timeout: int):
self._timeout = timeout

def timeout(self) -> int:
if self._timeout is not None:
return self._timeout
if self._parent:
return self._parent.timeout()
return 30000

def set_navigation_timeout(self, navigation_timeout: int):
self._navigation_timeout = navigation_timeout

def navigation_timeout(self) -> int:
if self._navigation_timeout is not None:
return self._navigation_timeout
if self._parent:
return self._parent.navigation_timeout()
return 30000


class Error(Exception):
Expand All @@ -133,11 +155,11 @@ class TimeoutError(Error):
pass


def serialize_error(ex: Exception) -> ErrorPayload:
return dict(message=str(ex))
def serialize_error(ex: Exception, tb) -> ErrorPayload:
return dict(message=str(ex), stack="".join(traceback.format_tb(tb)))


def parse_error(error: ErrorPayload):
def parse_error(error: ErrorPayload) -> Error:
base_error_class = Error
if error.get("name") == "TimeoutError":
base_error_class = TimeoutError
Expand All @@ -164,9 +186,22 @@ def locals_to_params(args: Dict) -> Dict:


class PendingWaitEvent:
def __init__(self, event: str, future: asyncio.Future):
def __init__(
self, event: str, future: asyncio.Future, timeout_future: asyncio.Future
):
self.event = event
self.future = future
self.timeout_future = timeout_future

def reject(self, is_crash: bool, target: str):
self.timeout_future.cancel()
if self.event == "close" and not is_crash:
return
if self.event == "crash" and is_crash:
return
self.future.set_exception(
Error(f"{target} crashed" if is_crash else f"{target} closed")
)


class RouteHandlerEntry:
Expand Down
Loading