From 062de0324ee6b4eccbc7e558ff1088d8ffdfee4b Mon Sep 17 00:00:00 2001 From: Chris Guidry Date: Wed, 26 Mar 2025 20:10:53 -0400 Subject: [PATCH] Adds extendable timeouts Closes #87 --- src/docket/__init__.py | 2 + src/docket/dependencies.py | 31 +++++++++ src/docket/worker.py | 55 ++++++++++++--- tests/conftest.py | 4 +- tests/test_fundamentals.py | 136 ++++++++++++++++++++++++++++++++++++- tests/test_worker.py | 6 +- 6 files changed, 219 insertions(+), 15 deletions(-) diff --git a/src/docket/__init__.py b/src/docket/__init__.py index 8ed49f0..bbfb566 100644 --- a/src/docket/__init__.py +++ b/src/docket/__init__.py @@ -18,6 +18,7 @@ Retry, TaskKey, TaskLogger, + Timeout, ) from .docket import Docket from .execution import Execution @@ -36,5 +37,6 @@ "ExponentialRetry", "Logged", "Perpetual", + "Timeout", "__version__", ] diff --git a/src/docket/dependencies.py b/src/docket/dependencies.py index c04e42c..4fe9659 100644 --- a/src/docket/dependencies.py +++ b/src/docket/dependencies.py @@ -1,6 +1,7 @@ import abc import inspect import logging +import time from datetime import timedelta from typing import Any, Awaitable, Callable, Counter, TypeVar, cast @@ -171,6 +172,36 @@ def perpetuate(self, *args: Any, **kwargs: Any) -> None: self.kwargs = kwargs +class Timeout(Dependency): + single = True + + base: timedelta + + _deadline: float + + def __init__(self, base: timedelta) -> None: + self.base = base + + def __call__( + self, docket: Docket, worker: Worker, execution: Execution + ) -> "Timeout": + return Timeout(base=self.base) + + def start(self) -> None: + self._deadline = time.monotonic() + self.base.total_seconds() + + def expired(self) -> bool: + return time.monotonic() >= self._deadline + + def remaining(self) -> timedelta: + return timedelta(seconds=self._deadline - time.monotonic()) + + def extend(self, by: timedelta | None = None) -> None: + if by is None: + by = self.base + self._deadline += by.total_seconds() + + def get_dependency_parameters( function: Callable[..., Awaitable[Any]], ) -> dict[str, Dependency]: diff --git a/src/docket/worker.py b/src/docket/worker.py index 3799ce8..674a2d3 100644 --- a/src/docket/worker.py +++ b/src/docket/worker.py @@ -50,7 +50,7 @@ if TYPE_CHECKING: # pragma: no cover - from .dependencies import Dependency + from .dependencies import Dependency, Timeout class _stream_due_tasks(Protocol): @@ -507,13 +507,20 @@ async def _execute(self, message: RedisMessage) -> None: }, links=links, ): - await execution.function( - *execution.args, - **{ - **execution.kwargs, - **dependencies, - }, - ) + from .dependencies import Timeout, get_single_dependency_of_type + + if timeout := get_single_dependency_of_type(dependencies, Timeout): + await self._run_function_with_timeout( + execution, dependencies, timeout + ) + else: + await execution.function( + *execution.args, + **{ + **execution.kwargs, + **dependencies, + }, + ) TASKS_SUCCEEDED.add(1, counter_labels) duration = datetime.now(timezone.utc) - start @@ -539,6 +546,38 @@ async def _execute(self, message: RedisMessage) -> None: TASKS_COMPLETED.add(1, counter_labels) TASK_DURATION.record(duration.total_seconds(), counter_labels) + async def _run_function_with_timeout( + self, + execution: Execution, + dependencies: dict[str, "Dependency"], + timeout: "Timeout", + ) -> None: + timeout.start() + task_coro = execution.function( + *execution.args, **execution.kwargs, **dependencies + ) + task = asyncio.create_task(task_coro) + try: + while not task.done(): # pragma: no branch + remaining = timeout.remaining().total_seconds() + if timeout.expired(): + task.cancel() + break + + try: + await asyncio.wait_for(asyncio.shield(task), timeout=remaining) + return + except asyncio.TimeoutError: + continue + finally: + if not task.done(): + task.cancel() + + try: + await task + except asyncio.CancelledError: + raise asyncio.TimeoutError + def _get_dependencies( self, execution: Execution, diff --git a/tests/conftest.py b/tests/conftest.py index b30b1a3..3ccae28 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -156,8 +156,8 @@ async def docket(redis_url: str, aiolib: str) -> AsyncGenerator[Docket, None]: async def worker(docket: Docket) -> AsyncGenerator[Worker, None]: async with Worker( docket, - minimum_check_interval=timedelta(milliseconds=10), - scheduling_resolution=timedelta(milliseconds=10), + minimum_check_interval=timedelta(milliseconds=5), + scheduling_resolution=timedelta(milliseconds=5), ) as worker: yield worker diff --git a/tests/test_fundamentals.py b/tests/test_fundamentals.py index 58604ed..2d07d39 100644 --- a/tests/test_fundamentals.py +++ b/tests/test_fundamentals.py @@ -5,8 +5,9 @@ as possible to aid with understanding docket. """ +import asyncio import logging -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from logging import LoggerAdapter from typing import Annotated, Callable from unittest.mock import AsyncMock, call @@ -26,6 +27,7 @@ Retry, TaskKey, TaskLogger, + Timeout, Worker, tasks, ) @@ -1065,3 +1067,135 @@ async def my_automatic_task( await worker.run_at_most({"my_automatic_task": 3}) assert calls == 3 + + +async def test_simple_timeout(docket: Docket, worker: Worker): + """A task can be scheduled with a timeout""" + + called = False + + async def task_with_timeout( + timeout: Timeout = Timeout(timedelta(milliseconds=100)), + ): + await asyncio.sleep(0.01) + + nonlocal called + called = True + + await docket.add(task_with_timeout)() + + start = datetime.now(timezone.utc) + + await worker.run_until_finished() + + elapsed = datetime.now(timezone.utc) - start + + assert called + assert elapsed <= timedelta(milliseconds=150) + + +async def test_simple_timeout_cancels_tasks(docket: Docket, worker: Worker): + """A task can be scheduled with a timeout and are cancelled""" + + called = False + + async def task_with_timeout( + timeout: Timeout = Timeout(timedelta(milliseconds=100)), + ): + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + nonlocal called + called = True + + await docket.add(task_with_timeout)() + + start = datetime.now(timezone.utc) + + await worker.run_until_finished() + + elapsed = datetime.now(timezone.utc) - start + + assert called + assert timedelta(milliseconds=100) <= elapsed <= timedelta(milliseconds=200) + + +async def test_timeout_can_be_extended(docket: Docket, worker: Worker): + """A task can be scheduled with a timeout and extend themselves""" + + called = False + + async def task_with_timeout( + timeout: Timeout = Timeout(timedelta(milliseconds=100)), + ): + await asyncio.sleep(0.05) + + timeout.extend(timedelta(milliseconds=200)) + + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + nonlocal called + called = True + + await docket.add(task_with_timeout)() + + start = datetime.now(timezone.utc) + + await worker.run_until_finished() + + elapsed = datetime.now(timezone.utc) - start + + assert called + assert timedelta(milliseconds=250) <= elapsed <= timedelta(milliseconds=400) + + +async def test_timeout_extends_by_base_by_default(docket: Docket, worker: Worker): + """A task can be scheduled with a timeout and extend itself by the base timeout""" + + called = False + + async def task_with_timeout( + timeout: Timeout = Timeout(timedelta(milliseconds=100)), + ): + await asyncio.sleep(0.05) + + timeout.extend() # defaults to the base timeout + + try: + await asyncio.sleep(5) + except asyncio.CancelledError: + nonlocal called + called = True + + await docket.add(task_with_timeout)() + + start = datetime.now(timezone.utc) + + await worker.run_until_finished() + + elapsed = datetime.now(timezone.utc) - start + + assert called + assert timedelta(milliseconds=150) <= elapsed <= timedelta(milliseconds=300) + + +async def test_timeout_is_compatible_with_retry(docket: Docket, worker: Worker): + """A task that times out can be retried""" + + successes: list[int] = [] + + async def task_with_timeout( + retry: Retry = Retry(attempts=3), + timeout: Timeout = Timeout(timedelta(milliseconds=100)), + ): + if retry.attempt == 1: + await asyncio.sleep(1) + + successes.append(retry.attempt) + + await docket.add(task_with_timeout)() + + await worker.run_until_finished() + + assert successes == [2] diff --git a/tests/test_worker.py b/tests/test_worker.py index fb4f9f3..0dae2b0 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -382,14 +382,12 @@ async def perpetual_task( intervals = [next - previous for previous, next in zip(timestamps, timestamps[1:])] minimum = min(intervals) - maximum = max(intervals) debug = ", ".join([f"{i.total_seconds() * 1000:.2f}ms" for i in intervals]) - # even with a variable duration, Docket attempts to schedule them equally and to - # abide by the target interval + # It's not reliable to assert the maximum duration on different machine setups, but + # we'll make sure that the minimum is observed, which is the guarantee assert minimum >= timedelta(milliseconds=50), debug - assert maximum <= timedelta(milliseconds=75), debug async def test_worker_can_exit_from_perpetual_tasks_that_queue_further_tasks(