Skip to content
2 changes: 2 additions & 0 deletions src/docket/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Depends,
ExponentialRetry,
Perpetual,
Progress,
Retry,
TaskArgument,
TaskKey,
Expand All @@ -41,6 +42,7 @@
"ExponentialRetry",
"Logged",
"Perpetual",
"Progress",
"Retry",
"TaskArgument",
"TaskKey",
Expand Down
140 changes: 140 additions & 0 deletions src/docket/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@

import typer
from rich.console import Console
from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn
from rich.table import Table

from . import __version__, tasks
from .docket import Docket, DocketSnapshot, WorkerInfo
from .execution import Operator
from .state import TaskStateStore
from .worker import Worker

app: typer.Typer = typer.Typer(
Expand Down Expand Up @@ -810,6 +812,144 @@ async def run() -> DocketSnapshot:
console.print(stats_table)


@app.command(help="Watch real-time progress for one or more tasks")
def watch(
task_keys: Annotated[
list[str],
typer.Argument(help="Task key(s) to monitor. You can specify multiple keys."),
],
docket_: Annotated[
str,
typer.Option(
"--docket",
help="The name of the docket",
envvar="DOCKET_NAME",
),
] = "docket",
url: Annotated[
str,
typer.Option(
help="The URL of the Redis server",
envvar="DOCKET_URL",
callback=validate_url,
),
] = "redis://localhost:6379/0",
poll_interval: Annotated[
float,
typer.Option(
"--poll-interval",
help="Seconds between progress checks for completed tasks",
),
] = 1.0,
) -> None:
"""Watch real-time progress for tasks using Redis Pub/Sub.

This command monitors progress updates in real-time for one or more tasks.
It polls the initial state, then subscribes to live updates via Redis Pub/Sub.

Note: Tasks must use Progress(publish_events=True) for real-time updates.

Examples:
docket watch task-key-123
docket watch task1 task2 task3
"""
console = Console()

async def monitor() -> None:
async with Docket(name=docket_, url=url) as docket:
store = TaskStateStore(docket, docket.record_ttl)

# Track which tasks are completed
completed_tasks: set[str] = set()
task_bars: dict[str, Any] = {}

# Create Rich progress display
with Progress(
SpinnerColumn(),
TextColumn("[bold blue]{task.description}"),
BarColumn(),
TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
console=console,
) as progress:
# Initialize progress bars for each task
for key in task_keys:
state = await store.get_task_state(key)
if state is None:
console.print(f"[yellow]Task {key!r} not found[/yellow]")
continue

# Create progress bar
description = f"{key}"
if state.completed_at:
description += " (completed)"
completed_tasks.add(key)

task_bar = progress.add_task(
description,
total=state.progress.total,
completed=state.progress.current,
)
task_bars[key] = task_bar

# Show timestamp info
console.print(
f"[dim]{key}: Started at {local_time(state.started_at)}[/dim]"
)

if not task_bars:
console.print("[red]No valid tasks found[/red]")
return

# Monitor progress updates
try:
async for key, progress_info in docket.monitor_progress(task_keys):
if key not in task_bars:
continue

task_bar = task_bars[key]
progress.update(
task_bar,
completed=progress_info.current,
total=progress_info.total,
)

# Check if task completed (current == total)
if (
progress_info.current == progress_info.total
and key not in completed_tasks
):
completed_tasks.add(key)
progress.update(
task_bar,
description=f"{key} (completed)",
)

# Show completion timestamp
state = await store.get_task_state(key)
if state and state.completed_at:
console.print(
f"[green]{key}: Completed at {local_time(state.completed_at)}[/green]"
)

# Exit if all tasks completed
if len(completed_tasks) == len(task_bars):
console.print(
f"[green]All {len(task_bars)} task(s) completed![/green]"
)
break

# Periodically check if tasks still exist
await asyncio.sleep(poll_interval)

except KeyboardInterrupt:
console.print("\n[yellow]Monitoring interrupted[/yellow]")

try:
asyncio.run(monitor())
except KeyboardInterrupt:
console.print("\n[yellow]Monitoring interrupted[/yellow]")


workers_app: typer.Typer = typer.Typer(
help="Look at the workers on a docket", no_args_is_help=True
)
Expand Down
137 changes: 137 additions & 0 deletions src/docket/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from .docket import Docket
from .execution import Execution, TaskFunction, get_signature
from .instrumentation import CACHE_SIZE
from .state import ProgressInfo, TaskStateStore


if TYPE_CHECKING: # pragma: no cover
from .worker import Worker
Expand Down Expand Up @@ -652,6 +654,141 @@ def is_bypassed(self) -> bool:
return self._initialized and self._concurrency_key is None


class Progress(Dependency):
"""Allows a task to report intermediate progress during execution.

Progress is stored in Redis and persists after task completion as a tombstone
record (with TTL). Visible via snapshots or get_progress().

Example:

```python
@task
async def long_running(progress: Progress = Progress()) -> None:
batch = get_some_work()
await progress.set_total(len(batch))
for item in batch:
do_some_work(item)
await progress.increment() # default 1
```
"""

single: bool = True

def __init__(self, publish_events: bool = False) -> None:
"""Initialize Progress dependency.

Args:
publish_events: If True, publish progress updates to Redis Pub/Sub
channel for real-time monitoring (default: False)
"""
# Track current state
self._current: int = 0
self._publish_events = publish_events

async def __aenter__(self) -> "Progress":
from docket.state import DEFAULT_PROGRESS_TOTAL

execution = self.execution.get()
docket = self.docket.get()

self._key = execution.key
self._docket = docket
self._total = DEFAULT_PROGRESS_TOTAL
self._current = 0
self._store = TaskStateStore(docket, docket.record_ttl)

await self._store.set_task_progress(
self._key, ProgressInfo(current=self._current, total=self._total)
)

return self

async def __aexit__(
self,
_exc_type: type[BaseException] | None,
_exc_value: BaseException | None,
_traceback: TracebackType | None,
) -> bool:
"""No cleanup needed - updates are applied immediately."""
return False

async def _publish_event(self) -> None:
"""Publish progress update to Redis Pub/Sub channel."""
if not self._publish_events:
return

import json

message = json.dumps(
{
"key": self._key,
"current": self._current,
"total": self._total,
}
)

async with self._docket.redis() as redis:
await redis.publish( # pyright: ignore[reportUnknownMemberType,reportGeneralTypeIssues]
f"{self._docket.name}:progress-events", message
)

async def set_total(self, total: int) -> None:
"""Set the total expected progress value.

Args:
total: Total expected progress value (must be positive)

Raises:
ValueError: If total is not positive
"""
if total <= 0:
raise ValueError(f"Progress total must be positive, got {total}")
self._total = total
await self._store.set_task_progress(
self._key, ProgressInfo(current=self._current, total=self._total)
)
await self._publish_event()

async def increment(self, amount: int = 1) -> None:
"""Increment progress by the given amount (default 1).

Args:
amount: Amount to increment by (default 1)
"""
self._current = await self._store.increment_task_progress(self._key, amount)
await self._publish_event()

async def set(self, current: int) -> None:
"""Set the current progress value directly.

Args:
current: Current progress value (must be non-negative and <= total)

Raises:
ValueError: If current is negative or exceeds total
"""
if current < 0:
raise ValueError(f"Progress current must be non-negative, got {current}")
if current > self._total:
raise ValueError(
f"Progress current ({current}) cannot exceed total ({self._total})"
)
self._current = current
await self._store.set_task_progress(
self._key, ProgressInfo(current=self._current, total=self._total)
)
await self._publish_event()

async def get(self) -> "ProgressInfo | None":
"""Get current progress info.

Returns:
ProgressInfo if progress exists, None otherwise
"""
return await self._store.get_task_progress(self._key)


D = TypeVar("D", bound=Dependency)


Expand Down
Loading
Loading