Skip to content

Commit

Permalink
chore: port route chaining, fallback, async, times (#1376)
Browse files Browse the repository at this point in the history
This is part 2/n of the 1.23 port.

Relates #1308, #1374

Ports:

  - [x] microsoft/playwright@a1324bd (fix(route): support route w/ async handler & times (#14317))
  - [x] microsoft/playwright@7a568a2 (feat(route): chain routes (#14771))
  - [x] microsoft/playwright@dcdd3c3 (feat(route): explicitly fall back to the next handler (#14834))
  - [x] microsoft/playwright@9cf068a (feat(fallback): allow falling back w/ overrides (#14849))
  - [x] microsoft/playwright@48f9867 (chore: remove stray fallback overrides check)
  - [x] microsoft/playwright@ae6f48c (fix(route): match against updated url while chaining (#15112))
  • Loading branch information
rwoll committed Jun 27, 2022
1 parent 84d94a3 commit 7b424eb
Show file tree
Hide file tree
Showing 14 changed files with 2,035 additions and 70 deletions.
34 changes: 20 additions & 14 deletions playwright/_impl/_browser_context.py
Expand Up @@ -96,8 +96,11 @@ def __init__(
)
self._channel.on(
"route",
lambda params: self._on_route(
from_channel(params.get("route")), from_channel(params.get("request"))
lambda params: asyncio.create_task(
self._on_route(
from_channel(params.get("route")),
from_channel(params.get("request")),
)
),
)

Expand Down Expand Up @@ -156,18 +159,21 @@ def _on_page(self, page: Page) -> None:
if page._opener and not page._opener.is_closed():
page._opener.emit(Page.Events.Popup, page)

def _on_route(self, route: Route, request: Request) -> None:
for handler_entry in self._routes:
if handler_entry.matches(request.url):
try:
handler_entry.handle(route, request)
finally:
if not handler_entry.is_active:
self._routes.remove(handler_entry)
if not len(self._routes) == 0:
asyncio.create_task(self._disable_interception())
break
route._internal_continue()
async def _on_route(self, route: Route, request: Request) -> None:
route_handlers = self._routes.copy()
for route_handler in route_handlers:
if not route_handler.matches(request.url):
continue
if route_handler.will_expire:
self._routes.remove(route_handler)
try:
handled = await route_handler.handle(route, request)
finally:
if len(self._routes) == 0:
asyncio.create_task(self._disable_interception())
if handled:
return
await route._internal_continue(is_internal=True)

def _on_binding(self, binding_call: BindingCall) -> None:
func = self._bindings.get(binding_call._initializer["name"])
Expand Down
20 changes: 13 additions & 7 deletions playwright/_impl/_helper.py
Expand Up @@ -76,11 +76,11 @@ class ErrorPayload(TypedDict, total=False):
value: Optional[Any]


class ContinueParameters(TypedDict, total=False):
class FallbackOverrideParameters(TypedDict, total=False):
url: Optional[str]
method: Optional[str]
headers: Optional[List[NameValue]]
postData: Optional[str]
headers: Optional[Dict[str, str]]
postData: Optional[Union[str, bytes]]


class ParsedMessageParams(TypedDict):
Expand Down Expand Up @@ -225,14 +225,17 @@ def __init__(
def matches(self, request_url: str) -> bool:
return self.matcher.matches(request_url)

def handle(self, route: "Route", request: "Request") -> None:
async def handle(self, route: "Route", request: "Request") -> bool:
handled_future = route._start_handling()
handler_task = []

def impl() -> None:
self._handled_count += 1
result = cast(
Callable[["Route", "Request"], Union[Coroutine, Any]], self.handler
)(route, request)
if inspect.iscoroutine(result):
asyncio.create_task(result)
handler_task.append(asyncio.create_task(result))

# As with event handlers, each route handler is a potentially blocking context
# so it needs a fiber.
Expand All @@ -242,9 +245,12 @@ def impl() -> None:
else:
impl()

[handled, *_] = await asyncio.gather(handled_future, *handler_task)
return handled

@property
def is_active(self) -> bool:
return self._handled_count < self._times
def will_expire(self) -> bool:
return self._handled_count + 1 >= self._times


def is_safe_close_error(error: Exception) -> bool:
Expand Down
144 changes: 114 additions & 30 deletions playwright/_impl/_network.py
Expand Up @@ -47,14 +47,22 @@
from_nullable_channel,
)
from playwright._impl._event_context_manager import EventContextManagerImpl
from playwright._impl._helper import ContinueParameters, locals_to_params
from playwright._impl._helper import FallbackOverrideParameters, locals_to_params
from playwright._impl._wait_helper import WaitHelper

if TYPE_CHECKING: # pragma: no cover
from playwright._impl._fetch import APIResponse
from playwright._impl._frame import Frame


def serialize_headers(headers: Dict[str, str]) -> HeadersArray:
return [
{"name": name, "value": value}
for name, value in headers.items()
if value is not None
]


class Request(ChannelOwner):
def __init__(
self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
Expand All @@ -80,21 +88,31 @@ def __init__(
}
self._provisional_headers = RawHeaders(self._initializer["headers"])
self._all_headers_future: Optional[asyncio.Future[RawHeaders]] = None
self._fallback_overrides: FallbackOverrideParameters = (
FallbackOverrideParameters()
)

def __repr__(self) -> str:
return f"<Request url={self.url!r} method={self.method!r}>"

def _apply_fallback_overrides(self, overrides: FallbackOverrideParameters) -> None:
self._fallback_overrides = cast(
FallbackOverrideParameters, {**self._fallback_overrides, **overrides}
)

@property
def url(self) -> str:
return self._initializer["url"]
return cast(str, self._fallback_overrides.get("url", self._initializer["url"]))

@property
def resource_type(self) -> str:
return self._initializer["resourceType"]

@property
def method(self) -> str:
return self._initializer["method"]
return cast(
str, self._fallback_overrides.get("method", self._initializer["method"])
)

async def sizes(self) -> RequestSizes:
response = await self.response()
Expand All @@ -104,10 +122,10 @@ async def sizes(self) -> RequestSizes:

@property
def post_data(self) -> Optional[str]:
data = self.post_data_buffer
data = self._fallback_overrides.get("postData", self.post_data_buffer)
if not data:
return None
return data.decode()
return data.decode() if isinstance(data, bytes) else data

@property
def post_data_json(self) -> Optional[Any]:
Expand All @@ -124,6 +142,13 @@ def post_data_json(self) -> Optional[Any]:

@property
def post_data_buffer(self) -> Optional[bytes]:
override = self._fallback_overrides.get("post_data")
if override:
return (
override.encode()
if isinstance(override, str)
else cast(bytes, override)
)
b64_content = self._initializer.get("postData")
if b64_content is None:
return None
Expand Down Expand Up @@ -157,6 +182,9 @@ def timing(self) -> ResourceTiming:

@property
def headers(self) -> Headers:
override = self._fallback_overrides.get("headers")
if override:
return RawHeaders._from_headers_dict_lossy(override).headers()
return self._provisional_headers.headers()

async def all_headers(self) -> Headers:
Expand All @@ -169,6 +197,9 @@ async def header_value(self, name: str) -> Optional[str]:
return (await self._actual_headers()).get(name)

async def _actual_headers(self) -> "RawHeaders":
override = self._fallback_overrides.get("headers")
if override:
return RawHeaders(serialize_headers(override))
if not self._all_headers_future:
self._all_headers_future = asyncio.Future()
headers = await self._channel.send("rawRequestHeaders")
Expand All @@ -181,6 +212,21 @@ def __init__(
self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
) -> None:
super().__init__(parent, type, guid, initializer)
self._handling_future: Optional[asyncio.Future["bool"]] = None

def _start_handling(self) -> "asyncio.Future[bool]":
self._handling_future = asyncio.Future()
return self._handling_future

def _report_handled(self, done: bool) -> None:
chain = self._handling_future
assert chain
self._handling_future = None
chain.set_result(done)

def _check_not_handled(self) -> None:
if not self._handling_future:
raise Error("Route is already handled!")

def __repr__(self) -> str:
return f"<Route request={self.request}>"
Expand All @@ -203,6 +249,7 @@ async def fulfill(
contentType: str = None,
response: "APIResponse" = None,
) -> None:
self._check_not_handled()
params = locals_to_params(locals())
if response:
del params["response"]
Expand Down Expand Up @@ -247,37 +294,74 @@ async def fulfill(
headers["content-length"] = str(length)
params["headers"] = serialize_headers(headers)
await self._race_with_page_close(self._channel.send("fulfill", params))
self._report_handled(True)

async def continue_(
async def fallback(
self,
url: str = None,
method: str = None,
headers: Dict[str, str] = None,
postData: Union[str, bytes] = None,
) -> None:
overrides: ContinueParameters = {}
if url:
overrides["url"] = url
if method:
overrides["method"] = method
if headers:
overrides["headers"] = serialize_headers(headers)
if isinstance(postData, str):
overrides["postData"] = base64.b64encode(postData.encode()).decode()
elif isinstance(postData, bytes):
overrides["postData"] = base64.b64encode(postData).decode()
await self._race_with_page_close(
self._channel.send("continue", cast(Any, overrides))
)
overrides = cast(FallbackOverrideParameters, locals_to_params(locals()))
self._check_not_handled()
self.request._apply_fallback_overrides(overrides)
self._report_handled(False)

def _internal_continue(self) -> None:
async def continue_(
self,
url: str = None,
method: str = None,
headers: Dict[str, str] = None,
postData: Union[str, bytes] = None,
) -> None:
overrides = cast(FallbackOverrideParameters, locals_to_params(locals()))
self._check_not_handled()
self.request._apply_fallback_overrides(overrides)
await self._internal_continue()
self._report_handled(True)

def _internal_continue(
self, is_internal: bool = False
) -> Coroutine[Any, Any, None]:
async def continue_route() -> None:
try:
await self.continue_()
except Exception:
pass

asyncio.create_task(continue_route())
post_data_for_wire: Optional[str] = None
post_data_from_overrides = self.request._fallback_overrides.get(
"postData"
)
if post_data_from_overrides is not None:
post_data_for_wire = (
base64.b64encode(post_data_from_overrides.encode()).decode()
if isinstance(post_data_from_overrides, str)
else base64.b64encode(post_data_from_overrides).decode()
)
params = locals_to_params(
cast(Dict[str, str], self.request._fallback_overrides)
)
if "headers" in params:
params["headers"] = serialize_headers(params["headers"])
if post_data_for_wire is not None:
params["postData"] = post_data_for_wire
await self._race_with_page_close(
self._channel.send(
"continue",
params,
)
)
except Exception as e:
if not is_internal:
raise e

return continue_route()

# FIXME: Port corresponding tests, and call this method
async def _redirected_navigation_request(self, url: str) -> None:
self._check_not_handled()
await self._race_with_page_close(
self._channel.send("redirectNavigationRequest", {"url": url})
)
self._report_handled(True)

async def _race_with_page_close(self, future: Coroutine) -> None:
if hasattr(self.request.frame, "_page"):
Expand Down Expand Up @@ -484,17 +568,17 @@ def _on_close(self) -> None:
self.emit(WebSocket.Events.Close, self)


def serialize_headers(headers: Dict[str, str]) -> HeadersArray:
return [{"name": name, "value": value} for name, value in headers.items()]


class RawHeaders:
def __init__(self, headers: HeadersArray) -> None:
self._headers_array = headers
self._headers_map: Dict[str, Dict[str, bool]] = defaultdict(dict)
for header in headers:
self._headers_map[header["name"].lower()][header["value"]] = True

@staticmethod
def _from_headers_dict_lossy(headers: Dict[str, str]) -> "RawHeaders":
return RawHeaders(serialize_headers(headers))

def get(self, name: str) -> Optional[str]:
values = self.get_all(name)
if not values:
Expand Down

0 comments on commit 7b424eb

Please sign in to comment.