Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/docket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Retry,
TaskKey,
TaskLogger,
Timeout,
)
from .docket import Docket
from .execution import Execution
Expand All @@ -36,5 +37,6 @@
"ExponentialRetry",
"Logged",
"Perpetual",
"Timeout",
"__version__",
]
31 changes: 31 additions & 0 deletions src/docket/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import inspect
import logging
import time
from datetime import timedelta
from typing import Any, Awaitable, Callable, Counter, TypeVar, cast

Expand Down Expand Up @@ -171,6 +172,36 @@ def perpetuate(self, *args: Any, **kwargs: Any) -> None:
self.kwargs = kwargs


class Timeout(Dependency):
single = True

base: timedelta

_deadline: float

def __init__(self, base: timedelta) -> None:
self.base = base

def __call__(
self, docket: Docket, worker: Worker, execution: Execution
) -> "Timeout":
return Timeout(base=self.base)

def start(self) -> None:
self._deadline = time.monotonic() + self.base.total_seconds()

def expired(self) -> bool:
return time.monotonic() >= self._deadline

def remaining(self) -> timedelta:
return timedelta(seconds=self._deadline - time.monotonic())

def extend(self, by: timedelta | None = None) -> None:
if by is None:
by = self.base
self._deadline += by.total_seconds()


def get_dependency_parameters(
function: Callable[..., Awaitable[Any]],
) -> dict[str, Dependency]:
Expand Down
55 changes: 47 additions & 8 deletions src/docket/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@


if TYPE_CHECKING: # pragma: no cover
from .dependencies import Dependency
from .dependencies import Dependency, Timeout


class _stream_due_tasks(Protocol):
Expand Down Expand Up @@ -507,13 +507,20 @@ async def _execute(self, message: RedisMessage) -> None:
},
links=links,
):
await execution.function(
*execution.args,
**{
**execution.kwargs,
**dependencies,
},
)
from .dependencies import Timeout, get_single_dependency_of_type

if timeout := get_single_dependency_of_type(dependencies, Timeout):
await self._run_function_with_timeout(
execution, dependencies, timeout
)
else:
await execution.function(
*execution.args,
**{
**execution.kwargs,
**dependencies,
},
)

TASKS_SUCCEEDED.add(1, counter_labels)
duration = datetime.now(timezone.utc) - start
Expand All @@ -539,6 +546,38 @@ async def _execute(self, message: RedisMessage) -> None:
TASKS_COMPLETED.add(1, counter_labels)
TASK_DURATION.record(duration.total_seconds(), counter_labels)

async def _run_function_with_timeout(
self,
execution: Execution,
dependencies: dict[str, "Dependency"],
timeout: "Timeout",
) -> None:
timeout.start()
task_coro = execution.function(
*execution.args, **execution.kwargs, **dependencies
)
task = asyncio.create_task(task_coro)
try:
while not task.done(): # pragma: no branch
remaining = timeout.remaining().total_seconds()
if timeout.expired():
task.cancel()
break

try:
await asyncio.wait_for(asyncio.shield(task), timeout=remaining)
return
except asyncio.TimeoutError:
continue
finally:
if not task.done():
task.cancel()

try:
await task
except asyncio.CancelledError:
raise asyncio.TimeoutError

def _get_dependencies(
self,
execution: Execution,
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ async def docket(redis_url: str, aiolib: str) -> AsyncGenerator[Docket, None]:
async def worker(docket: Docket) -> AsyncGenerator[Worker, None]:
async with Worker(
docket,
minimum_check_interval=timedelta(milliseconds=10),
scheduling_resolution=timedelta(milliseconds=10),
minimum_check_interval=timedelta(milliseconds=5),
scheduling_resolution=timedelta(milliseconds=5),
) as worker:
yield worker

Expand Down
136 changes: 135 additions & 1 deletion tests/test_fundamentals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
as possible to aid with understanding docket.
"""

import asyncio
import logging
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from logging import LoggerAdapter
from typing import Annotated, Callable
from unittest.mock import AsyncMock, call
Expand All @@ -26,6 +27,7 @@
Retry,
TaskKey,
TaskLogger,
Timeout,
Worker,
tasks,
)
Expand Down Expand Up @@ -1065,3 +1067,135 @@ async def my_automatic_task(
await worker.run_at_most({"my_automatic_task": 3})

assert calls == 3


async def test_simple_timeout(docket: Docket, worker: Worker):
"""A task can be scheduled with a timeout"""

called = False

async def task_with_timeout(
timeout: Timeout = Timeout(timedelta(milliseconds=100)),
):
await asyncio.sleep(0.01)

nonlocal called
called = True

await docket.add(task_with_timeout)()

start = datetime.now(timezone.utc)

await worker.run_until_finished()

elapsed = datetime.now(timezone.utc) - start

assert called
assert elapsed <= timedelta(milliseconds=150)


async def test_simple_timeout_cancels_tasks(docket: Docket, worker: Worker):
"""A task can be scheduled with a timeout and are cancelled"""

called = False

async def task_with_timeout(
timeout: Timeout = Timeout(timedelta(milliseconds=100)),
):
try:
await asyncio.sleep(5)
except asyncio.CancelledError:
nonlocal called
called = True

await docket.add(task_with_timeout)()

start = datetime.now(timezone.utc)

await worker.run_until_finished()

elapsed = datetime.now(timezone.utc) - start

assert called
assert timedelta(milliseconds=100) <= elapsed <= timedelta(milliseconds=200)


async def test_timeout_can_be_extended(docket: Docket, worker: Worker):
"""A task can be scheduled with a timeout and extend themselves"""

called = False

async def task_with_timeout(
timeout: Timeout = Timeout(timedelta(milliseconds=100)),
):
await asyncio.sleep(0.05)

timeout.extend(timedelta(milliseconds=200))

try:
await asyncio.sleep(5)
except asyncio.CancelledError:
nonlocal called
called = True

await docket.add(task_with_timeout)()

start = datetime.now(timezone.utc)

await worker.run_until_finished()

elapsed = datetime.now(timezone.utc) - start

assert called
assert timedelta(milliseconds=250) <= elapsed <= timedelta(milliseconds=400)


async def test_timeout_extends_by_base_by_default(docket: Docket, worker: Worker):
"""A task can be scheduled with a timeout and extend itself by the base timeout"""

called = False

async def task_with_timeout(
timeout: Timeout = Timeout(timedelta(milliseconds=100)),
):
await asyncio.sleep(0.05)

timeout.extend() # defaults to the base timeout

try:
await asyncio.sleep(5)
except asyncio.CancelledError:
nonlocal called
called = True

await docket.add(task_with_timeout)()

start = datetime.now(timezone.utc)

await worker.run_until_finished()

elapsed = datetime.now(timezone.utc) - start

assert called
assert timedelta(milliseconds=150) <= elapsed <= timedelta(milliseconds=300)


async def test_timeout_is_compatible_with_retry(docket: Docket, worker: Worker):
"""A task that times out can be retried"""

successes: list[int] = []

async def task_with_timeout(
retry: Retry = Retry(attempts=3),
timeout: Timeout = Timeout(timedelta(milliseconds=100)),
):
if retry.attempt == 1:
await asyncio.sleep(1)

successes.append(retry.attempt)

await docket.add(task_with_timeout)()

await worker.run_until_finished()

assert successes == [2]
6 changes: 2 additions & 4 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,12 @@ async def perpetual_task(

intervals = [next - previous for previous, next in zip(timestamps, timestamps[1:])]
minimum = min(intervals)
maximum = max(intervals)

debug = ", ".join([f"{i.total_seconds() * 1000:.2f}ms" for i in intervals])

# even with a variable duration, Docket attempts to schedule them equally and to
# abide by the target interval
# It's not reliable to assert the maximum duration on different machine setups, but
# we'll make sure that the minimum is observed, which is the guarantee
assert minimum >= timedelta(milliseconds=50), debug
assert maximum <= timedelta(milliseconds=75), debug


async def test_worker_can_exit_from_perpetual_tasks_that_queue_further_tasks(
Expand Down