Skip to content

Commit

Permalink
Merge pull request #20 from icgood/fixes
Browse files Browse the repository at this point in the history
Fix bugs in Python 3.11
  • Loading branch information
icgood committed Feb 11, 2023
2 parents 83568a9 + 84f98e4 commit e30bb59
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 13 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
license = f.read()

setup(name='swim-protocol',
version='0.3.10',
version='0.3.11',
author='Ian Good',
author_email='ian@icgood.net',
description='SWIM protocol implementation for exchanging cluster '
Expand Down
10 changes: 7 additions & 3 deletions swimprotocol/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import asyncio
from abc import abstractmethod
from asyncio import Event
from collections.abc import Sequence
from asyncio import Event, Task
from collections.abc import MutableSet, Sequence
from contextlib import ExitStack
from typing import TypeVar, Generic, Protocol, Any, NoReturn
from weakref import WeakKeyDictionary
Expand Down Expand Up @@ -43,15 +43,19 @@ class Listener(Generic[ListenT]):
def __init__(self, cls: type[ListenT]) -> None:
super().__init__()
self.event = Event()
self._running: MutableSet[Task[Any]] = set()
self._waiting: WeakKeyDictionary[Event, list[ListenT]] = \
WeakKeyDictionary()

async def _run_callback_poll(self, callback: ListenerCallback[ListenT]) \
-> NoReturn:
running = self._running
while True:
items = await self.poll()
for item in items:
asyncio.create_task(callback(item))
task = asyncio.create_task(callback(item))
running.add(task)
task.add_done_callback(running.discard)

def on_notify(self, callback: ListenerCallback[ListenT]) -> ExitStack:
"""Provides a context manager that causes *callback* to be called when
Expand Down
4 changes: 2 additions & 2 deletions swimprotocol/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def __len__(self) -> int:
def _refresh_statuses(self, member: Member) -> None:
if not member.local:
member_status = member.status
for status in Status:
for status in Status.all_statuses():
if member_status & status:
self._statuses[status].add(member)
else:
Expand Down Expand Up @@ -237,7 +237,7 @@ def get(self, name: str, validity: Optional[bytes] = None) -> Member:
member = Member(name, False)
self._non_local.add(member)
self._members[name] = member
for status in Status:
for status in Status.all_statuses():
if member.status & status:
self._statuses[status].add(member)
if not member.local and validity is not None \
Expand Down
6 changes: 6 additions & 0 deletions swimprotocol/status.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

from __future__ import annotations

from collections.abc import Collection
from enum import auto, Flag

__all__ = ['Status']
Expand Down Expand Up @@ -58,3 +59,8 @@ def transition(self, to: Status) -> Status:
return Status.SUSPECT
else:
return to

@classmethod
def all_statuses(cls) -> Collection[Status]:
"""A collection of all the statuses, including aggregate statuses."""
return cls.__members__.values()
9 changes: 7 additions & 2 deletions swimprotocol/udp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@

import asyncio
import logging
from asyncio import BaseTransport, Condition, DatagramProtocol, \
from asyncio import BaseTransport, Condition, Task, DatagramProtocol, \
DatagramTransport
from collections import deque
from collections.abc import MutableSet
from typing import cast, Final, Optional

from .pack import UdpPack
Expand Down Expand Up @@ -34,6 +35,7 @@ def __init__(self, address_parser: AddressParser,
self._transport: Optional[DatagramTransport] = None
self._queue_lock = Condition()
self._queue: deque[Packet] = deque()
self._running: MutableSet[Task[None]] = set()

@property
def transport(self) -> DatagramTransport:
Expand Down Expand Up @@ -61,7 +63,10 @@ def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
packet = self.udp_pack.unpack(data)
if packet is None:
return
asyncio.create_task(self._push(packet))
running = self._running
task = asyncio.create_task(self._push(packet))
running.add(task)
task.add_done_callback(running.discard)

def error_received(self, exc: Exception) -> None:
"""Called when a UDP send or receive operation fails.
Expand Down
16 changes: 11 additions & 5 deletions swimprotocol/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import asyncio
from abc import abstractmethod
from asyncio import Event, Task, TimeoutError
from collections.abc import Mapping, Sequence
from collections.abc import Mapping, MutableSet, Sequence
from contextlib import suppress
from typing import final, Protocol, Final, Optional, NoReturn
from weakref import WeakSet, WeakKeyDictionary
Expand Down Expand Up @@ -63,13 +63,19 @@ def __init__(self, config: BaseConfig, members: Members, io: IO) -> None:
self.config: Final = config
self.members: Final = members
self.io: Final = io
self._running: MutableSet[Task[None]] = set()
self._waiting: WeakKeyDictionary[Member, WeakSet[Event]] = \
WeakKeyDictionary()
self._listening: WeakKeyDictionary[Member, WeakSet[Member]] = \
WeakKeyDictionary()
self._suspect: WeakKeyDictionary[Member, Task[None]] = \
WeakKeyDictionary()

def _run_task(self, task: Task[None]) -> None:
running = self._running
running.add(task)
task.add_done_callback(running.discard)

def _add_waiting(self, member: Member, event: Event) -> None:
waiting = self._waiting.get(member)
if waiting is None:
Expand Down Expand Up @@ -181,8 +187,8 @@ async def check(self, target: Member) -> None:
count, status=Status.AVAILABLE, exclude={target})
if indirects:
await asyncio.wait([
self.io.send(indirect, PingReq(
source=local.source, target=target.name))
asyncio.create_task(self.io.send(indirect, PingReq(
source=local.source, target=target.name)))
for indirect in indirects])
online = await self._wait(target, self.config.ping_req_timeout)
new_status = Status.ONLINE if online else Status.SUSPECT
Expand Down Expand Up @@ -220,7 +226,7 @@ async def run_failure_detection(self) -> NoReturn:
targets = self.members.find(1)
assert targets
for target in targets:
asyncio.create_task(self.check(target))
self._run_task(asyncio.create_task(self.check(target)))
await asyncio.sleep(self.config.ping_interval)

async def run_dissemination(self) -> NoReturn:
Expand All @@ -236,7 +242,7 @@ async def run_dissemination(self) -> NoReturn:
while True:
targets = self.members.find(1, status=Status.AVAILABLE)
for target in targets:
asyncio.create_task(self.disseminate(target))
self._run_task(asyncio.create_task(self.disseminate(target)))
await asyncio.sleep(self.config.sync_interval)

@final
Expand Down

0 comments on commit e30bb59

Please sign in to comment.