Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions ld_eventsource/config/connect_strategy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from logging import Logger
from typing import Callable, Iterator, Optional, Union

from urllib3 import PoolManager

from ld_eventsource.http import _HttpClientImpl, _HttpConnectParams
from ld_eventsource.http import (DynamicQueryParams, _HttpClientImpl,
_HttpConnectParams)


class ConnectStrategy:
Expand Down Expand Up @@ -38,6 +40,7 @@ def http(
headers: Optional[dict] = None,
pool: Optional[PoolManager] = None,
urllib3_request_options: Optional[dict] = None,
query_params: Optional[DynamicQueryParams] = None
) -> ConnectStrategy:
"""
Creates the default HTTP implementation, specifying request parameters.
Expand All @@ -47,9 +50,11 @@ def http(
:param pool: optional urllib3 ``PoolManager`` to provide an HTTP client
:param urllib3_request_options: optional ``kwargs`` to add to the ``request`` call; these
can include any parameters supported by ``urllib3``, such as ``timeout``
:param query_params: optional callable that can be used to affect query parameters
dynamically for each connection attempt
"""
return _HttpConnectStrategy(
_HttpConnectParams(url, headers, pool, urllib3_request_options)
_HttpConnectParams(url, headers, pool, urllib3_request_options, query_params)
)


Expand Down
26 changes: 24 additions & 2 deletions ld_eventsource/http.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from logging import Logger
from typing import Callable, Iterator, Optional, Tuple
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit

from urllib3 import PoolManager
from urllib3.exceptions import MaxRetryError
Expand All @@ -9,6 +10,12 @@

_CHUNK_SIZE = 10000

DynamicQueryParams = Callable[[], dict[str, str]]
"""
A callable that returns a dictionary of query parameters to add to the URL.
This can be used to modify query parameters dynamically for each connection attempt.
"""


class _HttpConnectParams:
def __init__(
Expand All @@ -17,16 +24,22 @@ def __init__(
headers: Optional[dict] = None,
pool: Optional[PoolManager] = None,
urllib3_request_options: Optional[dict] = None,
query_params: Optional[DynamicQueryParams] = None
):
self.__url = url
self.__headers = headers
self.__pool = pool
self.__urllib3_request_options = urllib3_request_options
self.__query_params = query_params

@property
def url(self) -> str:
return self.__url

@property
def query_params(self) -> Optional[DynamicQueryParams]:
return self.__query_params

@property
def headers(self) -> Optional[dict]:
return self.__headers
Expand All @@ -48,7 +61,16 @@ def __init__(self, params: _HttpConnectParams, logger: Logger):
self.__logger = logger

def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callable]:
self.__logger.info("Connecting to stream at %s" % self.__params.url)
url = self.__params.url
if self.__params.query_params is not None:
qp = self.__params.query_params()
if qp:
url_parts = list(urlsplit(url))
query = dict(parse_qsl(url_parts[3]))
query.update(qp)
url_parts[3] = urlencode(query)
url = urlunsplit(url_parts)
self.__logger.info("Connecting to stream at %s" % url)

headers = self.__params.headers.copy() if self.__params.headers else {}
headers['Cache-Control'] = 'no-cache'
Expand All @@ -67,7 +89,7 @@ def connect(self, last_event_id: Optional[str]) -> Tuple[Iterator[bytes], Callab
try:
resp = self.__pool.request(
'GET',
self.__params.url,
url,
preload_content=False,
retries=Retry(
total=None, read=0, connect=0, status=0, other=0, redirect=3
Expand Down
2 changes: 1 addition & 1 deletion ld_eventsource/testing/http_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def do_POST(self):
def _do_request(self):
server_wrapper = self.server.server_wrapper
server_wrapper.requests.put(MockServerRequest(self))
handler = server_wrapper.matchers.get(self.path)
handler = server_wrapper.matchers.get(self.path.split("?")[0], None)
if handler:
handler.write(self)
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from urllib.parse import parse_qsl

from ld_eventsource import *
from ld_eventsource.config import *
from ld_eventsource.testing.helpers import *
Expand Down Expand Up @@ -56,6 +58,48 @@ def test_sse_client_reconnects_after_socket_closed():
assert event2.data == 'data2'


def test_sse_client_allows_modifying_query_params_dynamically():
count = 0

def dynamic_query_params() -> dict[str, str]:
nonlocal count
count += 1
params = {'count': str(count)}
if count > 1:
params['option'] = 'updated'

return params

with start_server() as server:
with make_stream() as stream1:
with make_stream() as stream2:
server.for_path('/', SequentialHandler(stream1, stream2))
stream1.push("event: a\ndata: data1\nid: id123\n\n")
stream2.push("event: b\ndata: data2\n\n")
with SSEClient(
connect=ConnectStrategy.http(f"{server.uri}?basis=unchanging&option=initial", query_params=dynamic_query_params),
error_strategy=ErrorStrategy.always_continue(),
initial_retry_delay=0,
) as client:
client.start()
next(client.events)
stream1.close()
next(client.events)
r1 = server.await_request()
r1_query_params = dict(parse_qsl(r1.path.split('?', 1)[1]))

# Ensure we can add, retain, and modify query parameters
assert r1_query_params.get('count') == '1'
assert r1_query_params.get('basis') == 'unchanging'
assert r1_query_params.get('option') == 'initial'

r2 = server.await_request()
r2_query_params = dict(parse_qsl(r2.path.split('?', 1)[1]))
assert r2_query_params.get('count') == '2'
assert r2_query_params.get('basis') == 'unchanging'
assert r2_query_params.get('option') == 'updated'


def test_sse_client_sends_last_event_id_on_reconnect():
with start_server() as server:
with make_stream() as stream1:
Expand Down