Skip to content

Commit

Permalink
Support Python 3.11 (#574)
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Jul 21, 2023
1 parent b7cbd97 commit ec1d7fe
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.7', '3.8', '3.9', '3.10' ]
python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11' ]
timeout-minutes: 15
steps:
- uses: actions/checkout@v3
Expand Down
10 changes: 7 additions & 3 deletions hivemind/averaging/matchmaking.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ async def rpc_join_group(
# wait for the group to be assembled or disbanded
timeout = max(0.0, self.potential_leaders.declared_expiration_time - get_dht_time())
await asyncio.wait(
{self.assembled_group, self.was_accepted_to_group.wait()},
{self.assembled_group, asyncio.create_task(self.was_accepted_to_group.wait())},
return_when=asyncio.FIRST_COMPLETED,
timeout=timeout,
)
Expand Down Expand Up @@ -480,7 +480,11 @@ async def pop_next_leader(self) -> PeerID:
self.peer_id.to_bytes(),
):
await asyncio.wait(
{self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED
{
asyncio.create_task(self.update_finished.wait()),
asyncio.create_task(self.declared_expiration.wait()),
},
return_when=asyncio.FIRST_COMPLETED,
)
self.declared_expiration.clear()
if self.update_finished.is_set():
Expand Down Expand Up @@ -511,7 +515,7 @@ async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None
self.update_finished.set()

await asyncio.wait(
{self.running.wait(), self.update_triggered.wait()},
{asyncio.create_task(self.running.wait()), asyncio.create_task(self.update_triggered.wait())},
return_when=asyncio.ALL_COMPLETED,
timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
)
Expand Down
5 changes: 4 additions & 1 deletion hivemind/averaging/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,10 @@ async def accumulate_part(

while part_index > self.current_part_index:
# wait for previous parts to finish processing ...
await asyncio.wait({self.current_part_future, self.finished.wait()}, return_when=asyncio.FIRST_COMPLETED)
await asyncio.wait(
{self.current_part_future, asyncio.create_task(self.finished.wait())},
return_when=asyncio.FIRST_COMPLETED,
)
if self.finished.is_set():
raise AllreduceException(f"attempted to aggregate part in a finalized {self.__class__.__name__}")

Expand Down
2 changes: 1 addition & 1 deletion hivemind/dht/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def call_ping(self, peer: PeerID, validate: bool = False, strict: bool = T
f"Peer {peer} can't access this node. " f"Probably, libp2p has failed to bypass the firewall"
)

if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
if response.dht_time != dht_pb2.PingResponse.DESCRIPTOR.fields_by_name["dht_time"].default_value:
if (
response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS
or response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS
Expand Down
4 changes: 3 additions & 1 deletion hivemind/dht/traverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ async def worker():
# get nearest neighbors (over network) and update search heaps. Abort if search finishes early
get_neighbors_task = asyncio.create_task(get_neighbors(chosen_peer, queries_to_call))
pending_tasks.add(get_neighbors_task)
await asyncio.wait([get_neighbors_task, search_finished_event.wait()], return_when=asyncio.FIRST_COMPLETED)
await_finished_task = asyncio.create_task(search_finished_event.wait())
await asyncio.wait([get_neighbors_task, await_finished_task], return_when=asyncio.FIRST_COMPLETED)
del await_finished_task
if search_finished_event.is_set():
break # other worker triggered finish_search, we exit immediately
pending_tasks.remove(get_neighbors_task)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ msgpack>=0.5.6
sortedcontainers
uvloop>=0.14.0
grpcio-tools>=1.33.2
protobuf>=3.12.2,<4.0.0
protobuf>=3.12.2
configargparse>=1.2.3
multiaddr>=0.0.9
pymultihash>=0.8.2
Expand Down
2 changes: 1 addition & 1 deletion tests/test_averaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def test_averaging_cancel():

step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers]

time.sleep(0.1)
time.sleep(0.05)
step_controls[0].cancel()
step_controls[1].cancel()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dht_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,4 +301,4 @@ async def test_dhtnode_edge_cases():
assert subkey in stored.value
assert stored.value[subkey].value == value

await asyncio.wait([node.shutdown() for node in peers])
await asyncio.wait([asyncio.create_task(node.shutdown()) for node in peers])
2 changes: 1 addition & 1 deletion tests/test_util_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ async def coro():

num_coros = max(100, mp.cpu_count() * 5 + 1)
# note: if we deprecate py3.7, this can be reduced to max(33, cpu + 5); see https://bugs.python.org/issue35279
await asyncio.wait({coro() for _ in range(num_coros)})
await asyncio.wait({asyncio.create_task(coro()) for _ in range(num_coros)})


def test_batch_tensor_descriptor_msgpack():
Expand Down

0 comments on commit ec1d7fe

Please sign in to comment.