diff --git a/examples/task_progress.py b/examples/task_progress.py new file mode 100644 index 0000000..9cb9494 --- /dev/null +++ b/examples/task_progress.py @@ -0,0 +1,111 @@ +"""Example demonstrating task progress tracking and real-time monitoring. + +This example shows how to: +- Report progress from within a task using ExecutionProgress +- Track progress with current value, total, and status messages +- Monitor task progress in real-time using the 'docket watch' command +- Schedule tasks for future execution + +Key Concepts: +- ExecutionProgress: Tracks task progress (current/total) and status messages +- Progress dependency: Injected into tasks via Progress() default parameter +- Real-time monitoring: Use 'docket watch' CLI to monitor running tasks +- State tracking: Tasks transition through SCHEDULED → QUEUED → RUNNING → COMPLETED + +Run this example with `uv run -m examples.task_progress` and use the printed 'docket watch' command to see live progress updates. +""" + +from datetime import datetime, timedelta, timezone +from docket import Docket, Progress, Worker +import asyncio +import rich.console + +from docket.execution import ExecutionProgress + +from .common import run_redis + + +async def long_task(progress: ExecutionProgress = Progress()) -> None: + """A long-running task that reports progress as it executes. + + This demonstrates the key progress tracking patterns: + - Progress dependency injection via Progress() default parameter + - Incremental progress updates with increment() + - Status messages with set_message() + + The ExecutionProgress object has a default total of 100, so we don't need + to call set_total() in this example. The progress automatically increments + from 0 to 100. + + Args: + progress: Injected ExecutionProgress tracker (automatically provided by Docket) + + Pattern for your own tasks: + 1. Add progress parameter with Progress() default + 2. Call increment() as work progresses (or set_total + increment) + 3. Optionally set_message() to show current status + 4. Monitor with: docket watch --url --docket + """ + # Simulate 100 steps of work, each taking 1 second + for i in range(1, 101): + await asyncio.sleep(1) # Simulate work being done + + # Increment progress by 1 (tracks that one more unit is complete) + await progress.increment() + + # Update status message every 10 items for demonstration + if i % 10 == 0: + await progress.set_message(f"{i} splines retriculated") + + +# Export tasks for docket CLI to discover +tasks = [long_task] + +# Console for printing user-friendly messages +console = rich.console.Console() + + +async def main(): + """Run the progress tracking example. + + This function demonstrates the complete lifecycle: + 1. Start a Redis container for testing + 2. Create a Docket (task queue) + 3. Start a Worker (executes tasks) + 4. Register and schedule a task + 5. 
Monitor progress with the 'docket watch' command + + The task is scheduled 20 seconds in the future to give you time to + run the watch command and see the task transition through states: + SCHEDULED → QUEUED → RUNNING → COMPLETED + """ + # Start a temporary Redis container for this example + # In production, you'd connect to your existing Redis instance + async with run_redis("7.4.2") as redis_url: + # Create a Docket connected to Redis + async with Docket(name="task-progress", url=redis_url) as docket: + # Start a Worker to execute tasks from the docket + async with Worker(docket, name="task-progress-worker") as worker: + # Register the task so the worker knows how to execute it + docket.register(long_task) + + # Schedule the task to run 20 seconds from now + # This gives you time to run the watch command before it starts + in_twenty_seconds = datetime.now(timezone.utc) + timedelta(seconds=20) + execution = await docket.add( + long_task, key="long-task", when=in_twenty_seconds + )() + + # Print instructions for monitoring + console.print(f"Execution {execution.key} started!") + console.print( + f"Run [blue]docket watch --url {redis_url} --docket {docket.name} {execution.key}[/blue] to see the progress!" + ) + + # Run the worker until all tasks complete + # The worker will wait for the scheduled time, then execute the task + await worker.run_until_finished() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index b321038..90c26fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dev = [ "pytest>=8.3.4", "pytest-asyncio>=0.24.0", "pytest-cov>=6.0.0", + "pytest-timeout>=2.4.0", "pytest-xdist>=3.6.1", "ruff>=0.9.7", ] diff --git a/src/docket/__init__.py b/src/docket/__init__.py index ff4d66c..6e405e4 100644 --- a/src/docket/__init__.py +++ b/src/docket/__init__.py @@ -18,6 +18,7 @@ Depends, ExponentialRetry, Perpetual, + Progress, Retry, TaskArgument, TaskKey, @@ -25,7 +26,7 @@ Timeout, ) from .docket import Docket -from .execution import Execution +from .execution import Execution, ExecutionState from .worker import Worker __all__ = [ @@ -38,9 +39,11 @@ "Depends", "Docket", "Execution", + "ExecutionState", "ExponentialRetry", "Logged", "Perpetual", + "Progress", "Retry", "TaskArgument", "TaskKey", diff --git a/src/docket/agenda.py b/src/docket/agenda.py index 91e1fc3..9b44252 100644 --- a/src/docket/agenda.py +++ b/src/docket/agenda.py @@ -181,6 +181,7 @@ async def scatter( # Create execution with unique key key = str(uuid7()) execution = Execution( + docket=docket, function=resolved_func, args=args, kwargs=kwargs, diff --git a/src/docket/cli.py b/src/docket/cli.py index 1534380..671877b 100644 --- a/src/docket/cli.py +++ b/src/docket/cli.py @@ -5,17 +5,29 @@ import os import socket import sys +import time from datetime import datetime, timedelta, timezone from functools import partial from typing import Annotated, Any, Collection +from unittest.mock import AsyncMock import typer from rich.console import Console +from rich.layout import Layout +from rich.live import Live +from rich.progress import ( + BarColumn, + Progress, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TaskID, +) from rich.table import Table from . 
import __version__, tasks from .docket import Docket, DocketSnapshot, WorkerInfo -from .execution import Operator +from .execution import Execution, ExecutionState, Operator from .worker import Worker app: typer.Typer = typer.Typer( @@ -376,7 +388,7 @@ async def run() -> None: asyncio.run(run()) -@app.command(help="Clear all pending and scheduled tasks from the docket") +@app.command(help="Clear all queued and scheduled tasks from the docket") def clear( docket_: Annotated[ str, @@ -509,10 +521,7 @@ async def run() -> None: async with Docket(name=docket_, url=url) as docket: when = datetime.now(timezone.utc) + delay execution = await docket.add(tasks.trace, when)(message) - print( - f"Added {execution.function.__name__} task {execution.key!r} to " - f"the docket {docket.name!r}" - ) + print(f"Added trace task {execution.key!r} to the docket {docket.name!r}") asyncio.run(run()) @@ -553,10 +562,7 @@ async def run() -> None: async with Docket(name=docket_, url=url) as docket: when = datetime.now(timezone.utc) + delay execution = await docket.add(tasks.fail, when)(message) - print( - f"Added {execution.function.__name__} task {execution.key!r} to " - f"the docket {docket.name!r}" - ) + print(f"Added fail task {execution.key!r} to the docket {docket.name!r}") asyncio.run(run()) @@ -597,10 +603,7 @@ async def run() -> None: async with Docket(name=docket_, url=url) as docket: when = datetime.now(timezone.utc) + delay execution = await docket.add(tasks.sleep, when)(seconds) - print( - f"Added {execution.function.__name__} task {execution.key!r} to " - f"the docket {docket.name!r}" - ) + print(f"Added sleep task {execution.key!r} to the docket {docket.name!r}") asyncio.run(run()) @@ -810,6 +813,214 @@ async def run() -> DocketSnapshot: console.print(stats_table) +@app.command(help="Monitor progress of a specific task execution") +def watch( + key: Annotated[str, typer.Argument(help="The task execution key to monitor")], + url: Annotated[ + str, + typer.Option( + "--url", + "-u", + envvar="DOCKET_REDIS_URL", + help="Redis URL (e.g., redis://localhost:6379/0)", + ), + ] = "redis://localhost:6379/0", + docket_name: Annotated[ + str, + typer.Option( + "--docket", + "-d", + envvar="DOCKET_NAME", + help="Docket name", + ), + ] = "docket", +) -> None: + """Monitor the progress of a specific task execution in real-time using event-driven updates.""" + + async def monitor() -> None: + async with Docket(docket_name, url) as docket: + execution = Execution( + docket, AsyncMock(), (), {}, datetime.now(timezone.utc), key, 1 + ) # TODO: Replace AsyncMock with actual task function + console = Console() + + # State colors for display + state_colors = { + ExecutionState.SCHEDULED: "yellow", + ExecutionState.QUEUED: "cyan", + ExecutionState.RUNNING: "blue", + ExecutionState.COMPLETED: "green", + ExecutionState.FAILED: "red", + } + + # Load initial snapshot + await execution.sync() + + # Track current state for display + current_state = execution.state + worker_name: str | None = execution.worker + error_message: str | None = execution.error + + # Initialize progress values + current_val = ( + execution.progress.current + if execution.progress.current is not None + else 0 + ) + total_val = execution.progress.total + progress_message = execution.progress.message + + active_progress = Progress( + TextColumn("[bold blue]{task.description}"), + BarColumn(bar_width=None), # Auto width + TaskProgressColumn(), + TimeElapsedColumn(), + expand=True, + ) + + progress_task_id = None + + def set_progress_start_time(task_id: 
TaskID, started_at: datetime) -> None: + """Set progress bar start time based on execution start time.""" + elapsed_since_start = datetime.now(timezone.utc) - started_at + monotonic_start = time.monotonic() - elapsed_since_start.total_seconds() + active_progress.tasks[task_id].start_time = monotonic_start + + # Initialize progress task if we have progress data + if current_val > 0 and total_val > 0: + progress_task_id = active_progress.add_task( + progress_message or "Processing...", + total=total_val, + completed=current_val, + ) + # Set start time based on execution.started_at if available + if execution.started_at is not None: # pragma: no cover + set_progress_start_time(progress_task_id, execution.started_at) + + def create_display_layout() -> Layout: + """Create the layout for watch display.""" + layout = Layout() + + # Build info lines + info_lines = [ + f"[bold]Task:[/bold] {key}", + f"[bold]Docket:[/bold] {docket_name}", + ] + + # Add state with color + state_color = state_colors.get(current_state, "white") + info_lines.append( + f"[bold]State:[/bold] [{state_color}]{current_state.value.upper()}[/{state_color}]" + ) + + # Add worker if available + if worker_name: + info_lines.append(f"[bold]Worker:[/bold] {worker_name}") + + # Add error if failed + if error_message: + info_lines.append(f"[red bold]Error:[/red bold] {error_message}") + + # Add completion status + if current_state == ExecutionState.COMPLETED: + info_lines.append( + "[green bold]✓ Task completed successfully[/green bold]" + ) + elif current_state == ExecutionState.FAILED: + info_lines.append("[red bold]✗ Task failed[/red bold]") + + info_section = "\n".join(info_lines) + + # Build layout without big gaps + if progress_task_id is not None: + # Choose the right progress instance + # Show info and progress together with minimal spacing + layout.split_column( + Layout(info_section, name="info", size=len(info_lines)), + Layout(active_progress, name="progress", size=2), + ) + else: + # Just show info + layout.update(Layout(info_section, name="info")) + + return layout + + # Create initial layout + layout = create_display_layout() + + # If already in terminal state, display once and exit + if current_state in (ExecutionState.COMPLETED, ExecutionState.FAILED): + console.print(layout) + return + + # Use Live for smooth updates + with Live(layout, console=console, refresh_per_second=4) as live: + # Subscribe to events and update display + async for event in execution.subscribe(): # pragma: no cover + if event["type"] == "state": + # Update state information + current_state = ExecutionState(event["state"]) + if worker := event.get("worker"): + worker_name = worker + if error := event.get("error"): + error_message = error + if started_at := event.get("started_at"): + execution.started_at = datetime.fromisoformat(started_at) + # Update progress bar start time if we have a progress task + if progress_task_id is not None: + set_progress_start_time( + progress_task_id, execution.started_at + ) + + # Update layout + layout = create_display_layout() + live.update(layout) + + # Exit if terminal state reached + if current_state in ( + ExecutionState.COMPLETED, + ExecutionState.FAILED, + ): + break + + elif event["type"] == "progress": + # Update progress information + current_val = event["current"] + total_val: int = event.get("total", execution.progress.total) + progress_message = event.get( + "message", execution.progress.message + ) + + # Update or create progress task + if total_val > 0 and execution.started_at is not None: + if 
progress_task_id is None: + # Create new progress task (first time only) + progress_task_id = active_progress.add_task( + progress_message or "Processing...", + total=total_val, + completed=current_val or 0, + ) + # Set start time based on execution.started_at if available + if started_at := execution.started_at: + set_progress_start_time( + progress_task_id, execution.started_at + ) + else: + # Update existing progress task + active_progress.update( + progress_task_id, + completed=current_val, + total=total_val, + description=progress_message or "Processing...", + ) + + # Update layout + layout = create_display_layout() + live.update(layout) + + asyncio.run(monitor()) + + workers_app: typer.Typer = typer.Typer( help="Look at the workers on a docket", no_args_is_help=True ) diff --git a/src/docket/dependencies.py b/src/docket/dependencies.py index 40daddd..e899005 100644 --- a/src/docket/dependencies.py +++ b/src/docket/dependencies.py @@ -22,8 +22,9 @@ ) from .docket import Docket -from .execution import Execution, TaskFunction, get_signature +from .execution import Execution, ExecutionProgress, TaskFunction, get_signature from .instrumentation import CACHE_SIZE +# Run and RunProgress have been consolidated into Execution if TYPE_CHECKING: # pragma: no cover from .worker import Worker @@ -192,6 +193,33 @@ async def my_task(logger: "LoggerAdapter[Logger]" = TaskLogger()) -> None: return cast("logging.LoggerAdapter[logging.Logger]", _TaskLogger()) +class _Progress(Dependency): + async def __aenter__(self) -> ExecutionProgress: + execution = self.execution.get() + return execution.progress + + +def Progress() -> ExecutionProgress: + """A dependency to report progress updates for the currently executing task. + + Tasks can use this to report their current progress (current/total values) and + status messages to external observers. + + Example: + + ```python + @task + async def process_records(records: list, progress: ExecutionProgress = Progress()) -> None: + await progress.set_total(len(records)) + for i, record in enumerate(records): + await process(record) + await progress.increment() + await progress.set_message(f"Processed {record.id}") + ``` + """ + return cast(ExecutionProgress, _Progress()) + + class ForcedRetry(Exception): """Raised when a task requests a retry via `in_` or `at`""" diff --git a/src/docket/docket.py b/src/docket/docket.py index f573f7e..409ce25 100644 --- a/src/docket/docket.py +++ b/src/docket/docket.py @@ -27,7 +27,7 @@ from typing_extensions import Self import redis.exceptions -from opentelemetry import propagate, trace +from opentelemetry import trace from redis.asyncio import ConnectionPool, Redis from uuid_extensions import uuid7 @@ -41,6 +41,7 @@ StrikeList, TaskFunction, ) + from .instrumentation import ( REDIS_DISRUPTIONS, STRIKES_IN_EFFECT, @@ -49,19 +50,12 @@ TASKS_REPLACED, TASKS_SCHEDULED, TASKS_STRICKEN, - message_setter, ) logger: logging.Logger = logging.getLogger(__name__) tracer: trace.Tracer = trace.get_tracer(__name__) -class _schedule_task(Protocol): - async def __call__( - self, keys: list[str], args: list[str | float | bytes] - ) -> str: ... 
# pragma: no cover - - class _cancel_task(Protocol): async def __call__( self, keys: list[str], args: list[str] @@ -144,7 +138,6 @@ async def my_task(greeting: str, recipient: str) -> None: _monitor_strikes_task: asyncio.Task[None] _connection_pool: ConnectionPool - _schedule_task_script: _schedule_task | None _cancel_task_script: _cancel_task | None def __init__( @@ -153,6 +146,7 @@ def __init__( url: str = "redis://localhost:6379/0", heartbeat_interval: timedelta = timedelta(seconds=2), missed_heartbeats: int = 5, + execution_ttl: timedelta = timedelta(hours=1), ) -> None: """ Args: @@ -167,12 +161,14 @@ def __init__( heartbeat_interval: How often workers send heartbeat messages to the docket. missed_heartbeats: How many heartbeats a worker can miss before it is considered dead. + execution_ttl: How long to keep completed or failed execution state records + in Redis before they expire. Defaults to 1 hour. """ self.name = name self.url = url self.heartbeat_interval = heartbeat_interval self.missed_heartbeats = missed_heartbeats - self._schedule_task_script = None + self.execution_ttl = execution_ttl self._cancel_task_script = None @property @@ -338,10 +334,27 @@ def add( key = str(uuid7()) async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution: - execution = Execution(function, args, kwargs, when, key, attempt=1) + execution = Execution(self, function, args, kwargs, when, key, attempt=1) + + # Check if task is stricken before scheduling + if self.strike_list.is_stricken(execution): + logger.warning( + "%r is stricken, skipping schedule of %r", + execution.function.__name__, + execution.key, + ) + TASKS_STRICKEN.add( + 1, + { + **self.labels(), + **execution.general_labels(), + "docket.where": "docket", + }, + ) + return execution - async with self.redis() as redis: - await self._schedule(redis, execution, replace=False) + # Schedule atomically (includes state record write) + await execution.schedule(replace=False) TASKS_ADDED.add(1, {**self.labels(), **execution.general_labels()}) TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()}) @@ -397,10 +410,27 @@ def replace( function = self.tasks[function] async def scheduler(*args: P.args, **kwargs: P.kwargs) -> Execution: - execution = Execution(function, args, kwargs, when, key, attempt=1) + execution = Execution(self, function, args, kwargs, when, key, attempt=1) + + # Check if task is stricken before scheduling + if self.strike_list.is_stricken(execution): + logger.warning( + "%r is stricken, skipping schedule of %r", + execution.function.__name__, + execution.key, + ) + TASKS_STRICKEN.add( + 1, + { + **self.labels(), + **execution.general_labels(), + "docket.where": "docket", + }, + ) + return execution - async with self.redis() as redis: - await self._schedule(redis, execution, replace=True) + # Schedule atomically (includes state record write) + await execution.schedule(replace=True) TASKS_REPLACED.add(1, {**self.labels(), **execution.general_labels()}) TASKS_CANCELLED.add(1, {**self.labels(), **execution.general_labels()}) @@ -419,8 +449,25 @@ async def schedule(self, execution: Execution) -> None: "code.function.name": execution.function.__name__, }, ): - async with self.redis() as redis: - await self._schedule(redis, execution, replace=False) + # Check if task is stricken before scheduling + if self.strike_list.is_stricken(execution): + logger.warning( + "%r is stricken, skipping schedule of %r", + execution.function.__name__, + execution.key, + ) + TASKS_STRICKEN.add( + 1, + { + **self.labels(), + 
**execution.general_labels(), + "docket.where": "docket", + }, + ) + return + + # Schedule atomically (includes state record write) + await execution.schedule(replace=False) TASKS_SCHEDULED.add(1, {**self.labels(), **execution.general_labels()}) @@ -456,125 +503,6 @@ def parked_task_key(self, key: str) -> str: def stream_id_key(self, key: str) -> str: return f"{self.name}:stream-id:{key}" - async def _schedule( - self, - redis: Redis, - execution: Execution, - replace: bool = False, - ) -> None: - """Schedule a task atomically. - - Handles: - - Checking for task existence - - Cancelling existing tasks when replacing - - Adding tasks to stream (immediate) or queue (future) - - Tracking stream message IDs for later cancellation - """ - if self.strike_list.is_stricken(execution): - logger.warning( - "%r is stricken, skipping schedule of %r", - execution.function.__name__, - execution.key, - ) - TASKS_STRICKEN.add( - 1, - { - **self.labels(), - **execution.general_labels(), - "docket.where": "docket", - }, - ) - return - - message: dict[bytes, bytes] = execution.as_message() - propagate.inject(message, setter=message_setter) - - key = execution.key - when = execution.when - known_task_key = self.known_task_key(key) - is_immediate = when <= datetime.now(timezone.utc) - - # Lock per task key to prevent race conditions between concurrent operations - async with redis.lock(f"{known_task_key}:lock", timeout=10): - if self._schedule_task_script is None: - self._schedule_task_script = cast( - _schedule_task, - redis.register_script( - # KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key - # ARGV: task_key, when_timestamp, is_immediate, replace, ...message_fields - """ - local stream_key = KEYS[1] - local known_key = KEYS[2] - local parked_key = KEYS[3] - local queue_key = KEYS[4] - local stream_id_key = KEYS[5] - - local task_key = ARGV[1] - local when_timestamp = ARGV[2] - local is_immediate = ARGV[3] == '1' - local replace = ARGV[4] == '1' - - -- Extract message fields from ARGV[5] onwards - local message = {} - for i = 5, #ARGV, 2 do - message[#message + 1] = ARGV[i] -- field name - message[#message + 1] = ARGV[i + 1] -- field value - end - - -- Handle replacement: cancel existing task if needed - if replace then - local existing_message_id = redis.call('GET', stream_id_key) - if existing_message_id then - redis.call('XDEL', stream_key, existing_message_id) - end - redis.call('DEL', known_key, parked_key, stream_id_key) - redis.call('ZREM', queue_key, task_key) - else - -- Check if task already exists - if redis.call('EXISTS', known_key) == 1 then - return 'EXISTS' - end - end - - if is_immediate then - -- Add to stream and store message ID for later cancellation - local message_id = redis.call('XADD', stream_key, '*', unpack(message)) - redis.call('SET', known_key, when_timestamp) - redis.call('SET', stream_id_key, message_id) - return message_id - else - -- Add to queue with task data in parked hash - redis.call('SET', known_key, when_timestamp) - redis.call('HSET', parked_key, unpack(message)) - redis.call('ZADD', queue_key, when_timestamp, task_key) - return 'QUEUED' - end - """ - ), - ) - schedule_task = self._schedule_task_script - - await schedule_task( - keys=[ - self.stream_key, - known_task_key, - self.parked_task_key(key), - self.queue_key, - self.stream_id_key(key), - ], - args=[ - key, - str(when.timestamp()), - "1" if is_immediate else "0", - "1" if replace else "0", - *[ - item - for field, value in message.items() - for item in (field, value) - ], - ], - ) - async 
def _cancel(self, redis: Redis, key: str) -> None: """Cancel a task atomically. @@ -587,24 +515,33 @@ async def _cancel(self, redis: Redis, key: str) -> None: self._cancel_task_script = cast( _cancel_task, redis.register_script( - # KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key + # KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key, runs_key # ARGV: task_key """ local stream_key = KEYS[1] + -- TODO: Remove in next breaking release (v0.14.0) - legacy key locations local known_key = KEYS[2] local parked_key = KEYS[3] local queue_key = KEYS[4] local stream_id_key = KEYS[5] + local runs_key = KEYS[6] local task_key = ARGV[1] + -- Get stream ID (check new location first, then legacy) + local message_id = redis.call('HGET', runs_key, 'stream_id') + + -- TODO: Remove in next breaking release (v0.14.0) - check legacy location + if not message_id then + message_id = redis.call('GET', stream_id_key) + end + -- Delete from stream if message ID exists - local message_id = redis.call('GET', stream_id_key) if message_id then redis.call('XDEL', stream_key, message_id) end - -- Clean up all task-related keys - redis.call('DEL', known_key, parked_key, stream_id_key) + -- Clean up all task-related keys (handles both new and legacy formats) + redis.call('DEL', known_key, parked_key, stream_id_key, runs_key) redis.call('ZREM', queue_key, task_key) return 'OK' @@ -621,6 +558,7 @@ async def _cancel(self, redis: Redis, key: str) -> None: self.parked_task_key(key), self.queue_key, self.stream_id_key(key), + f"{self.name}:runs:{key}", # runs_key ], args=[key], ) @@ -791,8 +729,7 @@ async def snapshot(self) -> DocketSnapshot: } for message_id, message in stream_messages: - function = self.tasks[message[b"function"].decode()] - execution = Execution.from_message(function, message) + execution = await Execution.from_message(self, message) if message_id in pending_lookup: worker_name = pending_lookup[message_id]["consumer"].decode() started = now - timedelta( @@ -803,8 +740,7 @@ async def snapshot(self) -> DocketSnapshot: future.append(execution) # pragma: no cover for message in queued_messages: - function = self.tasks[message[b"function"].decode()] - execution = Execution.from_message(function, message) + execution = await Execution.from_message(self, message) future.append(execution) workers = await self.workers() @@ -894,7 +830,7 @@ async def task_workers(self, task_name: str) -> Collection[WorkerInfo]: return workers async def clear(self) -> int: - """Clear all pending and scheduled tasks from the docket. + """Clear all queued and scheduled tasks from the docket. This removes all tasks from the stream (immediate tasks) and queue (scheduled tasks), along with their associated parked data. 
Running diff --git a/src/docket/execution.py b/src/docket/execution.py index 4c04718..8b5528d 100644 --- a/src/docket/execution.py +++ b/src/docket/execution.py @@ -1,9 +1,22 @@ import abc import enum import inspect +import json import logging -from datetime import datetime -from typing import Any, Awaitable, Callable, Hashable, Literal, Mapping, cast +from datetime import datetime, timezone +from typing import ( + TYPE_CHECKING, + Any, + AsyncGenerator, + Awaitable, + Callable, + Hashable, + Literal, + Mapping, + Protocol, + TypedDict, + cast, +) from typing_extensions import Self @@ -12,7 +25,10 @@ from opentelemetry import propagate, trace from .annotations import Logged -from .instrumentation import CACHE_SIZE, message_getter +from .instrumentation import CACHE_SIZE, message_getter, message_setter + +if TYPE_CHECKING: + from .docket import Docket logger: logging.Logger = logging.getLogger(__name__) @@ -20,6 +36,12 @@ Message = dict[bytes, bytes] +class _schedule_task(Protocol): + async def __call__( + self, keys: list[str], args: list[str | float | bytes] + ) -> str: ... # pragma: no cover + + _signature_cache: dict[Callable[..., Any], inspect.Signature] = {} @@ -40,9 +62,247 @@ def get_signature(function: Callable[..., Any]) -> inspect.Signature: return signature +class ExecutionState(enum.Enum): + """Lifecycle states for task execution.""" + + SCHEDULED = "scheduled" + """Task is scheduled and waiting in the queue for its execution time.""" + + QUEUED = "queued" + """Task has been moved to the stream and is ready to be claimed by a worker.""" + + RUNNING = "running" + """Task is currently being executed by a worker.""" + + COMPLETED = "completed" + """Task execution finished successfully.""" + + FAILED = "failed" + """Task execution failed.""" + + +class ProgressEvent(TypedDict): + type: Literal["progress"] + key: str + current: int | None + total: int + message: str | None + updated_at: str | None + + +class StateEvent(TypedDict): + type: Literal["state"] + key: str + state: ExecutionState + when: str + worker: str | None + started_at: str | None + completed_at: str | None + error: str | None + + +class ExecutionProgress: + """Manages user-reported progress for a task execution. + + Progress data is stored in Redis hash {docket}:progress:{key} and includes: + - current: Current progress value (integer) + - total: Total/target value (integer) + - message: User-provided status message (string) + - updated_at: Timestamp of last update (ISO 8601 string) + + This data is ephemeral and deleted when the task completes. + """ + + def __init__(self, docket: "Docket", key: str) -> None: + """Initialize progress tracker for a specific task. + + Args: + docket: The docket instance + key: The task execution key + """ + self.docket = docket + self.key = key + self._redis_key = f"{docket.name}:progress:{key}" + self.current: int | None = None + self.total: int = 1 + self.message: str | None = None + self.updated_at: datetime | None = None + + @classmethod + async def create(cls, docket: "Docket", key: str) -> Self: + """Create and initialize progress tracker by reading from Redis. + + Args: + docket: The docket instance + key: The task execution key + + Returns: + ExecutionProgress instance with attributes populated from Redis + """ + instance = cls(docket, key) + await instance.sync() + return instance + + async def set_total(self, total: int) -> None: + """Set the total/target value for progress tracking. + + Args: + total: The total number of units to complete. Must be at least 1. 
+ """ + if total < 1: + raise ValueError("Total must be at least 1") + + updated_at_dt = datetime.now(timezone.utc) + updated_at = updated_at_dt.isoformat() + async with self.docket.redis() as redis: + await redis.hset( + self._redis_key, + mapping={ + "total": str(total), + "updated_at": updated_at, + }, + ) + # Update instance attributes + self.total = total + self.updated_at = updated_at_dt + # Publish update event + await self._publish({"total": total, "updated_at": updated_at}) + + async def increment(self, amount: int = 1) -> None: + """Atomically increment the current progress value. + + Args: + amount: Amount to increment by. Must be at least 1. + """ + if amount < 1: + raise ValueError("Amount must be at least 1") + + updated_at_dt = datetime.now(timezone.utc) + updated_at = updated_at_dt.isoformat() + async with self.docket.redis() as redis: + new_current = await redis.hincrby(self._redis_key, "current", amount) + await redis.hset( + self._redis_key, + "updated_at", + updated_at, + ) + # Update instance attributes using Redis return value + self.current = new_current + self.updated_at = updated_at_dt + # Publish update event with new current value + await self._publish({"current": new_current, "updated_at": updated_at}) + + async def set_message(self, message: str | None) -> None: + """Update the progress status message. + + Args: + message: Status message describing current progress + """ + updated_at_dt = datetime.now(timezone.utc) + updated_at = updated_at_dt.isoformat() + async with self.docket.redis() as redis: + await redis.hset( + self._redis_key, + mapping={ + "message": message, + "updated_at": updated_at, + }, + ) + # Update instance attributes + self.message = message + self.updated_at = updated_at_dt + # Publish update event + await self._publish({"message": message, "updated_at": updated_at}) + + async def sync(self) -> None: + """Synchronize instance attributes with current progress data from Redis. + + Updates self.current, self.total, self.message, and self.updated_at + with values from Redis. Sets attributes to None if no data exists. + """ + async with self.docket.redis() as redis: + data = await redis.hgetall(self._redis_key) + if data: + self.current = int(data.get(b"current", b"0")) + self.total = int(data.get(b"total", b"100")) + self.message = data[b"message"].decode() if b"message" in data else None + self.updated_at = ( + datetime.fromisoformat(data[b"updated_at"].decode()) + if b"updated_at" in data + else None + ) + else: + self.current = None + self.total = 100 + self.message = None + self.updated_at = None + + async def _delete(self) -> None: + """Delete the progress data from Redis. + + Called internally when task execution completes. + """ + async with self.docket.redis() as redis: + await redis.delete(self._redis_key) + # Reset instance attributes + self.current = None + self.total = 100 + self.message = None + self.updated_at = None + + async def _publish(self, data: dict[str, Any]) -> None: + """Publish progress update to Redis pub/sub channel. 
+ + Args: + data: Progress data to publish (partial update) + """ + channel = f"{self.docket.name}:progress:{self.key}" + # Create ephemeral Redis client for publishing + async with self.docket.redis() as redis: + # Use instance attributes for current state + payload: ProgressEvent = { + "type": "progress", + "key": self.key, + "current": self.current if self.current is not None else 0, + "total": self.total, + "message": self.message, + "updated_at": data.get("updated_at"), + } + + # Publish JSON payload + await redis.publish(channel, json.dumps(payload)) + + async def subscribe(self) -> AsyncGenerator[ProgressEvent, None]: + """Subscribe to progress updates for this task. + + Yields: + Dict containing progress update events with fields: + - type: "progress" + - key: task key + - current: current progress value + - total: total/target value (or None) + - message: status message (or None) + - updated_at: ISO 8601 timestamp + """ + channel = f"{self.docket.name}:progress:{self.key}" + async with self.docket.redis() as redis: + async with redis.pubsub() as pubsub: + await pubsub.subscribe(channel) + async for message in pubsub.listen(): # pragma: no cover + if message["type"] == "message": + yield json.loads(message["data"]) + + class Execution: + """Represents a task execution with state management and progress tracking. + + Combines task invocation metadata (function, args, when, etc.) with + Redis-backed lifecycle state tracking and user-reported progress. + """ + def __init__( self, + docket: "Docket", function: TaskFunction, args: tuple[Any, ...], kwargs: dict[str, Any], @@ -52,6 +312,7 @@ def __init__( trace_context: opentelemetry.context.Context | None = None, redelivered: bool = False, ) -> None: + self.docket = docket self.function = function self.args = args self.kwargs = kwargs @@ -60,6 +321,13 @@ def __init__( self.attempt = attempt self.trace_context = trace_context self.redelivered = redelivered + self.state: ExecutionState = ExecutionState.SCHEDULED + self.worker: str | None = None + self.started_at: datetime | None = None + self.completed_at: datetime | None = None + self.error: str | None = None + self.progress: ExecutionProgress = ExecutionProgress(docket, key) + self._redis_key = f"{docket.name}:runs:{key}" def as_message(self) -> Message: return { @@ -72,8 +340,15 @@ def as_message(self) -> Message: } @classmethod - def from_message(cls, function: TaskFunction, message: Message) -> Self: - return cls( + async def from_message(cls, docket: "Docket", message: Message) -> Self: + function_name = message[b"function"].decode() + if not (function := docket.tasks.get(function_name)): + raise ValueError( + f"Task function {function_name!r} is not registered with the current docket" + ) + + instance = cls( + docket=docket, function=function, args=cloudpickle.loads(message[b"args"]), kwargs=cloudpickle.loads(message[b"kwargs"]), @@ -83,6 +358,8 @@ def from_message(cls, function: TaskFunction, message: Message) -> Self: trace_context=propagate.extract(message, getter=message_getter), redelivered=False, # Default to False, will be set to True in worker if it's a redelivery ) + await instance.sync() + return instance def general_labels(self) -> Mapping[str, str]: return {"docket.task": self.function.__name__} @@ -128,6 +405,407 @@ def incoming_span_links(self) -> list[trace.Link]: initiating_context = initiating_span.get_span_context() return [trace.Link(initiating_context)] if initiating_context.is_valid else [] + async def schedule(self, replace: bool = False) -> None: + """Schedule 
this task atomically in Redis. + + This performs an atomic operation that: + - Adds the task to the stream (immediate) or queue (future) + - Writes the execution state record + - Tracks metadata for later cancellation + + Args: + replace: If True, replaces any existing task with the same key. + If False, raises an error if the task already exists. + """ + message: dict[bytes, bytes] = self.as_message() + propagate.inject(message, setter=message_setter) + + key = self.key + when = self.when + known_task_key = self.docket.known_task_key(key) + is_immediate = when <= datetime.now(timezone.utc) + + async with self.docket.redis() as redis: + # Lock per task key to prevent race conditions between concurrent operations + async with redis.lock(f"{known_task_key}:lock", timeout=10): + # Register script for this connection (not cached to avoid event loop issues) + schedule_script = cast( + _schedule_task, + redis.register_script( + # KEYS: stream_key, known_key, parked_key, queue_key, stream_id_key, runs_key + # ARGV: task_key, when_timestamp, is_immediate, replace, ...message_fields + """ + local stream_key = KEYS[1] + -- TODO: Remove in next breaking release (v0.14.0) - legacy key locations + local known_key = KEYS[2] + local parked_key = KEYS[3] + local queue_key = KEYS[4] + local stream_id_key = KEYS[5] + local runs_key = KEYS[6] + + local task_key = ARGV[1] + local when_timestamp = ARGV[2] + local is_immediate = ARGV[3] == '1' + local replace = ARGV[4] == '1' + + -- Extract message fields from ARGV[5] onwards + local message = {} + for i = 5, #ARGV, 2 do + message[#message + 1] = ARGV[i] -- field name + message[#message + 1] = ARGV[i + 1] -- field value + end + + -- Handle replacement: cancel existing task if needed + if replace then + -- Get stream ID from runs hash (check new location first) + local existing_message_id = redis.call('HGET', runs_key, 'stream_id') + + -- TODO: Remove in next breaking release (v0.14.0) - check legacy location + if not existing_message_id then + existing_message_id = redis.call('GET', stream_id_key) + end + + if existing_message_id then + redis.call('XDEL', stream_key, existing_message_id) + end + + redis.call('ZREM', queue_key, task_key) + redis.call('DEL', parked_key) + + -- TODO: Remove in next breaking release (v0.14.0) - clean up legacy keys + redis.call('DEL', known_key, stream_id_key) + + -- Note: runs_key is updated below, not deleted + else + -- Check if task already exists (check new location first, then legacy) + local known_exists = redis.call('HEXISTS', runs_key, 'known') == 1 + if not known_exists then + -- TODO: Remove in next breaking release (v0.14.0) - check legacy location + known_exists = redis.call('EXISTS', known_key) == 1 + end + if known_exists then + return 'EXISTS' + end + end + + if is_immediate then + -- Add to stream for immediate execution + local message_id = redis.call('XADD', stream_key, '*', unpack(message)) + + -- Store state and metadata in runs hash + redis.call('HSET', runs_key, + 'state', 'queued', + 'when', when_timestamp, + 'known', when_timestamp, + 'stream_id', message_id + ) + else + -- Park task data for future execution + redis.call('HSET', parked_key, unpack(message)) + + -- Add to sorted set queue + redis.call('ZADD', queue_key, when_timestamp, task_key) + + -- Store state and metadata in runs hash + redis.call('HSET', runs_key, + 'state', 'scheduled', + 'when', when_timestamp, + 'known', when_timestamp + ) + end + + return 'OK' + """ + ), + ) + + await schedule_script( + keys=[ + self.docket.stream_key, + 
known_task_key, + self.docket.parked_task_key(key), + self.docket.queue_key, + self.docket.stream_id_key(key), + self._redis_key, + ], + args=[ + key, + str(when.timestamp()), + "1" if is_immediate else "0", + "1" if replace else "0", + *[ + item + for field, value in message.items() + for item in (field, value) + ], + ], + ) + + # Update local state based on whether task is immediate or scheduled + if is_immediate: + self.state = ExecutionState.QUEUED + await self._publish_state( + {"state": ExecutionState.QUEUED.value, "when": when.isoformat()} + ) + else: + self.state = ExecutionState.SCHEDULED + await self._publish_state( + {"state": ExecutionState.SCHEDULED.value, "when": when.isoformat()} + ) + + async def claim(self, worker: str) -> None: + """Atomically claim task and transition to RUNNING state. + + This consolidates worker operations when claiming a task into a single + atomic Lua script that: + - Sets state to RUNNING with worker name and timestamp + - Initializes progress tracking (current=0, total=100) + - Deletes known/stream_id fields to allow task rescheduling + - Cleans up legacy keys for backwards compatibility + + Args: + worker: Name of the worker claiming the task + """ + started_at = datetime.now(timezone.utc) + started_at_iso = started_at.isoformat() + + async with self.docket.redis() as redis: + claim_script = redis.register_script( + # KEYS: runs_key, progress_key, known_key, stream_id_key + # ARGV: worker, started_at_iso + """ + local runs_key = KEYS[1] + local progress_key = KEYS[2] + -- TODO: Remove in next breaking release (v0.14.0) - legacy key locations + local known_key = KEYS[3] + local stream_id_key = KEYS[4] + + local worker = ARGV[1] + local started_at = ARGV[2] + + -- Update execution state to running + redis.call('HSET', runs_key, + 'state', 'running', + 'worker', worker, + 'started_at', started_at + ) + + -- Initialize progress tracking + redis.call('HSET', progress_key, + 'current', '0', + 'total', '100' + ) + + -- Delete known/stream_id fields to allow task rescheduling + redis.call('HDEL', runs_key, 'known', 'stream_id') + + -- TODO: Remove in next breaking release (v0.14.0) - legacy key cleanup + redis.call('DEL', known_key, stream_id_key) + + return 'OK' + """ + ) + + await claim_script( + keys=[ + self._redis_key, # runs_key + self.progress._redis_key, # progress_key + f"{self.docket.name}:known:{self.key}", # legacy known_key + f"{self.docket.name}:stream-id:{self.key}", # legacy stream_id_key + ], + args=[worker, started_at_iso], + ) + + # Update local state + self.state = ExecutionState.RUNNING + self.worker = worker + self.started_at = started_at + self.progress.current = 0 + self.progress.total = 100 + + # Publish state change event + await self._publish_state( + { + "state": ExecutionState.RUNNING.value, + "worker": worker, + "started_at": started_at_iso, + } + ) + + async def mark_as_completed(self) -> None: + """Mark task as completed successfully. + + Sets TTL on state data (from docket.execution_ttl) and deletes progress data. 
+ """ + completed_at = datetime.now(timezone.utc).isoformat() + async with self.docket.redis() as redis: + await redis.hset( + self._redis_key, + mapping={ + "state": ExecutionState.COMPLETED.value, + "completed_at": completed_at, + }, + ) + # Set TTL from docket configuration + await redis.expire( + self._redis_key, int(self.docket.execution_ttl.total_seconds()) + ) + self.state = ExecutionState.COMPLETED + # Delete progress data + await self.progress._delete() + # Publish state change event + await self._publish_state( + {"state": ExecutionState.COMPLETED.value, "completed_at": completed_at} + ) + + async def mark_as_failed(self, error: str | None = None) -> None: + """Mark task as failed. + + Args: + error: Optional error message describing the failure + + Sets TTL on state data (from docket.execution_ttl) and deletes progress data. + """ + completed_at = datetime.now(timezone.utc).isoformat() + async with self.docket.redis() as redis: + mapping = { + "state": ExecutionState.FAILED.value, + "completed_at": completed_at, + } + if error: + mapping["error"] = error + await redis.hset(self._redis_key, mapping=mapping) + # Set TTL from docket configuration + await redis.expire( + self._redis_key, int(self.docket.execution_ttl.total_seconds()) + ) + self.state = ExecutionState.FAILED + # Delete progress data + await self.progress._delete() + # Publish state change event + state_data = { + "state": ExecutionState.FAILED.value, + "completed_at": completed_at, + } + if error: + state_data["error"] = error + await self._publish_state(state_data) + + async def sync(self) -> None: + """Synchronize instance attributes with current execution data from Redis. + + Updates self.state, execution metadata, and progress data from Redis. + Sets attributes to None if no data exists. + """ + async with self.docket.redis() as redis: + data = await redis.hgetall(self._redis_key) + if data: + # Update state + state_value = data.get(b"state") + if state_value: + if isinstance(state_value, bytes): + state_value = state_value.decode() + self.state = ExecutionState(state_value) + + # Update metadata + self.worker = data[b"worker"].decode() if b"worker" in data else None + self.started_at = ( + datetime.fromisoformat(data[b"started_at"].decode()) + if b"started_at" in data + else None + ) + self.completed_at = ( + datetime.fromisoformat(data[b"completed_at"].decode()) + if b"completed_at" in data + else None + ) + self.error = data[b"error"].decode() if b"error" in data else None + else: + # No data exists - reset to defaults + self.state = ExecutionState.SCHEDULED + self.worker = None + self.started_at = None + self.completed_at = None + self.error = None + + # Sync progress data + await self.progress.sync() + + async def _publish_state(self, data: dict) -> None: + """Publish state change to Redis pub/sub channel. + + Args: + data: State data to publish + """ + channel = f"{self.docket.name}:state:{self.key}" + # Create ephemeral Redis client for publishing + async with self.docket.redis() as redis: + # Build payload with all relevant state information + payload = { + "type": "state", + "key": self.key, + **data, + } + await redis.publish(channel, json.dumps(payload)) + + async def subscribe(self) -> AsyncGenerator[StateEvent | ProgressEvent, None]: + """Subscribe to both state and progress updates for this task. + + Emits the current state as the first event, then subscribes to real-time + state and progress updates via Redis pub/sub. 
+ + Yields: + Dict containing state or progress update events with a 'type' field: + - For state events: type="state", state, worker, timestamps, error + - For progress events: type="progress", current, total, message, updated_at + """ + # First, emit the current state + await self.sync() + + # Build initial state event from current attributes + initial_state: StateEvent = { + "type": "state", + "key": self.key, + "state": self.state, + "when": self.when.isoformat(), + "worker": self.worker, + "started_at": self.started_at.isoformat() if self.started_at else None, + "completed_at": self.completed_at.isoformat() + if self.completed_at + else None, + "error": self.error, + } + + yield initial_state + + progress_event: ProgressEvent = { + "type": "progress", + "key": self.key, + "current": self.progress.current, + "total": self.progress.total, + "message": self.progress.message, + "updated_at": self.progress.updated_at.isoformat() + if self.progress.updated_at + else None, + } + + yield progress_event + + # Then subscribe to real-time updates + state_channel = f"{self.docket.name}:state:{self.key}" + progress_channel = f"{self.docket.name}:progress:{self.key}" + async with self.docket.redis() as redis: + async with redis.pubsub() as pubsub: + await pubsub.subscribe(state_channel, progress_channel) + async for message in pubsub.listen(): # pragma: no cover + if message["type"] == "message": + message_data = json.loads(message["data"]) + if message_data["type"] == "state": + message_data["state"] = ExecutionState( + message_data["state"] + ) + yield message_data + def compact_signature(signature: inspect.Signature) -> str: from .dependencies import Dependency diff --git a/src/docket/worker.py b/src/docket/worker.py index d4e1ae8..e070aea 100644 --- a/src/docket/worker.py +++ b/src/docket/worker.py @@ -37,6 +37,8 @@ RedisReadGroupResponse, ) from .execution import compact_signature, get_signature + +# Run class has been consolidated into Execution from .instrumentation import ( QUEUE_DEPTH, REDIS_DISRUPTIONS, @@ -297,21 +299,20 @@ async def get_new_deliveries(redis: Redis) -> RedisReadGroupResponse: await asyncio.sleep(self.minimum_check_interval.total_seconds()) return result - def start_task( + async def start_task( message_id: RedisMessageID, message: RedisMessage, is_redelivery: bool = False, ) -> bool: - function_name = message[b"function"].decode() - if not (function := self.docket.tasks.get(function_name)): - logger.warning( - "Task function %r not found", - function_name, + try: + execution = await Execution.from_message(self.docket, message) + except ValueError as e: + logger.error( + "Unable to start task: %s", + e, extra=log_context, ) return False - - execution = Execution.from_message(function, message) execution.redelivered = is_redelivery task = asyncio.create_task(self._execute(execution), name=execution.key) @@ -364,7 +365,7 @@ async def ack_message(redis: Redis, message_id: RedisMessageID) -> None: if not message: # pragma: no cover continue - task_started = start_task( + task_started = await start_task( message_id, message, is_redelivery ) if not task_started: @@ -434,6 +435,16 @@ async def _scheduler_loop( 'attempt', task['attempt'] ) redis.call('DEL', hash_key) + + -- Set run state to queued + local run_key = ARGV[2] .. ":runs:" .. task['key'] + redis.call('HSET', run_key, 'state', 'queued') + + -- Publish state change event to pub/sub + local channel = ARGV[2] .. ":state:" .. task['key'] + local payload = '{"type":"state","key":"' .. task['key'] .. 
'","state":"queued","when":"' .. task['when'] .. '"}' + redis.call('PUBLISH', channel, payload) + due_work = due_work + 1 end end @@ -513,6 +524,11 @@ async def _delete_known_task( return logger.debug("Deleting known task", extra=self._log_context()) + # Delete known/stream_id from runs hash to allow task rescheduling + runs_key = f"{self.docket.name}:runs:{key}" + await redis.hdel(runs_key, "known", "stream_id") + + # TODO: Remove in next breaking release (v0.14.0) - legacy key cleanup known_task_key = self.docket.known_task_key(key) stream_id_key = self.docket.stream_id_key(key) await redis.delete(known_task_key, stream_id_key) @@ -548,6 +564,10 @@ async def _execute(self, execution: Execution) -> None: arrow = "↬" if execution.attempt > 1 else "↪" logger.info("%s [%s] %s", arrow, ms(punctuality), call, extra=log_context) + # Atomically claim task and transition to running state + # This also initializes progress and cleans up known/stream_id to allow rescheduling + await execution.claim(self.name) + dependencies: dict[str, Dependency] = {} with tracer.start_as_current_span( @@ -588,14 +608,11 @@ async def _execute(self, execution: Execution) -> None: # Successfully acquired slot pass - # Preemptively reschedule the perpetual task for the future, or clear - # the known task key for this task + # Preemptively reschedule the perpetual task for the future + # Note: known/stream_id already deleted by claim_and_run() rescheduled = await self._perpetuate_if_requested( execution, dependencies ) - if not rescheduled: - async with self.docket.redis() as redis: - await self._delete_known_task(redis, execution) dependency_failures = { k: v @@ -653,6 +670,10 @@ async def _execute(self, execution: Execution) -> None: execution, dependencies, timedelta(seconds=duration) ) + if not rescheduled: + # Mark execution as completed + await execution.mark_as_completed() + arrow = "↫" if rescheduled else "↩" logger.info( "%s [%s] %s", arrow, ms(duration), call, extra=log_context @@ -670,6 +691,10 @@ async def _execute(self, execution: Execution) -> None: execution, dependencies, timedelta(seconds=duration) ) + # Mark execution as failed with error message + error_msg = f"{type(e).__name__}: {str(e)}" + await execution.mark_as_failed(error_msg) + arrow = "↫" if retried else "↩" logger.exception( "%s [%s] %s", arrow, ms(duration), call, extra=log_context @@ -740,7 +765,8 @@ async def _retry_if_requested( execution.when = datetime.now(timezone.utc) + retry.delay execution.attempt += 1 - await self.docket.schedule(execution) + # Use replace=True since the task is being rescheduled after failure + await execution.schedule(replace=True) TASKS_RETRIED.add(1, {**self.labels(), **execution.general_labels()}) return True diff --git a/tests/cli/test_watch.py b/tests/cli/test_watch.py new file mode 100644 index 0000000..dea2fed --- /dev/null +++ b/tests/cli/test_watch.py @@ -0,0 +1,396 @@ +"""Tests for the docket watch CLI command.""" + +import asyncio +from datetime import datetime, timedelta, timezone +import os +from unittest.mock import AsyncMock + +import pytest + +from docket import Docket, Progress, Worker +from docket.execution import ExecutionProgress + +from .utils import run_cli + +# Skip CLI tests when using memory backend since CLI rejects memory:// URLs +pytestmark = pytest.mark.skipif( + os.environ.get("REDIS_VERSION") == "memory", + reason="CLI commands require a persistent Redis backend", +) + + +async def test_watch_completed_task(docket: Docket, the_task: AsyncMock): + """Watch should display completed 
task and exit immediately.""" + docket.register(the_task) + + # Create and complete a task + execution = await docket.add(the_task, key="completed-task")() + await execution.claim("worker-1") + await execution.mark_as_completed() + + # Watch should show completion and exit + result = await run_cli( + "watch", + "completed-task", + "--url", + docket.url, + "--docket", + docket.name, + ) + + assert result.exit_code == 0 + assert "completed-task" in result.output + assert "COMPLETED" in result.output.upper() + assert docket.name in result.output + assert "✓" in result.output or "completed successfully" in result.output.lower() + + +async def test_watch_failed_task(docket: Docket, the_task: AsyncMock): + """Watch should display failed task with error message.""" + docket.register(the_task) + + execution = await docket.add(the_task, key="failed-task")() + await execution.claim("worker-1") + await execution.mark_as_failed("Test error message") + + result = await run_cli( + "watch", + "failed-task", + "--url", + docket.url, + "--docket", + docket.name, + ) + + assert result.exit_code == 0 + assert "failed-task" in result.output + assert "FAILED" in result.output.upper() + assert docket.name in result.output + assert "✗" in result.output or "failed" in result.output.lower() + assert "Test error message" in result.output or "Error" in result.output + + +async def test_watch_running_task_until_completion(docket: Docket, worker: Worker): + """Watch should monitor task from running to completion.""" + + async def slower_task(): + # Sleep long enough for watch command to connect and receive events + await asyncio.sleep(0.5) + + docket.register(slower_task) + await docket.add(slower_task, key="slower-task")() + + # Start worker in background + worker_task = asyncio.create_task(worker.run_until_finished()) + + # Give worker a moment to claim the task + await asyncio.sleep(0.05) + + # Watch should receive state events while task runs + result = await run_cli( + "watch", + "slower-task", + "--url", + docket.url, + "--docket", + docket.name, + timeout=2.0, + ) + + await worker_task + + assert result.exit_code == 0 + assert "slower-task" in result.output + assert docket.name in result.output + assert "RUNNING" in result.output.upper() or "COMPLETED" in result.output.upper() + + +async def test_watch_with_progress_updates(docket: Docket, worker: Worker): + """Watch should display progress bar updates.""" + + async def task_with_progress(progress: ExecutionProgress = Progress()): + await progress.set_total(10) + await progress.set_message("Starting") + for i in range(10): + # Slower so watch has time to connect and receive events + await asyncio.sleep(0.1) + await progress.increment() + await progress.set_message(f"Step {i + 1}") + + docket.register(task_with_progress) + await docket.add(task_with_progress, key="progress-task")() + + worker_task = asyncio.create_task(worker.run_until_finished()) + + result = await run_cli( + "watch", + "progress-task", + "--url", + docket.url, + "--docket", + docket.name, + ) + + await worker_task + + assert result.exit_code == 0 + assert "progress-task" in result.output + assert docket.name in result.output + # State should be shown + assert "COMPLETED" in result.output.upper() + + +async def test_watch_scheduled_task_transition(docket: Docket, worker: Worker): + """Watch should show task transition from scheduled to completed.""" + + async def scheduled_task(): + await asyncio.sleep(0.01) + + docket.register(scheduled_task) + + # Schedule for near future + when = 
datetime.now(timezone.utc) + timedelta(seconds=2) + await docket.add(scheduled_task, when=when, key="scheduled-task")() + + worker_task = asyncio.create_task(worker.run_until_finished()) + + result = await run_cli( + "watch", + "scheduled-task", + "--url", + docket.url, + "--docket", + docket.name, + timeout=5.0, + ) + + await worker_task + + assert result.exit_code == 0 + assert "scheduled-task" in result.output + assert docket.name in result.output + # Should show final completed state + assert "COMPLETED" in result.output.upper() or "SCHEDULED" in result.output.upper() + + +async def test_watch_task_with_initial_progress(docket: Docket, worker: Worker): + """Watch should handle task that already has progress when monitoring starts.""" + + async def task_with_initial_progress(progress: ExecutionProgress = Progress()): + # Set progress before watch likely connects + await progress.set_total(20) + await progress.increment(5) + # Then continue slowly + for _ in range(15): + await asyncio.sleep(0.1) + await progress.increment() + + docket.register(task_with_initial_progress) + await docket.add(task_with_initial_progress, key="initial-progress")() + + worker_task = asyncio.create_task(worker.run_until_finished()) + + # Let task get started and report some progress + await asyncio.sleep(1.5) + + result = await run_cli( + "watch", + "initial-progress", + "--url", + docket.url, + "--docket", + docket.name, + timeout=3.0, + ) + + await worker_task + + assert result.exit_code == 0 + + +async def test_watch_task_with_worker_assignment(docket: Docket, worker: Worker): + """Watch should show worker name when task is claimed.""" + + async def long_running_task(): + # Long enough for watch to definitely connect + await asyncio.sleep(3.0) + + docket.register(long_running_task) + await docket.add(long_running_task, key="worker-assigned")() + + worker_task = asyncio.create_task(worker.run_until_finished()) + + # Give worker time to claim task + await asyncio.sleep(0.1) + + result = await run_cli( + "watch", + "worker-assigned", + "--url", + docket.url, + "--docket", + docket.name, + timeout=5.0, + ) + + await worker_task + + assert result.exit_code == 0 + assert "worker-assigned" in result.output + assert docket.name in result.output + assert "Worker" in result.output or worker.name in result.output + + +async def test_watch_task_that_starts_while_watching(docket: Docket, worker: Worker): + """Watch should receive started_at event and update progress bar timing.""" + + async def task_that_waits_then_progresses(progress: ExecutionProgress = Progress()): + # Immediately report progress so watch sees it + await progress.set_total(10) + await progress.increment(1) + await progress.set_message("Started") + # Then continue + for _ in range(9): + await asyncio.sleep(0.15) + await progress.increment() + + docket.register(task_that_waits_then_progresses) + + # Schedule task for slightly in future so watch can connect first + when = datetime.now(timezone.utc) + timedelta(seconds=2) + await docket.add(task_that_waits_then_progresses, when=when, key="timing-test")() + + worker_task = asyncio.create_task(worker.run_until_finished()) + + # Start watching BEFORE task starts + result = await run_cli( + "watch", + "timing-test", + "--url", + docket.url, + "--docket", + docket.name, + timeout=4.0, + ) + + await worker_task + + assert result.exit_code == 0 + assert "timing-test" in result.output + assert docket.name in result.output + + +async def test_watch_receives_progress_events_during_execution( + docket: Docket, 
worker: Worker +): + """Watch should receive and process progress events as they occur.""" + + async def task_with_many_updates(progress: ExecutionProgress = Progress()): + await progress.set_total(20) + for i in range(20): + await asyncio.sleep(0.08) # 1.6 seconds total + await progress.increment() + if i % 5 == 0: + await progress.set_message(f"Checkpoint {i}") + + docket.register(task_with_many_updates) + await docket.add(task_with_many_updates, key="many-updates")() + + worker_task = asyncio.create_task(worker.run_until_finished()) + + # Give worker just a moment then start watching + await asyncio.sleep(0.05) + + result = await run_cli( + "watch", + "many-updates", + "--url", + docket.url, + "--docket", + docket.name, + timeout=3.0, + ) + + await worker_task + + assert result.exit_code == 0 + assert "many-updates" in result.output + assert docket.name in result.output + + +async def test_watch_already_running_task_with_progress(docket: Docket, worker: Worker): + """Watch a task that's already running with progress when watch starts.""" + + async def task_already_running(progress: ExecutionProgress = Progress()): + # Set up some initial state quickly + await progress.set_total(30) + await progress.increment(10) + await progress.set_message("Already started") + # Then run slowly so watch can observe + for _ in range(20): + await asyncio.sleep(0.1) + await progress.increment() + + docket.register(task_already_running) + await docket.add(task_already_running, key="already-running")() + + # Start worker + worker_task = asyncio.create_task(worker.run_until_finished()) + + # Wait for task to actually start and report some progress + await asyncio.sleep(0.3) + + # Now start watching - task should already be RUNNING with progress + result = await run_cli( + "watch", + "already-running", + "--url", + docket.url, + "--docket", + docket.name, + timeout=4.0, + ) + + await worker_task + + assert result.exit_code == 0 + assert "already-running" in result.output + assert docket.name in result.output + + +async def test_watch_task_with_worker_in_state_event(docket: Docket, worker: Worker): + """Watch should handle state events containing worker name.""" + + async def task_with_delays(progress: ExecutionProgress = Progress()): + # Publish progress to create progress bar + await progress.set_total(15) + await progress.increment(1) + # Long delay so watch receives events + for _ in range(14): + await asyncio.sleep(0.15) + await progress.increment() + + docket.register(task_with_delays) + + # Schedule slightly in future + when = datetime.now(timezone.utc) + timedelta(milliseconds=150) + await docket.add(task_with_delays, when=when, key="worker-event")() + + worker_task = asyncio.create_task(worker.run_until_finished()) + + # Start watch early so it's listening when task starts + result = await run_cli( + "watch", + "worker-event", + "--url", + docket.url, + "--docket", + docket.name, + timeout=5.0, + ) + + await worker_task + + assert result.exit_code == 0 + assert "worker-event" in result.output + assert docket.name in result.output diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index 7d0c201..93ba265 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -588,6 +588,7 @@ async def task1(): ... async def task2(): ... execution1 = Execution( + docket=docket, key="task1-key", function=task1, args=(), @@ -597,6 +598,7 @@ async def task2(): ... 
) execution2 = Execution( + docket=docket, key="task2-key", function=task2, args=(), @@ -637,6 +639,7 @@ async def test_contextvar_not_leaked_to_caller(docket: Docket): async def dummy_task(): ... execution = Execution( + docket=docket, key="test-key", function=dummy_task, args=(), diff --git a/tests/test_docket.py b/tests/test_docket.py index 69400c0..c3cd10f 100644 --- a/tests/test_docket.py +++ b/tests/test_docket.py @@ -166,3 +166,45 @@ async def test_clear_no_redis_key_leaks(docket: Docket, the_task: AsyncMock): snapshot = await docket.snapshot() assert len(snapshot.future) == 0 assert len(snapshot.running) == 0 + + +async def test_docket_schedule_method_with_immediate_task( + docket: Docket, the_task: AsyncMock +): + """Test direct scheduling via docket.schedule(execution) for immediate execution.""" + from docket import Execution + + # Register task so snapshot can look it up + docket.register(the_task) + + execution = Execution( + docket, the_task, ("arg",), {}, datetime.now(timezone.utc), "test-key", 1 + ) + + await docket.schedule(execution) + + # Verify task was scheduled + snapshot = await docket.snapshot() + assert len(snapshot.future) == 1 + + +async def test_docket_schedule_with_stricken_task(docket: Docket, the_task: AsyncMock): + """Test that docket.schedule respects strike list.""" + from docket import Execution + + # Register task + docket.register(the_task) + + # Strike the task + await docket.strike("the_task") + + execution = Execution( + docket, the_task, (), {}, datetime.now(timezone.utc), "test-key", 1 + ) + + # Try to schedule - should be blocked + await docket.schedule(execution) + + # Verify task was NOT scheduled + snapshot = await docket.snapshot() + assert len(snapshot.future) == 0 diff --git a/tests/test_execution_progress.py b/tests/test_execution_progress.py new file mode 100644 index 0000000..a764a83 --- /dev/null +++ b/tests/test_execution_progress.py @@ -0,0 +1,850 @@ +"""Tests for task execution state and progress tracking.""" + +import asyncio +from datetime import datetime, timedelta, timezone +from unittest.mock import AsyncMock + +import pytest + +from docket import Docket, Execution, ExecutionState, Progress, Worker +from docket.execution import ExecutionProgress, ProgressEvent, StateEvent + + +async def test_run_state_scheduled(docket: Docket, the_task: AsyncMock): + """Execution should be set to QUEUED when an immediate task is added.""" + execution = await docket.add(the_task)("arg1", "arg2") + + assert isinstance(execution, Execution) + await execution.sync() + assert execution.state == ExecutionState.QUEUED + + +async def test_run_state_pending_to_running(docket: Docket, worker: Worker): + """Execution should transition from QUEUED to RUNNING during execution.""" + executed = asyncio.Event() + + async def test_task(): + # Verify we're in RUNNING state + executed.set() + + await docket.add(test_task)() + + # Start worker but don't wait for completion yet + worker_task = asyncio.create_task(worker.run_until_finished()) + + # Wait for task to start executing + await executed.wait() + + # Give it a moment to complete + await worker_task + + +async def test_run_state_completed_on_success( + docket: Docket, worker: Worker, the_task: AsyncMock +): + """Execution should be set to COMPLETED when task succeeds.""" + execution = await docket.add(the_task)() + + await worker.run_until_finished() + + await execution.sync() + assert execution.state == ExecutionState.COMPLETED + + +async def test_run_state_failed_on_exception(docket: Docket, worker: 
Worker):
+    """Execution should be set to FAILED when task raises an exception."""
+
+    async def failing_task():
+        raise ValueError("Task failed!")
+
+    execution = await docket.add(failing_task)()
+
+    await worker.run_until_finished()
+
+    await execution.sync()
+    assert execution.state == ExecutionState.FAILED
+
+
+async def test_progress_create(docket: Docket):
+    """Progress.create() should initialize an instance from Redis."""
+    # First create a progress instance and set some values
+    progress = ExecutionProgress(docket, "test-key")
+    await progress.set_total(100)
+    await progress.increment(5)
+    await progress.set_message("Test message")
+
+    # Now create a new instance using create()
+    progress2 = await ExecutionProgress.create(docket, "test-key")
+
+    # Verify it loaded the data from Redis
+    assert progress2.current == 5
+    assert progress2.total == 100
+    assert progress2.message == "Test message"
+    assert progress2.updated_at is not None
+
+
+async def test_progress_set_total(docket: Docket):
+    """Progress should be able to set the total value."""
+    progress = ExecutionProgress(docket, "test-key")
+
+    await progress.set_total(100)
+
+    assert progress.total == 100
+    assert progress.updated_at is not None
+
+
+async def test_progress_set_total_invalid(docket: Docket):
+    """Progress should raise an error if total is less than 1."""
+    progress = ExecutionProgress(docket, "test-key")
+    with pytest.raises(ValueError):
+        await progress.set_total(0)
+
+
+async def test_progress_increment_invalid(docket: Docket):
+    """Progress should raise an error if amount is less than 1."""
+    progress = ExecutionProgress(docket, "test-key")
+    with pytest.raises(ValueError):
+        await progress.increment(0)
+
+
+async def test_progress_increment(docket: Docket):
+    """Progress should atomically increment the current value."""
+    execution = Execution(
+        docket, AsyncMock(), (), {}, datetime.now(timezone.utc), "test-key", 1
+    )
+
+    # Claim the execution first (claim() marks it RUNNING and initializes current=0)
+    await execution.claim("worker-1")
+    progress = execution.progress
+
+    # Increment multiple times
+    await progress.increment()
+    await progress.increment()
+    await progress.increment(2)
+
+    assert progress.current == 4  # 0 + 1 + 1 + 2 = 4
+    assert progress.updated_at is not None
+
+
+async def test_progress_set_message(docket: Docket):
+    """Progress should be able to set a status message."""
+    progress = ExecutionProgress(docket, "test-key")
+
+    await progress.set_message("Processing items...")
+
+    assert progress.message == "Processing items..."
+ assert progress.updated_at is not None + + +async def test_progress_dependency_injection(docket: Docket, worker: Worker): + """Progress dependency should be injected into task functions.""" + progress_values: list[int] = [] + + async def task_with_progress(progress: ExecutionProgress = Progress()): + await progress.set_total(10) + for i in range(10): + await asyncio.sleep(0.001) + await progress.increment() + await progress.set_message(f"Processing item {i + 1}") + # Capture progress data + assert progress.current is not None + progress_values.append(progress.current) + + await docket.add(task_with_progress)() + + await worker.run_until_finished() + + # Verify progress was tracked + assert len(progress_values) > 0 + assert progress_values[-1] == 10 # Should reach 10 + + +async def test_progress_deleted_on_completion(docket: Docket, worker: Worker): + """Progress data should be deleted when task completes.""" + + async def task_with_progress(progress: ExecutionProgress = Progress()): + await progress.set_total(5) + await progress.increment() + + execution = await docket.add(task_with_progress)() + + # Before execution, no progress + await execution.progress.sync() + assert execution.progress.current is None + + await worker.run_until_finished() + + # After completion, progress should be deleted + await execution.progress.sync() + assert execution.progress.current is None + + +async def test_run_state_ttl_after_completion( + docket: Docket, worker: Worker, the_task: AsyncMock +): + """Run state should have TTL set after completion.""" + execution = await docket.add(the_task)() + + await worker.run_until_finished() + + # Verify state exists + await execution.sync() + assert execution.state == ExecutionState.COMPLETED + + # Verify TTL is set to the configured execution_ttl (default: 1 hour = 3600 seconds) + expected_ttl = int(docket.execution_ttl.total_seconds()) + async with docket.redis() as redis: + ttl = await redis.ttl(execution._redis_key) # type: ignore[reportPrivateUsage] + assert 0 < ttl <= expected_ttl # TTL should be set and reasonable + + +async def test_custom_execution_ttl(redis_url: str, the_task: AsyncMock): + """Docket should respect custom execution_ttl configuration.""" + # Create docket with custom 5-minute TTL + custom_ttl = timedelta(minutes=5) + async with Docket( + name="test-custom-ttl", url=redis_url, execution_ttl=custom_ttl + ) as docket: + async with Worker(docket) as worker: + execution = await docket.add(the_task)() + + await worker.run_until_finished() + + # Verify state is completed + await execution.sync() + assert execution.state == ExecutionState.COMPLETED + + # Verify TTL matches custom value (300 seconds) + expected_ttl = int(custom_ttl.total_seconds()) + async with docket.redis() as redis: + ttl = await redis.ttl(execution._redis_key) # type: ignore[reportPrivateUsage] + assert 0 < ttl <= expected_ttl + # Verify it's approximately the custom value (not the default 3600) + assert ttl > 200 # Should be close to 300, not near 0 + assert ttl <= 300 # Should not exceed configured value + + +async def test_full_lifecycle_integration(docket: Docket, worker: Worker): + """Test complete lifecycle: SCHEDULED -> QUEUED -> RUNNING -> COMPLETED.""" + states_observed: list[ExecutionState] = [] + + async def tracking_task(progress: ExecutionProgress = Progress()): + await progress.set_total(3) + for i in range(3): + await progress.increment() + await progress.set_message(f"Step {i + 1}") + await asyncio.sleep(0.01) + + # Schedule task in the future + when = 
datetime.now(timezone.utc) + timedelta(milliseconds=50) + execution = await docket.add(tracking_task, when=when)() + + # Should be SCHEDULED + await execution.sync() + assert execution.state == ExecutionState.SCHEDULED + states_observed.append(execution.state) + + # Run worker + await worker.run_until_finished() + + # Should be COMPLETED + await execution.sync() + assert execution.state == ExecutionState.COMPLETED + states_observed.append(execution.state) + + # Verify we observed the expected states + assert ExecutionState.SCHEDULED in states_observed + assert ExecutionState.COMPLETED in states_observed + + +async def test_progress_with_multiple_increments(docket: Docket, worker: Worker): + """Test progress tracking with realistic usage pattern.""" + + async def process_items(items: list[int], progress: ExecutionProgress = Progress()): + await progress.set_total(len(items)) + await progress.set_message("Starting processing") + + for i, _item in enumerate(items): + await asyncio.sleep(0.001) # Simulate work + await progress.increment() + await progress.set_message(f"Processed item {i + 1}/{len(items)}") + + await progress.set_message("All items processed") + + items = list(range(20)) + execution = await docket.add(process_items)(items) + + await worker.run_until_finished() + + # Verify final state + await execution.sync() + assert execution.state == ExecutionState.COMPLETED + + +async def test_progress_without_total(docket: Docket, worker: Worker): + """Progress should work even without setting total.""" + + async def task_without_total(progress: ExecutionProgress = Progress()): + for _ in range(5): + await progress.increment() + await asyncio.sleep(0.001) + + execution = await docket.add(task_without_total)() + + await worker.run_until_finished() + + await execution.sync() + assert execution.state == ExecutionState.COMPLETED + + +async def test_run_add_returns_run_instance(docket: Docket, the_task: AsyncMock): + """Verify that docket.add() returns an Execution instance.""" + result = await docket.add(the_task)("arg1") + + assert isinstance(result, Execution) + assert result.key is not None + assert len(result.key) > 0 + + +async def test_error_message_stored_on_failure(docket: Docket, worker: Worker): + """Failed run should store error message.""" + + async def failing_task(): + raise RuntimeError("Something went wrong!") + + execution = await docket.add(failing_task)() + + await worker.run_until_finished() + + # Check state is FAILED + await execution.sync() + assert execution.state == ExecutionState.FAILED + assert execution.error == "RuntimeError: Something went wrong!" 
+ + +async def test_concurrent_progress_updates(docket: Docket): + """Progress updates should be atomic and safe for concurrent access.""" + execution = Execution( + docket, AsyncMock(), (), {}, datetime.now(timezone.utc), "test-key", 1 + ) + progress = execution.progress + + await execution.claim("worker-1") + + # Simulate concurrent increments + async def increment_many(): + for _ in range(10): + await progress.increment() + + await asyncio.gather( + increment_many(), + increment_many(), + increment_many(), + ) + + # Sync to ensure we have the latest value from Redis + await progress.sync() + # Should be exactly 30 due to atomic HINCRBY + assert progress.current == 30 + + +async def test_progress_publish_events(docket: Docket): + """Progress updates should publish events to pub/sub channel.""" + execution = Execution( + docket, AsyncMock(), (), {}, datetime.now(timezone.utc), "test-key", 1 + ) + progress = execution.progress + + # Set up subscriber in background + events: list[ProgressEvent] = [] + + async def collect_events(): + async for event in progress.subscribe(): # pragma: no cover + events.append(event) + if len(events) >= 3: # Collect 3 events then stop + break + + subscriber_task = asyncio.create_task(collect_events()) + + # Give subscriber time to connect + await asyncio.sleep(0.1) + + # Initialize and publish updates + await execution.claim("worker-1") + await progress.set_total(100) + await progress.increment(10) + await progress.set_message("Processing...") + + # Wait for subscriber to collect events + await asyncio.wait_for(subscriber_task, timeout=2.0) + + # Verify we received progress events + assert len(events) >= 3 + + # Check set_total event + total_event = next(e for e in events if e.get("total") == 100) + assert total_event["type"] == "progress" + assert total_event["key"] == "test-key" + assert "updated_at" in total_event + + # Check increment event + increment_event = next(e for e in events if e.get("current") == 10) + assert increment_event["type"] == "progress" + assert increment_event["current"] == 10 + + # Check message event + message_event = next(e for e in events if e.get("message") == "Processing...") + assert message_event["type"] == "progress" + assert message_event["message"] == "Processing..." + + +async def test_state_publish_events(docket: Docket, the_task: AsyncMock): + """State changes should publish events to pub/sub channel.""" + # Note: This test verifies the pub/sub mechanism works. + # Pub/sub is skipped for memory:// backend, so this test effectively + # documents the expected behavior for real Redis backends. 
+ + execution = await docket.add(the_task, key="test-key")() + + # Verify state was set correctly + assert execution.state == ExecutionState.QUEUED + + # Verify state record exists in Redis + await execution.sync() + assert execution.state == ExecutionState.QUEUED + + +async def test_run_subscribe_both_state_and_progress(docket: Docket): + """Run.subscribe() should yield both state and progress events.""" + execution = Execution( + docket, AsyncMock(), (), {}, datetime.now(timezone.utc), "test-key", 1 + ) + + # Set up subscriber in background + all_events: list[StateEvent | ProgressEvent] = [] + + async def collect_events(): + async for event in execution.subscribe(): # pragma: no cover + all_events.append(event) + # Stop after we get a running state and some progress + if ( + len( + [ + e + for e in all_events + if e["type"] == "state" and e["state"] == ExecutionState.RUNNING + ] + ) + > 0 + and len([e for e in all_events if e["type"] == "progress"]) >= 3 + ): + break + + subscriber_task = asyncio.create_task(collect_events()) + + # Give subscriber time to connect + await asyncio.sleep(0.1) + + # Publish mixed state and progress events + await execution.claim("worker-1") + await execution.progress.set_total(50) + await execution.progress.increment(5) + + # Wait for subscriber to collect events + await asyncio.wait_for(subscriber_task, timeout=2.0) + + # Verify we got both types + state_events = [e for e in all_events if e["type"] == "state"] + progress_events = [e for e in all_events if e["type"] == "progress"] + + assert len(state_events) >= 1 + assert len(progress_events) >= 2 + + # Verify state event + running_event = next( + e for e in state_events if e["state"] == ExecutionState.RUNNING + ) + assert running_event["worker"] == "worker-1" + + # Verify progress events + total_event = next(e for e in progress_events if e.get("total") == 50) + assert total_event["current"] is not None and total_event["current"] >= 0 + + increment_event = next(e for e in progress_events if e.get("current") == 5) + assert increment_event["current"] == 5 + + +async def test_completed_state_publishes_event(docket: Docket): + """Completed state should publish event with completed_at timestamp.""" + execution = Execution( + docket, AsyncMock(), (), {}, datetime.now(timezone.utc), "test-key", 1 + ) + + # Set up subscriber + events: list[StateEvent] = [] + + async def collect_events(): + async for event in execution.subscribe(): # pragma: no cover + if event["type"] == "state": + events.append(event) + if any(e["state"] == ExecutionState.COMPLETED for e in events): + break + + subscriber_task = asyncio.create_task(collect_events()) + await asyncio.sleep(0.1) + + await execution.claim("worker-1") + await execution.mark_as_completed() + + await asyncio.wait_for(subscriber_task, timeout=2.0) + + # Find completed event + completed_event = next(e for e in events if e["state"] == ExecutionState.COMPLETED) + assert completed_event["type"] == "state" + assert "completed_at" in completed_event + + +async def test_failed_state_publishes_event_with_error(docket: Docket): + """Failed state should publish event with error message.""" + execution = Execution( + docket, AsyncMock(), (), {}, datetime.now(timezone.utc), "test-key", 1 + ) + + # Set up subscriber + events: list[StateEvent] = [] + + async def collect_events(): + async for event in execution.subscribe(): # pragma: no cover + if event["type"] == "state": + events.append(event) + if any(e["state"] == ExecutionState.FAILED for e in events): + break + + subscriber_task = 
asyncio.create_task(collect_events()) + await asyncio.sleep(0.1) + + await execution.claim("worker-1") + await execution.mark_as_failed("Something went wrong!") + + await asyncio.wait_for(subscriber_task, timeout=2.0) + + # Find failed event + failed_event = next(e for e in events if e["state"] == ExecutionState.FAILED) + assert failed_event["type"] == "state" + assert failed_event["error"] == "Something went wrong!" + assert "completed_at" in failed_event + + +async def test_end_to_end_progress_monitoring_with_worker( + docket: Docket, worker: Worker +): + """Test complete end-to-end progress monitoring with real worker execution.""" + collected_events: list[StateEvent | ProgressEvent] = [] + + async def task_with_progress(progress: ExecutionProgress = Progress()): + """Task that reports progress as it executes.""" + await progress.set_total(5) + await progress.set_message("Starting work") + + for i in range(5): + await asyncio.sleep(0.01) + await progress.increment() + await progress.set_message(f"Processing step {i + 1}/5") + + await progress.set_message("Work complete") + + # Schedule the task + execution = await docket.add(task_with_progress)() + + # Start subscriber to collect events + async def collect_events(): + async for event in execution.subscribe(): # pragma: no cover + collected_events.append(event) + # Stop when we reach completed state + if event["type"] == "state" and event["state"] == ExecutionState.COMPLETED: + break + + subscriber_task = asyncio.create_task(collect_events()) + + # Give subscriber time to connect + await asyncio.sleep(0.1) + + # Run the worker + await worker.run_until_finished() + + # Wait for subscriber to finish + await asyncio.wait_for(subscriber_task, timeout=5.0) + + # Verify we collected comprehensive events + assert len(collected_events) > 0 + + # Extract event types + state_events: list[StateEvent] = [ + e for e in collected_events if e["type"] == "state" + ] + progress_events = [e for e in collected_events if e["type"] == "progress"] + + # Verify state transitions occurred + # Note: scheduled may happen before subscriber connects + state_sequence: list[ExecutionState] = [e["state"] for e in state_events] + assert state_sequence == [ + ExecutionState.QUEUED, + ExecutionState.RUNNING, + ExecutionState.COMPLETED, + ] + + # Verify worker was recorded + running_events = [e for e in state_events if e["state"] == ExecutionState.RUNNING] + assert len(running_events) > 0 + assert "worker" in running_events[0] + + # Verify progress events were published + assert len(progress_events) >= 5 # At least one for each increment + + # Verify progress reached total + final_progress = progress_events[-1] + assert final_progress["current"] is not None and final_progress["current"] == 5 + assert final_progress["total"] == 5 + + # Verify messages were updated + message_events = [e for e in progress_events if e.get("message")] + assert len(message_events) > 0 + assert any( + "complete" in e["message"].lower() + for e in message_events + if e["message"] is not None + ) + + # Verify final state is completed + assert state_events[-1]["state"] == ExecutionState.COMPLETED + assert "completed_at" in state_events[-1] + + +async def test_end_to_end_failed_task_monitoring(docket: Docket, worker: Worker): + """Test progress monitoring for a task that fails.""" + collected_events: list[StateEvent | ProgressEvent] = [] + + async def failing_task(progress: ExecutionProgress = Progress()): + """Task that reports progress then fails.""" + await progress.set_total(10) + await 
progress.set_message("Starting work") + await progress.increment(3) + await progress.set_message("About to fail") + raise ValueError("Task failed intentionally") + + # Schedule the task + execution = await docket.add(failing_task)() + + # Start subscriber + async def collect_events(): + async for event in execution.subscribe(): # pragma: no cover + collected_events.append(event) + # Stop when we reach failed state + if event["type"] == "state" and event["state"] == ExecutionState.FAILED: + break + + subscriber_task = asyncio.create_task(collect_events()) + await asyncio.sleep(0.1) + + # Run the worker + await worker.run_until_finished() + + # Wait for subscriber + await asyncio.wait_for(subscriber_task, timeout=5.0) + + # Verify we got events + assert len(collected_events) > 0 + + state_events = [e for e in collected_events if e["type"] == "state"] + progress_events = [e for e in collected_events if e["type"] == "progress"] + + # Verify task reached running state + state_sequence = [e["state"] for e in state_events] + assert state_sequence == [ + ExecutionState.QUEUED, + ExecutionState.RUNNING, + ExecutionState.FAILED, + ] + + # Verify progress was reported before failure + assert len(progress_events) >= 2 + + # Find set_total event + total_event = next((e for e in progress_events if e.get("total") == 10), None) + assert total_event is not None + + # Find increment event + increment_event = next((e for e in progress_events if e.get("current") == 3), None) + assert increment_event is not None + + # Verify error message in failed event + failed_event = next(e for e in state_events if e["state"] == ExecutionState.FAILED) + assert failed_event["error"] is not None + assert "ValueError" in failed_event["error"] + assert "intentionally" in failed_event["error"] + + +async def test_mark_as_failed_without_error_message(docket: Docket): + """Test mark_as_failed with error=None.""" + execution = Execution( + docket, AsyncMock(), (), {}, datetime.now(timezone.utc), "test-key", 1 + ) + + await execution.claim("worker-1") + await execution.mark_as_failed(error=None) + + await execution.sync() + assert execution.state == ExecutionState.FAILED + assert execution.error is None + assert execution.completed_at is not None + + +async def test_execution_sync_with_no_redis_data(docket: Docket): + """Test sync() when no execution data exists in Redis.""" + execution = Execution( + docket, AsyncMock(), (), {}, datetime.now(timezone.utc), "nonexistent-key", 1 + ) + + # Sync without ever scheduling + await execution.sync() + + # Should reset to defaults + assert execution.state == ExecutionState.SCHEDULED + assert execution.worker is None + assert execution.started_at is None + assert execution.completed_at is None + assert execution.error is None + + +async def test_progress_publish_with_memory_backend(): + """Test that _publish() safely handles memory:// backend.""" + from docket import Docket + from docket.execution import ExecutionProgress + + # Create docket with memory:// URL + async with Docket(name="test-memory", url="memory://") as docket: + progress = ExecutionProgress(docket, "test-key") + + # This should not raise an error even though pub/sub doesn't work with memory:// + # The _publish method has an early return for memory:// backend + await getattr(progress, "_publish")({"type": "progress", "current": 10}) + + # Verify it completed without error + assert progress.docket.url == "memory://" + + +async def test_execution_sync_with_missing_state_field(docket: Docket): + """Test sync() when Redis data 
exists but has no 'state' field.""" + from unittest.mock import AsyncMock, patch + + execution = Execution( + docket, AsyncMock(), (), {}, datetime.now(timezone.utc), "test-key", 1 + ) + + # Set initial state + execution.state = ExecutionState.RUNNING + + # Mock Redis to return data WITHOUT state field + mock_data = { + b"worker": b"worker-1", + b"started_at": b"2024-01-01T00:00:00+00:00", + # No b"state" field - state_value will be None + } + + with patch.object(execution.docket, "redis") as mock_redis_ctx: + mock_redis = AsyncMock() + mock_redis.hgetall.return_value = mock_data + mock_redis_ctx.return_value.__aenter__.return_value = mock_redis + + # Mock progress sync to avoid extra Redis calls + with patch.object(execution.progress, "sync"): + await execution.sync() + + # State should NOT be updated (stays as RUNNING) + assert execution.state == ExecutionState.RUNNING + # But other fields should be updated + assert execution.worker == "worker-1" + assert execution.started_at is not None + + +async def test_execution_sync_with_string_state_value(docket: Docket): + """Test sync() handles non-bytes state value (defensive coding).""" + from unittest.mock import AsyncMock, patch + + execution = Execution( + docket, AsyncMock(), (), {}, datetime.now(timezone.utc), "test-key", 1 + ) + + # Mock Redis to return string state (defensive code handles both bytes and str) + mock_data = { + b"state": "completed", # String, not bytes! + b"worker": b"worker-1", + b"completed_at": b"2024-01-01T00:00:00+00:00", + } + + with patch.object(execution.docket, "redis") as mock_redis_ctx: + mock_redis = AsyncMock() + mock_redis.hgetall.return_value = mock_data + mock_redis_ctx.return_value.__aenter__.return_value = mock_redis + + # Mock progress sync + with patch.object(execution.progress, "sync"): + await execution.sync() + + # Should handle string and set state correctly + assert execution.state == ExecutionState.COMPLETED + assert execution.worker == "worker-1" + + +async def test_subscribing_to_completed_execution(docket: Docket, worker: Worker): + """Subscribing to already-completed executions should emit final state.""" + + async def completed_task(): + await asyncio.sleep(0.01) + + async def failed_task(): + await asyncio.sleep(0.01) + raise ValueError("Task failed") + + # Test subscribing to a completed task + execution = await docket.add(completed_task, key="already-done:123")() + + # Run the task to completion first + await worker.run_until_finished() + + # Now subscribe to the already-completed execution + async def get_first_event() -> StateEvent | None: + async for event in execution.subscribe(): # pragma: no cover + assert event["type"] == "state" + return event + + first_event = await asyncio.wait_for(get_first_event(), timeout=1.0) + assert first_event is not None + + # Verify the initial state includes completion metadata + assert first_event["type"] == "state" + assert first_event["state"] == ExecutionState.COMPLETED + assert first_event["completed_at"] is not None + assert first_event["error"] is None + + # Test subscribing to a failed task + execution = await docket.add(failed_task, key="already-failed:456")() + + # Run the task to failure first + await worker.run_until_finished() + + # Now subscribe to the already-failed execution + async def get_first_failed_event() -> StateEvent | None: + async for event in execution.subscribe(): # pragma: no cover + assert event["type"] == "state" + return event + + first_event = await asyncio.wait_for(get_first_failed_event(), timeout=1.0) + assert 
first_event is not None + + # Verify the initial state includes error metadata + assert first_event["type"] == "state" + assert first_event["state"] == ExecutionState.FAILED + assert first_event["completed_at"] is not None + assert first_event["error"] is not None + assert first_event["error"] == "ValueError: Task failed" diff --git a/tests/test_fundamentals.py b/tests/test_fundamentals.py index e75e7ab..cd2d47d 100644 --- a/tests/test_fundamentals.py +++ b/tests/test_fundamentals.py @@ -23,9 +23,11 @@ Depends, Docket, Execution, + ExecutionState, ExponentialRetry, Logged, Perpetual, + Progress, Retry, TaskArgument, TaskKey, @@ -34,6 +36,7 @@ Worker, tasks, ) +from docket.execution import ExecutionProgress, StateEvent @pytest.fixture @@ -581,6 +584,158 @@ async def the_task(a: str, b: str, this_key: str = TaskKey()): assert called +async def test_tasks_can_report_progress(docket: Docket, worker: Worker): + """docket should support tasks reporting their progress""" + + called = False + + async def the_task( + a: str, + b: str, + progress: ExecutionProgress = Progress(), + ): + assert a == "a" + assert b == "c" + + # Set the total expected work + await progress.set_total(100) + + # Increment progress + await progress.increment(10) + await progress.increment(20) + + # Set a status message + await progress.set_message("Processing items...") + + # Read back current progress + assert progress.current == 30 + assert progress.total == 100 + assert progress.message == "Processing items..." + + nonlocal called + called = True + + await docket.add(the_task, key="progress-task:123")("a", b="c") + + await worker.run_until_finished() + + assert called + + +async def test_tasks_can_access_execution_state(docket: Docket, worker: Worker): + """docket should support providing execution state and metadata to a task""" + + called = False + + async def the_task( + a: str, + b: str, + this_execution: Execution = CurrentExecution(), + ): + assert a == "a" + assert b == "c" + + assert isinstance(this_execution, Execution) + assert this_execution.key == "stateful-task:123" + assert this_execution.state == ExecutionState.RUNNING + assert this_execution.worker is not None + assert this_execution.started_at is not None + + nonlocal called + called = True + + await docket.add(the_task, key="stateful-task:123")("a", b="c") + + await worker.run_until_finished() + + assert called + + +async def test_execution_state_lifecycle( + docket: Docket, worker: Worker, now: Callable[[], datetime] +): + """docket executions transition through states: QUEUED → RUNNING → COMPLETED""" + + async def successful_task(): + await asyncio.sleep(0.01) + + async def failing_task(): + await asyncio.sleep(0.01) + raise ValueError("Task failed") + + # Test successful execution lifecycle + execution = await docket.add( + successful_task, key="success:123", when=now() + timedelta(seconds=1) + )() + + # Collect state events + state_events: list[StateEvent] = [] + + async def collect_states() -> None: + async for event in execution.subscribe(): # pragma: no cover + if event["type"] == "state": + state_events.append(event) + if event["state"] == ExecutionState.COMPLETED: + break + + subscriber_task = asyncio.create_task(collect_states()) + + await worker.run_until_finished() + await asyncio.wait_for(subscriber_task, timeout=7.0) + + # Verify we saw the state transitions + # Note: subscribe() emits the initial state first, then real-time updates + states = [e["state"] for e in state_events] + assert states == [ + ExecutionState.SCHEDULED, + 
ExecutionState.QUEUED, + ExecutionState.RUNNING, + ExecutionState.COMPLETED, + ] + + # Verify final state has completion metadata + final_state = state_events[-1] + assert final_state["state"] == ExecutionState.COMPLETED + assert final_state["completed_at"] is not None + assert "error" not in final_state # No error for successful completion + + # Test failed execution lifecycle + execution = await docket.add( + failing_task, key="failure:456", when=now() + timedelta(seconds=1) + )() + + failed_state_events: list[StateEvent] = [] + + async def collect_failed_states() -> None: + async for event in execution.subscribe(): # pragma: no cover + if event["type"] == "state": + failed_state_events.append(event) + if event["state"] == ExecutionState.FAILED: + break + + subscriber_task = asyncio.create_task(collect_failed_states()) + + await worker.run_until_finished() + await asyncio.wait_for(subscriber_task, timeout=7.0) + + # Verify we saw the state transitions + # Note: subscribe() emits the initial state first, then real-time updates + states = [e["state"] for e in failed_state_events] + assert states == [ + ExecutionState.SCHEDULED, + ExecutionState.QUEUED, + ExecutionState.RUNNING, + ExecutionState.FAILED, + ] + + # Verify final state has error information + final_state = failed_state_events[-1] + assert final_state["state"] == ExecutionState.FAILED + assert final_state["completed_at"] is not None + assert final_state["error"] is not None + assert final_state["error"] == "ValueError: Task failed" + + async def test_all_dockets_have_a_trace_task( docket: Docket, worker: Worker, caplog: pytest.LogCaptureFixture ): diff --git a/tests/test_instrumentation.py b/tests/test_instrumentation.py index b0b8ec0..3def22f 100644 --- a/tests/test_instrumentation.py +++ b/tests/test_instrumentation.py @@ -40,7 +40,7 @@ async def the_task(): assert isinstance(span, Span) captured.append(span) - execution = await docket.add(the_task)() + run = await docket.add(the_task)() await worker.run_until_finished() @@ -57,8 +57,9 @@ async def the_task(): assert task_span.attributes["docket.name"] == docket.name assert task_span.attributes["docket.task"] == "the_task" - assert task_span.attributes["docket.key"] == execution.key - assert task_span.attributes["docket.when"] == execution.when.isoformat() + assert task_span.attributes["docket.key"] == run.key + assert run.when is not None + assert task_span.attributes["docket.when"] == run.when.isoformat() assert task_span.attributes["docket.attempt"] == 1 assert task_span.attributes["code.function.name"] == "the_task" diff --git a/tests/test_striking.py b/tests/test_striking.py index a68e8b1..8d02a80 100644 --- a/tests/test_striking.py +++ b/tests/test_striking.py @@ -140,6 +140,7 @@ async def test_restoring_is_idempotent(docket: Docket): ], ) def test_strike_operators( + docket: Docket, operator: Operator, value: Any, test_value: Any, @@ -156,6 +157,7 @@ async def test_function(the_parameter: Any) -> None: pass # pragma: no cover execution = Execution( + docket=docket, function=test_function, args=(), kwargs={"the_parameter": test_value}, diff --git a/tests/test_worker.py b/tests/test_worker.py index c2a04ee..3c9515b 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -247,7 +247,10 @@ async def test_worker_handles_unregistered_task_execution_on_initial_delivery( with caplog.at_level(logging.WARNING): await worker.run_until_finished() - assert "Task function 'the_task' not found" in caplog.text + assert ( + "Task function 'the_task' is not registered with the 
current docket" + in caplog.text + ) async def test_worker_handles_unregistered_task_execution_on_redelivery( @@ -300,7 +303,10 @@ async def test_task(): with caplog.at_level(logging.WARNING): await worker_b.run_until_finished() - assert "Task function 'test_task' not found" in caplog.text + assert ( + "Task function 'test_task' is not registered with the current docket" + in caplog.text + ) builtin_tasks = {function.__name__ for function in standard_tasks} @@ -1262,6 +1268,7 @@ async def task_without_concurrency_dependency(): async with Worker(docket) as worker: # Create execution for task without concurrency dependency execution = Execution( + docket=docket, function=task_without_concurrency_dependency, args=(), kwargs={}, @@ -1287,6 +1294,7 @@ async def task_without_concurrency_dependency(): async with Worker(docket) as worker: # Create execution for task without concurrency dependency execution = Execution( + docket=docket, function=task_without_concurrency_dependency, args=(), kwargs={}, @@ -1315,6 +1323,7 @@ async def task_with_missing_arg( async with Worker(docket) as worker: # Create execution that doesn't have the required parameter execution = Execution( + docket=docket, function=task_with_missing_arg, args=(), kwargs={}, # Missing the required parameter @@ -1346,6 +1355,7 @@ async def task_with_missing_concurrency_arg( async with Worker(docket) as worker: # Create execution without the required parameter execution = Execution( + docket=docket, function=task_with_missing_concurrency_arg, args=(), kwargs={}, # Missing the required "missing_param" @@ -1659,7 +1669,7 @@ async def successful_task(): assert task_executed, "Task should have executed successfully" # Verify cleanup - await checker.verify_keys_returned_to_baseline("successful task execution") + await checker.verify_keys_increased("successful task execution") async def test_redis_key_cleanup_failed_task(docket: Docket, worker: Worker) -> None: @@ -1692,7 +1702,7 @@ async def failing_task(): assert task_attempted, "Task should have been attempted" # Verify cleanup despite failure - await checker.verify_keys_returned_to_baseline("failed task execution") + await checker.verify_keys_increased("failed task execution") async def test_redis_key_cleanup_cancelled_task(docket: Docket, worker: Worker) -> None: diff --git a/uv.lock b/uv.lock index 7587cab..ba1b30d 100644 --- a/uv.lock +++ b/uv.lock @@ -1520,6 +1520,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-cov" }, + { name = "pytest-timeout" }, { name = "pytest-xdist" }, { name = "ruff" }, ] @@ -1568,6 +1569,7 @@ dev = [ { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-cov", specifier = ">=6.0.0" }, + { name = "pytest-timeout", specifier = ">=2.4.0" }, { name = "pytest-xdist", specifier = ">=3.6.1" }, { name = "ruff", specifier = ">=0.9.7" }, ] @@ -1664,6 +1666,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, ] +[[package]] +name = "pytest-timeout" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = 
"sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, +] + [[package]] name = "pytest-xdist" version = "3.8.0"