2 changes: 2 additions & 0 deletions pyproject.toml
@@ -55,6 +55,8 @@ dev = [
"pytest>=8.3.4",
"pytest-asyncio>=0.24.0",
"pytest-cov>=6.0.0",
"pytest-flakefinder>=1.1.0",
"pytest-repeat>=0.9.3",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.6.1",
"ruff>=0.9.7",
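The two new dev dependencies target flaky-test reproduction. As a hedged example of their standard invocations (documented by the plugins themselves, not configured anywhere in this diff): `pytest --flake-finder --flake-runs=10` runs every collected test ten times in one session, and `pytest --count=10` does the same via pytest-repeat.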
2 changes: 1 addition & 1 deletion src/docket/cli.py
@@ -917,7 +917,7 @@ def set_progress_start_time(task_id: TaskID, started_at: datetime) -> None:

# Initialize progress task if we have progress data
if current_val > 0 and total_val > 0:
progress_task_id = active_progress.add_task(
progress_task_id = active_progress.add_task( # pragma: no cover
progress_message or "Processing...",
total=total_val,
completed=current_val,
4 changes: 2 additions & 2 deletions src/docket/docket.py
@@ -150,7 +150,7 @@ def __init__(
url: str = "redis://localhost:6379/0",
heartbeat_interval: timedelta = timedelta(seconds=2),
missed_heartbeats: int = 5,
execution_ttl: timedelta = timedelta(hours=1),
execution_ttl: timedelta = timedelta(minutes=15),
result_storage: AsyncKeyValue | None = None,
) -> None:
"""
@@ -167,7 +167,7 @@ def __init__(
missed_heartbeats: How many heartbeats a worker can miss before it is
considered dead.
execution_ttl: How long to keep completed or failed execution state records
in Redis before they expire. Defaults to 1 hour.
in Redis before they expire. Defaults to 15 minutes.
"""
self.name = name
self.url = url
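For context on the new default, a minimal sketch of how execution_ttl is passed; the import path and constructor call are assumptions for illustration, and the zero-TTL behavior refers to the execution.py and worker.py changes below:

from datetime import timedelta

from docket import Docket  # assumed import path

d_default = Docket("tasks")                                 # state kept 15 minutes (new default)
d_long = Docket("tasks", execution_ttl=timedelta(hours=1))  # restore the previous 1-hour window
d_ephemeral = Docket("tasks", execution_ttl=timedelta(0))   # falsy TTL: state deleted immediately,
                                                            # results and exceptions never stored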
60 changes: 42 additions & 18 deletions src/docket/execution.py
@@ -5,7 +5,7 @@
import inspect
import json
import logging
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from typing import (
TYPE_CHECKING,
Any,
@@ -692,7 +692,8 @@ async def mark_as_completed(self, result_key: str | None = None) -> None:
Args:
result_key: Optional key where the task result is stored

Sets TTL on state data (from docket.execution_ttl) and deletes progress data.
Sets TTL on state data (from docket.execution_ttl), or deletes state
immediately if execution_ttl is 0. Also deletes progress data.
"""
completed_at = datetime.now(timezone.utc).isoformat()
async with self.docket.redis() as redis:
@@ -706,10 +707,12 @@ async def mark_as_completed(self, result_key: str | None = None) -> None:
self._redis_key,
mapping=mapping,
)
# Set TTL from docket configuration
await redis.expire(
self._redis_key, int(self.docket.execution_ttl.total_seconds())
)
# Set TTL from docket configuration, or delete if TTL=0
if self.docket.execution_ttl:
ttl_seconds = int(self.docket.execution_ttl.total_seconds())
await redis.expire(self._redis_key, ttl_seconds)
else:
await redis.delete(self._redis_key)
self.state = ExecutionState.COMPLETED
self.result_key = result_key
# Delete progress data
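The branch above relies on timedelta truthiness rather than an explicit zero check; a quick sketch of those semantics:

from datetime import timedelta

# timedelta(0) is falsy and any nonzero duration is truthy, so
# `if self.docket.execution_ttl:` separates "expire later" from "delete now".
assert bool(timedelta(minutes=15))
assert not bool(timedelta(0))
assert int(timedelta(minutes=15).total_seconds()) == 900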
@@ -728,7 +731,8 @@ async def mark_as_failed(
error: Optional error message describing the failure
result_key: Optional key where the exception is stored

Sets TTL on state data (from docket.execution_ttl) and deletes progress data.
Sets TTL on state data (from docket.execution_ttl), or deletes state
immediately if execution_ttl is 0. Also deletes progress data.
"""
completed_at = datetime.now(timezone.utc).isoformat()
async with self.docket.redis() as redis:
@@ -741,10 +745,12 @@ async def mark_as_failed(
if result_key is not None:
mapping["result_key"] = result_key
await redis.hset(self._redis_key, mapping=mapping)
# Set TTL from docket configuration
await redis.expire(
self._redis_key, int(self.docket.execution_ttl.total_seconds())
)
# Set TTL from docket configuration, or delete if TTL=0
if self.docket.execution_ttl:
ttl_seconds = int(self.docket.execution_ttl.total_seconds())
await redis.expire(self._redis_key, ttl_seconds)
else:
await redis.delete(self._redis_key)
self.state = ExecutionState.FAILED
self.result_key = result_key
# Delete progress data
@@ -758,29 +764,47 @@ async def mark_as_failed(
state_data["error"] = error
await self._publish_state(state_data)

async def get_result(self, *, timeout: datetime | None = None) -> Any:
async def get_result(
self,
*,
timeout: timedelta | None = None,
deadline: datetime | None = None,
) -> Any:
"""Retrieve the result of this task execution.

If the execution is not yet complete, this method will wait using
pub/sub for state updates until completion.

Args:
timeout: Optional absolute datetime when to stop waiting.
If None, waits indefinitely.
timeout: Optional duration to wait before giving up.
If None and deadline is None, waits indefinitely.
deadline: Optional absolute datetime at which to stop waiting.
If None and timeout is None, waits indefinitely.

Returns:
The result of the task execution, or None if the task returned None.

Raises:
ValueError: If both timeout and deadline are provided
Exception: If the task failed, raises the stored exception
TimeoutError: If timeout is reached before execution completes
TimeoutError: If timeout/deadline is reached before execution completes
"""
# Validate that only one time limit is provided
if timeout is not None and deadline is not None:
raise ValueError("Cannot specify both timeout and deadline")

# Convert timeout to deadline if provided
if timeout is not None:
deadline = datetime.now(timezone.utc) + timeout

# Wait for execution to complete if not already done
if self.state not in (ExecutionState.COMPLETED, ExecutionState.FAILED):
# Calculate timeout duration if absolute timeout provided
# Calculate timeout duration if absolute deadline provided
timeout_seconds = None
if timeout is not None:
timeout_seconds = (timeout - datetime.now(timezone.utc)).total_seconds()
if deadline is not None:
timeout_seconds = (
deadline - datetime.now(timezone.utc)
).total_seconds()
if timeout_seconds <= 0:
raise TimeoutError(
f"Timeout waiting for execution {self.key} to complete"
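A usage sketch of the reworked signature; how an Execution handle is obtained is an assumption here (the tests below get one implicitly via docket.add):

from datetime import datetime, timedelta, timezone

async def wait_with_timeout(execution) -> object:
    # Relative limit: give up 30 seconds from now.
    return await execution.get_result(timeout=timedelta(seconds=30))

async def wait_until_deadline(execution) -> object:
    # Absolute limit: stop waiting at a fixed wall-clock instant.
    return await execution.get_result(
        deadline=datetime.now(timezone.utc) + timedelta(seconds=30)
    )

# Passing both raises ValueError; passing neither waits indefinitely.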
21 changes: 12 additions & 9 deletions src/docket/worker.py
@@ -689,7 +689,7 @@ async def _execute(self, execution: Execution) -> None:
if not rescheduled:
# Store result if appropriate
result_key = None
if result is not None:
if result is not None and self.docket.execution_ttl:
# Serialize and store result
pickled_result = cloudpickle.dumps(result) # type: ignore[arg-type]
# Base64-encode for JSON serialization
@@ -726,14 +726,17 @@ async def _execute(self, execution: Execution) -> None:

# Store exception in result_storage
result_key = None
pickled_exception = cloudpickle.dumps(e) # type: ignore[arg-type]
# Base64-encode for JSON serialization
encoded_exception = base64.b64encode(pickled_exception).decode("ascii")
result_key = execution.key
ttl_seconds = int(self.docket.execution_ttl.total_seconds())
await self.docket.result_storage.put(
result_key, {"data": encoded_exception}, ttl=ttl_seconds
)
if self.docket.execution_ttl:
pickled_exception = cloudpickle.dumps(e) # type: ignore[arg-type]
# Base64-encode for JSON serialization
encoded_exception = base64.b64encode(pickled_exception).decode(
"ascii"
)
result_key = execution.key
ttl_seconds = int(self.docket.execution_ttl.total_seconds())
await self.docket.result_storage.put(
result_key, {"data": encoded_exception}, ttl=ttl_seconds
)

# Mark execution as failed with error message
error_msg = f"{type(e).__name__}: {str(e)}"
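The failure path above mirrors the success path: cloudpickle for arbitrary Python objects, base64 so the bytes survive JSON serialization in the key/value store. The decode direction is not part of this diff; a round-trip sketch under that assumption:

import base64

import cloudpickle

def encode(obj: object) -> str:
    # Matches the storage code above: pickle, then base64 to ASCII text.
    return base64.b64encode(cloudpickle.dumps(obj)).decode("ascii")

def decode(encoded: str) -> object:
    # Assumed mirror of the storage code; not part of this PR.
    return cloudpickle.loads(base64.b64decode(encoded))

assert decode(encode({"answer": 42})) == {"answer": 42}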
45 changes: 32 additions & 13 deletions tests/cli/test_watch.py
@@ -14,6 +14,7 @@
from .waiting import (
wait_for_execution_state,
wait_for_progress_data,
wait_for_watch_subscribed,
wait_for_worker_assignment,
)

@@ -78,30 +79,48 @@ async def test_watch_failed_task(docket: Docket, the_task: AsyncMock):
async def test_watch_running_task_until_completion(docket: Docket, worker: Worker):
"""Watch should monitor task from running to completion."""

async def slower_task():
# Sleep long enough for watch command to connect and receive events
# Coordination key for synchronization
ready_key = f"{docket.name}:test:watch:ready"

async def coordinated_task():
# Wait for watch to be subscribed before completing
async with docket.redis() as redis:
while not await redis.get(ready_key): # type: ignore[misc]
await asyncio.sleep(0.01)
# Now do the work
await asyncio.sleep(0.5)

docket.register(slower_task)
await docket.add(slower_task, key="slower-task")()
docket.register(coordinated_task)
await docket.add(coordinated_task, key="slower-task")()

# Start worker in background
worker_task = asyncio.create_task(worker.run_until_finished())

# Wait for worker to claim the task
await wait_for_execution_state(docket, "slower-task", ExecutionState.RUNNING)

# Watch should receive state events while task runs
result = await run_cli(
"watch",
"slower-task",
"--url",
docket.url,
"--docket",
docket.name,
timeout=2.0,
# Start watch subprocess in background
watch_task = asyncio.create_task(
run_cli(
"watch",
"slower-task",
"--url",
docket.url,
"--docket",
docket.name,
timeout=5.0,
)
)

# Wait for watch to subscribe (deterministic synchronization)
await wait_for_watch_subscribed(docket, "slower-task")

# Signal task it can complete now
async with docket.redis() as redis:
await redis.set(ready_key, "1", ex=10)

# Wait for watch to finish
result = await watch_task
await worker_task

assert result.exit_code == 0
41 changes: 41 additions & 0 deletions tests/cli/waiting.py
@@ -131,3 +131,44 @@ async def wait_for_worker_assignment(
raise TimeoutError( # pragma: no cover
f"No worker was assigned to task {key} within {timeout}s"
)


async def wait_for_watch_subscribed(
docket: Docket,
key: str,
*,
timeout: float = 3.0,
interval: float = 0.01,
) -> None:
"""Wait for watch command to subscribe to state channel.

Uses Redis PUBSUB NUMSUB to detect when watch has subscribed.
This ensures watch won't miss state events published after subscription.

Args:
docket: Docket instance
key: Task key
timeout: Maximum time to wait in seconds
interval: Sleep interval between checks in seconds

Raises:
TimeoutError: If watch doesn't subscribe within timeout
"""
start_time = time.monotonic()
state_channel = f"{docket.name}:state:{key}"

while time.monotonic() - start_time < timeout:
async with docket.redis() as redis:
result = await redis.pubsub_numsub(state_channel) # type: ignore[misc]
# Returns list of tuples: [(channel_bytes, count), ...]
for channel, count in result: # type: ignore[misc]
if isinstance(channel, bytes): # pragma: no branch
channel = channel.decode()
if channel == state_channel and count > 0: # pragma: no branch
return

await asyncio.sleep(interval)

raise TimeoutError( # pragma: no cover
f"Watch command did not subscribe to {state_channel} within {timeout}s"
)