Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions fila/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from fila.errors import (
Expand Down Expand Up @@ -29,6 +30,8 @@
import ssl
from collections.abc import AsyncIterator

_log = logging.getLogger(__name__)


class AsyncClient:
"""Asynchronous client for the Fila message broker (FIBP transport).
Expand Down Expand Up @@ -175,6 +178,8 @@ async def enqueue_many(
) -> list[EnqueueResult]:
"""Enqueue multiple messages, possibly targeting different queues.

Per-queue FIBP requests are issued concurrently via ``asyncio.gather``.

Args:
messages: List of ``(queue, headers, payload)`` tuples.

Expand All @@ -184,6 +189,7 @@ async def enqueue_many(
Raises:
TransportError: For unexpected FIBP failures.
"""
import asyncio
from collections import defaultdict

by_queue: dict[str, list[tuple[int, dict[str, str], bytes]]] = defaultdict(list)
Expand All @@ -196,27 +202,35 @@ async def enqueue_many(
by_queue[queue].append((idx, hdrs or {}, payload))
order.append((queue, idx))

results_by_queue: dict[str, list[EnqueueResult]] = {}
for queue_name, items in by_queue.items():
async def _send_one_queue(
queue_name: str,
items: list[tuple[int, dict[str, str], bytes]],
) -> tuple[str, list[EnqueueResult]]:
corr_id = self._conn.alloc_corr_id()
msgs = [(queue_name, h, p) for _, h, p in items]
frame = encode_enqueue(corr_id, msgs)
try:
body = await self._conn.send_request(frame, corr_id)
except FibpError as e:
err = str(e)
results_by_queue[queue_name] = [
return queue_name, [
EnqueueResult(message_id=None, error=err) for _ in items
]
continue
decoded = decode_enqueue_response(body)
per_queue: list[EnqueueResult] = []
for ok, msg_id, _err_code, err_msg in decoded:
if ok:
per_queue.append(EnqueueResult(message_id=msg_id, error=None))
else:
per_queue.append(EnqueueResult(message_id=None, error=err_msg))
results_by_queue[queue_name] = per_queue
return queue_name, per_queue

coros = [
_send_one_queue(queue_name, items)
for queue_name, items in by_queue.items()
]
gathered = await asyncio.gather(*coros)
results_by_queue: dict[str, list[EnqueueResult]] = dict(gathered)

per_queue_counters: dict[str, int] = defaultdict(int)
final: list[EnqueueResult] = []
Expand Down Expand Up @@ -263,6 +277,10 @@ async def _consume_iter(
decode_consume_message(body)
)
except Exception:
_log.warning(
"failed to decode consume message; skipping frame",
exc_info=True,
)
continue
yield ConsumeMessage(
id=msg_id,
Expand Down
4 changes: 2 additions & 2 deletions fila/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING

from fila.errors import EnqueueError, _map_enqueue_error_code
from fila.errors import _map_enqueue_error_code
from fila.fibp import (
FibpError,
decode_enqueue_response,
Expand Down Expand Up @@ -99,7 +99,7 @@ def _flush_queue_batch(
try:
body = conn.send_request(frame, corr_id).result()
except FibpError as e:
err = EnqueueError(f"enqueue transport error: {e.message}")
err = _map_enqueue_error_code(e.code, e.message)
for item in items:
item.future.set_exception(err)
return
Expand Down
10 changes: 7 additions & 3 deletions fila/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

from fila.batcher import AutoAccumulator, LingerAccumulator
Expand Down Expand Up @@ -30,6 +31,8 @@
import ssl
from collections.abc import Iterator

_log = logging.getLogger(__name__)


class Client:
"""Synchronous client for the Fila message broker (FIBP transport).
Expand Down Expand Up @@ -276,9 +279,6 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]:
Yields messages as they become available. The iterator ends when the
server closes the stream.

If the server returns a leader-hint error, the client transparently
reconnects to the leader address and retries once.

Args:
queue: Queue to consume from.

Expand Down Expand Up @@ -310,6 +310,10 @@ def _consume_iter(self, cq: object) -> Iterator[ConsumeMessage]:
decode_consume_message(body)
)
except Exception:
_log.warning(
"failed to decode consume message; skipping frame",
exc_info=True,
)
continue
yield ConsumeMessage(
id=msg_id,
Expand Down
17 changes: 7 additions & 10 deletions fila/fibp.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def __init__(
self._api_key = api_key

self._lock = threading.Lock()
self._send_lock = threading.Lock()
self._next_corr_id: int = 1
# corr_id → Future[bytes] for request/response ops
self._pending: dict[int, Future[bytes]] = {}
Expand Down Expand Up @@ -410,15 +411,17 @@ def send_request(self, frame: bytes, corr_id: int) -> Future[bytes]:
fut: Future[bytes] = Future()
with self._lock:
self._pending[corr_id] = fut
self._sock.sendall(frame)
with self._send_lock:
self._sock.sendall(frame)
return fut

def open_consume_stream(self, frame: bytes, corr_id: int) -> _ConsumeQueue:
"""Register a consume queue, send *frame*, and return the queue."""
cq = _ConsumeQueue()
with self._lock:
self._consume_queues[corr_id] = cq
self._sock.sendall(frame)
with self._send_lock:
self._sock.sendall(frame)
return cq

def alloc_corr_id(self) -> int:
Expand Down Expand Up @@ -704,14 +707,8 @@ def make_ssl_context(
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

if ca_cert is not None:
# Write CA cert to a temp file (SSLContext only accepts file paths).
with tempfile.NamedTemporaryFile(delete=False, suffix=".pem") as f:
f.write(ca_cert)
ca_path = f.name
try:
ctx.load_verify_locations(ca_path)
finally:
os.unlink(ca_path)
# Pass PEM bytes directly via cadata to avoid writing a temp file.
ctx.load_verify_locations(cadata=ca_cert.decode())
else:
ctx.load_default_certs()

Expand Down
Loading