From 6626754b5980a1290e4eec121f267f954db5ad8d Mon Sep 17 00:00:00 2001
From: Thomas Steinacher
Date: Wed, 24 Apr 2024 13:40:26 +0100
Subject: [PATCH] Task lock should be non-local (#333)

Analyzing the thread safety of the heartbeat function in the sync
executor:

- Redis-py should be thread-safe by default unless
  `single_connection_client` is given, which is false by default.
- When doing a `heartbeat()` we update the task timestamp and renew any
  task locks and any queue lock.
- The queue lock
  ([`Semaphore`](https://github.com/closeio/tasktiger/blob/master/tasktiger/redis_semaphore.py))
  does not care about threads, so we can renew it from a different
  thread.
- Redis locks, which we use for locking tasks, are thread-local by
  default, so we need to set `thread_local=False` (a minimal sketch of
  this behaviour follows the diff below).
- We use Redis locks in a couple of other places, but those are not
  accessed by the heartbeat thread.

Also:

- Improved the heartbeat loop so we skip the heartbeat if the task
  already completed during the wait, and added exception handling.
- Fixed the test to give the worker a bit more time to start (it had
  some failures locally) and to have at least one task lock and one
  queue lock that should be renewed.
---
 tasktiger/executor.py |  8 +++++---
 tasktiger/worker.py   |  2 ++
 tests/test_workers.py | 39 +++++++++++++++++++++++++++++++++------
 3 files changed, 40 insertions(+), 9 deletions(-)

diff --git a/tasktiger/executor.py b/tasktiger/executor.py
index 293f4ea..0b98b95 100644
--- a/tasktiger/executor.py
+++ b/tasktiger/executor.py
@@ -403,9 +403,11 @@ def _periodic_heartbeat(
         queue_lock: Optional[Semaphore],
         stop_event: threading.Event,
     ) -> None:
-        while not stop_event.is_set():
-            stop_event.wait(self.config["ACTIVE_TASK_UPDATE_TIMER"])
-            self.heartbeat(queue, task_ids, log, locks, queue_lock)
+        while not stop_event.wait(self.config["ACTIVE_TASK_UPDATE_TIMER"]):
+            try:
+                self.heartbeat(queue, task_ids, log, locks, queue_lock)
+            except Exception:
+                log.exception("task heartbeat failed")
 
     def execute(
         self,
diff --git a/tasktiger/worker.py b/tasktiger/worker.py
index 6c1c49c..e6d1252 100644
--- a/tasktiger/worker.py
+++ b/tasktiger/worker.py
@@ -674,6 +674,8 @@ def _execute_task_group(
                 lock = self.connection.lock(
                     self._key("lockv2", lock_id),
                     timeout=self.config["ACTIVE_TASK_UPDATE_TIMEOUT"],
+                    # Sync worker uses a thread to renew the lock.
+                    thread_local=False,
                 )
                 if not lock.acquire(blocking=False):
                     log.info("could not acquire lock", task_id=task.id)
diff --git a/tests/test_workers.py b/tests/test_workers.py
index 4288812..1b30fdf 100644
--- a/tests/test_workers.py
+++ b/tests/test_workers.py
@@ -10,6 +10,7 @@
 from tasktiger import Task, Worker
 from tasktiger._internal import ACTIVE
 from tasktiger.executor import SyncExecutor
+from tasktiger.worker import LOCK_REDIS_KEY
 
 from .config import DELAY
 from .tasks import (
@@ -188,7 +189,9 @@ def test_handles_timeout(self, tiger, ensure_queues):
         ensure_queues(error={"default": 1})
 
     def test_heartbeat(self, tiger):
-        task = Task(tiger, sleep_task)
+        # Test both task heartbeat and lock renewal.
+        # We set unique=True so the task ID matches the lock key.
+        task = Task(tiger, sleep_task, lock=True, unique=True)
         task.delay()
 
         # Start a worker and wait until it starts processing.
@@ -196,18 +199,42 @@ def test_heartbeat(self, tiger):
             target=external_worker,
             kwargs={
                 "patch_config": {"ACTIVE_TASK_UPDATE_TIMER": DELAY / 2},
-                "worker_kwargs": {"executor_class": SyncExecutor},
+                "worker_kwargs": {
+                    # Test queue lock.
+ "max_workers_per_queue": 1, + "executor_class": SyncExecutor, + }, }, ) worker.start() - time.sleep(DELAY / 2) + time.sleep(DELAY) + + queue_key = tiger._key(ACTIVE, "default") + queue_lock_key = tiger._key(LOCK_REDIS_KEY, "default") + task_lock_key = tiger._key("lockv2", task.id) - key = tiger._key(ACTIVE, "default") conn = tiger.connection - heartbeat_1 = conn.zscore(key, task.id) + + heartbeat_1 = conn.zscore(queue_key, task.id) + queue_lock_1 = conn.zrange(queue_lock_key, 0, -1, withscores=True)[0][ + 1 + ] + task_lock_1 = conn.pttl(task_lock_key) + time.sleep(DELAY / 2) - heartbeat_2 = conn.zscore(key, task.id) + + heartbeat_2 = conn.zscore(queue_key, task.id) + queue_lock_2 = conn.zrange(queue_lock_key, 0, -1, withscores=True)[0][ + 1 + ] + task_lock_2 = conn.pttl(task_lock_key) + assert heartbeat_2 > heartbeat_1 > 0 + assert queue_lock_2 > queue_lock_1 > 0 + + # Active task update timeout is 2 * DELAY and we renew every DELAY / 2. + assert task_lock_1 > DELAY + assert task_lock_2 > DELAY worker.kill()