diff --git a/clients/python/src/examples/cli.py b/clients/python/src/examples/cli.py index 02ce6e40..4c9fdc60 100644 --- a/clients/python/src/examples/cli.py +++ b/clients/python/src/examples/cli.py @@ -83,22 +83,34 @@ def scheduler() -> None: type=int, ) def worker(rpc_host: str, concurrency: int, push_mode: bool, grpc_port: int) -> None: - from taskbroker_client.worker import TaskWorker + from taskbroker_client.worker import PushTaskWorker, TaskWorker click.echo("Starting worker") - worker = TaskWorker( - app_module="examples.app:app", - broker_hosts=[rpc_host], - max_child_task_count=100, - concurrency=concurrency, - child_tasks_queue_maxsize=concurrency * 2, - result_queue_maxsize=concurrency * 2, - rebalance_after=32, - processing_pool_name="examples", - process_type="forkserver", - push_mode=push_mode, - grpc_port=grpc_port, - ) + if push_mode: + worker: PushTaskWorker | TaskWorker = PushTaskWorker( + app_module="examples.app:app", + broker_service=rpc_host, + max_child_task_count=100, + concurrency=concurrency, + child_tasks_queue_maxsize=concurrency * 2, + result_queue_maxsize=concurrency * 2, + rebalance_after=32, + processing_pool_name="examples", + process_type="forkserver", + grpc_port=grpc_port, + ) + else: + worker = TaskWorker( + app_module="examples.app:app", + broker_hosts=[rpc_host], + max_child_task_count=100, + concurrency=concurrency, + child_tasks_queue_maxsize=concurrency * 2, + result_queue_maxsize=concurrency * 2, + rebalance_after=32, + processing_pool_name="examples", + process_type="forkserver", + ) exitcode = worker.start() raise SystemExit(exitcode) diff --git a/clients/python/src/taskbroker_client/worker/__init__.py b/clients/python/src/taskbroker_client/worker/__init__.py index 82fd7acf..d94f2b62 100644 --- a/clients/python/src/taskbroker_client/worker/__init__.py +++ b/clients/python/src/taskbroker_client/worker/__init__.py @@ -1,3 +1,3 @@ -from .worker import TaskWorker +from .worker import PushTaskWorker, TaskWorker -__all__ = ("TaskWorker",) +__all__ = ("TaskWorker", "PushTaskWorker") diff --git a/clients/python/src/taskbroker_client/worker/client.py b/clients/python/src/taskbroker_client/worker/client.py index 751b4c6c..2a1c56b2 100644 --- a/clients/python/src/taskbroker_client/worker/client.py +++ b/clients/python/src/taskbroker_client/worker/client.py @@ -486,3 +486,106 @@ def update_task( receive_timestamp=time.monotonic(), ) return None + + +class PushTaskbrokerClient: + """ + Taskworker RPC client wrapper + + Push brokers are a deployment so they don't need to be connected to individually. There is one service provided + that works for all the brokers. + """ + + def __init__( + self, + service: str, + application: str, + metrics: MetricsBackend, + health_check_settings: HealthCheckSettings | None = None, + rpc_secret: str | None = None, + grpc_config: str | None = None, + ) -> None: + self._application = application + self._service = service + self._rpc_secret = rpc_secret + self._metrics = metrics + + self._grpc_options: list[tuple[str, Any]] = [ + ("grpc.max_receive_message_length", MAX_ACTIVATION_SIZE) + ] + if grpc_config: + self._grpc_options.append(("grpc.service_config", grpc_config)) + + logger.info( + "taskworker.push_client.start", + extra={"service": service, "options": self._grpc_options}, + ) + + self._stub = self._connect_to_host(service) + + self._health_check_settings = health_check_settings + self._timestamp_since_touch_lock = threading.Lock() + self._timestamp_since_touch = 0.0 + + def _emit_health_check(self) -> None: + if self._health_check_settings is None: + return + + with self._timestamp_since_touch_lock: + cur_time = time.time() + if ( + cur_time - self._timestamp_since_touch + < self._health_check_settings.touch_interval_sec + ): + return + + self._health_check_settings.file_path.touch() + self._metrics.incr( + "taskworker.client.health_check.touched", + ) + self._timestamp_since_touch = cur_time + + def _connect_to_host(self, host: str) -> ConsumerServiceStub: + logger.info("taskworker.push_client.connect", extra={"host": host}) + channel = grpc.insecure_channel(host, options=self._grpc_options) + secrets = parse_rpc_secret_list(self._rpc_secret) + if secrets: + channel = grpc.intercept_channel(channel, RequestSignatureInterceptor(secrets)) + return ConsumerServiceStub(channel) + + def update_task( + self, + processing_result: ProcessingResult, + ) -> None: + """ + Update the status for a given task activation. + """ + self._emit_health_check() + + request = SetTaskStatusRequest( + id=processing_result.task_id, + status=processing_result.status, + fetch_next_task=None, + ) + + retries = 0 + exception = None + while retries < 3: + try: + with self._metrics.timer( + "taskworker.update_task.rpc", tags={"service": self._service} + ): + self._stub.SetTaskStatus(request) + exception = None + break + except grpc.RpcError as err: + exception = err + self._metrics.incr( + "taskworker.client.rpc_error", + tags={"method": "SetTaskStatus", "status": err.code().name}, + ) + finally: + retries += 1 + + if exception: + raise exception diff --git a/clients/python/src/taskbroker_client/worker/worker.py b/clients/python/src/taskbroker_client/worker/worker.py index 27aa5fca..b4441f6a 100644 --- a/clients/python/src/taskbroker_client/worker/worker.py +++ b/clients/python/src/taskbroker_client/worker/worker.py @@ -10,7 +10,7 @@ from multiprocessing.context import ForkContext, ForkServerContext, SpawnContext from multiprocessing.process import BaseProcess from pathlib import Path -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any, Callable, List import grpc from sentry_protos.taskbroker.v1 import taskbroker_pb2_grpc @@ -31,6 +31,7 @@ from taskbroker_client.worker.client import ( HealthCheckSettings, HostTemporarilyUnavailable, + PushTaskbrokerClient, RequestSignatureServerInterceptor, TaskbrokerClient, parse_rpc_secret_list, @@ -51,8 +52,8 @@ class WorkerServicer(taskbroker_pb2_grpc.WorkerServiceServicer): gRPC servicer that receives task activations pushed from the broker """ - def __init__(self, worker: TaskWorker) -> None: - self.worker = worker + def __init__(self, worker: TaskWorkerProcessingPool) -> None: + self.worker_pool = worker def PushTask( self, @@ -61,9 +62,9 @@ def PushTask( ) -> PushTaskResponse: """Handle incoming task activation.""" start_time = time.monotonic() - self.worker._metrics.incr( + self.worker_pool._metrics.incr( "taskworker.worker.push_rpc", - tags={"result": "attempt", "processing_pool": self.worker._processing_pool_name}, + tags={"result": "attempt", "processing_pool": self.worker_pool._processing_pool_name}, ) # Create `InflightTaskActivation` from the pushed task @@ -74,51 +75,51 @@ def PushTask( ) # Push the task to the worker queue (wait at most 5 seconds) - if not self.worker.push_task(inflight, timeout=5): - self.worker._metrics.incr( + if not self.worker_pool.push_task(inflight, timeout=5): + self.worker_pool._metrics.incr( "taskworker.worker.push_rpc", - tags={"result": "busy", "processing_pool": self.worker._processing_pool_name}, + tags={"result": "busy", "processing_pool": self.worker_pool._processing_pool_name}, ) - self.worker._metrics.distribution( + self.worker_pool._metrics.distribution( "taskworker.worker.push_rpc.duration", time.monotonic() - start_time, - tags={"result": "busy", "processing_pool": self.worker._processing_pool_name}, + tags={"result": "busy", "processing_pool": self.worker_pool._processing_pool_name}, ) context.abort(grpc.StatusCode.RESOURCE_EXHAUSTED, "worker busy") + else: + self.worker_pool._metrics.incr( + "taskworker.worker.push_rpc", + tags={ + "result": "accepted", + "processing_pool": self.worker_pool._processing_pool_name, + }, + ) - self.worker._metrics.incr( - "taskworker.worker.push_rpc", - tags={"result": "accepted", "processing_pool": self.worker._processing_pool_name}, - ) - - self.worker._metrics.distribution( - "taskworker.worker.push_rpc.duration", - time.monotonic() - start_time, - tags={"result": "accepted", "processing_pool": self.worker._processing_pool_name}, - ) + self.worker_pool._metrics.distribution( + "taskworker.worker.push_rpc.duration", + time.monotonic() - start_time, + tags={ + "result": "accepted", + "processing_pool": self.worker_pool._processing_pool_name, + }, + ) return PushTaskResponse() -class TaskWorker: - """ - A TaskWorker fetches tasks from a taskworker RPC host and handles executing task activations. +class RequeueException(Exception): + pass - Tasks are executed in a forked process so that processing timeouts can be enforced. - As tasks are completed status changes will be sent back to the RPC host and new tasks - will be fetched. - - Taskworkers can be run with `sentry run taskworker` - """ - mp_context: ForkContext | SpawnContext | ForkServerContext +class PushTaskWorker: + _mp_context: ForkContext | SpawnContext | ForkServerContext def __init__( self, app_module: str, - broker_hosts: list[str], + broker_service: str, max_child_task_count: int | None = None, namespace: str | None = None, concurrency: int = 1, @@ -129,27 +130,37 @@ def __init__( process_type: str = "spawn", health_check_file_path: str | None = None, health_check_sec_per_touch: float = DEFAULT_WORKER_HEALTH_CHECK_SEC_PER_TOUCH, - push_mode: bool = False, grpc_port: int = 50052, - **kwargs: dict[str, Any], ) -> None: - self.options = kwargs - self._app_module = app_module - self._max_child_task_count = max_child_task_count - self._namespace = namespace - self._concurrency = concurrency app = import_app(app_module) - if push_mode: - logger.info("Running in PUSH mode") + if process_type == "fork": + self._mp_context = multiprocessing.get_context("fork") + elif process_type == "spawn": + self._mp_context = multiprocessing.get_context("spawn") + elif process_type == "forkserver": + self._mp_context = multiprocessing.get_context("forkserver") else: - logger.info("Running in PULL mode") + raise ValueError(f"Invalid process type: {process_type}") - self.client = TaskbrokerClient( - hosts=broker_hosts, + self.worker_pool = TaskWorkerProcessingPool( + app_module=app_module, + mp_context=self._mp_context, + send_result_fn=self._send_result, + max_child_task_count=max_child_task_count, + concurrency=concurrency, + child_tasks_queue_maxsize=child_tasks_queue_maxsize, + result_queue_maxsize=result_queue_maxsize, + processing_pool_name=processing_pool_name, + process_type=process_type, + ) + + logger.info("Running in PUSH mode") + + self.client = PushTaskbrokerClient( + service=broker_service, application=app.name, metrics=app.metrics, - max_tasks_before_rebalance=rebalance_after, health_check_settings=( None if health_check_file_path is None @@ -159,175 +170,230 @@ def __init__( grpc_config=app.config["grpc_config"], ) self._metrics = app.metrics + self._concurrency = concurrency + self._grpc_sync_event = self._mp_context.Event() - if process_type == "fork": - self.mp_context = multiprocessing.get_context("fork") - elif process_type == "spawn": - self.mp_context = multiprocessing.get_context("spawn") - elif process_type == "forkserver": - self.mp_context = multiprocessing.get_context("forkserver") - else: - raise ValueError(f"Invalid process type: {process_type}") - self._process_type = process_type - - self._child_tasks: multiprocessing.Queue[InflightTaskActivation] = self.mp_context.Queue( - maxsize=child_tasks_queue_maxsize - ) - self._processed_tasks: multiprocessing.Queue[ProcessingResult] = self.mp_context.Queue( - maxsize=result_queue_maxsize - ) - self._children: list[BaseProcess] = [] - self._shutdown_event = self.mp_context.Event() - self._result_thread: threading.Thread | None = None - self._spawn_children_thread: threading.Thread | None = None - - self._gettask_backoff_seconds = 0 self._setstatus_backoff_seconds = 0 self._processing_pool_name: str = processing_pool_name or "unknown" - self._push_mode = push_mode self._grpc_port = grpc_port self._grpc_secrets = parse_rpc_secret_list(app.config["rpc_secret"]) + def _send_result( + self, result: ProcessingResult, is_draining: bool = False + ) -> InflightTaskActivation | None: + """ + Send a result to the broker. If the set has failed before, sleep briefly before retrying. + """ + self._metrics.distribution( + "taskworker.worker.complete_duration", + time.monotonic() - result.receive_timestamp, + tags={"processing_pool": self._processing_pool_name}, + ) + + logger.debug( + "taskworker.workers._send_result", + extra={ + "task_id": result.task_id, + "next": False, # Push mode doesn't support fetching next tasks + "processing_pool": self._processing_pool_name, + }, + ) + # Use the shutdown_event as a sleep mechanism + self._grpc_sync_event.wait(self._setstatus_backoff_seconds) + + try: + self.client.update_task(result) + self._setstatus_backoff_seconds = 0 + return None + except grpc.RpcError as e: + self._setstatus_backoff_seconds = min(self._setstatus_backoff_seconds + 1, 10) + logger.warning( + "taskworker.send_update_task.failed", + extra={"task_id": result.task_id, "error": e}, + ) + if e.code() != grpc.StatusCode.NOT_FOUND: + # If the task was not found, we can't update it, so we should just return None + raise RequeueException(f"Failed to update task: {e}") + except HostTemporarilyUnavailable as e: + self._setstatus_backoff_seconds = min( + self._setstatus_backoff_seconds + 4, MAX_BACKOFF_SECONDS_WHEN_HOST_UNAVAILABLE + ) + logger.info( + "taskworker.send_update_task.temporarily_unavailable", + extra={"task_id": result.task_id, "error": str(e)}, + ) + raise RequeueException(f"Failed to update task: {e}") + + return None + def start(self) -> int: """ - When in PULL mode, this starts a loop that runs until the worker completes its `max_task_count` or it is killed. - When in PUSH mode, this starts the worker gRPC server. + This starts the worker gRPC server. """ - self.start_result_thread() - self.start_spawn_children_thread() + self.worker_pool.start_result_thread() + self.worker_pool.start_spawn_children_thread() # Convert signals into KeyboardInterrupt. # Running shutdown() within the signal handler can lead to deadlocks + + server: grpc.Server | None = None + def signal_handler(*args: Any) -> None: + if server: + server.stop(grace=5) raise KeyboardInterrupt() signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - if self._push_mode: - server = None + try: + # Start gRPC server + interceptors: List[ServerInterceptor] = [] + + if self._grpc_secrets: + interceptors = [RequestSignatureServerInterceptor(self._grpc_secrets)] + + server = grpc.server( + ThreadPoolExecutor(max_workers=self._concurrency), + interceptors=interceptors, + ) + + taskbroker_pb2_grpc.add_WorkerServiceServicer_to_server( + WorkerServicer(self.worker_pool), server + ) + server.add_insecure_port(f"[::]:{self._grpc_port}") + server.start() + logger.info("taskworker.grpc_server.started", extra={"port": self._grpc_port}) try: - # Start gRPC server - interceptors: List[ServerInterceptor] = [] - - if self._grpc_secrets: - interceptors = [RequestSignatureServerInterceptor(self._grpc_secrets)] - - server = grpc.server( - ThreadPoolExecutor(max_workers=self._concurrency), - interceptors=interceptors, - ) - - taskbroker_pb2_grpc.add_WorkerServiceServicer_to_server( - WorkerServicer(self), server - ) - server.add_insecure_port(f"[::]:{self._grpc_port}") - server.start() - logger.info("taskworker.grpc_server.started", extra={"port": self._grpc_port}) - - try: - server.wait_for_termination() - except KeyboardInterrupt: - # Signals are converted to KeyboardInterrupt, swallow for exit code 0 - pass - - finally: - if server is not None: - server.stop(grace=5) - - self.shutdown() - else: - try: - while True: - self.run_once() + server.wait_for_termination() except KeyboardInterrupt: - self.shutdown() - raise + # Signals are converted to KeyboardInterrupt, swallow for exit code 0 + pass - return 0 + finally: + if server is not None: + server.stop(grace=5) - def run_once(self) -> None: - """Access point for tests to run a single worker loop""" - self._add_task() + self.worker_pool.shutdown() + + return 0 def shutdown(self) -> None: """ - Shutdown cleanly - Activate the shutdown event and drain results before terminating children. + Shutdown the worker. """ - logger.info("taskworker.worker.shutdown.start") - self._shutdown_event.set() + self._grpc_sync_event.set() + self.worker_pool.shutdown() - logger.info("taskworker.worker.shutdown.spawn_children") - if self._spawn_children_thread: - self._spawn_children_thread.join() - logger.info("taskworker.worker.shutdown.children") - for child in self._children: - child.terminate() - for child in self._children: - child.join() +class TaskWorker: + """ + A TaskWorker fetches tasks from a taskworker RPC host and handles executing task activations. - logger.info("taskworker.worker.shutdown.result") - if self._result_thread: - # Use a timeout as sometimes this thread can deadlock on the Event. - self._result_thread.join(timeout=5) + Tasks are executed in a forked/spawned/forkserver process so that processing timeouts can be enforced. + As tasks are completed status changes will be sent back to the RPC host and new tasks + will be fetched. + """ - # Drain any remaining results synchronously - while True: - try: - result = self._processed_tasks.get_nowait() - self._send_result(result, fetch=False) - except queue.Empty: - break + _mp_context: ForkContext | SpawnContext | ForkServerContext - logger.info("taskworker.worker.shutdown.complete") + def __init__( + self, + app_module: str, + broker_hosts: list[str], + max_child_task_count: int | None = None, + namespace: str | None = None, + concurrency: int = 1, + child_tasks_queue_maxsize: int = DEFAULT_WORKER_QUEUE_SIZE, + result_queue_maxsize: int = DEFAULT_WORKER_QUEUE_SIZE, + rebalance_after: int = DEFAULT_REBALANCE_AFTER, + processing_pool_name: str | None = None, + process_type: str = "spawn", + health_check_file_path: str | None = None, + health_check_sec_per_touch: float = DEFAULT_WORKER_HEALTH_CHECK_SEC_PER_TOUCH, + ) -> None: + self._namespace = namespace + app = import_app(app_module) - def push_task(self, inflight: InflightTaskActivation, timeout: float | None = None) -> bool: - """ - Push a task to child tasks queue. + if process_type == "fork": + self._mp_context = multiprocessing.get_context("fork") + elif process_type == "spawn": + self._mp_context = multiprocessing.get_context("spawn") + elif process_type == "forkserver": + self._mp_context = multiprocessing.get_context("forkserver") + else: + raise ValueError(f"Invalid process type: {process_type}") - When timeout is `None`, blocks until the queue has space. When timeout is - set (e.g. 5.0), waits at most that many seconds and returns `False` if the - queue is still full (worker busy). + self.worker_pool = TaskWorkerProcessingPool( + app_module=app_module, + mp_context=self._mp_context, + send_result_fn=self._send_result, + max_child_task_count=max_child_task_count, + concurrency=concurrency, + child_tasks_queue_maxsize=child_tasks_queue_maxsize, + result_queue_maxsize=result_queue_maxsize, + processing_pool_name=processing_pool_name, + process_type=process_type, + ) + + logger.info("Running in PULL mode") + + self.client = TaskbrokerClient( + hosts=broker_hosts, + application=app.name, + metrics=app.metrics, + max_tasks_before_rebalance=rebalance_after, + health_check_settings=( + None + if health_check_file_path is None + else HealthCheckSettings(Path(health_check_file_path), health_check_sec_per_touch) + ), + rpc_secret=app.config["rpc_secret"], + grpc_config=app.config["grpc_config"], + ) + self._metrics = app.metrics + + self._grpc_sync_event = self._mp_context.Event() + + self._gettask_backoff_seconds = 0 + self._setstatus_backoff_seconds = 0 + + self._processing_pool_name: str = processing_pool_name or "unknown" + + def start(self) -> int: """ - try: - self._metrics.gauge("taskworker.child_tasks.size", self._child_tasks.qsize()) - except Exception as e: - # 'qsize' does not work on macOS - logger.debug("taskworker.child_tasks.size.error", extra={"error": e}) + This starts a loop that runs until the worker completes its `max_task_count` or it is killed. + """ + self.worker_pool.start_result_thread() + self.worker_pool.start_spawn_children_thread() + + # Convert signals into KeyboardInterrupt. + # Running shutdown() within the signal handler can lead to deadlocks + def signal_handler(*args: Any) -> None: + raise KeyboardInterrupt() + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) - start_time = time.monotonic() try: - self._child_tasks.put(inflight, timeout=timeout) - except queue.Full: - self._metrics.incr( - "taskworker.worker.push_task.busy", - tags={"processing_pool": self._processing_pool_name}, - ) - return False - self._metrics.distribution( - "taskworker.worker.child_task.put.duration", - time.monotonic() - start_time, - tags={"processing_pool": self._processing_pool_name}, - ) - return True + while True: + self.run_once() + except KeyboardInterrupt: + self.shutdown() + raise + + def run_once(self) -> None: + """Access point for tests to run a single worker loop""" + self._add_task() def _add_task(self) -> bool: """ Add a task to child tasks queue. Returns False if no new task was fetched. """ - if self._child_tasks.full(): - # I want to see how this differs between pools that operate well, - # and those that are not as effective. I suspect that with a consistent - # load of slowish tasks (like 5-15 seconds) that this will happen - # infrequently, resulting in the child tasks queue being full - # causing processing deadline expiration. - # Whereas in pools that have consistent short tasks, this happens - # more frequently, allowing workers to run more smoothly. + if self.worker_pool.is_worker_full(): self._metrics.incr( "taskworker.worker.add_tasks.child_tasks_full", tags={"processing_pool": self._processing_pool_name}, @@ -338,106 +404,24 @@ def _add_task(self) -> bool: inflight = self.fetch_task() if inflight: - try: - start_time = time.monotonic() - self._child_tasks.put(inflight) - self._metrics.distribution( - "taskworker.worker.child_task.put.duration", - time.monotonic() - start_time, - tags={"processing_pool": self._processing_pool_name}, - ) - except queue.Full: - self._metrics.incr( - "taskworker.worker.child_tasks.put.full", - tags={"processing_pool": self._processing_pool_name}, - ) - logger.warning( - "taskworker.add_task.child_task_queue_full", - extra={ - "task_id": inflight.activation.id, - "processing_pool": self._processing_pool_name, - }, - ) - return True - else: - return False - - def start_result_thread(self) -> None: - """ - Start a thread that delivers results and fetches new tasks. - We need to ship results in a thread because the RPC calls block for 20-50ms, - and many tasks execute more quickly than that. - - Without additional threads, we end up publishing results too slowly - and tasks accumulate in the `processed_tasks` queues and can cross - their processing deadline. - """ - - def result_thread() -> None: - logger.debug("taskworker.worker.result_thread.started") - iopool = ThreadPoolExecutor(max_workers=self._concurrency) - with iopool as executor: - while not self._shutdown_event.is_set(): - # TODO We should remove fetch_next = False from sentry as it couldn't be rolled - # out everywhere. - # fetch_next = self._processing_pool_name not in options.get( - # "taskworker.fetch_next.disabled_pools" - # ) - try: - result = self._processed_tasks.get(timeout=1.0) - executor.submit(self._send_result, result, fetch=True) - except queue.Empty: - self._metrics.incr( - "taskworker.worker.result_thread.queue_empty", - tags={"processing_pool": self._processing_pool_name}, - ) - continue + return self.worker_pool.push_task(inflight) - self._result_thread = threading.Thread( - name="send-result", target=result_thread, daemon=True - ) - self._result_thread.start() + return False - def _send_result(self, result: ProcessingResult, fetch: bool = True) -> bool: + def _send_result( + self, result: ProcessingResult, is_draining: bool = False + ) -> InflightTaskActivation | None: """ - Send a result to the broker and conditionally fetch an additional task - - Run in a thread to avoid blocking the process, and during shutdown/ - See `start_result_thread` + Send a result to the broker and conditionally fetch an additional task. Return a boolean indicating whether the result was sent successfully. """ self._metrics.distribution( "taskworker.worker.complete_duration", time.monotonic() - result.receive_timestamp, tags={"processing_pool": self._processing_pool_name}, ) - - if fetch: - fetch_next = None - if not self._child_tasks.full(): - fetch_next = FetchNextTask(namespace=self._namespace) - - next = self._send_update_task(result, fetch_next) - if next: - try: - start_time = time.monotonic() - self._child_tasks.put(next) - self._metrics.distribution( - "taskworker.worker.child_task.put.duration", - time.monotonic() - start_time, - tags={"processing_pool": self._processing_pool_name}, - ) - except queue.Full: - logger.warning( - "taskworker.send_result.child_task_queue_full", - extra={ - "task_id": next.activation.id, - "processing_pool": self._processing_pool_name, - }, - ) - return True - - self._send_update_task(result, fetch_next=None) - return True + fetch_next = None if is_draining else FetchNextTask(namespace=self._namespace) + next_task = self._send_update_task(result, fetch_next) + return next_task def _send_update_task( self, result: ProcessingResult, fetch_next: FetchNextTask | None @@ -453,8 +437,8 @@ def _send_update_task( "processing_pool": self._processing_pool_name, }, ) - # Use the shutdown_event as a sleep mechanism - self._shutdown_event.wait(self._setstatus_backoff_seconds) + + self._grpc_sync_event.wait(self._setstatus_backoff_seconds) try: next_task = self.client.update_task(result, fetch_next) @@ -462,13 +446,11 @@ def _send_update_task( return next_task except grpc.RpcError as e: self._setstatus_backoff_seconds = min(self._setstatus_backoff_seconds + 1, 10) - if e.code() == grpc.StatusCode.UNAVAILABLE: - self._processed_tasks.put(result) logger.warning( "taskworker.send_update_task.failed", extra={"task_id": result.task_id, "error": e}, ) - return None + raise RequeueException(f"Failed to update task: {e}") except HostTemporarilyUnavailable as e: self._setstatus_backoff_seconds = min( self._setstatus_backoff_seconds + 4, MAX_BACKOFF_SECONDS_WHEN_HOST_UNAVAILABLE @@ -477,9 +459,133 @@ def _send_update_task( "taskworker.send_update_task.temporarily_unavailable", extra={"task_id": result.task_id, "error": str(e)}, ) - self._processed_tasks.put(result) + raise RequeueException(f"Failed to update task: {e}") + + def fetch_task(self) -> InflightTaskActivation | None: + self._grpc_sync_event.wait(self._gettask_backoff_seconds) + try: + activation = self.client.get_task(self._namespace) + except grpc.RpcError as e: + logger.info( + "taskworker.fetch_task.failed", + extra={"error": e, "processing_pool": self._processing_pool_name}, + ) + + self._gettask_backoff_seconds = min( + self._gettask_backoff_seconds + 4, MAX_BACKOFF_SECONDS_WHEN_HOST_UNAVAILABLE + ) + return None + + if not activation: + self._metrics.incr( + "taskworker.worker.fetch_task.not_found", + tags={"processing_pool": self._processing_pool_name}, + ) + logger.debug( + "taskworker.fetch_task.not_found", + extra={"processing_pool": self._processing_pool_name}, + ) + self._gettask_backoff_seconds = min(self._gettask_backoff_seconds + 1, 5) return None + self._gettask_backoff_seconds = 0 + return activation + + def shutdown(self) -> None: + """ + Shutdown the worker. + """ + self._grpc_sync_event.set() + self.worker_pool.shutdown() + + +class TaskWorkerProcessingPool: + def __init__( + self, + app_module: str, + # Here the bool is used to indicate whether this is a normal fetch or is being called + # during shutdown. + send_result_fn: Callable[[ProcessingResult, bool], InflightTaskActivation | None], + mp_context: ForkContext | SpawnContext | ForkServerContext, + max_child_task_count: int | None = None, + concurrency: int = 1, + child_tasks_queue_maxsize: int = DEFAULT_WORKER_QUEUE_SIZE, + result_queue_maxsize: int = DEFAULT_WORKER_QUEUE_SIZE, + processing_pool_name: str | None = None, + process_type: str = "spawn", + ) -> None: + self._concurrency = concurrency + self._processing_pool_name = processing_pool_name or "unknown" + self._send_result = send_result_fn + self._max_child_task_count = max_child_task_count + self._app_module = app_module + app = import_app(app_module) + self._metrics = app.metrics + + self._mp_context = mp_context + self._process_type = process_type + + self._child_tasks: multiprocessing.Queue[InflightTaskActivation] = self._mp_context.Queue( + maxsize=child_tasks_queue_maxsize + ) + self._processed_tasks: multiprocessing.Queue[ProcessingResult] = self._mp_context.Queue( + maxsize=result_queue_maxsize + ) + self._children: list[BaseProcess] = [] + self._shutdown_event = self._mp_context.Event() + self._result_thread: threading.Thread | None = None + self._spawn_children_thread: threading.Thread | None = None + + def send_result(self, result: ProcessingResult, is_draining: bool = False) -> None: + """ + Call the passed in function. If is_draining is True, the function should not fetch a new task. + That function should return: + - An InflightTaskActivation if a new task was fetched + - None if no new task was fetched + - A RequeueException if the result failed to send and should be retried + """ + try: + worker_full = is_draining or self._child_tasks.full() + next_task = self._send_result(result, worker_full) + if next_task: + self.push_task(next_task) + except RequeueException: + logger.warning("activation status couldn't be updated") + # This can cause an infinite loop if we are draining and the result fails to send + if not is_draining: + self.put_result(result) + + def start_result_thread(self) -> None: + """ + Start a thread that delivers results and fetches new tasks. + We need to ship results in a thread because the RPC calls block for 20-50ms, + and many tasks execute more quickly than that. + + Without additional threads, we end up publishing results too slowly + and tasks accumulate in the `processed_tasks` queues and can cross + their processing deadline. + """ + + def result_thread() -> None: + logger.debug("taskworker.worker.result_thread.started") + iopool = ThreadPoolExecutor(max_workers=self._concurrency) + with iopool as executor: + while not self._shutdown_event.is_set(): + try: + result = self._processed_tasks.get(timeout=1.0) + executor.submit(self.send_result, result, False) + except queue.Empty: + self._metrics.incr( + "taskworker.worker.result_thread.queue_empty", + tags={"processing_pool": self._processing_pool_name}, + ) + continue + + self._result_thread = threading.Thread( + name="send-result", target=result_thread, daemon=True + ) + self._result_thread.start() + def start_spawn_children_thread(self) -> None: def spawn_children_thread() -> None: logger.debug("taskworker.worker.spawn_children_thread.started") @@ -489,7 +595,7 @@ def spawn_children_thread() -> None: time.sleep(0.1) continue for i in range(self._concurrency - len(self._children)): - process = self.mp_context.Process( + process = self._mp_context.Process( name=f"taskworker-child-{i}", target=child_process, args=( @@ -518,33 +624,79 @@ def spawn_children_thread() -> None: ) self._spawn_children_thread.start() - def fetch_task(self) -> InflightTaskActivation | None: - # Use the shutdown_event as a sleep mechanism - self._shutdown_event.wait(self._gettask_backoff_seconds) - try: - activation = self.client.get_task(self._namespace) - except grpc.RpcError as e: - logger.info( - "taskworker.fetch_task.failed", - extra={"error": e, "processing_pool": self._processing_pool_name}, - ) + def push_task(self, inflight: InflightTaskActivation, timeout: float | None = None) -> bool: + """ + Push a task to child tasks queue. - self._gettask_backoff_seconds = min( - self._gettask_backoff_seconds + 4, MAX_BACKOFF_SECONDS_WHEN_HOST_UNAVAILABLE - ) - return None + When timeout is `None`, blocks until the queue has space. When timeout is + set (e.g. 5.0), waits at most that many seconds and returns `False` if the + queue is still full (worker busy). + """ + try: + self._metrics.gauge("taskworker.child_tasks.size", self._child_tasks.qsize()) + except Exception as e: + # 'qsize' does not work on macOS + logger.debug("taskworker.child_tasks.size.error", extra={"error": e}) - if not activation: + start_time = time.monotonic() + try: + self._child_tasks.put(inflight, timeout=timeout) + except queue.Full: self._metrics.incr( - "taskworker.worker.fetch_task.not_found", + "taskworker.worker.child_tasks.put.full", tags={"processing_pool": self._processing_pool_name}, ) - logger.debug( - "taskworker.fetch_task.not_found", - extra={"processing_pool": self._processing_pool_name}, + logger.warning( + "taskworker.add_task.child_task_queue_full", + extra={ + "task_id": inflight.activation.id, + "processing_pool": self._processing_pool_name, + }, ) - self._gettask_backoff_seconds = min(self._gettask_backoff_seconds + 1, 5) - return None + return False - self._gettask_backoff_seconds = 0 - return activation + self._metrics.distribution( + "taskworker.worker.child_task.put.duration", + time.monotonic() - start_time, + tags={"processing_pool": self._processing_pool_name}, + ) + return True + + def is_worker_full(self) -> bool: + return self._child_tasks.full() + + def put_result(self, result: ProcessingResult) -> None: + self._processed_tasks.put(result) + + def shutdown(self) -> None: + """ + Shutdown cleanly + Activate the shutdown event and drain results before terminating children. + """ + logger.info("taskworker.worker.shutdown.start") + self._shutdown_event.set() + + logger.info("taskworker.worker.shutdown.spawn_children") + if self._spawn_children_thread: + self._spawn_children_thread.join() + + logger.info("taskworker.worker.shutdown.children") + for child in self._children: + child.terminate() + for child in self._children: + child.join() + + logger.info("taskworker.worker.shutdown.result") + if self._result_thread: + # Use a timeout as sometimes this thread can deadlock on the Event. + self._result_thread.join(timeout=5) + + # Drain any remaining results synchronously + while True: + try: + result = self._processed_tasks.get_nowait() + self.send_result(result, True) + except queue.Empty: + break + + logger.info("taskworker.worker.shutdown.complete") diff --git a/clients/python/tests/worker/test_worker.py b/clients/python/tests/worker/test_worker.py index 9e9f5f33..8ec90e50 100644 --- a/clients/python/tests/worker/test_worker.py +++ b/clients/python/tests/worker/test_worker.py @@ -3,7 +3,7 @@ import queue import time from collections.abc import MutableMapping -from multiprocessing import Event +from multiprocessing import Event, get_context from typing import Any from unittest import TestCase, mock @@ -30,7 +30,12 @@ from taskbroker_client.retry import NoRetriesRemainingError from taskbroker_client.state import current_task from taskbroker_client.types import InflightTaskActivation, ProcessingResult -from taskbroker_client.worker.worker import TaskWorker, WorkerServicer +from taskbroker_client.worker.worker import ( + PushTaskWorker, + TaskWorker, + TaskWorkerProcessingPool, + WorkerServicer, +) from taskbroker_client.worker.workerchild import ProcessingDeadlineExceeded, child_process SIMPLE_TASK = InflightTaskActivation( @@ -233,8 +238,8 @@ def test_run_once_no_next_task(self) -> None: # No next_task returned mock_client.update_task.return_value = None - taskworker.start_result_thread() - taskworker.start_spawn_children_thread() + taskworker.worker_pool.start_result_thread() + taskworker.worker_pool.start_spawn_children_thread() start = time.time() while True: taskworker.run_once() @@ -273,8 +278,8 @@ def update_task_response(*args: Any, **kwargs: Any) -> InflightTaskActivation | mock_client.update_task.side_effect = update_task_response mock_client.get_task.return_value = SIMPLE_TASK - taskworker.start_result_thread() - taskworker.start_spawn_children_thread() + taskworker.worker_pool.start_result_thread() + taskworker.worker_pool.start_spawn_children_thread() # Run until two tasks have been processed start = time.time() @@ -325,8 +330,8 @@ def get_task_response(*args: Any, **kwargs: Any) -> InflightTaskActivation | Non mock_client.update_task.side_effect = update_task_response mock_client.get_task.side_effect = get_task_response - taskworker.start_result_thread() - taskworker.start_spawn_children_thread() + taskworker.worker_pool.start_result_thread() + taskworker.worker_pool.start_spawn_children_thread() # Run until the update has 'completed' start = time.time() @@ -343,12 +348,16 @@ def get_task_response(*args: Any, **kwargs: Any) -> InflightTaskActivation | Non assert mock_client.update_task.call_count == 3 def test_push_task_queue(self) -> None: - taskworker = TaskWorker( + taskworker = TaskWorkerProcessingPool( app_module="examples.app:app", - broker_hosts=["127.0.0.1:50051"], + send_result_fn=lambda x, y: None, + mp_context=get_context("fork"), max_child_task_count=100, - process_type="fork", + concurrency=1, child_tasks_queue_maxsize=2, + result_queue_maxsize=2, + processing_pool_name="test", + process_type="fork", ) # We can enqueue the first task @@ -380,13 +389,14 @@ def update_task_response(*args: Any, **kwargs: Any) -> None: mock_client.update_task.side_effect = update_task_response mock_client.get_task.return_value = RETRY_STATE_TASK - taskworker.start_result_thread() - taskworker.start_spawn_children_thread() + taskworker.worker_pool.start_result_thread() + taskworker.worker_pool.start_spawn_children_thread() # Run until two tasks have been processed start = time.time() while True: taskworker.run_once() + time.sleep(0.1) if mock_client.update_task.call_count >= 1: break if time.time() - start > max_runtime: @@ -412,48 +422,35 @@ def update_task_response(*args: Any, **kwargs: Any) -> None: redis.delete("no-retries-remaining") def test_constructor_push_mode(self) -> None: - taskworker = TaskWorker( + taskworker = PushTaskWorker( app_module="examples.app:app", - broker_hosts=["127.0.0.1:50051"], + broker_service="127.0.0.1:50051", max_child_task_count=100, process_type="fork", - push_mode=True, grpc_port=50099, ) - # Make sure delivery mode and gRPC port arguments are stored - self.assertTrue(taskworker._push_mode) + self.assertTrue(taskworker.client is not None) self.assertEqual(taskworker._grpc_port, 50099) - def test_constructor_pull_mode(self) -> None: - taskworker = TaskWorker( - app_module="examples.app:app", - broker_hosts=["127.0.0.1:50051"], - max_child_task_count=100, - process_type="fork", - ) - - # Make sure delivery mode and gRPC port are set to their defaults - self.assertFalse(taskworker._push_mode) - self.assertEqual(taskworker._grpc_port, 50052) - class TestWorkerServicer(TestCase): def test_push_task_success(self) -> None: - taskworker = TaskWorker( + taskworker = PushTaskWorker( app_module="examples.app:app", - broker_hosts=["127.0.0.1:50051"], + broker_service="127.0.0.1:50051", max_child_task_count=100, process_type="fork", - push_mode=True, ) - with mock.patch.object(taskworker, "push_task", return_value=True) as mock_push_task: + with mock.patch.object( + taskworker.worker_pool, "push_task", return_value=True + ) as mock_push_task: request = PushTaskRequest( task=SIMPLE_TASK.activation, callback_url="broker-host:50051", ) mock_context = mock.MagicMock() - servicer = WorkerServicer(taskworker) + servicer = WorkerServicer(taskworker.worker_pool) response = servicer.PushTask(request, mock_context) @@ -465,20 +462,20 @@ def test_push_task_success(self) -> None: self.assertEqual(inflight.host, "broker-host:50051") def test_push_task_worker_busy(self) -> None: - taskworker = TaskWorker( + taskworker = PushTaskWorker( app_module="examples.app:app", - broker_hosts=["127.0.0.1:50051"], + broker_service="127.0.0.1:50051", max_child_task_count=100, process_type="fork", child_tasks_queue_maxsize=1, ) - with mock.patch.object(taskworker, "push_task", return_value=False): + with mock.patch.object(taskworker.worker_pool, "push_task", return_value=False): request = PushTaskRequest( task=SIMPLE_TASK.activation, callback_url="broker-host:50051", ) mock_context = mock.MagicMock() - servicer = WorkerServicer(taskworker) + servicer = WorkerServicer(taskworker.worker_pool) servicer.PushTask(request, mock_context)