diff --git a/durabletask/worker.py b/durabletask/worker.py index 681e2df..9606dbb 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -346,7 +346,7 @@ def __init__( else: self._interceptors = None - self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options) + self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger) @property def concurrency_options(self) -> ConcurrencyOptions: @@ -533,6 +533,7 @@ def stream_reader(): if work_item.HasField("orchestratorRequest"): self._async_worker_manager.submit_orchestration( self._execute_orchestrator, + self._cancel_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken, @@ -540,6 +541,7 @@ def stream_reader(): elif work_item.HasField("activityRequest"): self._async_worker_manager.submit_activity( self._execute_activity, + self._cancel_activity, work_item.activityRequest, stub, work_item.completionToken, @@ -547,6 +549,7 @@ def stream_reader(): elif work_item.HasField("entityRequest"): self._async_worker_manager.submit_entity_batch( self._execute_entity_batch, + self._cancel_entity_batch, work_item.entityRequest, stub, work_item.completionToken, @@ -554,6 +557,7 @@ def stream_reader(): elif work_item.HasField("entityRequestV2"): self._async_worker_manager.submit_entity_batch( self._execute_entity_batch, + self._cancel_entity_batch, work_item.entityRequestV2, stub, work_item.completionToken @@ -670,6 +674,19 @@ def _execute_orchestrator( f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}" ) + def _cancel_orchestrator( + self, + req: pb.OrchestratorRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): + stub.AbandonTaskOrchestratorWorkItem( + pb.AbandonOrchestrationTaskRequest( + completionToken=completionToken + ) + ) + self._logger.info(f"Cancelled orchestration task for invocation ID: {req.instanceId}") + def _execute_activity( self, req: pb.ActivityRequest, @@ -703,6 +720,19 @@ def _execute_activity( f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" ) + def _cancel_activity( + self, + req: pb.ActivityRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): + stub.AbandonTaskActivityWorkItem( + pb.AbandonActivityTaskRequest( + completionToken=completionToken + ) + ) + self._logger.info(f"Cancelled activity task for task ID: {req.taskId} on orchestration ID: {req.orchestrationInstance.instanceId}") + def _execute_entity_batch( self, req: Union[pb.EntityBatchRequest, pb.EntityRequest], @@ -771,6 +801,19 @@ def _execute_entity_batch( return batch_result + def _cancel_entity_batch( + self, + req: Union[pb.EntityBatchRequest, pb.EntityRequest], + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): + stub.AbandonTaskEntityWorkItem( + pb.AbandonEntityTaskRequest( + completionToken=completionToken + ) + ) + self._logger.info(f"Cancelled entity batch task for instance ID: {req.instanceId}") + class _RuntimeOrchestrationContext(task.OrchestrationContext): _generator: Optional[Generator[task.Task, Any, Any]] @@ -1933,8 +1976,10 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool: class _AsyncWorkerManager: - def __init__(self, concurrency_options: ConcurrencyOptions): + def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logger): self.concurrency_options = concurrency_options + self._logger = logger + self.activity_semaphore = None self.orchestration_semaphore = None self.entity_semaphore = None @@ -2044,17 +2089,51 @@ async def run(self): ) # Start background consumers for each work type - if self.activity_queue is not None and self.orchestration_queue is not None \ - and self.entity_batch_queue is not None: - await asyncio.gather( - self._consume_queue(self.activity_queue, self.activity_semaphore), - self._consume_queue( - self.orchestration_queue, self.orchestration_semaphore - ), - self._consume_queue( - self.entity_batch_queue, self.entity_semaphore + try: + if self.activity_queue is not None and self.orchestration_queue is not None \ + and self.entity_batch_queue is not None: + await asyncio.gather( + self._consume_queue(self.activity_queue, self.activity_semaphore), + self._consume_queue( + self.orchestration_queue, self.orchestration_semaphore + ), + self._consume_queue( + self.entity_batch_queue, self.entity_semaphore + ) ) - ) + except Exception as queue_exception: + self._logger.error(f"Shutting down worker - Uncaught error in worker manager: {queue_exception}") + while self.activity_queue is not None and not self.activity_queue.empty(): + try: + func, cancellation_func, args, kwargs = self.activity_queue.get_nowait() + await self._run_func(cancellation_func, *args, **kwargs) + self._logger.error(f"Activity work item args: {args}, kwargs: {kwargs}") + except asyncio.QueueEmpty: + # Queue was empty, no cancellation needed + pass + except Exception as cancellation_exception: + self._logger.error(f"Uncaught error while cancelling activity work item: {cancellation_exception}") + while self.orchestration_queue is not None and not self.orchestration_queue.empty(): + try: + func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait() + await self._run_func(cancellation_func, *args, **kwargs) + self._logger.error(f"Orchestration work item args: {args}, kwargs: {kwargs}") + except asyncio.QueueEmpty: + # Queue was empty, no cancellation needed + pass + except Exception as cancellation_exception: + self._logger.error(f"Uncaught error while cancelling orchestration work item: {cancellation_exception}") + while self.entity_batch_queue is not None and not self.entity_batch_queue.empty(): + try: + func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait() + await self._run_func(cancellation_func, *args, **kwargs) + self._logger.error(f"Entity batch work item args: {args}, kwargs: {kwargs}") + except asyncio.QueueEmpty: + # Queue was empty, no cancellation needed + pass + except Exception as cancellation_exception: + self._logger.error(f"Uncaught error while cancelling entity batch work item: {cancellation_exception}") + self.shutdown() async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore): # List to track running tasks @@ -2074,19 +2153,22 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor except asyncio.TimeoutError: continue - func, args, kwargs = work + func, cancellation_func, args, kwargs = work # Create a concurrent task for processing task = asyncio.create_task( - self._process_work_item(semaphore, queue, func, args, kwargs) + self._process_work_item(semaphore, queue, func, cancellation_func, args, kwargs) ) running_tasks.add(task) async def _process_work_item( - self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs + self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, cancellation_func, args, kwargs ): async with semaphore: try: await self._run_func(func, *args, **kwargs) + except Exception as work_exception: + self._logger.error(f"Uncaught error while processing work item, item will be abandoned: {work_exception}") + await self._run_func(cancellation_func, *args, **kwargs) finally: queue.task_done() @@ -2105,8 +2187,10 @@ async def _run_func(self, func, *args, **kwargs): self.thread_pool, lambda: func(*args, **kwargs) ) - def submit_activity(self, func, *args, **kwargs): - work_item = (func, args, kwargs) + def submit_activity(self, func, cancellation_func, *args, **kwargs): + if self._shutdown: + raise RuntimeError("Cannot submit new work items after shutdown has been initiated.") + work_item = (func, cancellation_func, args, kwargs) self._ensure_queues_for_current_loop() if self.activity_queue is not None: self.activity_queue.put_nowait(work_item) @@ -2114,8 +2198,10 @@ def submit_activity(self, func, *args, **kwargs): # No event loop running, store in pending list self._pending_activity_work.append(work_item) - def submit_orchestration(self, func, *args, **kwargs): - work_item = (func, args, kwargs) + def submit_orchestration(self, func, cancellation_func, *args, **kwargs): + if self._shutdown: + raise RuntimeError("Cannot submit new work items after shutdown has been initiated.") + work_item = (func, cancellation_func, args, kwargs) self._ensure_queues_for_current_loop() if self.orchestration_queue is not None: self.orchestration_queue.put_nowait(work_item) @@ -2123,8 +2209,10 @@ def submit_orchestration(self, func, *args, **kwargs): # No event loop running, store in pending list self._pending_orchestration_work.append(work_item) - def submit_entity_batch(self, func, *args, **kwargs): - work_item = (func, args, kwargs) + def submit_entity_batch(self, func, cancellation_func, *args, **kwargs): + if self._shutdown: + raise RuntimeError("Cannot submit new work items after shutdown has been initiated.") + work_item = (func, cancellation_func, args, kwargs) self._ensure_queues_for_current_loop() if self.entity_batch_queue is not None: self.entity_batch_queue.put_nowait(work_item) @@ -2136,7 +2224,7 @@ def shutdown(self): self._shutdown = True self.thread_pool.shutdown(wait=True) - def reset_for_new_run(self): + async def reset_for_new_run(self): """Reset the manager state for a new run.""" self._shutdown = False # Clear any existing queues - they'll be recreated when needed @@ -2145,18 +2233,28 @@ def reset_for_new_run(self): # This ensures no items from previous runs remain try: while not self.activity_queue.empty(): - self.activity_queue.get_nowait() - except Exception: - pass + func, cancellation_func, args, kwargs = self.activity_queue.get_nowait() + await self._run_func(cancellation_func, *args, **kwargs) + except Exception as reset_exception: + self._logger.warning(f"Error while clearing activity queue during reset: {reset_exception}") if self.orchestration_queue is not None: try: while not self.orchestration_queue.empty(): - self.orchestration_queue.get_nowait() - except Exception: - pass + func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait() + await self._run_func(cancellation_func, *args, **kwargs) + except Exception as reset_exception: + self._logger.warning(f"Error while clearing orchestration queue during reset: {reset_exception}") + if self.entity_batch_queue is not None: + try: + while not self.entity_batch_queue.empty(): + func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait() + await self._run_func(cancellation_func, *args, **kwargs) + except Exception as reset_exception: + self._logger.warning(f"Error while clearing entity queue during reset: {reset_exception}") # Clear pending work lists self._pending_activity_work.clear() self._pending_orchestration_work.clear() + self._pending_entity_batch_work.clear() # Export public API diff --git a/tests/durabletask/test_worker_concurrency_loop.py b/tests/durabletask/test_worker_concurrency_loop.py index de6753b..6fd1270 100644 --- a/tests/durabletask/test_worker_concurrency_loop.py +++ b/tests/durabletask/test_worker_concurrency_loop.py @@ -52,13 +52,21 @@ def dummy_orchestrator(req, stub, completionToken): time.sleep(0.1) stub.CompleteOrchestratorTask('ok') + def cancel_dummy_orchestrator(req, stub, completionToken): + pass + def dummy_activity(req, stub, completionToken): time.sleep(0.1) stub.CompleteActivityTask('ok') + def cancel_dummy_activity(req, stub, completionToken): + pass + # Patch the worker's _execute_orchestrator and _execute_activity worker._execute_orchestrator = dummy_orchestrator + worker._cancel_orchestrator = cancel_dummy_orchestrator worker._execute_activity = dummy_activity + worker._cancel_activity = cancel_dummy_activity orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)] activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] @@ -67,9 +75,9 @@ async def run_test(): # Start the worker manager's run loop in the background worker_task = asyncio.create_task(worker._async_worker_manager.run()) for req in orchestrator_requests: - worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken()) for req in activity_requests: - worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + worker._async_worker_manager.submit_activity(dummy_activity, cancel_dummy_activity, req, stub, DummyCompletionToken()) await asyncio.sleep(1.0) orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') activity_count = sum(1 for t, _ in stub.completed if t == 'activity') @@ -120,8 +128,8 @@ def fn(*args, **kwargs): # Submit more work than concurrency allows for i in range(5): - manager.submit_orchestration(make_work("orch", i)) - manager.submit_activity(make_work("act", i)) + manager.submit_orchestration(make_work("orch", i), lambda *a, **k: None) + manager.submit_activity(make_work("act", i), lambda *a, **k: None) # Run the manager loop in a thread (sync context) def run_manager(): @@ -131,6 +139,11 @@ def run_manager(): t.start() time.sleep(1.5) # Let work process manager.shutdown() + + # Ensure the queues have been started + if (manager.activity_queue is None or manager.orchestration_queue is None): + raise RuntimeError("Worker manager queues not initialized") + # Unblock the consumers by putting dummy items in the queues manager.activity_queue.put_nowait((lambda: None, (), {})) manager.orchestration_queue.put_nowait((lambda: None, (), {})) diff --git a/tests/durabletask/test_worker_concurrency_loop_async.py b/tests/durabletask/test_worker_concurrency_loop_async.py index c7ba238..8482c20 100644 --- a/tests/durabletask/test_worker_concurrency_loop_async.py +++ b/tests/durabletask/test_worker_concurrency_loop_async.py @@ -50,13 +50,21 @@ async def dummy_orchestrator(req, stub, completionToken): await asyncio.sleep(0.1) stub.CompleteOrchestratorTask('ok') + async def cancel_dummy_orchestrator(req, stub, completionToken): + pass + async def dummy_activity(req, stub, completionToken): await asyncio.sleep(0.1) stub.CompleteActivityTask('ok') + async def cancel_dummy_activity(req, stub, completionToken): + pass + # Patch the worker's _execute_orchestrator and _execute_activity - grpc_worker._execute_orchestrator = dummy_orchestrator - grpc_worker._execute_activity = dummy_activity + grpc_worker._execute_orchestrator = dummy_orchestrator.__get__(grpc_worker, TaskHubGrpcWorker) + grpc_worker._cancel_orchestrator = cancel_dummy_orchestrator.__get__(grpc_worker, TaskHubGrpcWorker) + grpc_worker._execute_activity = dummy_activity.__get__(grpc_worker, TaskHubGrpcWorker) + grpc_worker._cancel_activity = cancel_dummy_activity.__get__(grpc_worker, TaskHubGrpcWorker) orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)] activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] @@ -65,10 +73,15 @@ async def run_test(): # Clear stub state before each run stub.completed.clear() worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run()) + # Need to yield to that thread in order to let it start up on the second run + startup_attempts = 0 + while grpc_worker._async_worker_manager._shutdown and startup_attempts < 10: + await asyncio.sleep(0.1) + startup_attempts += 1 for req in orchestrator_requests: - grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken()) for req in activity_requests: - grpc_worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + grpc_worker._async_worker_manager.submit_activity(dummy_activity, cancel_dummy_activity, req, stub, DummyCompletionToken()) await asyncio.sleep(1.0) orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') activity_count = sum(1 for t, _ in stub.completed if t == 'activity')