diff --git a/playwright/_impl/_browser_context.py b/playwright/_impl/_browser_context.py index a28e2b27b..8424bef43 100644 --- a/playwright/_impl/_browser_context.py +++ b/playwright/_impl/_browser_context.py @@ -13,7 +13,6 @@ # limitations under the License. import asyncio -import inspect import json from pathlib import Path from types import SimpleNamespace @@ -154,20 +153,14 @@ def _on_page(self, page: Page) -> None: page._opener.emit(Page.Events.Popup, page) def _on_route(self, route: Route, request: Request) -> None: - handled = False for handler_entry in self._routes: if handler_entry.matches(request.url): - result = handler_entry.handle(route, request) - if inspect.iscoroutine(result): - asyncio.create_task(result) - handled = True + if handler_entry.handle(route, request): + self._routes.remove(handler_entry) + if not len(self._routes) == 0: + asyncio.create_task(self._disable_interception()) break - if not handled: - asyncio.create_task(route.continue_()) - else: - self._routes = list( - filter(lambda route: route.expired() is False, self._routes) - ) + route._internal_continue() def _on_binding(self, binding_call: BindingCall) -> None: func = self._bindings.get(binding_call._initializer["name"]) @@ -279,9 +272,10 @@ async def unroute( ) ) if len(self._routes) == 0: - await self._channel.send( - "setNetworkInterceptionEnabled", dict(enabled=False) - ) + await self._disable_interception() + + async def _disable_interception(self) -> None: + await self._channel.send("setNetworkInterceptionEnabled", dict(enabled=False)) def expect_event( self, diff --git a/playwright/_impl/_helper.py b/playwright/_impl/_helper.py index 282b4472f..87cb7ca9d 100644 --- a/playwright/_impl/_helper.py +++ b/playwright/_impl/_helper.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio import fnmatch +import inspect import math import os import re @@ -206,25 +207,26 @@ def __init__( self, matcher: URLMatcher, handler: RouteHandlerCallback, - times: Optional[int], + times: Optional[int] = None, ): self.matcher = matcher self.handler = handler - self._times = times + self._times = times if times else 2 ** 32 self._handled_count = 0 - def expired(self) -> bool: - return self._times is not None and self._handled_count >= self._times - def matches(self, request_url: str) -> bool: return self.matcher.matches(request_url) - def handle(self, route: "Route", request: "Request") -> Union[Coroutine, Any]: - if self._times: + def handle(self, route: "Route", request: "Request") -> bool: + try: + result = cast( + Callable[["Route", "Request"], Union[Coroutine, Any]], self.handler + )(route, request) + if inspect.iscoroutine(result): + asyncio.create_task(result) + finally: self._handled_count += 1 - return cast( - Callable[["Route", "Request"], Union[Coroutine, Any]], self.handler - )(route, request) + return self._handled_count >= self._times def is_safe_close_error(error: Exception) -> bool: diff --git a/playwright/_impl/_network.py b/playwright/_impl/_network.py index e0c769d35..b480e16a6 100644 --- a/playwright/_impl/_network.py +++ b/playwright/_impl/_network.py @@ -238,6 +238,15 @@ async def continue_( overrides["postData"] = base64.b64encode(postData).decode() await self._channel.send("continue", cast(Any, overrides)) + def _internal_continue(self) -> None: + async def continue_route() -> None: + try: + await self.continue_() + except Exception: + pass + + asyncio.create_task(continue_route()) + class Response(ChannelOwner): def __init__( diff --git a/playwright/_impl/_page.py b/playwright/_impl/_page.py index b07e5733e..022cdb2de 100644 --- a/playwright/_impl/_page.py +++ b/playwright/_impl/_page.py @@ -212,20 +212,14 @@ def _on_frame_detached(self, frame: Frame) -> None: self.emit(Page.Events.FrameDetached, frame) def _on_route(self, route: Route, request: Request) -> None: - handled = False for handler_entry in self._routes: if handler_entry.matches(request.url): - result = handler_entry.handle(route, request) - if inspect.iscoroutine(result): - asyncio.create_task(result) - handled = True - break - if not handled: - self._browser_context._on_route(route, request) - else: - self._routes = list( - filter(lambda route: route.expired() is False, self._routes) - ) + if handler_entry.handle(route, request): + self._routes.remove(handler_entry) + if len(self._routes) == 0: + asyncio.create_task(self._disable_interception()) + return + self._browser_context._on_route(route, request) def _on_binding(self, binding_call: "BindingCall") -> None: func = self._bindings.get(binding_call._initializer["name"]) @@ -575,9 +569,10 @@ async def unroute( ) ) if len(self._routes) == 0: - await self._channel.send( - "setNetworkInterceptionEnabled", dict(enabled=False) - ) + await self._disable_interception() + + async def _disable_interception(self) -> None: + await self._channel.send("setNetworkInterceptionEnabled", dict(enabled=False)) async def screenshot( self,