From 388ef7fd088ddb4ce990839b9071c1d52ff9d3b9 Mon Sep 17 00:00:00 2001 From: Dennis Vink Date: Fri, 24 Apr 2026 15:03:15 +0200 Subject: [PATCH] Fix AsyncClient ownership and add WebChannel tests --- pyproject.toml | 5 ++ src/python_webchannel/channel.py | 112 ++++++++++--------------------- tests/test_channel.py | 98 +++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 75 deletions(-) create mode 100644 tests/test_channel.py diff --git a/pyproject.toml b/pyproject.toml index bc94970..850fbc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,11 @@ dependencies = [ "anyio>=3.7" ] +[project.optional-dependencies] +test = [ + "pytest>=8.0" +] + [project.urls] Homepage = "https://hurozo.com" Repository = "https://github.com/hurozo/python-webchannel" diff --git a/src/python_webchannel/channel.py b/src/python_webchannel/channel.py index 2e1085b..5d7780a 100644 --- a/src/python_webchannel/channel.py +++ b/src/python_webchannel/channel.py @@ -10,13 +10,13 @@ from dataclasses import dataclass from enum import Enum from http.cookiejar import CookieJar -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence -from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl +from typing import Any, Dict, List, Mapping, Optional, Sequence +from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse import httpx from .errors import WebChannelError -from .events import EventType, EventTarget, MessageEvent, get_stat_event_target, Event, Stat, StatEvent +from .events import Event, EventTarget, EventType, MessageEvent, Stat, StatEvent, get_stat_event_target from .httpcors import generate_encoded_http_headers_overwrite_param from .options import WebChannelOptions from .wire import LATEST_CHANNEL_VERSION, RAW_DATA_KEY, QueuedMap, WireV8 @@ -53,14 +53,13 @@ def __init__( self._url = url self._base_url = url self._options = options.copy() + self._owns_client = http_client is None self._client = http_client or httpx.AsyncClient() self._wire = WireV8() self._state = ChannelState.INIT self._sid: Optional[str] = None self._host_prefix: Optional[str] = None - self._http_session_id_param: Optional[str] = ( - self._options.http_session_id_param - ) + self._http_session_id_param: Optional[str] = self._options.http_session_id_param self._http_session_id_value: Optional[str] = None self._channel_version = LATEST_CHANNEL_VERSION self._client_version = 22 @@ -95,7 +94,6 @@ def __init__( self._fetch_headers = dict(self._options.fetch_headers or {}) self._session_id_placeholder = "gsessionid" - # ------------------------------------------------------------------ async def open(self) -> None: if self._state != ChannelState.INIT: return @@ -129,19 +127,14 @@ async def close(self) -> None: except asyncio.CancelledError: pass self._backchannel_started = False - await self._client.aclose() + if self._owns_client: + await self._client.aclose() self.dispatch_event(EventType.CLOSE, None) - # ------------------------------------------------------------------ async def _perform_handshake(self) -> None: params = self._build_base_params() rid = self._consume_rid() - params.update( - { - "RID": str(rid), - "CVER": str(self._client_version), - } - ) + params.update({"RID": str(rid), "CVER": str(self._client_version)}) body, headers = self._build_handshake_payload() headers.setdefault("X-Client-Protocol", "webchannel") @@ -179,10 +172,7 @@ async def _perform_handshake(self) -> None: status=response.status_code, body=text[:2048].decode(errors="replace"), ) - raise WebChannelError( - status=str(response.status_code), - message=text.decode(errors="replace"), - ) + raise WebChannelError(status=str(response.status_code), message=text.decode(errors="replace")) buffer = "" async for chunk in response.aiter_text(): @@ -292,14 +282,12 @@ async def _post_maps(self, maps: Sequence[QueuedMap]) -> None: params = self._build_base_params() rid = self._consume_rid() aid = self._acknowledged_array_id if self._acknowledged_array_id >= 0 else -1 - params.update( - { - "SID": self._sid or "", - "RID": str(rid), - "AID": str(aid), - "CVER": str(self._client_version), - } - ) + params.update({ + "SID": self._sid or "", + "RID": str(rid), + "AID": str(aid), + "CVER": str(self._client_version), + }) body = self._wire.encode_message_queue(maps, len(maps)) headers = self._build_message_headers() headers.pop("Content-Type", None) @@ -330,10 +318,7 @@ async def _post_maps(self, maps: Sequence[QueuedMap]) -> None: status=response.status_code, body=text[:2048].decode(errors="replace"), ) - raise WebChannelError( - status=str(response.status_code), - message=text.decode(errors="replace"), - ) + raise WebChannelError(status=str(response.status_code), message=text.decode(errors="replace")) buffer = "" async for chunk in response.aiter_text(): @@ -358,16 +343,11 @@ async def _handle_post_response(self, chunk: str) -> None: self._acknowledged_array_id = self._last_post_response_array_id if arrays_outstanding == 0: return - LOGGER.debug( - "Outstanding backchannel arrays=%s bytes=%s", - arrays_outstanding, - outstanding_bytes, - ) + LOGGER.debug("Outstanding backchannel arrays=%s bytes=%s", arrays_outstanding, outstanding_bytes) if self._detect_buffering_proxy and not self._stats_emitted: self._stats_target.dispatch_event(Event.STAT_EVENT, StatEvent(Stat.PROXY)) self._stats_emitted = True elif isinstance(payload, list): - # Forward unexpected array payloads to the normal handler. for entry in payload: if isinstance(entry, list) and len(entry) >= 2: await self._handle_channel_payload(entry[1]) @@ -376,16 +356,14 @@ async def _run_backchannel(self) -> None: while not self._closed and self._state == ChannelState.OPENED: params = self._build_base_params() aid = self._acknowledged_array_id if self._acknowledged_array_id >= 0 else -1 - params.update( - { - "RID": "rpc", - "SID": self._sid or "", - "AID": str(aid), - "CI": "0" if self._enable_streaming else "1", - "TYPE": "xmlhttp", - "CVER": str(self._client_version), - } - ) + params.update({ + "RID": "rpc", + "SID": self._sid or "", + "AID": str(aid), + "CI": "0" if self._enable_streaming else "1", + "TYPE": "xmlhttp", + "CVER": str(self._client_version), + }) if not self._enable_streaming and self._long_polling_timeout: params["TO"] = str(self._long_polling_timeout) @@ -397,11 +375,7 @@ async def _run_backchannel(self) -> None: self._apply_fetch_headers(headers) if "Cookie" not in headers: self._log("backchannel:missing_cookie", cookies=str(self._client.cookies)) - self._log( - "backchannel:request", - url=url, - headers=self._sanitize_headers(headers), - ) + self._log("backchannel:request", url=url, headers=self._sanitize_headers(headers)) async with self._client.stream( "GET", url, @@ -421,17 +395,12 @@ async def _run_backchannel(self) -> None: status=response.status_code, body=text[:2048].decode(errors="replace"), ) - raise WebChannelError( - status=str(response.status_code), - message=text.decode(errors="replace"), - ) + raise WebChannelError(status=str(response.status_code), message=text.decode(errors="replace")) buffer = "" async for chunk in response.aiter_text(): if self._detect_buffering_proxy and not self._stats_emitted: - self._stats_target.dispatch_event( - Event.STAT_EVENT, StatEvent(Stat.NOPROXY) - ) + self._stats_target.dispatch_event(Event.STAT_EVENT, StatEvent(Stat.NOPROXY)) self._stats_emitted = True buffer += chunk messages, buffer = self._extract_chunks(buffer) @@ -455,14 +424,12 @@ def _ensure_backchannel(self) -> None: async def _send_terminate(self) -> None: params = self._build_base_params() - params.update( - { - "SID": self._sid or "", - "RID": str(self._consume_rid()), - "TYPE": "terminate", - "CVER": str(self._client_version), - } - ) + params.update({ + "SID": self._sid or "", + "RID": str(self._consume_rid()), + "TYPE": "terminate", + "CVER": str(self._client_version), + }) url = self._build_url(self._base_url, params) self._log("terminate:request", url=url) try: @@ -472,7 +439,6 @@ async def _send_terminate(self) -> None: except httpx.HTTPError: LOGGER.debug("Failed to send terminate request", exc_info=True) - # ------------------------------------------------------------------ def _normalize_outgoing_message(self, message: Any) -> Dict[str, Any]: if isinstance(message, dict): if self._options.send_raw_json: @@ -504,10 +470,7 @@ def _build_message_headers(self) -> Dict[str, str]: return dict(self._message_headers) def _build_base_params(self) -> Dict[str, str]: - params = { - "VER": str(self._channel_version), - "zx": self._generate_zx(), - } + params = {"VER": str(self._channel_version), "zx": self._generate_zx()} if self._options.message_url_params: params.update({k: str(v) for k, v in self._options.message_url_params.items()}) if self._http_session_id_param and self._http_session_id_value: @@ -597,9 +560,9 @@ def _refresh_cookie_header(self) -> None: if isinstance(jar, CookieJar): for cookie in jar: items.append(f"{cookie.name}={cookie.value}") - else: # pragma: no cover - fallback for alternate cookie containers + else: # pragma: no cover try: - items = [f"{key}={value}" for key, value in jar.items()] # type: ignore[attr-defined] + items = [f"{key}={value}" for key, value in jar.items()] except Exception: items = [] if items: @@ -607,7 +570,6 @@ def _refresh_cookie_header(self) -> None: items.append(f"SID={self._sid}") self._cookie_header = "; ".join(items) elif self._cookie_header: - # Clear cached header if cookies disappeared. self._cookie_header = None def _apply_fetch_headers(self, headers: Dict[str, str]) -> None: diff --git a/tests/test_channel.py b/tests/test_channel.py new file mode 100644 index 0000000..7115057 --- /dev/null +++ b/tests/test_channel.py @@ -0,0 +1,98 @@ +import asyncio + +import httpx + +from python_webchannel.channel import ChannelState, WebChannel +from python_webchannel.events import Event, EventType, Stat +from python_webchannel.options import WebChannelOptions + + +class StubAsyncClient: + def __init__(self): + self.cookies = httpx.Cookies() + self.closed = False + self.posts = [] + + async def aclose(self): + self.closed = True + + async def post(self, *args, **kwargs): + self.posts.append((args, kwargs)) + return None + + +def test_injected_client_is_not_closed_by_channel(): + client = StubAsyncClient() + channel = WebChannel("https://example.com/channel", WebChannelOptions(), http_client=client) + + asyncio.run(channel.close()) + + assert client.closed is False + + +def test_owned_client_is_closed_by_channel(): + channel = WebChannel("https://example.com/channel", WebChannelOptions()) + client = channel._client + + asyncio.run(channel.close()) + + assert client.is_closed is True + + +def test_handshake_payload_opens_channel_and_starts_backchannel(): + client = StubAsyncClient() + channel = WebChannel("https://example.com/channel", WebChannelOptions(), http_client=client) + channel._state = ChannelState.OPENING + channel._last_array_id = 7 + started = [] + opened = [] + + channel.listen(EventType.OPEN, lambda payload: opened.append(payload)) + channel._ensure_backchannel = lambda: started.append(True) + + asyncio.run(channel._handle_handshake_payload(["c", "SID123", "hostprefix", 8, 1, 10])) + + assert channel._state == ChannelState.OPENED + assert channel._sid == "SID123" + assert channel._host_prefix == "hostprefix" + assert channel._acknowledged_array_id == 7 + assert started == [True] + assert opened == [None] + + +def test_post_response_updates_ack_and_emits_proxy_stat_once(): + client = StubAsyncClient() + channel = WebChannel( + "https://example.com/channel", + WebChannelOptions(detect_buffering_proxy=True), + http_client=client, + ) + stats = [] + channel._stats_target.listen(Event.STAT_EVENT, lambda event: stats.append(event.stat)) + + asyncio.run(channel._handle_post_response("[1,5,128]")) + asyncio.run(channel._handle_post_response("[1,6,64]")) + + assert channel._acknowledged_array_id == 6 + assert stats == [Stat.PROXY] + + +def test_message_dispatch_supports_sync_and_async_listeners(): + client = StubAsyncClient() + channel = WebChannel("https://example.com/channel", WebChannelOptions(), http_client=client) + received = [] + + async def async_listener(event): + received.append(("async", event.data)) + + def sync_listener(event): + received.append(("sync", event.data)) + + channel.listen(EventType.MESSAGE, sync_listener) + channel.listen(EventType.MESSAGE, async_listener) + + asyncio.run(channel._handle_channel_payload(["d", {"__data__": '{"hello":"world"}'}])) + asyncio.run(asyncio.sleep(0)) + + assert ("sync", {"hello": "world"}) in received + assert ("async", {"hello": "world"}) in received