167 changes: 135 additions & 32 deletions agentlightning/runner/agent.py
@@ -11,8 +11,21 @@

import asyncio
import logging
import threading
import time
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Sequence, TypeVar, cast
from contextlib import suppress
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
List,
Literal,
Optional,
Sequence,
TypeVar,
cast,
)

from opentelemetry.sdk.trace import ReadableSpan

@@ -30,6 +43,7 @@
RolloutRawResult,
Span,
)
from agentlightning.utils.system_snapshot import system_snapshot

if TYPE_CHECKING:
from agentlightning.execution.events import ExecutionEvent
@@ -52,19 +66,31 @@ class LitAgentRunner(Runner[T_task]):
worker_id: Identifier for the active worker process, if any.
"""

def __init__(self, tracer: Tracer, max_rollouts: Optional[int] = None, poll_interval: float = 5.0) -> None:
def __init__(
self,
tracer: Tracer,
max_rollouts: Optional[int] = None,
poll_interval: float = 5.0,
heartbeat_interval: float = 10.0,
heartbeat_launch_mode: Literal["asyncio", "thread"] = "asyncio",
) -> None:
"""Initialize the agent runner.

Args:
tracer: [`Tracer`][agentlightning.Tracer] used for rollout spans.
max_rollouts: Optional cap on iterations processed by
[`iter`][agentlightning.LitAgentRunner.iter].
poll_interval: Seconds to wait between store polls when no work is available.
heartbeat_interval: Seconds to wait between heartbeats sent to the store. Values of zero or less disable the heartbeat loop.
heartbeat_launch_mode: Launch mode for the heartbeat loop. Can be "asyncio" or "thread".
"asyncio" is the default and recommended mode. Use "thread" when blocking coroutines prevent the event loop from sending heartbeats on time.
"""
super().__init__()
self._tracer = tracer
self._max_rollouts = max_rollouts
self._poll_interval = poll_interval
self._heartbeat_interval = heartbeat_interval
self._heartbeat_launch_mode = heartbeat_launch_mode

# Set later
self._agent: Optional[LitAgent[T_task]] = None
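For reference, a minimal construction sketch using the new heartbeat options; `make_tracer()` is a hypothetical placeholder for however the concrete `Tracer` is built and is not part of this change:

runner = LitAgentRunner(
    tracer=make_tracer(),            # any concrete Tracer implementation (assumed helper)
    max_rollouts=100,                # stop iter() after 100 rollouts
    poll_interval=5.0,               # seconds between store polls when the queue is empty
    heartbeat_interval=10.0,         # <= 0 disables the heartbeat loop entirely
    heartbeat_launch_mode="thread",  # fall back to a thread if coroutines may block the loop
)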
@@ -304,6 +330,67 @@ async def _post_process_rollout_result(

return trace_spans

async def _emit_heartbeat(self, store: LightningStore) -> None:
"""Send a heartbeat tick to the store."""
worker_id = self.get_worker_id()

try:
await store.update_worker(worker_id, system_snapshot())
except asyncio.CancelledError:
# Propagate cancellation so the heartbeat loop can stop promptly
raise
except Exception:
logger.exception("%s Unable to update worker heartbeat.", self._log_prefix())

def _start_heartbeat_loop(self, store: LightningStore) -> Optional[Callable[[], Awaitable[None]]]:
"""Start a background heartbeat loop and return an async stopper."""

if self._heartbeat_interval <= 0:
return None

if self.worker_id is None:
logger.warning("%s Cannot start heartbeat loop without worker_id.", self._log_prefix())
return None

if self._heartbeat_launch_mode == "asyncio":
stop_event = asyncio.Event()

async def heartbeat_loop() -> None:
while not stop_event.is_set():
await self._emit_heartbeat(store)
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(stop_event.wait(), timeout=self._heartbeat_interval)

task = asyncio.create_task(heartbeat_loop(), name=f"{self.get_worker_id()}-heartbeat")

async def stop() -> None:
stop_event.set()
with suppress(asyncio.CancelledError):
await task

return stop

if self._heartbeat_launch_mode == "thread":
stop_evt = threading.Event()

def thread_worker() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
while not stop_evt.is_set():
loop.run_until_complete(self._emit_heartbeat(store))
stop_evt.wait(self._heartbeat_interval)

thread = threading.Thread(target=thread_worker, name=f"{self.get_worker_id()}-heartbeat", daemon=True)
thread.start()

async def stop() -> None:
stop_evt.set()
await asyncio.to_thread(thread.join)

return stop

raise ValueError(f"Unsupported heartbeat launch mode: {self._heartbeat_launch_mode}")

async def _sleep_until_next_poll(self, event: Optional[ExecutionEvent] = None) -> None:
"""Sleep until the next poll interval, with optional event-based interruption.

@@ -450,39 +537,49 @@ async def iter(self, *, event: Optional[ExecutionEvent] = None) -> None:
logger.info(f"{self._log_prefix()} Started async rollouts (max: {self._max_rollouts or 'unlimited'}).")
store = self.get_store()

while not (event is not None and event.is_set()) and (
self._max_rollouts is None or num_tasks_processed < self._max_rollouts
):
# Retrieve the next rollout
next_rollout: Optional[Rollout] = None
while not (event is not None and event.is_set()):
logger.debug(f"{self._log_prefix()} Try to poll for next rollout.")
next_rollout = await store.dequeue_rollout()
if next_rollout is None:
logger.debug(f"{self._log_prefix()} No rollout to poll. Waiting for {self._poll_interval} seconds.")
await self._sleep_until_next_poll(event)
else:
break
stop_heartbeat = self._start_heartbeat_loop(store)

if next_rollout is None:
return

try:
# Claim the rollout by updating the current worker id
await store.update_attempt(
next_rollout.rollout_id, next_rollout.attempt.attempt_id, worker_id=self.get_worker_id()
)
except Exception:
# This exception could happen if the rollout is dequeued and the other end died for some reason
logger.exception(f"{self._log_prefix()} Exception during update_attempt, giving up the rollout.")
continue
try:
while not (event is not None and event.is_set()) and (
self._max_rollouts is None or num_tasks_processed < self._max_rollouts
):
# Retrieve the next rollout
next_rollout: Optional[Rollout] = None
while not (event is not None and event.is_set()):
logger.debug(f"{self._log_prefix()} Try to poll for next rollout.")
next_rollout = await store.dequeue_rollout(worker_id=self.get_worker_id())
if next_rollout is None:
logger.debug(
f"{self._log_prefix()} No rollout to poll. Waiting for {self._poll_interval} seconds."
)
await self._sleep_until_next_poll(event)
else:
break

# Execute the step
await self._step_impl(next_rollout)
if next_rollout is None:
return

num_tasks_processed += 1
if num_tasks_processed % 10 == 0 or num_tasks_processed == 1:
logger.info(f"{self._log_prefix()} Progress: {num_tasks_processed}/{self._max_rollouts or 'unlimited'}")
try:
# Claim the rollout by updating the current worker id
await store.update_attempt(
next_rollout.rollout_id, next_rollout.attempt.attempt_id, worker_id=self.get_worker_id()
)
except Exception:
# This exception could happen if the rollout is dequeued and the other end died for some reason
logger.exception(f"{self._log_prefix()} Exception during update_attempt, giving up the rollout.")
continue

# Execute the step
await self._step_impl(next_rollout)

num_tasks_processed += 1
if num_tasks_processed % 10 == 0 or num_tasks_processed == 1:
logger.info(
f"{self._log_prefix()} Progress: {num_tasks_processed}/{self._max_rollouts or 'unlimited'}"
)
finally:
if stop_heartbeat is not None:
await stop_heartbeat()

logger.info(f"{self._log_prefix()} Finished async rollouts. Processed {num_tasks_processed} tasks.")

@@ -526,6 +623,12 @@ async def step(
resources_id = None

attempted_rollout = await self.get_store().start_rollout(input=input, mode=mode, resources_id=resources_id)
# Register the attempt as being run by the current worker
await self.get_store().update_attempt(
attempted_rollout.rollout_id,
attempted_rollout.attempt.attempt_id,
worker_id=self.get_worker_id(),
)
rollout_id = await self._step_impl(attempted_rollout, raise_on_exception=True)

completed_rollout = await store.get_rollout_by_id(rollout_id)
54 changes: 53 additions & 1 deletion agentlightning/store/base.py
@@ -17,6 +17,7 @@
RolloutStatus,
Span,
TaskInput,
Worker,
)


@@ -167,7 +168,7 @@ async def enqueue_rollout(
"""
raise NotImplementedError()

async def dequeue_rollout(self) -> Optional[AttemptedRollout]:
async def dequeue_rollout(self, worker_id: Optional[str] = None) -> Optional[AttemptedRollout]:
"""Claim the oldest queued rollout and transition it to `preparing`.

This function does not block.
@@ -180,6 +181,8 @@ async def dequeue_rollout(self) -> Optional[AttemptedRollout]:
the number of attempts already registered for the rollout plus one.
* Return an [`AttemptedRollout`][agentlightning.AttemptedRollout] snapshot so the
runner knows both rollout metadata and the attempt identifier.
* Optionally refresh the caller's [`Worker`][agentlightning.Worker] telemetry
(e.g., `last_dequeue_time`) when `worker_id` is provided.

Returns:
The next attempt to execute, or `None` when no eligible rollouts are queued.
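As a usage sketch (assuming `store` and `runner` objects are already in scope, mirroring the runner changes above):

rollout = await store.dequeue_rollout(worker_id=runner.get_worker_id())
if rollout is not None:
    # The store has already created a fresh attempt in `preparing`; claiming it
    # via update_attempt records which worker is executing it.
    await store.update_attempt(
        rollout.rollout_id,
        rollout.attempt.attempt_id,
        worker_id=runner.get_worker_id(),
    )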
@@ -527,6 +530,12 @@ async def update_attempt(
Similar to [`update_rollout()`][agentlightning.LightningStore.update_rollout],
parameters also default to the sentinel [`UNSET`][agentlightning.store.base.UNSET].

If `worker_id` is provided, the worker status is updated according to the following rules:

1. If attempt status is "succeeded" or "failed", the corresponding worker status will be set to "idle".
2. If attempt status is "unresponsive" or "timeout", the corresponding worker status will be set to "unknown".
3. Otherwise, the worker status will be set to "busy".

Args:
rollout_id: Identifier of the rollout whose attempt will be updated.
attempt_id: Attempt identifier or `"latest"` as a convenience.
@@ -543,3 +552,46 @@
ValueError: Implementations must raise when the rollout or attempt is unknown.
"""
raise NotImplementedError()
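
The worker-status rules above reduce to a small mapping; a sketch, with the status strings taken from this docstring rather than from the `Worker` model definition:

def worker_status_for_attempt(attempt_status: str) -> str:
    # Terminal attempt outcomes free the worker.
    if attempt_status in ("succeeded", "failed"):
        return "idle"
    # Lost or stalled attempts leave the worker state undetermined.
    if attempt_status in ("unresponsive", "timeout"):
        return "unknown"
    # Any other (in-progress) status keeps the worker marked as busy.
    return "busy"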

async def query_workers(self) -> List[Worker]:
"""Query all workers in the system.

Returns:
A list of all workers.
"""
raise NotImplementedError()

async def get_worker_by_id(self, worker_id: str) -> Optional[Worker]:
"""Retrieve a single worker by identifier.

Args:
worker_id: Identifier of the worker.

Returns:
The worker record if it exists, otherwise `None`.

Raises:
NotImplementedError: Subclasses must implement lookup semantics.
"""
raise NotImplementedError()

async def update_worker(
self,
worker_id: str,
heartbeat_stats: Dict[str, Any] | Unset = UNSET,
) -> Worker:
"""Record a heartbeat for `worker_id` and refresh telemetry.

Implementations must treat this API as heartbeat-only: it should snapshot
the latest stats when provided, stamp `last_heartbeat_time` with the
current wall clock, and rely on other store mutations (`dequeue_rollout`,
`update_attempt`, etc.) to drive the worker's busy/idle status,
assignment, and activity timestamps.

Args:
worker_id: Identifier of the worker to update.
heartbeat_stats: Replacement worker heartbeat statistics (non-null when provided).
"""
raise NotImplementedError()
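
A minimal in-memory sketch of the heartbeat-only contract, assuming a `self._workers` dictionary and a mutable `Worker` model; the field names `heartbeat_stats` and `last_heartbeat_time` are illustrative:

async def update_worker(
    self,
    worker_id: str,
    heartbeat_stats: Dict[str, Any] | Unset = UNSET,
) -> Worker:
    worker = self._workers[worker_id]  # lookup/error handling is up to the concrete store
    if not isinstance(heartbeat_stats, Unset):
        worker.heartbeat_stats = heartbeat_stats  # snapshot the latest stats when provided
    worker.last_heartbeat_time = time.time()  # stamp with the current wall clock
    return worker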