Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ dependencies = [
"anyio>=3.7"
]

[project.optional-dependencies]
test = [
"pytest>=8.0"
]

[project.urls]
Homepage = "https://hurozo.com"
Repository = "https://github.com/hurozo/python-webchannel"
Expand Down
112 changes: 37 additions & 75 deletions src/python_webchannel/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from dataclasses import dataclass
from enum import Enum
from http.cookiejar import CookieJar
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence
from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl
from typing import Any, Dict, List, Mapping, Optional, Sequence
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse

import httpx

from .errors import WebChannelError
from .events import EventType, EventTarget, MessageEvent, get_stat_event_target, Event, Stat, StatEvent
from .events import Event, EventTarget, EventType, MessageEvent, Stat, StatEvent, get_stat_event_target
from .httpcors import generate_encoded_http_headers_overwrite_param
from .options import WebChannelOptions
from .wire import LATEST_CHANNEL_VERSION, RAW_DATA_KEY, QueuedMap, WireV8
Expand Down Expand Up @@ -53,14 +53,13 @@ def __init__(
self._url = url
self._base_url = url
self._options = options.copy()
self._owns_client = http_client is None
self._client = http_client or httpx.AsyncClient()
self._wire = WireV8()
self._state = ChannelState.INIT
self._sid: Optional[str] = None
self._host_prefix: Optional[str] = None
self._http_session_id_param: Optional[str] = (
self._options.http_session_id_param
)
self._http_session_id_param: Optional[str] = self._options.http_session_id_param
self._http_session_id_value: Optional[str] = None
self._channel_version = LATEST_CHANNEL_VERSION
self._client_version = 22
Expand Down Expand Up @@ -95,7 +94,6 @@ def __init__(
self._fetch_headers = dict(self._options.fetch_headers or {})
self._session_id_placeholder = "gsessionid"

# ------------------------------------------------------------------
async def open(self) -> None:
if self._state != ChannelState.INIT:
return
Expand Down Expand Up @@ -129,19 +127,14 @@ async def close(self) -> None:
except asyncio.CancelledError:
pass
self._backchannel_started = False
await self._client.aclose()
if self._owns_client:
await self._client.aclose()
self.dispatch_event(EventType.CLOSE, None)

# ------------------------------------------------------------------
async def _perform_handshake(self) -> None:
params = self._build_base_params()
rid = self._consume_rid()
params.update(
{
"RID": str(rid),
"CVER": str(self._client_version),
}
)
params.update({"RID": str(rid), "CVER": str(self._client_version)})

body, headers = self._build_handshake_payload()
headers.setdefault("X-Client-Protocol", "webchannel")
Expand Down Expand Up @@ -179,10 +172,7 @@ async def _perform_handshake(self) -> None:
status=response.status_code,
body=text[:2048].decode(errors="replace"),
)
raise WebChannelError(
status=str(response.status_code),
message=text.decode(errors="replace"),
)
raise WebChannelError(status=str(response.status_code), message=text.decode(errors="replace"))

buffer = ""
async for chunk in response.aiter_text():
Expand Down Expand Up @@ -292,14 +282,12 @@ async def _post_maps(self, maps: Sequence[QueuedMap]) -> None:
params = self._build_base_params()
rid = self._consume_rid()
aid = self._acknowledged_array_id if self._acknowledged_array_id >= 0 else -1
params.update(
{
"SID": self._sid or "",
"RID": str(rid),
"AID": str(aid),
"CVER": str(self._client_version),
}
)
params.update({
"SID": self._sid or "",
"RID": str(rid),
"AID": str(aid),
"CVER": str(self._client_version),
})
body = self._wire.encode_message_queue(maps, len(maps))
headers = self._build_message_headers()
headers.pop("Content-Type", None)
Expand Down Expand Up @@ -330,10 +318,7 @@ async def _post_maps(self, maps: Sequence[QueuedMap]) -> None:
status=response.status_code,
body=text[:2048].decode(errors="replace"),
)
raise WebChannelError(
status=str(response.status_code),
message=text.decode(errors="replace"),
)
raise WebChannelError(status=str(response.status_code), message=text.decode(errors="replace"))

buffer = ""
async for chunk in response.aiter_text():
Expand All @@ -358,16 +343,11 @@ async def _handle_post_response(self, chunk: str) -> None:
self._acknowledged_array_id = self._last_post_response_array_id
if arrays_outstanding == 0:
return
LOGGER.debug(
"Outstanding backchannel arrays=%s bytes=%s",
arrays_outstanding,
outstanding_bytes,
)
LOGGER.debug("Outstanding backchannel arrays=%s bytes=%s", arrays_outstanding, outstanding_bytes)
if self._detect_buffering_proxy and not self._stats_emitted:
self._stats_target.dispatch_event(Event.STAT_EVENT, StatEvent(Stat.PROXY))
self._stats_emitted = True
elif isinstance(payload, list):
# Forward unexpected array payloads to the normal handler.
for entry in payload:
if isinstance(entry, list) and len(entry) >= 2:
await self._handle_channel_payload(entry[1])
Expand All @@ -376,16 +356,14 @@ async def _run_backchannel(self) -> None:
while not self._closed and self._state == ChannelState.OPENED:
params = self._build_base_params()
aid = self._acknowledged_array_id if self._acknowledged_array_id >= 0 else -1
params.update(
{
"RID": "rpc",
"SID": self._sid or "",
"AID": str(aid),
"CI": "0" if self._enable_streaming else "1",
"TYPE": "xmlhttp",
"CVER": str(self._client_version),
}
)
params.update({
"RID": "rpc",
"SID": self._sid or "",
"AID": str(aid),
"CI": "0" if self._enable_streaming else "1",
"TYPE": "xmlhttp",
"CVER": str(self._client_version),
})
if not self._enable_streaming and self._long_polling_timeout:
params["TO"] = str(self._long_polling_timeout)

Expand All @@ -397,11 +375,7 @@ async def _run_backchannel(self) -> None:
self._apply_fetch_headers(headers)
if "Cookie" not in headers:
self._log("backchannel:missing_cookie", cookies=str(self._client.cookies))
self._log(
"backchannel:request",
url=url,
headers=self._sanitize_headers(headers),
)
self._log("backchannel:request", url=url, headers=self._sanitize_headers(headers))
async with self._client.stream(
"GET",
url,
Expand All @@ -421,17 +395,12 @@ async def _run_backchannel(self) -> None:
status=response.status_code,
body=text[:2048].decode(errors="replace"),
)
raise WebChannelError(
status=str(response.status_code),
message=text.decode(errors="replace"),
)
raise WebChannelError(status=str(response.status_code), message=text.decode(errors="replace"))

buffer = ""
async for chunk in response.aiter_text():
if self._detect_buffering_proxy and not self._stats_emitted:
self._stats_target.dispatch_event(
Event.STAT_EVENT, StatEvent(Stat.NOPROXY)
)
self._stats_target.dispatch_event(Event.STAT_EVENT, StatEvent(Stat.NOPROXY))
self._stats_emitted = True
buffer += chunk
messages, buffer = self._extract_chunks(buffer)
Expand All @@ -455,14 +424,12 @@ def _ensure_backchannel(self) -> None:

async def _send_terminate(self) -> None:
params = self._build_base_params()
params.update(
{
"SID": self._sid or "",
"RID": str(self._consume_rid()),
"TYPE": "terminate",
"CVER": str(self._client_version),
}
)
params.update({
"SID": self._sid or "",
"RID": str(self._consume_rid()),
"TYPE": "terminate",
"CVER": str(self._client_version),
})
url = self._build_url(self._base_url, params)
self._log("terminate:request", url=url)
try:
Expand All @@ -472,7 +439,6 @@ async def _send_terminate(self) -> None:
except httpx.HTTPError:
LOGGER.debug("Failed to send terminate request", exc_info=True)

# ------------------------------------------------------------------
def _normalize_outgoing_message(self, message: Any) -> Dict[str, Any]:
if isinstance(message, dict):
if self._options.send_raw_json:
Expand Down Expand Up @@ -504,10 +470,7 @@ def _build_message_headers(self) -> Dict[str, str]:
return dict(self._message_headers)

def _build_base_params(self) -> Dict[str, str]:
params = {
"VER": str(self._channel_version),
"zx": self._generate_zx(),
}
params = {"VER": str(self._channel_version), "zx": self._generate_zx()}
if self._options.message_url_params:
params.update({k: str(v) for k, v in self._options.message_url_params.items()})
if self._http_session_id_param and self._http_session_id_value:
Expand Down Expand Up @@ -597,17 +560,16 @@ def _refresh_cookie_header(self) -> None:
if isinstance(jar, CookieJar):
for cookie in jar:
items.append(f"{cookie.name}={cookie.value}")
else: # pragma: no cover - fallback for alternate cookie containers
else: # pragma: no cover
try:
items = [f"{key}={value}" for key, value in jar.items()] # type: ignore[attr-defined]
items = [f"{key}={value}" for key, value in jar.items()]
except Exception:
items = []
if items:
if self._sid and not any(cookie_str.startswith("SID=") for cookie_str in items):
items.append(f"SID={self._sid}")
self._cookie_header = "; ".join(items)
elif self._cookie_header:
# Clear cached header if cookies disappeared.
self._cookie_header = None

def _apply_fetch_headers(self, headers: Dict[str, str]) -> None:
Expand Down
98 changes: 98 additions & 0 deletions tests/test_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import asyncio

import httpx

from python_webchannel.channel import ChannelState, WebChannel
from python_webchannel.events import Event, EventType, Stat
from python_webchannel.options import WebChannelOptions


class StubAsyncClient:
def __init__(self):
self.cookies = httpx.Cookies()
self.closed = False
self.posts = []

async def aclose(self):
self.closed = True

async def post(self, *args, **kwargs):
self.posts.append((args, kwargs))
return None


def test_injected_client_is_not_closed_by_channel():
client = StubAsyncClient()
channel = WebChannel("https://example.com/channel", WebChannelOptions(), http_client=client)

asyncio.run(channel.close())

assert client.closed is False


def test_owned_client_is_closed_by_channel():
channel = WebChannel("https://example.com/channel", WebChannelOptions())
client = channel._client

asyncio.run(channel.close())

assert client.is_closed is True


def test_handshake_payload_opens_channel_and_starts_backchannel():
client = StubAsyncClient()
channel = WebChannel("https://example.com/channel", WebChannelOptions(), http_client=client)
channel._state = ChannelState.OPENING
channel._last_array_id = 7
started = []
opened = []

channel.listen(EventType.OPEN, lambda payload: opened.append(payload))
channel._ensure_backchannel = lambda: started.append(True)

asyncio.run(channel._handle_handshake_payload(["c", "SID123", "hostprefix", 8, 1, 10]))

assert channel._state == ChannelState.OPENED
assert channel._sid == "SID123"
assert channel._host_prefix == "hostprefix"
assert channel._acknowledged_array_id == 7
assert started == [True]
assert opened == [None]


def test_post_response_updates_ack_and_emits_proxy_stat_once():
client = StubAsyncClient()
channel = WebChannel(
"https://example.com/channel",
WebChannelOptions(detect_buffering_proxy=True),
http_client=client,
)
stats = []
channel._stats_target.listen(Event.STAT_EVENT, lambda event: stats.append(event.stat))

asyncio.run(channel._handle_post_response("[1,5,128]"))
asyncio.run(channel._handle_post_response("[1,6,64]"))

assert channel._acknowledged_array_id == 6
assert stats == [Stat.PROXY]


def test_message_dispatch_supports_sync_and_async_listeners():
client = StubAsyncClient()
channel = WebChannel("https://example.com/channel", WebChannelOptions(), http_client=client)
received = []

async def async_listener(event):
received.append(("async", event.data))

def sync_listener(event):
received.append(("sync", event.data))

channel.listen(EventType.MESSAGE, sync_listener)
channel.listen(EventType.MESSAGE, async_listener)

asyncio.run(channel._handle_channel_payload(["d", {"__data__": '{"hello":"world"}'}]))
asyncio.run(asyncio.sleep(0))

assert ("sync", {"hello": "world"}) in received
assert ("async", {"hello": "world"}) in received