From 7a2a044cb02d2e025793cc0454a3abcefe90397e Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Fri, 7 Jul 2023 17:41:05 +0200 Subject: [PATCH] Implements a task scheduler --- changelog.d/15891.feature | 1 + synapse/app/generic_worker.py | 2 + synapse/handlers/task_scheduler.py | 221 ++++++++++++++++++ synapse/server.py | 6 + synapse/storage/databases/main/__init__.py | 2 + .../storage/databases/main/task_scheduler.py | 123 ++++++++++ .../main/delta/79/03_scheduled_tasks.sql | 26 +++ synapse/types/__init__.py | 20 ++ tests/handlers/test_task_scheduler.py | 108 +++++++++ 9 files changed, 509 insertions(+) create mode 100644 changelog.d/15891.feature create mode 100644 synapse/handlers/task_scheduler.py create mode 100644 synapse/storage/databases/main/task_scheduler.py create mode 100644 synapse/storage/schema/main/delta/79/03_scheduled_tasks.sql create mode 100644 tests/handlers/test_task_scheduler.py diff --git a/changelog.d/15891.feature b/changelog.d/15891.feature new file mode 100644 index 000000000000..5a3d12a32e2f --- /dev/null +++ b/changelog.d/15891.feature @@ -0,0 +1 @@ +Implements a task scheduler. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index dc79efcc142f..d25e3548e075 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -91,6 +91,7 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore +from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore @@ -144,6 +145,7 @@ class GenericWorkerStore( TransactionWorkerStore, LockStore, SessionStore, + TaskSchedulerWorkerStore, ): # Properties that multiple storage classes define. Tell mypy what the # expected type is. diff --git a/synapse/handlers/task_scheduler.py b/synapse/handlers/task_scheduler.py new file mode 100644 index 000000000000..3c499c364b87 --- /dev/null +++ b/synapse/handlers/task_scheduler.py @@ -0,0 +1,221 @@ +import logging +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set, Tuple + +from twisted.python.failure import Failure + +from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import JsonMapping, ScheduledTask, TaskStatus +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class TaskSchedulerHandler: + # Precision of the scheduler, evaluation of tasks to run will only happen + # every `SCHEDULE_INTERVAL_MS` ms + SCHEDULE_INTERVAL_MS = 5 * 60 * 1000 # 5mn + # Time before a complete or failed task is deleted from the DB + KEEP_TASKS_FOR_MS = 7 * 24 * 60 * 60 * 1000 # 1 week + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + self.running_tasks: Set[str] = set() + self.actions: Dict[ + str, + Callable[ + [ScheduledTask, bool], + Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], + ], + ] = {} + self.run_background_tasks = hs.config.worker.run_background_tasks + + if self.run_background_tasks: + self.clock.looping_call( + run_as_background_process, + TaskSchedulerHandler.SCHEDULE_INTERVAL_MS, + "scheduled_tasks_loop", + self._scheduled_tasks_loop, + ) + + def register_action( + self, + function: Callable[ + [ScheduledTask, bool], + Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], + ], + action_name: str, + ) -> None: + """Register a function to be executed when an action is scheduled with + the specified action name. + + Actions need to be registered as early as possible so that a resumed action + can find its matching function. It's usually better to NOT do that right before + calling `schedule_task` but rather in an `__init__` method. + + Args: + function: The function to be executed for this action. The parameters + passed to the function when launched are the `ScheduledTask` being run, + and a `first_launch` boolean to signal if it's a resumed task or the first + launch of it. The function should return a tuple of new `status`, `result` + and `error` as specified in `ScheduledTask`. + action_name: The name of the action to be associated with the function + """ + self.actions[action_name] = function + + async def schedule_task( + self, + action: str, + *, + resource_id: Optional[str] = None, + timestamp: Optional[int] = None, + params: Optional[JsonMapping] = None, + ) -> str: + """Schedule a new potentially resumable task. A function matching the specified + `action` should have been previously registered with `register_action`. + + Args: + action: the name of a previously registered action + resource_id: a task can be associated with a resource id to facilitate + getting all tasks associated with a specific resource + timestamp: if `None`, the task will be launched immediatly + params: a set of parameters that can be easily accessed from inside the + executed function + + Returns: the id of the scheduled task + """ + if action not in self.actions: + raise Exception( + f"No function associated with the action {action} of the scheduled task" + ) + + launch_now = False + if timestamp is None or timestamp < self.clock.time_msec(): + timestamp = self.clock.time_msec() + launch_now = True + + task = ScheduledTask( + random_string(16), + action, + TaskStatus.SCHEDULED, + timestamp, + resource_id, + params, + None, + None, + ) + await self.store.upsert_scheduled_task(task) + + if launch_now and self.run_background_tasks: + await self._launch_task(task, True) + + return task.id + + async def update_task( + self, + id: str, + *, + status: Optional[TaskStatus] = None, + result: Optional[JsonMapping] = None, + error: Optional[str] = None, + ) -> None: + """Update some task associated values. + + This is used internally in this handler, and also exposed publically so it can + be used inside task functions. This allows to store in DB the progress of a task + so it can be resumed properly after a restart of synapse. + + Args: + id: the id of the task to update + status: the new `TaskStatus` of the task + result: the new result of the task + error: the new error of the task + """ + await self.store.update_scheduled_task( + id, + timestamp=self.clock.time_msec(), + status=status, + result=result, + error=error, + ) + + async def get_task(self, id: str) -> Optional[ScheduledTask]: + """Get a specific task description by id. + + Args: + id: the id of the task to retrieve + + Returns: the task description or `None` if it doesn't exist + or it has already been cleaned + """ + return await self.store.get_scheduled_task(id) + + async def get_tasks( + self, action: str, resource_id: Optional[str] + ) -> List[ScheduledTask]: + """Get a list of tasks associated with an action name, and + optionally with a resource id. + + Args: + action: the action name of the tasks to retrieve + resource_id: if `None`, returns all associated tasks for + the specified action name, regardless of the resource id + + Returns: a list of `ScheduledTask` + """ + return await self.store.get_scheduled_tasks(action, resource_id) + + async def _scheduled_tasks_loop(self) -> None: + for task in await self.store.get_scheduled_tasks(): + if task.id not in self.running_tasks: + if ( + task.status == TaskStatus.SCHEDULED + and task.timestamp < self.clock.time_msec() + ): + await self._launch_task(task, True) + elif task.status == TaskStatus.ACTIVE: + await self._launch_task(task, False) + elif ( + task.status == TaskStatus.COMPLETE + or task.status == TaskStatus.FAILED + ) and self.clock.time_msec() > task.timestamp + TaskSchedulerHandler.KEEP_TASKS_FOR_MS: + await self.store.delete_scheduled_task(task.id) + + async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None: + if task.action not in self.actions: + raise Exception( + f"No function associated with the action {task.action} of the scheduled task" + ) + + function = self.actions[task.action] + + async def wrapper() -> None: + try: + (status, result, error) = await function(task, first_launch) + except Exception: + f = Failure() + logger.error( + f"scheduled task {task.id} failed", + exc_info=(f.type, f.value, f.getTracebackObject()), + ) + status = TaskStatus.FAILED + result = None + error = f.getErrorMessage() + + await self.update_task( + task.id, + status=status, + result=result, + error=error, + ) + self.running_tasks.remove(task.id) + + await self.update_task(task.id, status=TaskStatus.ACTIVE) + self.running_tasks.add(task.id) + description = task.action + if task.resource_id: + description += f"-{task.resource_id}" + run_as_background_process(description, wrapper) diff --git a/synapse/server.py b/synapse/server.py index b72b76a38b35..c0e1277f6ae4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -105,6 +105,7 @@ from synapse.handlers.sso import SsoHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler +from synapse.handlers.task_scheduler import TaskSchedulerHandler from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler from synapse.http.client import ( @@ -242,6 +243,7 @@ class HomeServer(metaclass=abc.ABCMeta): "profile", "room_forgetter", "stats", + "task_scheduler", ] # This is overridden in derived application classes @@ -912,3 +914,7 @@ def get_request_ratelimiter(self) -> RequestRatelimiter: def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager: """Usage metrics shared between phone home stats and the prometheus exporter.""" return CommonUsageMetricsManager(self) + + @cache_in_self + def get_task_scheduler_handler(self) -> TaskSchedulerHandler: + return TaskSchedulerHandler(self) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index b6028853c939..cb8fb665e478 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -70,6 +70,7 @@ from .stats import StatsStore from .stream import StreamWorkerStore from .tags import TagsStore +from .task_scheduler import TaskSchedulerWorkerStore from .transactions import TransactionWorkerStore from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore @@ -127,6 +128,7 @@ class DataStore( CacheInvalidationWorkerStore, LockStore, SessionStore, + TaskSchedulerWorkerStore, ): def __init__( self, diff --git a/synapse/storage/databases/main/task_scheduler.py b/synapse/storage/databases/main/task_scheduler.py new file mode 100644 index 000000000000..ea7ccd3fa45a --- /dev/null +++ b/synapse/storage/databases/main/task_scheduler.py @@ -0,0 +1,123 @@ +import json +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection +from synapse.types import JsonDict, JsonMapping, ScheduledTask, TaskStatus + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class TaskSchedulerWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + @staticmethod + def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask: + row["status"] = TaskStatus(row["status"]) + if row["params"] is not None: + row["params"] = json.loads(row["params"]) + if row["result"] is not None: + row["result"] = json.loads(row["result"]) + return ScheduledTask(**row) + + async def get_scheduled_tasks( + self, action: Optional[str] = None, resource_id: Optional[str] = None + ) -> List[ScheduledTask]: + keyvalues = {} + if action: + keyvalues["action"] = action + if resource_id: + keyvalues["resource_id"] = resource_id + + rows = await self.db_pool.simple_select_list( + table="scheduled_tasks", + keyvalues=keyvalues, + retcols=( + "id", + "action", + "status", + "timestamp", + "resource_id", + "params", + "result", + "error", + ), + desc="get_scheduled_tasks", + ) + + return [TaskSchedulerWorkerStore._convert_row_to_task(row) for row in rows] + + async def upsert_scheduled_task(self, task: ScheduledTask) -> None: + await self.db_pool.simple_upsert( + "scheduled_tasks", + {"id": task.id}, + { + "action": task.action, + "status": task.status, + "timestamp": task.timestamp, + "resource_id": task.resource_id, + "params": None if task.params is None else json.dumps(task.params), + "result": None if task.result is None else json.dumps(task.result), + "error": task.error, + }, + desc="upsert_scheduled_task", + ) + + async def update_scheduled_task( + self, + id: str, + *, + timestamp: Optional[int] = None, + status: Optional[TaskStatus] = None, + result: Optional[JsonMapping] = None, + error: Optional[str] = None, + ) -> None: + updatevalues: JsonDict = {} + if timestamp is not None: + updatevalues["timestamp"] = timestamp + if status is not None: + updatevalues["status"] = status + if result is not None: + updatevalues["result"] = json.dumps(result) + if error is not None: + updatevalues["error"] = error + await self.db_pool.simple_update( + "scheduled_tasks", + {"id": id}, + updatevalues, + desc="update_scheduled_task", + ) + + async def get_scheduled_task(self, id: str) -> Optional[ScheduledTask]: + row = await self.db_pool.simple_select_one( + table="scheduled_tasks", + keyvalues={"id": id}, + retcols=( + "id", + "action", + "status", + "timestamp", + "resource_id", + "params", + "result", + "error", + ), + allow_none=True, + desc="get_scheduled_task", + ) + + return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None + + async def delete_scheduled_task(self, id: str) -> None: + await self.db_pool.simple_delete( + "scheduled_tasks", + keyvalues={"id": id}, + desc="delete_scheduled_task", + ) diff --git a/synapse/storage/schema/main/delta/79/03_scheduled_tasks.sql b/synapse/storage/schema/main/delta/79/03_scheduled_tasks.sql new file mode 100644 index 000000000000..4ee43887b6c7 --- /dev/null +++ b/synapse/storage/schema/main/delta/79/03_scheduled_tasks.sql @@ -0,0 +1,26 @@ +/* Copyright 2023 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- cf ScheduledTask docstring for the meaning of the fields. +CREATE TABLE IF NOT EXISTS scheduled_tasks( + id text PRIMARY KEY, + action text NOT NULL, + status text NOT NULL, + timestamp bigint NOT NULL, + resource_id text, + params text, + result text, + error text +); diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 095be070e0c5..78affbd884cc 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -15,6 +15,7 @@ import abc import re import string +from enum import Enum from typing import ( TYPE_CHECKING, AbstractSet, @@ -979,3 +980,22 @@ class UserProfile(TypedDict): class RetentionPolicy: min_lifetime: Optional[int] = None max_lifetime: Optional[int] = None + + +class TaskStatus(str, Enum): + SCHEDULED = "scheduled" + ACTIVE = "active" + COMPLETE = "complete" + FAILED = "failed" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class ScheduledTask: + id: str + action: str + status: TaskStatus + timestamp: int + resource_id: Optional[str] + params: Optional[JsonMapping] + result: Optional[JsonMapping] + error: Optional[str] diff --git a/tests/handlers/test_task_scheduler.py b/tests/handlers/test_task_scheduler.py new file mode 100644 index 000000000000..2b9df13ebb30 --- /dev/null +++ b/tests/handlers/test_task_scheduler.py @@ -0,0 +1,108 @@ +from typing import Optional, Tuple + +from twisted.internet.task import deferLater +from twisted.test.proto_helpers import MemoryReactor + +from synapse.handlers.task_scheduler import TaskSchedulerHandler +from synapse.server import HomeServer +from synapse.types import JsonMapping, ScheduledTask, TaskStatus +from synapse.util import Clock + +from tests import unittest + + +class TestTaskScheduler(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.handler = hs.get_task_scheduler_handler() + self.handler.register_action(self._test_task, "_test_task") + self.handler.register_action(self._raising_task, "_raising_task") + self.handler.register_action(self._resumable_task, "_resumable_task") + + async def _test_task( + self, task: ScheduledTask, first_launch: bool + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + result = None + if task.params: + result = task.params + return (TaskStatus.COMPLETE, result, None) + + def test_schedule_task(self) -> None: + timestamp = self.clock.time_msec() + 2 * 60 * 1000 + task_id = self.get_success( + self.handler.schedule_task( + "_test_task", + timestamp=timestamp, + params={"val": 1}, + ) + ) + + task = self.get_success(self.handler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.SCHEDULED) + self.assertIsNone(task.result) + + self.reactor.advance((TaskSchedulerHandler.SCHEDULE_INTERVAL_MS / 1000) + 1) + + task = self.get_success(self.handler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.COMPLETE) + assert task.result is not None + self.assertTrue(task.result.get("val") == 1) + + self.reactor.advance((TaskSchedulerHandler.KEEP_TASKS_FOR_MS / 1000) + 1) + + task = self.get_success(self.handler.get_task(task_id)) + self.assertIsNone(task) + + def test_schedule_task_now(self) -> None: + task_id = self.get_success( + self.handler.schedule_task("_test_task", params={"val": 1}) + ) + + task = self.get_success(self.handler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.COMPLETE) + assert task.result is not None + self.assertTrue(task.result.get("val") == 1) + + async def _raising_task( + self, task: ScheduledTask, first_launch: bool + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + raise Exception("raising") + + def test_schedule_raising_task_now(self) -> None: + task_id = self.get_success(self.handler.schedule_task("_raising_task")) + + task = self.get_success(self.handler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.FAILED) + self.assertEqual(task.error, "raising") + + async def _resumable_task( + self, task: ScheduledTask, first_launch: bool + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + if task.result and "in_progress" in task.result: + return TaskStatus.COMPLETE, {"success": True}, None + else: + await self.handler.update_task(task.id, result={"in_progress": True}) + # Await forever to simulate an aborted task because of a restart + await deferLater(self.reactor, 2**16) + # This should never been called + return TaskStatus.ACTIVE, None, None + + def test_schedule_resumable_task_now(self) -> None: + task_id = self.get_success(self.handler.schedule_task("_resumable_task")) + + task = self.get_success(self.handler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.ACTIVE) + + # Simulate a synapse restart by emptying the list of running tasks + self.handler.running_tasks = set() + self.reactor.advance((TaskSchedulerHandler.SCHEDULE_INTERVAL_MS / 1000) + 1) + + task = self.get_success(self.handler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.COMPLETE) + assert task.result is not None + self.assertTrue(task.result.get("success"))