refactor: safely store and close tasks #253

Open · wants to merge 5 commits into master
37 changes: 20 additions & 17 deletions nextcore/common/dispatcher.py
@@ -21,11 +21,13 @@

 from __future__ import annotations

-from asyncio import CancelledError, Future, create_task
+from asyncio import CancelledError, Future
 from collections import defaultdict
 from logging import getLogger
 from typing import TYPE_CHECKING, Generic, Hashable, TypeVar, cast, overload

+from anyio import create_task_group
+
 from .maybe_coro import maybe_coro

 # Types
@@ -441,22 +443,23 @@ async def dispatch(self, event_name: EventNameT, *args: Any) -> None:
"""
logger.debug("Dispatching event %s", event_name)

# Event handlers
# Tasks are used here as some event handler/check might take a long time.
for handler in self._global_event_handlers:
logger.debug("Dispatching to a global handler")
create_task(self._run_global_event_handler(handler, event_name, *args))
for handler in self._event_handlers.get(event_name, []):
logger.debug("Dispatching to a local handler")
create_task(self._run_event_handler(handler, event_name, *args))

# Wait for handlers
for check, future in self._wait_for_handlers.get(event_name, []):
logger.debug("Dispatching to a wait_for handler")
create_task(self._run_wait_for_handler(check, future, event_name, *args))
for check, future in self._global_wait_for_handlers:
logger.debug("Dispatching to a global wait_for handler")
create_task(self._run_global_wait_for_handler(check, future, event_name, *args))
async with create_task_group() as tg:
Review thread on this line:

Member: Won't this block until all callbacks have finished running? Because if so, that is going to cause issues.

Author: > "When the task encounters an await statement that requires the task to sleep until something happens, the event loop is then free to work on another task." (AnyIO docs)

It looks like it is blocking. I'm not sure what else to do though.

Member: Could a per-dispatcher task group be created?

+            # Event handlers
+            # Tasks are used here as some event handler/check might take a long time.
+            for handler in self._global_event_handlers:
+                logger.debug("Dispatching to a global handler")
+                tg.start_soon(self._run_global_event_handler, handler, event_name, *args)
+            for handler in self._event_handlers.get(event_name, []):
+                logger.debug("Dispatching to a local handler")
+                tg.start_soon(self._run_event_handler, handler, event_name, *args)
+
+            # Wait for handlers
+            for check, future in self._global_wait_for_handlers:
+                logger.debug("Dispatching to a global wait_for handler")
+                tg.start_soon(self._run_global_wait_for_handler, check, future, event_name, *args)
+            for check, future in self._wait_for_handlers.get(event_name, []):
+                logger.debug("Dispatching to a wait_for handler")
+                tg.start_soon(self._run_wait_for_handler, check, future, event_name, *args)

     async def _run_event_handler(self, callback: EventCallback, event_name: EventNameT, *args: Any) -> None:
         """Run event with exception handlers"""
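On the review thread above: an `async with create_task_group()` block only exits once every task spawned into it has finished, so `dispatch` as written waits for all of its handlers. One way to read the per-dispatcher suggestion is to give the dispatcher a single long-lived task group and hand handler tasks to it, so `dispatch` itself returns immediately. A minimal sketch of that idea — not code from this PR; the `run()` method and `_task_group` attribute are hypothetical:

```python
from __future__ import annotations

from typing import Any, Callable, Coroutine

import anyio
from anyio.abc import TaskGroup


class Dispatcher:
    """Sketch: handler tasks are owned by one long-lived, per-dispatcher task group."""

    def __init__(self) -> None:
        self._task_group: TaskGroup | None = None  # hypothetical attribute

    async def run(self, *, task_status=anyio.TASK_STATUS_IGNORED) -> None:
        # Keep one task group open for the dispatcher's lifetime; dispatch()
        # hands tasks to it instead of waiting for them itself.
        async with anyio.create_task_group() as tg:
            self._task_group = tg
            task_status.started()
            await anyio.sleep_forever()  # runs until the surrounding scope is cancelled

    async def dispatch(self, handler: Callable[..., Coroutine[Any, Any, None]], *args: Any) -> None:
        assert self._task_group is not None, "run() is not active"
        self._task_group.start_soon(handler, *args)  # schedules and returns immediately


async def main() -> None:
    dispatcher = Dispatcher()

    async def handler(value: int) -> None:
        await anyio.sleep(0.5)
        print("handled", value)

    async with anyio.create_task_group() as tg:
        await tg.start(dispatcher.run)         # waits until the inner group is open
        await dispatcher.dispatch(handler, 1)  # does not block on the handler
        await anyio.sleep(1)
        tg.cancel_scope.cancel()               # shuts down run() and any leftover handlers


if __name__ == "__main__":
    anyio.run(main)
```

The trade-off is that the dispatcher then needs an explicit lifecycle (something has to run and later cancel `run()`), whereas the approach in this PR keeps `dispatch` self-contained at the cost of waiting for its handlers.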
19 changes: 17 additions & 2 deletions nextcore/gateway/shard.py
@@ -177,6 +177,8 @@ class Shard:
"_logger",
"_received_heartbeat_ack",
"_http_client",
"_receive_task",
"_heartbeat_task",
"_heartbeat_sent_at",
"_latency",
)
@@ -228,6 +230,8 @@ def __init__(
         self._logger: Logger = getLogger(f"{__name__}.{self.shard_id}")
         self._received_heartbeat_ack: bool = True
         self._http_client: HTTPClient = http_client # TODO: Should this be private?
+        self._receive_task: asyncio.Task[None] | None = None
+        self._heartbeat_task: asyncio.Task[None] | None = None

         # Latency
         self._heartbeat_sent_at: float | None = None
@@ -282,7 +286,7 @@ async def connect(self) -> None:
         self._received_heartbeat_ack = True
         self._ws = ws # Use the new connection

-        create_task(self._receive_loop())
+        self._receive_task = create_task(self._receive_loop())

         # Connection logic is continued in _handle_hello to account for that rate limits are defined there.

@@ -334,6 +338,17 @@ async def close(self, *, cleanup: bool = True) -> None:
         await self._ws.close(code=999)
         self._ws = None # Clear it to save some ram
         self._send_rate_limit = None # No longer applies
+
+        # safely stop running tasks
+
+        if self._receive_task is not None:
+            self._receive_task.cancel()
+            self._receive_task = None
+
+        if self._heartbeat_task is not None:
+            self._heartbeat_task.cancel()
+            self._heartbeat_task = None
+
         self.connected.clear()

     @property
@@ -538,7 +553,7 @@ async def _handle_hello(self, data: HelloEvent) -> None:
         heartbeat_interval = data["d"]["heartbeat_interval"] / 1000 # Convert from ms to seconds

         loop = get_running_loop()
-        loop.create_task(self._heartbeat_loop(heartbeat_interval))
+        self._heartbeat_task = loop.create_task(self._heartbeat_loop(heartbeat_interval))

         # Create a rate limiter
         times, per = self._GATEWAY_SEND_RATE_LIMITS
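The shard.py side of the PR is the usual asyncio pattern of keeping a reference to every background task so that close() can cancel it and then drop the reference (a task created by create_task and never stored may also be garbage-collected before it finishes). A stripped-down illustration of that pattern, using a hypothetical Worker class rather than the real Shard:

```python
from __future__ import annotations

import asyncio


class Worker:
    """Sketch: background tasks are stored so close() can cancel them."""

    def __init__(self) -> None:
        self._receive_task: asyncio.Task[None] | None = None
        self._heartbeat_task: asyncio.Task[None] | None = None

    async def connect(self) -> None:
        # Store the task objects; a bare create_task() call leaves nothing
        # to cancel later.
        self._receive_task = asyncio.create_task(self._receive_loop())
        self._heartbeat_task = asyncio.create_task(self._heartbeat_loop(41.25))

    async def close(self) -> None:
        # Safely stop running tasks, mirroring the diff above.
        for task in (self._receive_task, self._heartbeat_task):
            if task is not None:
                task.cancel()
        self._receive_task = None
        self._heartbeat_task = None

    async def _receive_loop(self) -> None:
        while True:
            await asyncio.sleep(1)

    async def _heartbeat_loop(self, interval: float) -> None:
        while True:
            await asyncio.sleep(interval)


async def main() -> None:
    worker = Worker()
    await worker.connect()
    await asyncio.sleep(0.1)
    await worker.close()


if __name__ == "__main__":
    asyncio.run(main())
```

Note that cancel() only requests cancellation; awaiting the cancelled task (while suppressing CancelledError) would additionally guarantee the loops have fully unwound before close() returns, which the diff does not do.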
26 changes: 13 additions & 13 deletions nextcore/gateway/shard_manager.py
@@ -21,12 +21,13 @@

 from __future__ import annotations

-from asyncio import CancelledError, gather, get_running_loop
+from asyncio import CancelledError, Task, gather, get_running_loop
 from collections import defaultdict
 from logging import getLogger
 from typing import TYPE_CHECKING

 from aiohttp import ClientConnectionError
+from anyio import create_task_group

 from ..common import Dispatcher, TimesPer
 from ..http import Route
@@ -188,17 +189,20 @@ async def connect(self) -> None:
         else:
             shard_ids = self.shard_ids

-        for shard_id in shard_ids:
-            shard = self._spawn_shard(shard_id, self._active_shard_count)
+        async with create_task_group() as tg:
+            for shard_id in shard_ids:
+                shard = self._spawn_shard(shard_id, self._active_shard_count)
+                # Here we lazy connect the shard. This gives us a bit more speed when connecting large sets of shards.
+                await tg.spawn(shard.connect)

-            # Register event listeners
-            shard.raw_dispatcher.add_listener(self._on_raw_shard_receive)
-            shard.event_dispatcher.add_listener(self._on_shard_dispatch)
-            shard.dispatcher.add_listener(self._on_shard_critical, "critical")
+                # Register event listeners
+                shard.raw_dispatcher.add_listener(self._on_raw_shard_receive)
+                shard.event_dispatcher.add_listener(self._on_shard_dispatch)
+                shard.dispatcher.add_listener(self._on_shard_critical, "critical")

-            logger.info("Added shard event listeners")
+                logger.info("Added shard event listeners")

-            self.active_shards.append(shard)
+                self.active_shards.append(shard)

     def _spawn_shard(self, shard_id: int, shard_count: int) -> Shard:
         assert self.max_concurrency is not None, "max_concurrency is not set. This is set in connect"
@@ -214,10 +218,6 @@ def _spawn_shard(self, shard_id: int, shard_count: int) -> Shard:
             presence=self.presence,
         )

-        # Here we lazy connect the shard. This gives us a bit more speed when connecting large sets of shards.
-        loop = get_running_loop()
-        loop.create_task(shard.connect())
-
         return shard

     async def rescale_shards(self, shard_count: int, shard_ids: list[int] | None = None) -> None:
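In shard_manager.py the shards are now connected inside a task group instead of via fire-and-forget create_task calls, so connect() keeps running until every shard.connect has completed or raised. A minimal sketch of that shape; the Shard stand-in and connect_all function are invented for illustration, and it uses tg.start_soon where the diff uses await tg.spawn(...) (both schedule the coroutine on the group):

```python
from __future__ import annotations

from anyio import create_task_group, run, sleep


class Shard:
    """Hypothetical stand-in for nextcore's Shard."""

    def __init__(self, shard_id: int) -> None:
        self.shard_id = shard_id

    async def connect(self) -> None:
        await sleep(0.1)  # pretend gateway handshake
        print(f"shard {self.shard_id} connected")


async def connect_all(shard_ids: list[int]) -> list[Shard]:
    shards: list[Shard] = []
    # All connects run concurrently; the task group only exits once every
    # shard.connect has finished, and an exception in one cancels its siblings.
    async with create_task_group() as tg:
        for shard_id in shard_ids:
            shard = Shard(shard_id)
            tg.start_soon(shard.connect)
            shards.append(shard)
    return shards


if __name__ == "__main__":
    run(connect_all, [0, 1, 2])
```

This also means the reviewer's blocking concern from dispatcher.py applies here: a shard that never finishes connecting would keep connect() from returning, which is a behaviour change from the previous lazy create_task approach.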
1 change: 1 addition & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ packages = [
 [tool.poetry.dependencies]
 python = "^3.8"
 aiohttp = ">=3.6.0,<4.0.0"
+anyio = "^3.7.0"
 frozendict = "^2.3.0"
 types-frozendict = "^2.0.6" # Could we extend the version requirement
 typing-extensions = "^4.1.1" # Same as above