diff --git a/playwright/_impl/_browser_context.py b/playwright/_impl/_browser_context.py index e3c12fd61..a5b81bb04 100644 --- a/playwright/_impl/_browser_context.py +++ b/playwright/_impl/_browser_context.py @@ -96,8 +96,11 @@ def __init__( ) self._channel.on( "route", - lambda params: self._on_route( - from_channel(params.get("route")), from_channel(params.get("request")) + lambda params: asyncio.create_task( + self._on_route( + from_channel(params.get("route")), + from_channel(params.get("request")), + ) ), ) @@ -156,18 +159,21 @@ def _on_page(self, page: Page) -> None: if page._opener and not page._opener.is_closed(): page._opener.emit(Page.Events.Popup, page) - def _on_route(self, route: Route, request: Request) -> None: - for handler_entry in self._routes: - if handler_entry.matches(request.url): - try: - handler_entry.handle(route, request) - finally: - if not handler_entry.is_active: - self._routes.remove(handler_entry) - if not len(self._routes) == 0: - asyncio.create_task(self._disable_interception()) - break - route._internal_continue() + async def _on_route(self, route: Route, request: Request) -> None: + route_handlers = self._routes.copy() + for route_handler in route_handlers: + if not route_handler.matches(request.url): + continue + if route_handler.will_expire: + self._routes.remove(route_handler) + try: + handled = await route_handler.handle(route, request) + finally: + if len(self._routes) == 0: + asyncio.create_task(self._disable_interception()) + if handled: + return + await route._internal_continue(is_internal=True) def _on_binding(self, binding_call: BindingCall) -> None: func = self._bindings.get(binding_call._initializer["name"]) diff --git a/playwright/_impl/_helper.py b/playwright/_impl/_helper.py index 71d55b917..0a80e76e3 100644 --- a/playwright/_impl/_helper.py +++ b/playwright/_impl/_helper.py @@ -76,11 +76,11 @@ class ErrorPayload(TypedDict, total=False): value: Optional[Any] -class ContinueParameters(TypedDict, total=False): +class FallbackOverrideParameters(TypedDict, total=False): url: Optional[str] method: Optional[str] - headers: Optional[List[NameValue]] - postData: Optional[str] + headers: Optional[Dict[str, str]] + postData: Optional[Union[str, bytes]] class ParsedMessageParams(TypedDict): @@ -225,14 +225,17 @@ def __init__( def matches(self, request_url: str) -> bool: return self.matcher.matches(request_url) - def handle(self, route: "Route", request: "Request") -> None: + async def handle(self, route: "Route", request: "Request") -> bool: + handled_future = route._start_handling() + handler_task = [] + def impl() -> None: self._handled_count += 1 result = cast( Callable[["Route", "Request"], Union[Coroutine, Any]], self.handler )(route, request) if inspect.iscoroutine(result): - asyncio.create_task(result) + handler_task.append(asyncio.create_task(result)) # As with event handlers, each route handler is a potentially blocking context # so it needs a fiber. @@ -242,9 +245,12 @@ def impl() -> None: else: impl() + [handled, *_] = await asyncio.gather(handled_future, *handler_task) + return handled + @property - def is_active(self) -> bool: - return self._handled_count < self._times + def will_expire(self) -> bool: + return self._handled_count + 1 >= self._times def is_safe_close_error(error: Exception) -> bool: diff --git a/playwright/_impl/_network.py b/playwright/_impl/_network.py index cf10bbebf..88b820c58 100644 --- a/playwright/_impl/_network.py +++ b/playwright/_impl/_network.py @@ -47,7 +47,7 @@ from_nullable_channel, ) from playwright._impl._event_context_manager import EventContextManagerImpl -from playwright._impl._helper import ContinueParameters, locals_to_params +from playwright._impl._helper import FallbackOverrideParameters, locals_to_params from playwright._impl._wait_helper import WaitHelper if TYPE_CHECKING: # pragma: no cover @@ -55,6 +55,14 @@ from playwright._impl._frame import Frame +def serialize_headers(headers: Dict[str, str]) -> HeadersArray: + return [ + {"name": name, "value": value} + for name, value in headers.items() + if value is not None + ] + + class Request(ChannelOwner): def __init__( self, parent: ChannelOwner, type: str, guid: str, initializer: Dict @@ -80,13 +88,21 @@ def __init__( } self._provisional_headers = RawHeaders(self._initializer["headers"]) self._all_headers_future: Optional[asyncio.Future[RawHeaders]] = None + self._fallback_overrides: FallbackOverrideParameters = ( + FallbackOverrideParameters() + ) def __repr__(self) -> str: return f"" + def _apply_fallback_overrides(self, overrides: FallbackOverrideParameters) -> None: + self._fallback_overrides = cast( + FallbackOverrideParameters, {**self._fallback_overrides, **overrides} + ) + @property def url(self) -> str: - return self._initializer["url"] + return cast(str, self._fallback_overrides.get("url", self._initializer["url"])) @property def resource_type(self) -> str: @@ -94,7 +110,9 @@ def resource_type(self) -> str: @property def method(self) -> str: - return self._initializer["method"] + return cast( + str, self._fallback_overrides.get("method", self._initializer["method"]) + ) async def sizes(self) -> RequestSizes: response = await self.response() @@ -104,10 +122,10 @@ async def sizes(self) -> RequestSizes: @property def post_data(self) -> Optional[str]: - data = self.post_data_buffer + data = self._fallback_overrides.get("postData", self.post_data_buffer) if not data: return None - return data.decode() + return data.decode() if isinstance(data, bytes) else data @property def post_data_json(self) -> Optional[Any]: @@ -124,6 +142,13 @@ def post_data_json(self) -> Optional[Any]: @property def post_data_buffer(self) -> Optional[bytes]: + override = self._fallback_overrides.get("post_data") + if override: + return ( + override.encode() + if isinstance(override, str) + else cast(bytes, override) + ) b64_content = self._initializer.get("postData") if b64_content is None: return None @@ -157,6 +182,9 @@ def timing(self) -> ResourceTiming: @property def headers(self) -> Headers: + override = self._fallback_overrides.get("headers") + if override: + return RawHeaders._from_headers_dict_lossy(override).headers() return self._provisional_headers.headers() async def all_headers(self) -> Headers: @@ -169,6 +197,9 @@ async def header_value(self, name: str) -> Optional[str]: return (await self._actual_headers()).get(name) async def _actual_headers(self) -> "RawHeaders": + override = self._fallback_overrides.get("headers") + if override: + return RawHeaders(serialize_headers(override)) if not self._all_headers_future: self._all_headers_future = asyncio.Future() headers = await self._channel.send("rawRequestHeaders") @@ -181,6 +212,21 @@ def __init__( self, parent: ChannelOwner, type: str, guid: str, initializer: Dict ) -> None: super().__init__(parent, type, guid, initializer) + self._handling_future: Optional[asyncio.Future["bool"]] = None + + def _start_handling(self) -> "asyncio.Future[bool]": + self._handling_future = asyncio.Future() + return self._handling_future + + def _report_handled(self, done: bool) -> None: + chain = self._handling_future + assert chain + self._handling_future = None + chain.set_result(done) + + def _check_not_handled(self) -> None: + if not self._handling_future: + raise Error("Route is already handled!") def __repr__(self) -> str: return f"" @@ -203,6 +249,7 @@ async def fulfill( contentType: str = None, response: "APIResponse" = None, ) -> None: + self._check_not_handled() params = locals_to_params(locals()) if response: del params["response"] @@ -247,37 +294,74 @@ async def fulfill( headers["content-length"] = str(length) params["headers"] = serialize_headers(headers) await self._race_with_page_close(self._channel.send("fulfill", params)) + self._report_handled(True) - async def continue_( + async def fallback( self, url: str = None, method: str = None, headers: Dict[str, str] = None, postData: Union[str, bytes] = None, ) -> None: - overrides: ContinueParameters = {} - if url: - overrides["url"] = url - if method: - overrides["method"] = method - if headers: - overrides["headers"] = serialize_headers(headers) - if isinstance(postData, str): - overrides["postData"] = base64.b64encode(postData.encode()).decode() - elif isinstance(postData, bytes): - overrides["postData"] = base64.b64encode(postData).decode() - await self._race_with_page_close( - self._channel.send("continue", cast(Any, overrides)) - ) + overrides = cast(FallbackOverrideParameters, locals_to_params(locals())) + self._check_not_handled() + self.request._apply_fallback_overrides(overrides) + self._report_handled(False) - def _internal_continue(self) -> None: + async def continue_( + self, + url: str = None, + method: str = None, + headers: Dict[str, str] = None, + postData: Union[str, bytes] = None, + ) -> None: + overrides = cast(FallbackOverrideParameters, locals_to_params(locals())) + self._check_not_handled() + self.request._apply_fallback_overrides(overrides) + await self._internal_continue() + self._report_handled(True) + + def _internal_continue( + self, is_internal: bool = False + ) -> Coroutine[Any, Any, None]: async def continue_route() -> None: try: - await self.continue_() - except Exception: - pass - - asyncio.create_task(continue_route()) + post_data_for_wire: Optional[str] = None + post_data_from_overrides = self.request._fallback_overrides.get( + "postData" + ) + if post_data_from_overrides is not None: + post_data_for_wire = ( + base64.b64encode(post_data_from_overrides.encode()).decode() + if isinstance(post_data_from_overrides, str) + else base64.b64encode(post_data_from_overrides).decode() + ) + params = locals_to_params( + cast(Dict[str, str], self.request._fallback_overrides) + ) + if "headers" in params: + params["headers"] = serialize_headers(params["headers"]) + if post_data_for_wire is not None: + params["postData"] = post_data_for_wire + await self._race_with_page_close( + self._channel.send( + "continue", + params, + ) + ) + except Exception as e: + if not is_internal: + raise e + + return continue_route() + + # FIXME: Port corresponding tests, and call this method + async def _redirected_navigation_request(self, url: str) -> None: + self._check_not_handled() + await self._race_with_page_close( + self._channel.send("redirectNavigationRequest", {"url": url}) + ) + self._report_handled(True) async def _race_with_page_close(self, future: Coroutine) -> None: if hasattr(self.request.frame, "_page"): @@ -484,10 +568,6 @@ def _on_close(self) -> None: self.emit(WebSocket.Events.Close, self) -def serialize_headers(headers: Dict[str, str]) -> HeadersArray: - return [{"name": name, "value": value} for name, value in headers.items()] - - class RawHeaders: def __init__(self, headers: HeadersArray) -> None: self._headers_array = headers @@ -495,6 +575,10 @@ def __init__(self, headers: HeadersArray) -> None: for header in headers: self._headers_map[header["name"].lower()][header["value"]] = True + @staticmethod + def _from_headers_dict_lossy(headers: Dict[str, str]) -> "RawHeaders": + return RawHeaders(serialize_headers(headers)) + def get(self, name: str) -> Optional[str]: values = self.get_all(name) if not values: diff --git a/playwright/_impl/_page.py b/playwright/_impl/_page.py index b13cbba1f..1245ff819 100644 --- a/playwright/_impl/_page.py +++ b/playwright/_impl/_page.py @@ -190,8 +190,10 @@ def __init__( ) self._channel.on( "route", - lambda params: self._on_route( - from_channel(params["route"]), from_channel(params["request"]) + lambda params: asyncio.create_task( + self._on_route( + from_channel(params["route"]), from_channel(params["request"]) + ) ), ) self._channel.on("video", lambda params: self._on_video(params)) @@ -231,22 +233,21 @@ def _on_frame_detached(self, frame: Frame) -> None: frame._detached = True self.emit(Page.Events.FrameDetached, frame) - def _on_route(self, route: Route, request: Request) -> None: - # Make this artificially async so that we could chain routes. - async def inner_route() -> None: - for handler_entry in self._routes: - if handler_entry.matches(request.url): - try: - handler_entry.handle(route, request) - finally: - if not handler_entry.is_active: - self._routes.remove(handler_entry) - if len(self._routes) == 0: - asyncio.create_task(self._disable_interception()) - return - self._browser_context._on_route(route, request) - - asyncio.create_task(inner_route()) + async def _on_route(self, route: Route, request: Request) -> None: + route_handlers = self._routes.copy() + for route_handler in route_handlers: + if not route_handler.matches(request.url): + continue + if route_handler.will_expire: + self._routes.remove(route_handler) + try: + handled = await route_handler.handle(route, request) + finally: + if len(self._routes) == 0: + asyncio.create_task(self._disable_interception()) + if handled: + return + await self._browser_context._on_route(route, request) def _on_binding(self, binding_call: "BindingCall") -> None: func = self._bindings.get(binding_call._initializer["name"]) diff --git a/playwright/async_api/_generated.py b/playwright/async_api/_generated.py index 1eca729b8..a2f2198be 100644 --- a/playwright/async_api/_generated.py +++ b/playwright/async_api/_generated.py @@ -691,6 +691,89 @@ async def fulfill( ) ) + async def fallback( + self, + *, + url: str = None, + method: str = None, + headers: typing.Optional[typing.Dict[str, str]] = None, + post_data: typing.Union[str, bytes] = None + ) -> NoneType: + """Route.fallback + + When several routes match the given pattern, they run in the order opposite to their registration. That way the last + registered route can always override all the previos ones. In the example below, request will be handled by the + bottom-most handler first, then it'll fall back to the previous one and in the end will be aborted by the first + registered route. + + ```py + await page.route(\"**/*\", lambda route: route.abort()) # Runs last. + await page.route(\"**/*\", lambda route: route.fallback()) # Runs second. + await page.route(\"**/*\", lambda route: route.fallback()) # Runs first. + ``` + + Registering multiple routes is useful when you want separate handlers to handle different kinds of requests, for example + API calls vs page resources or GET requests vs POST requests as in the example below. + + ```py + # Handle GET requests. + def handle_post(route): + if route.request.method != \"GET\": + route.fallback() + return + # Handling GET only. + # ... + + # Handle POST requests. + def handle_post(route): + if route.request.method != \"POST\": + route.fallback() + return + # Handling POST only. + # ... + + await page.route(\"**/*\", handle_get) + await page.route(\"**/*\", handle_post) + ``` + + One can also modify request while falling back to the subsequent handler, that way intermediate route handler can modify + url, method, headers and postData of the request. + + ```py + async def handle(route, request): + # override headers + headers = { + **request.headers, + \"foo\": \"foo-value\" # set \"foo\" header + \"bar\": None # remove \"bar\" header + } + await route.fallback(headers=headers) + } + await page.route(\"**/*\", handle) + ``` + + Parameters + ---------- + url : Union[str, NoneType] + If set changes the request URL. New URL must have same protocol as original one. Changing the URL won't affect the route + matching, all the routes are matched using the original request URL. + method : Union[str, NoneType] + If set changes the request method (e.g. GET or POST) + headers : Union[Dict[str, str], NoneType] + If set changes the request HTTP headers. Header values will be converted to a string. + post_data : Union[bytes, str, NoneType] + If set changes the post data of request + """ + + return mapping.from_maybe_impl( + await self._impl_obj.fallback( + url=url, + method=method, + headers=mapping.to_impl(headers), + postData=post_data, + ) + ) + async def continue_( self, *, diff --git a/playwright/sync_api/_generated.py b/playwright/sync_api/_generated.py index 88c3aca7d..56c7c954b 100644 --- a/playwright/sync_api/_generated.py +++ b/playwright/sync_api/_generated.py @@ -701,6 +701,91 @@ def fulfill( ) ) + def fallback( + self, + *, + url: str = None, + method: str = None, + headers: typing.Optional[typing.Dict[str, str]] = None, + post_data: typing.Union[str, bytes] = None + ) -> NoneType: + """Route.fallback + + When several routes match the given pattern, they run in the order opposite to their registration. That way the last + registered route can always override all the previos ones. In the example below, request will be handled by the + bottom-most handler first, then it'll fall back to the previous one and in the end will be aborted by the first + registered route. + + ```py + page.route(\"**/*\", lambda route: route.abort()) # Runs last. + page.route(\"**/*\", lambda route: route.fallback()) # Runs second. + page.route(\"**/*\", lambda route: route.fallback()) # Runs first. + ``` + + Registering multiple routes is useful when you want separate handlers to handle different kinds of requests, for example + API calls vs page resources or GET requests vs POST requests as in the example below. + + ```py + # Handle GET requests. + def handle_post(route): + if route.request.method != \"GET\": + route.fallback() + return + # Handling GET only. + # ... + + # Handle POST requests. + def handle_post(route): + if route.request.method != \"POST\": + route.fallback() + return + # Handling POST only. + # ... + + page.route(\"**/*\", handle_get) + page.route(\"**/*\", handle_post) + ``` + + One can also modify request while falling back to the subsequent handler, that way intermediate route handler can modify + url, method, headers and postData of the request. + + ```py + def handle(route, request): + # override headers + headers = { + **request.headers, + \"foo\": \"foo-value\" # set \"foo\" header + \"bar\": None # remove \"bar\" header + } + route.fallback(headers=headers) + } + page.route(\"**/*\", handle) + ``` + + Parameters + ---------- + url : Union[str, NoneType] + If set changes the request URL. New URL must have same protocol as original one. Changing the URL won't affect the route + matching, all the routes are matched using the original request URL. + method : Union[str, NoneType] + If set changes the request method (e.g. GET or POST) + headers : Union[Dict[str, str], NoneType] + If set changes the request HTTP headers. Header values will be converted to a string. + post_data : Union[bytes, str, NoneType] + If set changes the post data of request + """ + + return mapping.from_maybe_impl( + self._sync( + self._impl_obj.fallback( + url=url, + method=method, + headers=mapping.to_impl(headers), + postData=post_data, + ) + ) + ) + def continue_( self, *, diff --git a/scripts/expected_api_mismatch.txt b/scripts/expected_api_mismatch.txt index 5fad86551..75018a398 100644 --- a/scripts/expected_api_mismatch.txt +++ b/scripts/expected_api_mismatch.txt @@ -21,5 +21,4 @@ Method not implemented: PlaywrightAssertions.expect # Pending 1.23 ports Method not implemented: BrowserContext.route_from_har -Method not implemented: Route.fallback Method not implemented: Page.route_from_har diff --git a/tests/assets/global-var.html b/tests/assets/global-var.html new file mode 100644 index 000000000..52eb94e55 --- /dev/null +++ b/tests/assets/global-var.html @@ -0,0 +1,3 @@ + diff --git a/tests/async/test_browsercontext_request_fallback.py b/tests/async/test_browsercontext_request_fallback.py new file mode 100644 index 000000000..9f583e0ff --- /dev/null +++ b/tests/async/test_browsercontext_request_fallback.py @@ -0,0 +1,383 @@ +# Copyright (c) Microsoft Corporation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest + +from playwright.async_api import BrowserContext, Error, Page, Request, Route +from tests.server import Server + + +async def test_should_work(page: Page, context: BrowserContext, server: Server) -> None: + await context.route("**/*", lambda route: asyncio.create_task(route.fallback())) + await page.goto(server.EMPTY_PAGE) + + +async def test_should_fall_back( + page: Page, context: BrowserContext, server: Server +) -> None: + intercepted = [] + await context.route( + "**/empty.html", + lambda route: ( + intercepted.append(1), + asyncio.create_task(route.fallback()), + ), + ) + await context.route( + "**/empty.html", + lambda route: ( + intercepted.append(2), + asyncio.create_task(route.fallback()), + ), + ) + await context.route( + "**/empty.html", + lambda route: ( + intercepted.append(3), + asyncio.create_task(route.fallback()), + ), + ) + + await page.goto(server.EMPTY_PAGE) + assert intercepted == [3, 2, 1] + + +async def test_should_fall_back_async_delayed( + page: Page, context: BrowserContext, server: Server +) -> None: + intercepted = [] + + def create_handler(i: int): + async def handler(route): + intercepted.append(i) + await asyncio.sleep(0.1) + await route.fallback() + + return handler + + await context.route("**/empty.html", create_handler(1)) + await context.route("**/empty.html", create_handler(2)) + await context.route("**/empty.html", create_handler(3)) + await page.goto(server.EMPTY_PAGE) + assert intercepted == [3, 2, 1] + + +async def test_should_chain_once( + page: Page, context: BrowserContext, server: Server +) -> None: + await context.route( + "**/madeup.txt", + lambda route: asyncio.create_task( + route.fulfill(status=200, body="fulfilled one") + ), + times=1, + ) + await context.route( + "**/madeup.txt", lambda route: asyncio.create_task(route.fallback()), times=1 + ) + + resp = await page.goto(server.PREFIX + "/madeup.txt") + body = await resp.body() + assert body == b"fulfilled one" + + +async def test_should_not_chain_fulfill( + page: Page, context: BrowserContext, server: Server +) -> None: + failed = [False] + + def handler(route: Route): + failed[0] = True + + await context.route("**/empty.html", handler) + await context.route( + "**/empty.html", + lambda route: asyncio.create_task(route.fulfill(status=200, body="fulfilled")), + ) + await context.route( + "**/empty.html", lambda route: asyncio.create_task(route.fallback()) + ) + + response = await page.goto(server.EMPTY_PAGE) + body = await response.body() + assert body == b"fulfilled" + assert not failed[0] + + +async def test_should_not_chain_abort( + page: Page, + context: BrowserContext, + server: Server, + is_webkit: bool, + is_firefox: bool, +) -> None: + failed = [False] + + def handler(route: Route): + failed[0] = True + + await context.route("**/empty.html", handler) + await context.route( + "**/empty.html", lambda route: asyncio.create_task(route.abort()) + ) + await context.route( + "**/empty.html", lambda route: asyncio.create_task(route.fallback()) + ) + + with pytest.raises(Error) as excinfo: + await page.goto(server.EMPTY_PAGE) + if is_webkit: + assert "Blocked by Web Inspector" in excinfo.value.message + elif is_firefox: + assert "NS_ERROR_FAILURE" in excinfo.value.message + else: + assert "net::ERR_FAILED" in excinfo.value.message + assert not failed[0] + + +async def test_should_fall_back_after_exception( + page: Page, context: BrowserContext, server: Server +) -> None: + await context.route("**/empty.html", lambda route: route.continue_()) + + async def handler(route: Route): + try: + await route.fulfill(response=47) + except Exception: + await route.fallback() + + await context.route("**/empty.html", handler) + + await page.goto(server.EMPTY_PAGE) + + +async def test_should_amend_http_headers( + page: Page, context: BrowserContext, server: Server +) -> None: + values = [] + + async def handler(route: Route): + values.append(route.request.headers.get("foo")) + values.append(await route.request.header_value("FOO")) + await route.continue_() + + await context.route("**/sleep.zzz", handler) + + async def handler_with_header_mods(route: Route): + await route.fallback(headers={**route.request.headers, "FOO": "bar"}) + + await context.route("**/*", handler_with_header_mods) + + await page.goto(server.EMPTY_PAGE) + async with page.expect_request("/sleep.zzz") as request_info: + await page.evaluate("() => fetch('/sleep.zzz')") + request = await request_info.value + values.append(request.headers.get("foo")) + assert values == ["bar", "bar", "bar"] + + +async def test_should_delete_header_with_undefined_value( + page: Page, context: BrowserContext, server: Server +) -> None: + await page.goto(server.EMPTY_PAGE) + server.set_route( + "/something", + lambda r: ( + r.setHeader("Acces-Control-Allow-Origin", "*"), + r.write(b"done"), + r.finish(), + ), + ) + + intercepted_request = [] + + async def capture_and_continue(route: Route, request: Request): + intercepted_request.append(request) + await route.continue_() + + await context.route("**/*", capture_and_continue) + + async def delete_foo_header(route: Route, request: Request): + headers = await request.all_headers() + await route.fallback(headers={**headers, "foo": None}) + + await context.route(server.PREFIX + "/something", delete_foo_header) + + [server_req, text] = await asyncio.gather( + server.wait_for_request("/something"), + page.evaluate( + """ + async url => { + const data = await fetch(url, { + headers: { + foo: 'a', + bar: 'b', + } + }); + return data.text(); + } + """, + server.PREFIX + "/something", + ), + ) + + assert text == "done" + assert not intercepted_request[0].headers.get("foo") + assert intercepted_request[0].headers.get("bar") == "b" + assert not server_req.getHeader("foo") + assert server_req.getHeader("bar") == "b" + + +async def test_should_amend_method( + page: Page, context: BrowserContext, server: Server +) -> None: + await page.goto(server.EMPTY_PAGE) + + method = [] + await context.route( + "**/*", + lambda route: ( + method.append(route.request.method), + asyncio.create_task(route.continue_()), + ), + ) + await context.route( + "**/*", lambda route: asyncio.create_task(route.fallback(method="POST")) + ) + + [request, _] = await asyncio.gather( + server.wait_for_request("/sleep.zzz"), + page.evaluate("() => fetch('/sleep.zzz')"), + ) + + assert method == ["POST"] + assert request.method == b"POST" + + +async def test_should_override_request_url( + page: Page, context: BrowserContext, server: Server +) -> None: + url = [] + await context.route( + "**/global-var.html", + lambda route: ( + url.append(route.request.url), + asyncio.create_task(route.continue_()), + ), + ) + await context.route( + "**/foo", + lambda route: asyncio.create_task( + route.fallback(url=server.PREFIX + "/global-var.html") + ), + ) + + [server_request, response, _] = await asyncio.gather( + server.wait_for_request("/global-var.html"), + page.wait_for_event("response"), + page.goto(server.PREFIX + "/foo"), + ) + + assert url == [server.PREFIX + "/global-var.html"] + assert response.url == server.PREFIX + "/foo" + assert await page.evaluate("() => window['globalVar']") == 123 + assert server_request.uri == b"/global-var.html" + assert server_request.method == b"GET" + + +async def test_should_amend_post_data( + page: Page, context: BrowserContext, server: Server +) -> None: + await page.goto(server.EMPTY_PAGE) + post_data = [] + await context.route( + "**/*", + lambda route: ( + post_data.append(route.request.post_data), + asyncio.create_task(route.continue_()), + ), + ) + await context.route( + "**/*", lambda route: asyncio.create_task(route.fallback(post_data="doggo")) + ) + [server_request, _] = await asyncio.gather( + server.wait_for_request("/sleep.zzz"), + page.evaluate("() => fetch('/sleep.zzz', { method: 'POST', body: 'birdy' })"), + ) + assert post_data == ["doggo"] + assert server_request.post_body == b"doggo" + + +async def test_should_amend_binary_post_data( + page: Page, context: BrowserContext, server: Server +): + await page.goto(server.EMPTY_PAGE) + post_data_buffer = [] + await context.route( + "**/*", + lambda route: ( + post_data_buffer.append(route.request.post_data), + asyncio.create_task(route.continue_()), + ), + ) + await context.route( + "**/*", + lambda route: asyncio.create_task( + route.fallback(post_data=b"\x00\x01\x02\x03\x04") + ), + ) + + [server_request, result] = await asyncio.gather( + server.wait_for_request("/sleep.zzz"), + page.evaluate("fetch('/sleep.zzz', { method: 'POST', body: 'birdy' })"), + ) + # FIXME: should this be bytes? + assert post_data_buffer == ["\x00\x01\x02\x03\x04"] + assert server_request.method == b"POST" + assert server_request.post_body == b"\x00\x01\x02\x03\x04" + + +async def test_should_chain_fallback_into_page( + context: BrowserContext, page: Page, server: Server +) -> None: + intercepted = [] + await context.route( + "**/empty.html", + lambda route: (intercepted.append(1), asyncio.create_task(route.fallback())), + ) + await context.route( + "**/empty.html", + lambda route: (intercepted.append(2), asyncio.create_task(route.fallback())), + ) + await context.route( + "**/empty.html", + lambda route: (intercepted.append(3), asyncio.create_task(route.fallback())), + ) + await page.route( + "**/empty.html", + lambda route: (intercepted.append(4), asyncio.create_task(route.fallback())), + ) + await page.route( + "**/empty.html", + lambda route: (intercepted.append(5), asyncio.create_task(route.fallback())), + ) + await page.route( + "**/empty.html", + lambda route: (intercepted.append(6), asyncio.create_task(route.fallback())), + ) + + await page.goto(server.EMPTY_PAGE) + assert intercepted == [6, 5, 4, 3, 2, 1] diff --git a/tests/async/test_browsercontext_request_intercept.py b/tests/async/test_browsercontext_request_intercept.py new file mode 100644 index 000000000..763073df0 --- /dev/null +++ b/tests/async/test_browsercontext_request_intercept.py @@ -0,0 +1,176 @@ +# Copyright (c) Microsoft Corporation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from pathlib import Path + +from twisted.web import http + +from playwright.async_api import BrowserContext, Page, Route +from tests.server import Server + + +async def test_should_fulfill_intercepted_response( + page: Page, context: BrowserContext, server: Server +): + async def handle(route: Route): + response = await page.request.fetch(route.request) + await route.fulfill( + response=response, + status=201, + headers={"foo": "bar"}, + content_type="text/plain", + body="Yo, page!", + ) + + await context.route("**/*", handle) + response = await page.goto(server.PREFIX + "/empty.html") + assert response.status == 201 + assert response.headers["foo"] == "bar" + assert response.headers["content-type"] == "text/plain" + assert await page.evaluate("() => document.body.textContent") == "Yo, page!" + + +async def test_should_fulfill_response_with_empty_body( + page: Page, context: BrowserContext, server: Server +): + async def handle(route: Route): + response = await page.request.fetch(route.request) + await route.fulfill( + response=response, status=201, body="", headers={"content-length": "0"} + ) + + await context.route("**/*", handle) + response = await page.goto(server.PREFIX + "/title.html") + assert response.status == 201 + assert await response.text() == "" + + +async def test_should_override_with_defaults_when_intercepted_response_not_provided( + page: Page, context: BrowserContext, server: Server, browser_name: str +): + def server_handler(request: http.Request): + request.setHeader("foo", "bar") + request.write("my content".encode()) + request.finish() + + server.set_route("/empty.html", server_handler) + + async def handle(route: Route): + await page.request.fetch(route.request) + await route.fulfill(status=201) + + await context.route("**/*", handle) + response = await page.goto(server.EMPTY_PAGE) + assert response.status == 201 + assert await response.text() == "" + if browser_name == "webkit": + assert response.headers == {"content-type": "text/plain"} + else: + assert response.headers == {} + + +async def test_should_fulfill_with_any_response( + page: Page, context: BrowserContext, server: Server +): + def server_handler(request: http.Request): + request.setHeader("foo", "bar") + request.write("Woo-hoo".encode()) + request.finish() + + server.set_route("/sample", server_handler) + sample_response = await page.request.get(server.PREFIX + "/sample") + await context.route( + "**/*", + lambda route: route.fulfill( + response=sample_response, status=201, content_type="text/plain" + ), + ) + response = await page.goto(server.EMPTY_PAGE) + assert response.status == 201 + assert await response.text() == "Woo-hoo" + assert response.headers["foo"] == "bar" + + +async def test_should_support_fulfill_after_intercept( + page: Page, context: BrowserContext, server: Server, assetdir: Path +): + request_future = asyncio.create_task(server.wait_for_request("/title.html")) + + async def handle_route(route: Route): + response = await page.request.fetch(route.request) + await route.fulfill(response=response) + + await context.route("**", handle_route) + response = await page.goto(server.PREFIX + "/title.html") + request = await request_future + assert request.uri.decode() == "/title.html" + original = (assetdir / "title.html").read_text() + assert await response.text() == original + + +async def test_should_give_access_to_the_intercepted_response( + page: Page, context: BrowserContext, server: Server +): + await page.goto(server.EMPTY_PAGE) + + route_task = asyncio.Future() + await context.route("**/title.html", lambda route: route_task.set_result(route)) + + eval_task = asyncio.create_task( + page.evaluate("url => fetch(url)", server.PREFIX + "/title.html") + ) + + route = await route_task + response = await page.request.fetch(route.request) + + assert response.status == 200 + assert response.status_text == "OK" + assert response.ok is True + assert response.url.endswith("/title.html") is True + assert response.headers["content-type"] == "text/html; charset=utf-8" + assert list( + filter( + lambda header: header["name"].lower() == "content-type", + response.headers_array, + ) + ) == [{"name": "Content-Type", "value": "text/html; charset=utf-8"}] + + await asyncio.gather( + route.fulfill(response=response), + eval_task, + ) + + +async def test_should_give_access_to_the_intercepted_response_body( + page: Page, context: BrowserContext, server: Server +): + await page.goto(server.EMPTY_PAGE) + + route_task = asyncio.Future() + await context.route("**/simple.json", lambda route: route_task.set_result(route)) + + eval_task = asyncio.create_task( + page.evaluate("url => fetch(url)", server.PREFIX + "/simple.json") + ) + + route = await route_task + response = await page.request.fetch(route.request) + + assert await response.text() == '{"foo": "bar"}\n' + + await asyncio.gather( + route.fulfill(response=response), + eval_task, + ) diff --git a/tests/async/test_page_request_fallback.py b/tests/async/test_page_request_fallback.py new file mode 100644 index 000000000..1196190e2 --- /dev/null +++ b/tests/async/test_page_request_fallback.py @@ -0,0 +1,354 @@ +# Copyright (c) Microsoft Corporation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +import pytest + +from playwright.async_api import Error, Page, Request, Route +from tests.server import Server + + +async def test_should_work(page: Page, server: Server) -> None: + await page.route("**/*", lambda route: asyncio.create_task(route.fallback())) + await page.goto(server.EMPTY_PAGE) + + +async def test_should_fall_back(page: Page, server: Server) -> None: + intercepted = [] + await page.route( + "**/empty.html", + lambda route: ( + intercepted.append(1), + asyncio.create_task(route.fallback()), + ), + ) + await page.route( + "**/empty.html", + lambda route: ( + intercepted.append(2), + asyncio.create_task(route.fallback()), + ), + ) + await page.route( + "**/empty.html", + lambda route: ( + intercepted.append(3), + asyncio.create_task(route.fallback()), + ), + ) + + await page.goto(server.EMPTY_PAGE) + assert intercepted == [3, 2, 1] + + +async def test_should_fall_back_async_delayed(page: Page, server: Server) -> None: + intercepted = [] + + def create_handler(i: int): + async def handler(route): + intercepted.append(i) + await asyncio.sleep(0.1) + await route.fallback() + + return handler + + await page.route("**/empty.html", create_handler(1)) + await page.route("**/empty.html", create_handler(2)) + await page.route("**/empty.html", create_handler(3)) + await page.goto(server.EMPTY_PAGE) + assert intercepted == [3, 2, 1] + + +async def test_should_chain_once(page: Page, server: Server) -> None: + await page.route( + "**/madeup.txt", + lambda route: asyncio.create_task( + route.fulfill(status=200, body="fulfilled one") + ), + times=1, + ) + await page.route( + "**/madeup.txt", lambda route: asyncio.create_task(route.fallback()), times=1 + ) + + resp = await page.goto(server.PREFIX + "/madeup.txt") + body = await resp.body() + assert body == b"fulfilled one" + + +async def test_should_not_chain_fulfill(page: Page, server: Server) -> None: + failed = [False] + + def handler(route: Route): + failed[0] = True + + await page.route("**/empty.html", handler) + await page.route( + "**/empty.html", + lambda route: asyncio.create_task(route.fulfill(status=200, body="fulfilled")), + ) + await page.route( + "**/empty.html", lambda route: asyncio.create_task(route.fallback()) + ) + + response = await page.goto(server.EMPTY_PAGE) + body = await response.body() + assert body == b"fulfilled" + assert not failed[0] + + +async def test_should_not_chain_abort( + page: Page, server: Server, is_webkit: bool, is_firefox: bool +) -> None: + failed = [False] + + def handler(route: Route): + failed[0] = True + + await page.route("**/empty.html", handler) + await page.route("**/empty.html", lambda route: asyncio.create_task(route.abort())) + await page.route( + "**/empty.html", lambda route: asyncio.create_task(route.fallback()) + ) + + with pytest.raises(Error) as excinfo: + await page.goto(server.EMPTY_PAGE) + if is_webkit: + assert "Blocked by Web Inspector" in excinfo.value.message + elif is_firefox: + assert "NS_ERROR_FAILURE" in excinfo.value.message + else: + assert "net::ERR_FAILED" in excinfo.value.message + assert not failed[0] + + +async def test_should_fall_back_after_exception(page: Page, server: Server) -> None: + await page.route("**/empty.html", lambda route: route.continue_()) + + async def handler(route: Route): + try: + await route.fulfill(response=47) + except Exception: + await route.fallback() + + await page.route("**/empty.html", handler) + + await page.goto(server.EMPTY_PAGE) + + +async def test_should_amend_http_headers(page: Page, server: Server) -> None: + values = [] + + async def handler(route: Route): + values.append(route.request.headers.get("foo")) + values.append(await route.request.header_value("FOO")) + await route.continue_() + + await page.route("**/sleep.zzz", handler) + + async def handler_with_header_mods(route: Route): + await route.fallback(headers={**route.request.headers, "FOO": "bar"}) + + await page.route("**/*", handler_with_header_mods) + + await page.goto(server.EMPTY_PAGE) + async with page.expect_request("/sleep.zzz") as request_info: + await page.evaluate("() => fetch('/sleep.zzz')") + request = await request_info.value + values.append(request.headers.get("foo")) + assert values == ["bar", "bar", "bar"] + + +async def test_should_delete_header_with_undefined_value( + page: Page, server: Server +) -> None: + await page.goto(server.EMPTY_PAGE) + server.set_route( + "/something", + lambda r: ( + r.setHeader("Acces-Control-Allow-Origin", "*"), + r.write(b"done"), + r.finish(), + ), + ) + + intercepted_request = [] + + async def capture_and_continue(route: Route, request: Request): + intercepted_request.append(request) + await route.continue_() + + await page.route("**/*", capture_and_continue) + + async def delete_foo_header(route: Route, request: Request): + headers = await request.all_headers() + await route.fallback(headers={**headers, "foo": None}) + + await page.route(server.PREFIX + "/something", delete_foo_header) + + [server_req, text] = await asyncio.gather( + server.wait_for_request("/something"), + page.evaluate( + """ + async url => { + const data = await fetch(url, { + headers: { + foo: 'a', + bar: 'b', + } + }); + return data.text(); + } + """, + server.PREFIX + "/something", + ), + ) + + assert text == "done" + assert not intercepted_request[0].headers.get("foo") + assert intercepted_request[0].headers.get("bar") == "b" + assert not server_req.getHeader("foo") + assert server_req.getHeader("bar") == "b" + + +async def test_should_amend_method(page: Page, server: Server) -> None: + await page.goto(server.EMPTY_PAGE) + + method = [] + await page.route( + "**/*", + lambda route: ( + method.append(route.request.method), + asyncio.create_task(route.continue_()), + ), + ) + await page.route( + "**/*", lambda route: asyncio.create_task(route.fallback(method="POST")) + ) + + [request, _] = await asyncio.gather( + server.wait_for_request("/sleep.zzz"), + page.evaluate("() => fetch('/sleep.zzz')"), + ) + + assert method == ["POST"] + assert request.method == b"POST" + + +async def test_should_override_request_url(page: Page, server: Server) -> None: + url = [] + await page.route( + "**/global-var.html", + lambda route: ( + url.append(route.request.url), + asyncio.create_task(route.continue_()), + ), + ) + await page.route( + "**/foo", + lambda route: asyncio.create_task( + route.fallback(url=server.PREFIX + "/global-var.html") + ), + ) + + [server_request, response, _] = await asyncio.gather( + server.wait_for_request("/global-var.html"), + page.wait_for_event("response"), + page.goto(server.PREFIX + "/foo"), + ) + + assert url == [server.PREFIX + "/global-var.html"] + assert response.url == server.PREFIX + "/foo" + assert await page.evaluate("() => window['globalVar']") == 123 + assert server_request.uri == b"/global-var.html" + assert server_request.method == b"GET" + + +async def test_should_amend_post_data(page: Page, server: Server) -> None: + await page.goto(server.EMPTY_PAGE) + post_data = [] + await page.route( + "**/*", + lambda route: ( + post_data.append(route.request.post_data), + asyncio.create_task(route.continue_()), + ), + ) + await page.route( + "**/*", lambda route: asyncio.create_task(route.fallback(post_data="doggo")) + ) + [server_request, _] = await asyncio.gather( + server.wait_for_request("/sleep.zzz"), + page.evaluate("() => fetch('/sleep.zzz', { method: 'POST', body: 'birdy' })"), + ) + assert post_data == ["doggo"] + assert server_request.post_body == b"doggo" + + +async def test_should_amend_binary_post_data(page, server): + await page.goto(server.EMPTY_PAGE) + post_data_buffer = [] + await page.route( + "**/*", + lambda route: ( + post_data_buffer.append(route.request.post_data), + asyncio.create_task(route.continue_()), + ), + ) + await page.route( + "**/*", + lambda route: asyncio.create_task( + route.fallback(post_data=b"\x00\x01\x02\x03\x04") + ), + ) + + [server_request, result] = await asyncio.gather( + server.wait_for_request("/sleep.zzz"), + page.evaluate("fetch('/sleep.zzz', { method: 'POST', body: 'birdy' })"), + ) + # FIXME: should this be bytes? + assert post_data_buffer == ["\x00\x01\x02\x03\x04"] + assert server_request.method == b"POST" + assert server_request.post_body == b"\x00\x01\x02\x03\x04" + + +async def test_should_chain_fallback_with_dynamic_url( + server: Server, page: Page +) -> None: + intercepted = [] + await page.route( + "**/bar", + lambda route: ( + intercepted.append(1), + asyncio.create_task(route.fallback(url=server.EMPTY_PAGE)), + ), + ) + await page.route( + "**/foo", + lambda route: ( + intercepted.append(2), + asyncio.create_task(route.fallback(url="http://localhost/bar")), + ), + ) + await page.route( + "**/empty.html", + lambda route: ( + intercepted.append(3), + asyncio.create_task(route.fallback(url="http://localhost/foo")), + ), + ) + + await page.goto(server.EMPTY_PAGE) + assert intercepted == [3, 2, 1] diff --git a/tests/sync/test_browsercontext_request_fallback.py b/tests/sync/test_browsercontext_request_fallback.py new file mode 100644 index 000000000..aae1b087f --- /dev/null +++ b/tests/sync/test_browsercontext_request_fallback.py @@ -0,0 +1,341 @@ +# Copyright (c) Microsoft Corporation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +# pyright: reportUndefinedVariable=false, reportGeneralTypeIssues=false, reportOptionalMemberAccess=false + +import pytest + +from playwright.sync_api import BrowserContext, Error, Page, Request, Route +from tests.server import Server + + +def test_should_work(page: Page, context: BrowserContext, server: Server) -> None: + context.route("**/*", lambda route: route.fallback()) + page.goto(server.EMPTY_PAGE) + + +def test_should_fall_back(page: Page, context: BrowserContext, server: Server) -> None: + intercepted = [] + context.route( + "**/empty.html", + lambda route: ( + intercepted.append(1), + route.fallback(), + ), + ) + context.route( + "**/empty.html", + lambda route: ( + intercepted.append(2), + route.fallback(), + ), + ) + context.route( + "**/empty.html", + lambda route: ( + intercepted.append(3), + route.fallback(), + ), + ) + + page.goto(server.EMPTY_PAGE) + assert intercepted == [3, 2, 1] + + +def test_should_fall_back_async_delayed( + page: Page, context: BrowserContext, server: Server +) -> None: + intercepted = [] + + def create_handler(i: int): + def handler(route): + intercepted.append(i) + page.wait_for_timeout(500) + route.fallback() + + return handler + + context.route("**/empty.html", create_handler(1)) + context.route("**/empty.html", create_handler(2)) + context.route("**/empty.html", create_handler(3)) + page.goto(server.EMPTY_PAGE) + assert intercepted == [3, 2, 1] + + +def test_should_chain_once(page: Page, context: BrowserContext, server: Server) -> None: + context.route( + "**/madeup.txt", + lambda route: route.fulfill(status=200, body="fulfilled one"), + times=1, + ) + context.route("**/madeup.txt", lambda route: route.fallback(), times=1) + + resp = page.goto(server.PREFIX + "/madeup.txt") + body = resp.body() + assert body == b"fulfilled one" + + +def test_should_not_chain_fulfill( + page: Page, context: BrowserContext, server: Server +) -> None: + failed = [False] + + def handler(route: Route): + failed[0] = True + + context.route("**/empty.html", handler) + context.route( + "**/empty.html", + lambda route: route.fulfill(status=200, body="fulfilled"), + ) + context.route("**/empty.html", lambda route: route.fallback()) + + response = page.goto(server.EMPTY_PAGE) + body = response.body() + assert body == b"fulfilled" + assert not failed[0] + + +def test_should_not_chain_abort( + page: Page, + context: BrowserContext, + server: Server, + is_webkit: bool, + is_firefox: bool, +) -> None: + failed = [False] + + def handler(route: Route): + failed[0] = True + + context.route("**/empty.html", handler) + context.route("**/empty.html", lambda route: route.abort()) + context.route("**/empty.html", lambda route: route.fallback()) + + with pytest.raises(Error) as excinfo: + page.goto(server.EMPTY_PAGE) + if is_webkit: + assert "Blocked by Web Inspector" in excinfo.value.message + elif is_firefox: + assert "NS_ERROR_FAILURE" in excinfo.value.message + else: + assert "net::ERR_FAILED" in excinfo.value.message + assert not failed[0] + + +def test_should_fall_back_after_exception( + page: Page, context: BrowserContext, server: Server +) -> None: + context.route("**/empty.html", lambda route: route.continue_()) + + def handler(route: Route): + try: + route.fulfill(response=47) + except Exception: + route.fallback() + + context.route("**/empty.html", handler) + + page.goto(server.EMPTY_PAGE) + + +def test_should_amend_http_headers( + page: Page, context: BrowserContext, server: Server +) -> None: + values = [] + + def handler(route: Route): + values.append(route.request.headers.get("foo")) + values.append(route.request.header_value("FOO")) + route.continue_() + + context.route("**/sleep.zzz", handler) + + def handler_with_header_mods(route: Route): + route.fallback(headers={**route.request.headers, "FOO": "bar"}) + + context.route("**/*", handler_with_header_mods) + + page.goto(server.EMPTY_PAGE) + with page.expect_request("/sleep.zzz") as request_info: + page.evaluate("() => fetch('/sleep.zzz')") + request = request_info.value + values.append(request.headers.get("foo")) + assert values == ["bar", "bar", "bar"] + + +def test_should_delete_header_with_undefined_value( + page: Page, context: BrowserContext, server: Server +) -> None: + page.goto(server.EMPTY_PAGE) + server.set_route( + "/something", + lambda r: ( + r.setHeader("Acces-Control-Allow-Origin", "*"), + r.write(b"done"), + r.finish(), + ), + ) + + intercepted_request = [] + + def capture_and_continue(route: Route, request: Request): + intercepted_request.append(request) + route.continue_() + + context.route("**/*", capture_and_continue) + + def delete_foo_header(route: Route, request: Request): + headers = request.all_headers() + route.fallback(headers={**headers, "foo": None}) + + context.route(server.PREFIX + "/something", delete_foo_header) + with server.expect_request("/something") as server_req_info: + text = page.evaluate( + """ + async url => { + const data = await fetch(url, { + headers: { + foo: 'a', + bar: 'b', + } + }); + return data.text(); + } + """, + server.PREFIX + "/something", + ) + server_req = server_req_info.value + assert text == "done" + assert not intercepted_request[0].headers.get("foo") + assert intercepted_request[0].headers.get("bar") == "b" + assert not server_req.getHeader("foo") + assert server_req.getHeader("bar") == "b" + + +def test_should_amend_method( + page: Page, context: BrowserContext, server: Server +) -> None: + page.goto(server.EMPTY_PAGE) + method = [] + context.route( + "**/*", + lambda route: ( + method.append(route.request.method), + route.continue_(), + ), + ) + context.route("**/*", lambda route: route.fallback(method="POST")) + + with server.expect_request("/sleep.zzz") as request_info: + page.evaluate("() => fetch('/sleep.zzz')") + request = request_info.value + assert method == ["POST"] + assert request.method == b"POST" + + +def test_should_override_request_url( + page: Page, context: BrowserContext, server: Server +) -> None: + url = [] + context.route( + "**/global-var.html", + lambda route: ( + url.append(route.request.url), + route.continue_(), + ), + ) + context.route( + "**/foo", + lambda route: route.fallback(url=server.PREFIX + "/global-var.html"), + ) + + with server.expect_request("/global-var.html") as server_request_info: + with page.expect_event("response") as response_info: + page.goto(server.PREFIX + "/foo") + server_request = server_request_info.value + response = response_info.value + assert url == [server.PREFIX + "/global-var.html"] + assert response.url == server.PREFIX + "/foo" + assert page.evaluate("() => window['globalVar']") == 123 + assert server_request.uri == b"/global-var.html" + assert server_request.method == b"GET" + + +def test_should_amend_post_data( + page: Page, context: BrowserContext, server: Server +) -> None: + page.goto(server.EMPTY_PAGE) + post_data = [] + context.route( + "**/*", + lambda route: ( + post_data.append(route.request.post_data), + route.continue_(), + ), + ) + context.route("**/*", lambda route: route.fallback(post_data="doggo")) + + with server.expect_request("/sleep.zzz") as server_request_info: + page.evaluate("() => fetch('/sleep.zzz', { method: 'POST', body: 'birdy' })"), + server_request = server_request_info.value + assert post_data == ["doggo"] + assert server_request.post_body == b"doggo" + + +def test_should_amend_binary_post_data( + page: Page, context: BrowserContext, server: Server +) -> None: + page.goto(server.EMPTY_PAGE) + post_data_buffer = [] + context.route( + "**/*", + lambda route: ( + post_data_buffer.append(route.request.post_data), + route.continue_(), + ), + ) + context.route( + "**/*", lambda route: route.fallback(post_data=b"\x00\x01\x02\x03\x04") + ) + + with server.expect_request("/sleep.zzz") as server_request_info: + page.evaluate("() => fetch('/sleep.zzz', { method: 'POST', body: 'birdy' })") + server_request = server_request_info.value + # FIXME: should this be bytes? + assert post_data_buffer == ["\x00\x01\x02\x03\x04"] + assert server_request.method == b"POST" + assert server_request.post_body == b"\x00\x01\x02\x03\x04" + + +def test_should_chain_fallback_into_page( + context: BrowserContext, page: Page, server: Server +) -> None: + intercepted = [] + context.route( + "**/empty.html", lambda route: (intercepted.append(1), route.fallback()) + ) + context.route( + "**/empty.html", lambda route: (intercepted.append(2), route.fallback()) + ) + context.route( + "**/empty.html", lambda route: (intercepted.append(3), route.fallback()) + ) + page.route("**/empty.html", lambda route: (intercepted.append(4), route.fallback())) + page.route("**/empty.html", lambda route: (intercepted.append(5), route.fallback())) + page.route("**/empty.html", lambda route: (intercepted.append(6), route.fallback())) + + page.goto(server.EMPTY_PAGE) + assert intercepted == [6, 5, 4, 3, 2, 1] diff --git a/tests/sync/test_browsercontext_request_intercept.py b/tests/sync/test_browsercontext_request_intercept.py new file mode 100644 index 000000000..b136038ec --- /dev/null +++ b/tests/sync/test_browsercontext_request_intercept.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from twisted.web import http + +from playwright.sync_api import BrowserContext, Page, Route +from tests.server import Server + + +def test_should_fulfill_intercepted_response( + page: Page, context: BrowserContext, server: Server +) -> None: + def handle(route: Route) -> None: + response = page.request.fetch(route.request) + route.fulfill( + response=response, + status=201, + headers={"foo": "bar"}, + content_type="text/plain", + body="Yo, page!", + ) + + context.route("**/*", handle) + response = page.goto(server.PREFIX + "/empty.html") + assert response + assert response.status == 201 + assert response.headers["foo"] == "bar" + assert response.headers["content-type"] == "text/plain" + assert page.evaluate("() => document.body.textContent") == "Yo, page!" + + +def test_should_fulfill_response_with_empty_body( + page: Page, context: BrowserContext, server: Server +) -> None: + def handle(route: Route) -> None: + response = page.request.fetch(route.request) + route.fulfill( + response=response, status=201, body="", headers={"content-length": "0"} + ) + + context.route("**/*", handle) + response = page.goto(server.PREFIX + "/title.html") + assert response + assert response.status == 201 + assert response.text() == "" + + +def test_should_override_with_defaults_when_intercepted_response_not_provided( + page: Page, context: BrowserContext, server: Server, browser_name: str +) -> None: + def server_handler(request: http.Request) -> None: + request.setHeader("foo", "bar") + request.write("my content".encode()) + request.finish() + + server.set_route("/empty.html", server_handler) + + def handle(route: Route) -> None: + page.request.fetch(route.request) + route.fulfill(status=201) + + context.route("**/*", handle) + response = page.goto(server.EMPTY_PAGE) + assert response + assert response.status == 201 + assert response.text() == "" + if browser_name == "webkit": + assert response.headers == {"content-type": "text/plain"} + else: + assert response.headers == {} + + +def test_should_fulfill_with_any_response( + page: Page, context: BrowserContext, server: Server +) -> None: + def server_handler(request: http.Request) -> None: + request.setHeader("foo", "bar") + request.write("Woo-hoo".encode()) + request.finish() + + server.set_route("/sample", server_handler) + sample_response = page.request.get(server.PREFIX + "/sample") + context.route( + "**/*", + lambda route: route.fulfill( + response=sample_response, status=201, content_type="text/plain" + ), + ) + response = page.goto(server.EMPTY_PAGE) + assert response + assert response.status == 201 + assert response.text() == "Woo-hoo" + assert response.headers["foo"] == "bar" + + +def test_should_support_fulfill_after_intercept( + page: Page, context: BrowserContext, server: Server, assetdir: Path +) -> None: + def handle_route(route: Route) -> None: + response = page.request.fetch(route.request) + route.fulfill(response=response) + + context.route("**", handle_route) + with server.expect_request("/title.html") as request_info: + response = page.goto(server.PREFIX + "/title.html") + assert response + request = request_info.value + assert request.uri.decode() == "/title.html" + original = (assetdir / "title.html").read_text() + assert response.text() == original diff --git a/tests/sync/test_page_request_fallback.py b/tests/sync/test_page_request_fallback.py new file mode 100644 index 000000000..9fd273b94 --- /dev/null +++ b/tests/sync/test_page_request_fallback.py @@ -0,0 +1,321 @@ +# Copyright (c) Microsoft Corporation. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# mypy: ignore-errors +# pyright: reportUndefinedVariable=false, reportGeneralTypeIssues=false, reportOptionalMemberAccess=false + +import pytest + +from playwright.sync_api import Error, Page, Request, Route +from tests.server import Server + + +def test_should_work(page: Page, server: Server) -> None: + page.route("**/*", lambda route: route.fallback()) + page.goto(server.EMPTY_PAGE) + + +def test_should_fall_back(page: Page, server: Server) -> None: + intercepted = [] + page.route( + "**/empty.html", + lambda route: ( + intercepted.append(1), + route.fallback(), + ), + ) + page.route( + "**/empty.html", + lambda route: ( + intercepted.append(2), + route.fallback(), + ), + ) + page.route( + "**/empty.html", + lambda route: ( + intercepted.append(3), + route.fallback(), + ), + ) + + page.goto(server.EMPTY_PAGE) + assert intercepted == [3, 2, 1] + + +def test_should_fall_back_async_delayed(page: Page, server: Server) -> None: + intercepted = [] + + def create_handler(i: int): + def handler(route): + intercepted.append(i) + page.wait_for_timeout(500) + route.fallback() + + return handler + + page.route("**/empty.html", create_handler(1)) + page.route("**/empty.html", create_handler(2)) + page.route("**/empty.html", create_handler(3)) + page.goto(server.EMPTY_PAGE) + assert intercepted == [3, 2, 1] + + +def test_should_chain_once(page: Page, server: Server) -> None: + page.route( + "**/madeup.txt", + lambda route: route.fulfill(status=200, body="fulfilled one"), + times=1, + ) + page.route("**/madeup.txt", lambda route: route.fallback(), times=1) + + resp = page.goto(server.PREFIX + "/madeup.txt") + body = resp.body() + assert body == b"fulfilled one" + + +def test_should_not_chain_fulfill(page: Page, server: Server) -> None: + failed = [False] + + def handler(route: Route): + failed[0] = True + + page.route("**/empty.html", handler) + page.route( + "**/empty.html", + lambda route: route.fulfill(status=200, body="fulfilled"), + ) + page.route("**/empty.html", lambda route: route.fallback()) + + response = page.goto(server.EMPTY_PAGE) + body = response.body() + assert body == b"fulfilled" + assert not failed[0] + + +def test_should_not_chain_abort( + page: Page, server: Server, is_webkit: bool, is_firefox: bool +) -> None: + failed = [False] + + def handler(route: Route): + failed[0] = True + + page.route("**/empty.html", handler) + page.route("**/empty.html", lambda route: route.abort()) + page.route("**/empty.html", lambda route: route.fallback()) + + with pytest.raises(Error) as excinfo: + page.goto(server.EMPTY_PAGE) + if is_webkit: + assert "Blocked by Web Inspector" in excinfo.value.message + elif is_firefox: + assert "NS_ERROR_FAILURE" in excinfo.value.message + else: + assert "net::ERR_FAILED" in excinfo.value.message + assert not failed[0] + + +def test_should_fall_back_after_exception(page: Page, server: Server) -> None: + page.route("**/empty.html", lambda route: route.continue_()) + + def handler(route: Route): + try: + route.fulfill(response=47) + except Exception: + route.fallback() + + page.route("**/empty.html", handler) + + page.goto(server.EMPTY_PAGE) + + +def test_should_amend_http_headers(page: Page, server: Server) -> None: + values = [] + + def handler(route: Route): + values.append(route.request.headers.get("foo")) + values.append(route.request.header_value("FOO")) + route.continue_() + + page.route("**/sleep.zzz", handler) + + def handler_with_header_mods(route: Route): + route.fallback(headers={**route.request.headers, "FOO": "bar"}) + + page.route("**/*", handler_with_header_mods) + + page.goto(server.EMPTY_PAGE) + with page.expect_request("/sleep.zzz") as request_info: + page.evaluate("() => fetch('/sleep.zzz')") + request = request_info.value + values.append(request.headers.get("foo")) + assert values == ["bar", "bar", "bar"] + + +def test_should_delete_header_with_undefined_value(page: Page, server: Server) -> None: + page.goto(server.EMPTY_PAGE) + server.set_route( + "/something", + lambda r: ( + r.setHeader("Acces-Control-Allow-Origin", "*"), + r.write(b"done"), + r.finish(), + ), + ) + + intercepted_request = [] + + def capture_and_continue(route: Route, request: Request): + intercepted_request.append(request) + route.continue_() + + page.route("**/*", capture_and_continue) + + def delete_foo_header(route: Route, request: Request): + headers = request.all_headers() + route.fallback(headers={**headers, "foo": None}) + + page.route(server.PREFIX + "/something", delete_foo_header) + with server.expect_request("/something") as server_req_info: + text = page.evaluate( + """ + async url => { + const data = await fetch(url, { + headers: { + foo: 'a', + bar: 'b', + } + }); + return data.text(); + } + """, + server.PREFIX + "/something", + ) + server_req = server_req_info.value + assert text == "done" + assert not intercepted_request[0].headers.get("foo") + assert intercepted_request[0].headers.get("bar") == "b" + assert not server_req.getHeader("foo") + assert server_req.getHeader("bar") == "b" + + +def test_should_amend_method(page: Page, server: Server) -> None: + page.goto(server.EMPTY_PAGE) + method = [] + page.route( + "**/*", + lambda route: ( + method.append(route.request.method), + route.continue_(), + ), + ) + page.route("**/*", lambda route: route.fallback(method="POST")) + + with server.expect_request("/sleep.zzz") as request_info: + page.evaluate("() => fetch('/sleep.zzz')") + request = request_info.value + assert method == ["POST"] + assert request.method == b"POST" + + +def test_should_override_request_url(page: Page, server: Server) -> None: + url = [] + page.route( + "**/global-var.html", + lambda route: ( + url.append(route.request.url), + route.continue_(), + ), + ) + page.route( + "**/foo", + lambda route: route.fallback(url=server.PREFIX + "/global-var.html"), + ) + + with server.expect_request("/global-var.html") as server_request_info: + with page.expect_event("response") as response_info: + page.goto(server.PREFIX + "/foo") + server_request = server_request_info.value + response = response_info.value + assert url == [server.PREFIX + "/global-var.html"] + assert response.url == server.PREFIX + "/foo" + assert page.evaluate("() => window['globalVar']") == 123 + assert server_request.uri == b"/global-var.html" + assert server_request.method == b"GET" + + +def test_should_amend_post_data(page: Page, server: Server) -> None: + page.goto(server.EMPTY_PAGE) + post_data = [] + page.route( + "**/*", + lambda route: ( + post_data.append(route.request.post_data), + route.continue_(), + ), + ) + page.route("**/*", lambda route: route.fallback(post_data="doggo")) + + with server.expect_request("/sleep.zzz") as server_request_info: + page.evaluate("() => fetch('/sleep.zzz', { method: 'POST', body: 'birdy' })"), + server_request = server_request_info.value + assert post_data == ["doggo"] + assert server_request.post_body == b"doggo" + + +def test_should_amend_binary_post_data(page: Page, server: Server) -> None: + page.goto(server.EMPTY_PAGE) + post_data_buffer = [] + page.route( + "**/*", + lambda route: ( + post_data_buffer.append(route.request.post_data), + route.continue_(), + ), + ) + page.route("**/*", lambda route: route.fallback(post_data=b"\x00\x01\x02\x03\x04")) + + with server.expect_request("/sleep.zzz") as server_request_info: + page.evaluate("() => fetch('/sleep.zzz', { method: 'POST', body: 'birdy' })") + server_request = server_request_info.value + # FIXME: should this be bytes? + assert post_data_buffer == ["\x00\x01\x02\x03\x04"] + assert server_request.method == b"POST" + assert server_request.post_body == b"\x00\x01\x02\x03\x04" + + +def test_should_chain_fallback_with_dynamic_url(server: Server, page: Page) -> None: + intercepted = [] + page.route( + "**/bar", + lambda route: (intercepted.append(1), route.fallback(url=server.EMPTY_PAGE)), + ) + page.route( + "**/foo", + lambda route: ( + intercepted.append(2), + route.fallback(url="http://localhost/bar"), + ), + ) + page.route( + "**/empty.html", + lambda route: ( + intercepted.append(3), + route.fallback(url="http://localhost/foo"), + ), + ) + + page.goto(server.EMPTY_PAGE) + assert intercepted == [3, 2, 1]