Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/s2python/connection/async_/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ async def send_msg_and_await_reception_status(
if reception_status_task in done:
try:
reception_status = await reception_status_task
except TimeoutError:
except (TimeoutError, asyncio.TimeoutError):
logger.error("Did not receive a reception status on time for %s", s2_msg.message_id)
self._stop_event.set()
raise
Expand Down
44 changes: 24 additions & 20 deletions src/s2python/reception_status_awaiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@


class ReceptionStatusAwaiter:
"""Notify coroutines waiting for a `ReceptionStatus` by subject message ID.

Reception statuses are single-consumer: once awaited, they are removed."""

received: Dict[uuid.UUID, ReceptionStatus]
awaiting: Dict[uuid.UUID, asyncio.Event]

Expand All @@ -23,38 +27,38 @@ def __init__(self) -> None:
async def wait_for_reception_status(
self, message_id: uuid.UUID, timeout_reception_status: float
) -> ReceptionStatus:
if message_id in self.received:
reception_status = self.received[message_id]
else:
if message_id in self.awaiting:
received_event = self.awaiting[message_id]
else:
received_event = asyncio.Event()
self.awaiting[message_id] = received_event

await asyncio.wait_for(received_event.wait(), timeout_reception_status)
reception_status = self.received[message_id]
existing = self.received.pop(message_id, None)
if existing is not None:
return existing

if message_id in self.awaiting:
del self.awaiting[message_id]
received_event = self.awaiting.get(message_id)
if received_event is None:
received_event = asyncio.Event()
self.awaiting[message_id] = received_event

return reception_status
try:
await asyncio.wait_for(received_event.wait(), timeout_reception_status)
return self.received.pop(message_id)
finally:
self.awaiting.pop(message_id, None)

async def receive_reception_status(self, reception_status: ReceptionStatus) -> None:
if not isinstance(reception_status, ReceptionStatus):
raise RuntimeError(
f"Expected a ReceptionStatus but received message {reception_status}"
)

if reception_status.subject_message_id in self.received:
mid = reception_status.subject_message_id

if mid in self.received:
raise RuntimeError(
f"ReceptationStatus for message_subject_id {reception_status.subject_message_id} has already "
f"been received!"
f"ReceptionStatus for message_subject_id {mid} has already been received!"
)

self.received[reception_status.subject_message_id] = reception_status
awaiting = self.awaiting.get(reception_status.subject_message_id)
self.received[mid] = reception_status

if awaiting:
awaiting = self.awaiting.get(mid)
if awaiting is not None:
awaiting.set()
del self.awaiting[reception_status.subject_message_id]
self.awaiting.pop(mid, None)
15 changes: 9 additions & 6 deletions tests/unit/reception_status_awaiter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ async def test__wait_for_reception_status__multiple_receive_while_waiting(self):
should_be_waiting_still_1 = not wait_task_1.done()
should_be_waiting_still_2 = not wait_task_2.done()
await awaiter.receive_reception_status(s2_reception_status)
await wait_task_1
await wait_task_2
received_s2_reception_status_1 = wait_task_1.result()
received_s2_reception_status_2 = wait_task_2.result()
results = await asyncio.gather(wait_task_1, wait_task_2, return_exceptions=True)

# Assert
expected_s2_reception_status = ReceptionStatus( # pyright: ignore[reportCallIssue]
Expand All @@ -95,8 +92,14 @@ async def test__wait_for_reception_status__multiple_receive_while_waiting(self):

self.assertTrue(should_be_waiting_still_1)
self.assertTrue(should_be_waiting_still_2)
self.assertEqual(expected_s2_reception_status, received_s2_reception_status_1)
self.assertEqual(expected_s2_reception_status, received_s2_reception_status_2)

successful_results = [result for result in results if not isinstance(result, Exception)]
exception_results = [result for result in results if isinstance(result, Exception)]

self.assertEqual(1, len(successful_results))
self.assertEqual(1, len(exception_results))
self.assertEqual(expected_s2_reception_status, successful_results[0])
self.assertIsInstance(exception_results[0], asyncio.TimeoutError)

async def test__receive_reception_status__wrong_message(self):
# Arrange
Expand Down
Loading