diff --git a/src/conductor/client/automator/lease_tracker.py b/src/conductor/client/automator/lease_tracker.py index 794e54e2..f98b5253 100644 --- a/src/conductor/client/automator/lease_tracker.py +++ b/src/conductor/client/automator/lease_tracker.py @@ -1,6 +1,27 @@ -"""Shared lease extension (heartbeat) tracking for TaskRunner and AsyncTaskRunner.""" +"""Centralized lease extension (heartbeat) management for Conductor task runners. +Architecture: + LeaseManager runs a single background daemon thread that periodically checks + for tasks needing lease extension heartbeats. Due heartbeats are dispatched + to a small fixed ThreadPoolExecutor for parallel, non-blocking API calls. + + This decouples heartbeat work entirely from worker poll loops, preventing + heartbeat API calls (and their retries) from blocking task polling. + + Thread-safe: track() and untrack() can be called from any thread or event loop. +""" + +import logging +import os +import threading +import time +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass +from typing import Any, Dict, Optional + +from conductor.client.http.models.task_result import TaskResult + +logger = logging.getLogger(__name__) # Lease extension constants (matches Java SDK) LEASE_EXTEND_RETRY_COUNT = 3 @@ -15,3 +36,189 @@ class LeaseInfo: response_timeout_seconds: float last_heartbeat_time: float # time.monotonic() of last heartbeat (or task start) interval_seconds: float # 80% of responseTimeoutSeconds + task_client: Any = None # Sync TaskResourceApi for sending heartbeats + + +class LeaseManager: + """Centralized lease extension manager for all workers in a process. + + One background daemon thread checks for due heartbeats at a fixed interval. + A small ThreadPoolExecutor sends heartbeat API calls in parallel. + Poll loops are never blocked by heartbeat work. + + Usage: + manager = LeaseManager.get_instance() + manager.track(task_id, workflow_id, timeout, task_client) + # ... task completes ... + manager.untrack(task_id) + """ + + _instance: Optional['LeaseManager'] = None + _instance_lock = threading.Lock() + _instance_pid: Optional[int] = None + + @classmethod + def get_instance(cls, check_interval: float = 1.0, + max_heartbeat_workers: int = 4) -> 'LeaseManager': + """Get or create the process-wide LeaseManager singleton. + + Fork-safe: a new instance is created after fork (threads don't survive fork). + """ + current_pid = os.getpid() + if cls._instance is None or cls._instance_pid != current_pid: + with cls._instance_lock: + if cls._instance is None or cls._instance_pid != current_pid: + cls._instance = cls( + check_interval=check_interval, + max_heartbeat_workers=max_heartbeat_workers, + ) + cls._instance_pid = current_pid + return cls._instance + + @classmethod + def _reset_instance(cls): + """Reset the singleton. 
For testing only.""" + with cls._instance_lock: + if cls._instance is not None: + cls._instance.shutdown() + cls._instance = None + cls._instance_pid = None + + def __init__(self, check_interval: float = 1.0, max_heartbeat_workers: int = 4): + self._tracked: Dict[str, LeaseInfo] = {} + self._lock = threading.Lock() + self._executor = ThreadPoolExecutor( + max_workers=max_heartbeat_workers, + thread_name_prefix="lease-heartbeat", + ) + self._stop_event = threading.Event() + self._check_interval = check_interval + self._thread: Optional[threading.Thread] = None + self._started = False + self._start_lock = threading.Lock() + + def _ensure_started(self) -> None: + """Lazily start the background thread on first track() call.""" + if self._started: + return + with self._start_lock: + if not self._started: + self._thread = threading.Thread( + target=self._run, daemon=True, name="lease-manager", + ) + self._thread.start() + self._started = True + logger.debug( + "LeaseManager started (check_interval=%.1fs)", self._check_interval, + ) + + def track(self, task_id: str, workflow_instance_id: str, + response_timeout_seconds: float, task_client: Any) -> None: + """Start tracking a task for lease extension heartbeats. + + Thread-safe. Can be called from any worker thread or event loop. + + Args: + task_id: Conductor task ID. + workflow_instance_id: Workflow instance this task belongs to. + response_timeout_seconds: The task's server-side response timeout. + task_client: A **sync** TaskResourceApi for sending heartbeat API calls. + """ + interval = response_timeout_seconds * LEASE_EXTEND_DURATION_FACTOR + if interval < 1: + logger.debug( + "Skipping lease tracking for task %s (interval %.1fs too short)", + task_id, interval, + ) + return + + info = LeaseInfo( + task_id=task_id, + workflow_instance_id=workflow_instance_id, + response_timeout_seconds=response_timeout_seconds, + last_heartbeat_time=time.monotonic(), + interval_seconds=interval, + task_client=task_client, + ) + with self._lock: + self._tracked[task_id] = info + self._ensure_started() + logger.debug( + "Tracking lease for task %s (timeout=%ss, heartbeat every %ss)", + task_id, response_timeout_seconds, interval, + ) + + def untrack(self, task_id: str) -> None: + """Stop tracking a task. Thread-safe.""" + with self._lock: + removed = self._tracked.pop(task_id, None) + if removed is not None: + logger.debug("Untracked lease for task %s", task_id) + + @property + def tracked_count(self) -> int: + """Number of currently tracked tasks.""" + with self._lock: + return len(self._tracked) + + # -- Background thread ----------------------------------------------------- + + def _run(self) -> None: + """Background loop — checks for due heartbeats at fixed intervals.""" + while not self._stop_event.is_set(): + try: + self._check_and_send() + except Exception as e: + logger.error("LeaseManager error: %s", e) + self._stop_event.wait(self._check_interval) + + def _check_and_send(self) -> None: + """Find tasks with due heartbeats and dispatch to the thread pool.""" + now = time.monotonic() + with self._lock: + due = [ + info for info in self._tracked.values() + if now - info.last_heartbeat_time >= info.interval_seconds + ] + for info in due: + # Update timestamp immediately to prevent double-dispatch on next tick + info.last_heartbeat_time = time.monotonic() + self._executor.submit(self._send_heartbeat, info) + + @staticmethod + def _send_heartbeat(info: LeaseInfo) -> None: + """Send a single lease extension heartbeat with retry. 
+ + Runs in a pool thread — blocking retries only block the pool thread, + never a poll loop. + """ + result = TaskResult( + task_id=info.task_id, + workflow_instance_id=info.workflow_instance_id, + extend_lease=True, + ) + for attempt in range(LEASE_EXTEND_RETRY_COUNT): + try: + info.task_client.update_task(body=result) + logger.debug("Extended lease for task %s", info.task_id) + return + except Exception as e: + if attempt < LEASE_EXTEND_RETRY_COUNT - 1: + time.sleep(0.5 * (attempt + 2)) + else: + logger.error( + "Failed to extend lease for task %s after %d attempts: %s", + info.task_id, LEASE_EXTEND_RETRY_COUNT, e, + ) + + # -- Lifecycle ------------------------------------------------------------- + + def shutdown(self) -> None: + """Stop the background thread and thread pool.""" + self._stop_event.set() + if self._started and self._thread is not None: + self._thread.join(timeout=5) + self._executor.shutdown(wait=False) + with self._lock: + self._tracked.clear() + logger.debug("LeaseManager shut down") diff --git a/src/conductor/client/automator/task_runner.py b/src/conductor/client/automator/task_runner.py index 242c8799..af566de1 100644 --- a/src/conductor/client/automator/task_runner.py +++ b/src/conductor/client/automator/task_runner.py @@ -36,7 +36,7 @@ from conductor.client.worker.worker_config import resolve_worker_config, get_worker_config_oneline from conductor.client.worker.exception import NonRetryableException from conductor.client.automator.json_schema_generator import generate_json_schema_from_function -from conductor.client.automator.lease_tracker import LeaseInfo, LEASE_EXTEND_RETRY_COUNT, LEASE_EXTEND_DURATION_FACTOR +from conductor.client.automator.lease_tracker import LeaseManager logger = logging.getLogger( Configuration.get_logging_formatted_name( @@ -112,8 +112,9 @@ def __init__( self._consecutive_empty_polls = 0 # Track empty polls to implement backoff self._shutdown = False # Flag to indicate graceful shutdown self._use_update_v2 = True # Will be set to False if server doesn't support v2 endpoint - self._lease_info = {} # task_id -> LeaseInfo for lease extension heartbeats - self._lease_lock = threading.Lock() # Protects _lease_info for free-threaded Python + self._lease_manager = LeaseManager.get_instance() + self._tracked_task_ids = set() # Local set for cleanup on shutdown + self._tracked_task_ids_lock = threading.Lock() def run(self) -> None: if self.configuration is not None: @@ -153,9 +154,12 @@ def _cleanup(self) -> None: """Clean up resources - called on exit.""" logger.debug("Cleaning up TaskRunner resources...") - # Stop all lease extension tracking - with self._lease_lock: - self._lease_info.clear() + # Untrack all tasks this runner was tracking from the shared LeaseManager + with self._tracked_task_ids_lock: + task_ids = list(self._tracked_task_ids) + self._tracked_task_ids.clear() + for task_id in task_ids: + self._lease_manager.untrack(task_id) # Shutdown ThreadPoolExecutor (EAFP style - more Pythonic) try: @@ -429,9 +433,6 @@ def __register_task_definition(self) -> None: def run_once(self) -> None: try: - # Send lease extension heartbeats for any tasks that are due - self._send_due_heartbeats() - # Check completed async tasks first (non-blocking) self.__check_completed_async_tasks() @@ -1077,74 +1078,29 @@ def __update_task(self, task_result: TaskResult): return None - # -- Lease extension (heartbeat) methods ---------------------------------- + # -- Lease extension (heartbeat) delegation to LeaseManager ---------------- def _track_lease(self, 
task: Task) -> None: - """Start tracking a task for lease extension heartbeat.""" - lease_enabled = getattr(self.worker, 'lease_extend_enabled', False) - if not lease_enabled: + """Start tracking a task for lease extension via the shared LeaseManager.""" + if not getattr(self.worker, 'lease_extend_enabled', False): return timeout = getattr(task, 'response_timeout_seconds', None) or 0 if timeout <= 0: return - interval = timeout * LEASE_EXTEND_DURATION_FACTOR - if interval < 1: - return - info = LeaseInfo( + self._lease_manager.track( task_id=task.task_id, workflow_instance_id=task.workflow_instance_id, response_timeout_seconds=timeout, - last_heartbeat_time=time.monotonic(), - interval_seconds=interval, - ) - with self._lease_lock: - self._lease_info[task.task_id] = info - logger.debug( - "Tracking lease for task %s (timeout=%ss, heartbeat every %ss)", - task.task_id, timeout, interval, + task_client=self.task_client, ) + with self._tracked_task_ids_lock: + self._tracked_task_ids.add(task.task_id) def _untrack_lease(self, task_id: str) -> None: """Stop tracking a task for lease extension.""" - with self._lease_lock: - removed = self._lease_info.pop(task_id, None) - if removed is not None: - logger.debug("Untracked lease for task %s", task_id) - - def _send_due_heartbeats(self) -> None: - """Check all tracked tasks and send heartbeats for any that are due.""" - if not self._lease_info: - return - now = time.monotonic() - with self._lease_lock: - infos = list(self._lease_info.values()) - for info in infos: - elapsed = now - info.last_heartbeat_time - if elapsed < info.interval_seconds: - continue - self._send_heartbeat(info) - info.last_heartbeat_time = time.monotonic() - - def _send_heartbeat(self, info: LeaseInfo) -> None: - """Send a single lease extension heartbeat with retry.""" - result = TaskResult( - task_id=info.task_id, - workflow_instance_id=info.workflow_instance_id, - extend_lease=True, - ) - for attempt in range(LEASE_EXTEND_RETRY_COUNT): - try: - self.task_client.update_task(body=result) - logger.debug("Extended lease for task %s", info.task_id) - return - except Exception as e: - if attempt < LEASE_EXTEND_RETRY_COUNT - 1: - time.sleep(0.5 * (attempt + 2)) - else: - logger.error( - "Failed to extend lease for task %s after %d attempts: %s", - info.task_id, LEASE_EXTEND_RETRY_COUNT, e, - ) + self._lease_manager.untrack(task_id) + with self._tracked_task_ids_lock: + self._tracked_task_ids.discard(task_id) # -------------------------------------------------------------------------- diff --git a/tests/unit/automator/test_lease_manager.py b/tests/unit/automator/test_lease_manager.py new file mode 100644 index 00000000..9c25d37e --- /dev/null +++ b/tests/unit/automator/test_lease_manager.py @@ -0,0 +1,320 @@ +"""Tests for the centralized LeaseManager.""" + +import threading +import time +import unittest +from unittest.mock import MagicMock, call, patch + +from conductor.client.automator.lease_tracker import ( + LeaseManager, + LeaseInfo, + LEASE_EXTEND_DURATION_FACTOR, + LEASE_EXTEND_RETRY_COUNT, +) + + +class TestLeaseManagerTrackUntrack(unittest.TestCase): + """Test track/untrack operations.""" + + def setUp(self): + LeaseManager._reset_instance() + self.manager = LeaseManager(check_interval=60) # Long interval — we trigger manually + + def tearDown(self): + self.manager.shutdown() + LeaseManager._reset_instance() + + def test_track_adds_task(self): + client = MagicMock() + self.manager.track('task-1', 'wf-1', 30.0, client) + self.assertEqual(self.manager.tracked_count, 1) + + def 
test_untrack_removes_task(self): + client = MagicMock() + self.manager.track('task-1', 'wf-1', 30.0, client) + self.manager.untrack('task-1') + self.assertEqual(self.manager.tracked_count, 0) + + def test_untrack_nonexistent_is_noop(self): + self.manager.untrack('nonexistent') + self.assertEqual(self.manager.tracked_count, 0) + + def test_track_skips_short_interval(self): + """Tasks with response_timeout < ~1.25s (interval < 1s) should be skipped.""" + client = MagicMock() + self.manager.track('task-1', 'wf-1', 1.0, client) # 1.0 * 0.8 = 0.8 < 1 + self.assertEqual(self.manager.tracked_count, 0) + + def test_track_accepts_valid_timeout(self): + client = MagicMock() + self.manager.track('task-1', 'wf-1', 10.0, client) # 10 * 0.8 = 8.0 >= 1 + self.assertEqual(self.manager.tracked_count, 1) + + def test_track_multiple_tasks(self): + client = MagicMock() + for i in range(10): + self.manager.track(f'task-{i}', f'wf-{i}', 30.0, client) + self.assertEqual(self.manager.tracked_count, 10) + + def test_track_overwrites_existing(self): + client = MagicMock() + self.manager.track('task-1', 'wf-1', 30.0, client) + self.manager.track('task-1', 'wf-1', 60.0, client) + self.assertEqual(self.manager.tracked_count, 1) + + +class TestLeaseManagerHeartbeat(unittest.TestCase): + """Test heartbeat dispatch logic.""" + + def setUp(self): + LeaseManager._reset_instance() + self.manager = LeaseManager(check_interval=60) + + def tearDown(self): + self.manager.shutdown() + LeaseManager._reset_instance() + + def test_heartbeat_sent_when_due(self): + """Heartbeat should be dispatched when interval has elapsed.""" + client = MagicMock() + self.manager.track('task-1', 'wf-1', 10.0, client) + + # Fast-forward: set last_heartbeat_time to the past + with self.manager._lock: + info = self.manager._tracked['task-1'] + info.last_heartbeat_time = time.monotonic() - 20 # Well past the 8s interval + + self.manager._check_and_send() + + # Wait for the pool thread to execute the heartbeat + self.manager._executor.shutdown(wait=True) + client.update_task.assert_called_once() + result = client.update_task.call_args[1]['body'] + self.assertEqual(result.task_id, 'task-1') + self.assertEqual(result.workflow_instance_id, 'wf-1') + self.assertTrue(result.extend_lease) + + def test_heartbeat_not_sent_when_not_due(self): + """Heartbeat should NOT be dispatched when interval hasn't elapsed.""" + client = MagicMock() + self.manager.track('task-1', 'wf-1', 10.0, client) + + self.manager._check_and_send() + + self.manager._executor.shutdown(wait=True) + client.update_task.assert_not_called() + + def test_heartbeat_retries_on_failure(self): + """Heartbeat should retry up to LEASE_EXTEND_RETRY_COUNT times.""" + client = MagicMock() + client.update_task.side_effect = Exception("server error") + + info = LeaseInfo( + task_id='task-1', + workflow_instance_id='wf-1', + response_timeout_seconds=30.0, + last_heartbeat_time=time.monotonic(), + interval_seconds=24.0, + task_client=client, + ) + + with patch('conductor.client.automator.lease_tracker.time.sleep'): + LeaseManager._send_heartbeat(info) + + self.assertEqual(client.update_task.call_count, LEASE_EXTEND_RETRY_COUNT) + + def test_heartbeat_stops_retrying_on_success(self): + """Heartbeat should stop retrying after a successful call.""" + client = MagicMock() + client.update_task.side_effect = [Exception("fail"), None] # Fail then succeed + + info = LeaseInfo( + task_id='task-1', + workflow_instance_id='wf-1', + response_timeout_seconds=30.0, + last_heartbeat_time=time.monotonic(), + 
interval_seconds=24.0, + task_client=client, + ) + + with patch('conductor.client.automator.lease_tracker.time.sleep'): + LeaseManager._send_heartbeat(info) + + self.assertEqual(client.update_task.call_count, 2) + + def test_multiple_tasks_heartbeats_dispatched_independently(self): + """Each due task gets its own heartbeat dispatch.""" + client_a = MagicMock() + client_b = MagicMock() + + self.manager.track('task-a', 'wf-a', 10.0, client_a) + self.manager.track('task-b', 'wf-b', 10.0, client_b) + + # Make both due + with self.manager._lock: + past = time.monotonic() - 20 + self.manager._tracked['task-a'].last_heartbeat_time = past + self.manager._tracked['task-b'].last_heartbeat_time = past + + self.manager._check_and_send() + self.manager._executor.shutdown(wait=True) + + client_a.update_task.assert_called_once() + client_b.update_task.assert_called_once() + + +class TestLeaseManagerNonBlocking(unittest.TestCase): + """Test that heartbeats don't block the caller.""" + + def setUp(self): + LeaseManager._reset_instance() + + def tearDown(self): + LeaseManager._reset_instance() + + def test_poll_loop_not_blocked_by_slow_heartbeat(self): + """The caller should return immediately even if heartbeat is slow.""" + slow_client = MagicMock() + slow_client.update_task.side_effect = lambda **kw: time.sleep(2) + + manager = LeaseManager(check_interval=60) + manager.track('task-1', 'wf-1', 10.0, slow_client) + + with manager._lock: + manager._tracked['task-1'].last_heartbeat_time = time.monotonic() - 20 + + start = time.monotonic() + manager._check_and_send() # Submits to pool, returns immediately + elapsed = time.monotonic() - start + + # _check_and_send should return in < 100ms (it just submits to the pool) + self.assertLess(elapsed, 0.1, "check_and_send blocked for too long") + + manager.shutdown() + + +class TestLeaseManagerSingleton(unittest.TestCase): + """Test singleton behavior.""" + + def setUp(self): + LeaseManager._reset_instance() + + def tearDown(self): + LeaseManager._reset_instance() + + def test_get_instance_returns_same_object(self): + a = LeaseManager.get_instance() + b = LeaseManager.get_instance() + self.assertIs(a, b) + a.shutdown() + + def test_reset_creates_new_instance(self): + a = LeaseManager.get_instance() + LeaseManager._reset_instance() + b = LeaseManager.get_instance() + self.assertIsNot(a, b) + b.shutdown() + + @patch('conductor.client.automator.lease_tracker.os.getpid') + def test_new_instance_after_fork(self, mock_getpid): + """After fork (different PID), a fresh instance should be created.""" + mock_getpid.return_value = 1000 + a = LeaseManager.get_instance() + + mock_getpid.return_value = 2000 # Simulate fork + b = LeaseManager.get_instance() + + self.assertIsNot(a, b) + a.shutdown() + b.shutdown() + + +class TestLeaseManagerBackgroundThread(unittest.TestCase): + """Test the background thread lifecycle.""" + + def setUp(self): + LeaseManager._reset_instance() + + def tearDown(self): + LeaseManager._reset_instance() + + def test_thread_starts_lazily_on_first_track(self): + manager = LeaseManager(check_interval=60) + self.assertFalse(manager._started) + + client = MagicMock() + manager.track('task-1', 'wf-1', 10.0, client) + self.assertTrue(manager._started) + self.assertTrue(manager._thread.is_alive()) + + manager.shutdown() + + def test_thread_not_started_if_no_tracks(self): + manager = LeaseManager(check_interval=60) + self.assertFalse(manager._started) + manager.shutdown() + + def test_background_thread_sends_heartbeats(self): + """Verify the background thread 
actually dispatches heartbeats.""" + client = MagicMock() + manager = LeaseManager(check_interval=0.1) # Check every 100ms + + manager.track('task-1', 'wf-1', 10.0, client) + + # Make it due + with manager._lock: + manager._tracked['task-1'].last_heartbeat_time = time.monotonic() - 20 + + # Wait for background thread to pick it up + time.sleep(0.5) + + manager.shutdown() + client.update_task.assert_called() + + def test_shutdown_stops_thread(self): + manager = LeaseManager(check_interval=0.1) + client = MagicMock() + manager.track('task-1', 'wf-1', 10.0, client) + self.assertTrue(manager._thread.is_alive()) + + manager.shutdown() + self.assertFalse(manager._thread.is_alive()) + + +class TestLeaseManagerThreadSafety(unittest.TestCase): + """Test concurrent track/untrack operations.""" + + def setUp(self): + LeaseManager._reset_instance() + + def tearDown(self): + LeaseManager._reset_instance() + + def test_concurrent_track_untrack(self): + """Many threads tracking/untracking should not corrupt state.""" + manager = LeaseManager(check_interval=60) + client = MagicMock() + errors = [] + + def track_and_untrack(thread_id): + try: + for i in range(50): + task_id = f'task-{thread_id}-{i}' + manager.track(task_id, f'wf-{thread_id}', 30.0, client) + manager.untrack(task_id) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=track_and_untrack, args=(t,)) for t in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(errors, []) + self.assertEqual(manager.tracked_count, 0) + manager.shutdown() + + +if __name__ == '__main__': + unittest.main()
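
Reviewer note, a minimal sketch of how a worker loop is expected to consume the new API. `poll_task`, `execute_task`, and the pre-built sync `task_client` are hypothetical stand-ins for this illustration; only `LeaseManager.get_instance()`, `track()`, and `untrack()` come from this diff:

    from conductor.client.automator.lease_tracker import LeaseManager

    def worker_loop(poll_task, execute_task, task_client):
        # Process-wide singleton; safe to call from every worker thread.
        manager = LeaseManager.get_instance()
        while True:
            task = poll_task()  # hypothetical poll helper
            if task is None:
                continue
            # Heartbeats start here: the manager's background thread extends
            # the lease every 0.8 * response_timeout_seconds while we work,
            # without ever blocking this loop.
            manager.track(
                task_id=task.task_id,
                workflow_instance_id=task.workflow_instance_id,
                response_timeout_seconds=task.response_timeout_seconds,
                task_client=task_client,
            )
            try:
                execute_task(task)  # hypothetical execution helper
            finally:
                # Stop heartbeats even if execution raised, mirroring
                # TaskRunner._track_lease / _untrack_lease above.
                manager.untrack(task.task_id)

The try/finally mirrors what task_runner.py now does, so a worker function that raises cannot leak a tracked task.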
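The timing constants, worked through once. Values are taken from the module and the test comments (LEASE_EXTEND_DURATION_FACTOR is 0.8, LEASE_EXTEND_RETRY_COUNT is 3):

    # Heartbeat interval: 80% of the task's response timeout. track() skips
    # tasks whose interval would be under 1 second.
    LEASE_EXTEND_DURATION_FACTOR = 0.8
    LEASE_EXTEND_RETRY_COUNT = 3

    assert 10.0 * LEASE_EXTEND_DURATION_FACTOR == 8.0   # tracked (>= 1s)
    assert 30.0 * LEASE_EXTEND_DURATION_FACTOR == 24.0  # tracked
    assert 1.0 * LEASE_EXTEND_DURATION_FACTOR < 1.0     # 0.8s -> skipped

    # Retry backoff in _send_heartbeat: sleep 0.5 * (attempt + 2) between
    # tries. Attempt 0 fails -> sleep 1.0s; attempt 1 fails -> sleep 1.5s;
    # attempt 2 fails -> log an error, no sleep. Worst case ~2.5s spent in a
    # pool thread, never in a poll loop.
    sleeps = [0.5 * (attempt + 2)
              for attempt in range(LEASE_EXTEND_RETRY_COUNT - 1)]
    assert sleeps == [1.0, 1.5]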
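Why get_instance() compares os.getpid(): neither the lease-manager daemon thread nor the heartbeat pool survives os.fork(), so a forked child inherits a singleton whose machinery is dead. A POSIX-only illustration of the re-creation behavior that test_new_instance_after_fork simulates with a mocked PID; not part of the diff:

    import os

    from conductor.client.automator.lease_tracker import LeaseManager

    parent_manager = LeaseManager.get_instance()

    pid = os.fork()
    if pid == 0:
        # Child: the inherited _instance_pid no longer matches os.getpid(),
        # so a fresh manager (with its own lazily started thread and pool)
        # is built on first use.
        child_manager = LeaseManager.get_instance()
        assert child_manager is not parent_manager
        os._exit(0)
    else:
        os.waitpid(pid, 0)
        # Parent keeps its original instance.
        assert LeaseManager.get_instance() is parent_manager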