diff --git a/setup.py b/setup.py index 3c25fff..59dfa38 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ license = f.read() setup(name='swim-protocol', - version='0.3.10', + version='0.3.11', author='Ian Good', author_email='ian@icgood.net', description='SWIM protocol implementation for exchanging cluster ' diff --git a/swimprotocol/listener.py b/swimprotocol/listener.py index dda4119..e809e05 100644 --- a/swimprotocol/listener.py +++ b/swimprotocol/listener.py @@ -3,8 +3,8 @@ import asyncio from abc import abstractmethod -from asyncio import Event -from collections.abc import Sequence +from asyncio import Event, Task +from collections.abc import MutableSet, Sequence from contextlib import ExitStack from typing import TypeVar, Generic, Protocol, Any, NoReturn from weakref import WeakKeyDictionary @@ -43,15 +43,19 @@ class Listener(Generic[ListenT]): def __init__(self, cls: type[ListenT]) -> None: super().__init__() self.event = Event() + self._running: MutableSet[Task[Any]] = set() self._waiting: WeakKeyDictionary[Event, list[ListenT]] = \ WeakKeyDictionary() async def _run_callback_poll(self, callback: ListenerCallback[ListenT]) \ -> NoReturn: + running = self._running while True: items = await self.poll() for item in items: - asyncio.create_task(callback(item)) + task = asyncio.create_task(callback(item)) + running.add(task) + task.add_done_callback(running.discard) def on_notify(self, callback: ListenerCallback[ListenT]) -> ExitStack: """Provides a context manager that causes *callback* to be called when diff --git a/swimprotocol/members.py b/swimprotocol/members.py index b3d0445..07c377f 100644 --- a/swimprotocol/members.py +++ b/swimprotocol/members.py @@ -172,7 +172,7 @@ def __len__(self) -> int: def _refresh_statuses(self, member: Member) -> None: if not member.local: member_status = member.status - for status in Status: + for status in Status.all_statuses(): if member_status & status: self._statuses[status].add(member) else: @@ -237,7 +237,7 @@ def get(self, name: str, validity: Optional[bytes] = None) -> Member: member = Member(name, False) self._non_local.add(member) self._members[name] = member - for status in Status: + for status in Status.all_statuses(): if member.status & status: self._statuses[status].add(member) if not member.local and validity is not None \ diff --git a/swimprotocol/status.py b/swimprotocol/status.py index 1360740..5ee4ae2 100644 --- a/swimprotocol/status.py +++ b/swimprotocol/status.py @@ -1,6 +1,7 @@ from __future__ import annotations +from collections.abc import Collection from enum import auto, Flag __all__ = ['Status'] @@ -58,3 +59,8 @@ def transition(self, to: Status) -> Status: return Status.SUSPECT else: return to + + @classmethod + def all_statuses(cls) -> Collection[Status]: + """A collection of all the statuses, including aggregate statuses.""" + return cls.__members__.values() diff --git a/swimprotocol/udp/protocol.py b/swimprotocol/udp/protocol.py index d1924d2..8badeee 100644 --- a/swimprotocol/udp/protocol.py +++ b/swimprotocol/udp/protocol.py @@ -3,9 +3,10 @@ import asyncio import logging -from asyncio import BaseTransport, Condition, DatagramProtocol, \ +from asyncio import BaseTransport, Condition, Task, DatagramProtocol, \ DatagramTransport from collections import deque +from collections.abc import MutableSet from typing import cast, Final, Optional from .pack import UdpPack @@ -34,6 +35,7 @@ def __init__(self, address_parser: AddressParser, self._transport: Optional[DatagramTransport] = None self._queue_lock = Condition() self._queue: deque[Packet] = deque() + self._running: MutableSet[Task[None]] = set() @property def transport(self) -> DatagramTransport: @@ -61,7 +63,10 @@ def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: packet = self.udp_pack.unpack(data) if packet is None: return - asyncio.create_task(self._push(packet)) + running = self._running + task = asyncio.create_task(self._push(packet)) + running.add(task) + task.add_done_callback(running.discard) def error_received(self, exc: Exception) -> None: """Called when a UDP send or receive operation fails. diff --git a/swimprotocol/worker.py b/swimprotocol/worker.py index 9fe7c8b..1495e57 100644 --- a/swimprotocol/worker.py +++ b/swimprotocol/worker.py @@ -4,7 +4,7 @@ import asyncio from abc import abstractmethod from asyncio import Event, Task, TimeoutError -from collections.abc import Mapping, Sequence +from collections.abc import Mapping, MutableSet, Sequence from contextlib import suppress from typing import final, Protocol, Final, Optional, NoReturn from weakref import WeakSet, WeakKeyDictionary @@ -63,6 +63,7 @@ def __init__(self, config: BaseConfig, members: Members, io: IO) -> None: self.config: Final = config self.members: Final = members self.io: Final = io + self._running: MutableSet[Task[None]] = set() self._waiting: WeakKeyDictionary[Member, WeakSet[Event]] = \ WeakKeyDictionary() self._listening: WeakKeyDictionary[Member, WeakSet[Member]] = \ @@ -70,6 +71,11 @@ def __init__(self, config: BaseConfig, members: Members, io: IO) -> None: self._suspect: WeakKeyDictionary[Member, Task[None]] = \ WeakKeyDictionary() + def _run_task(self, task: Task[None]) -> None: + running = self._running + running.add(task) + task.add_done_callback(running.discard) + def _add_waiting(self, member: Member, event: Event) -> None: waiting = self._waiting.get(member) if waiting is None: @@ -181,8 +187,8 @@ async def check(self, target: Member) -> None: count, status=Status.AVAILABLE, exclude={target}) if indirects: await asyncio.wait([ - self.io.send(indirect, PingReq( - source=local.source, target=target.name)) + asyncio.create_task(self.io.send(indirect, PingReq( + source=local.source, target=target.name))) for indirect in indirects]) online = await self._wait(target, self.config.ping_req_timeout) new_status = Status.ONLINE if online else Status.SUSPECT @@ -220,7 +226,7 @@ async def run_failure_detection(self) -> NoReturn: targets = self.members.find(1) assert targets for target in targets: - asyncio.create_task(self.check(target)) + self._run_task(asyncio.create_task(self.check(target))) await asyncio.sleep(self.config.ping_interval) async def run_dissemination(self) -> NoReturn: @@ -236,7 +242,7 @@ async def run_dissemination(self) -> NoReturn: while True: targets = self.members.find(1, status=Status.AVAILABLE) for target in targets: - asyncio.create_task(self.disseminate(target)) + self._run_task(asyncio.create_task(self.disseminate(target))) await asyncio.sleep(self.config.sync_interval) @final