From de13ab856a14ac90627a5fef33ed36556326c661 Mon Sep 17 00:00:00 2001 From: Joel Collins Date: Thu, 23 Apr 2020 12:05:58 +0100 Subject: [PATCH 1/3] Updated task management to better match Gevent pool interface --- labthings/core/tasks/__init__.py | 14 +-- labthings/core/tasks/pool.py | 132 +++++++++--------------- labthings/core/tasks/thread.py | 15 +-- labthings/server/default_views/tasks.py | 8 +- tests/test_core_tasks_pool.py | 17 ++- tests/test_core_tasks_thread.py | 6 +- tests/test_server_decorators.py | 4 +- tests/test_server_default_views.py | 10 +- tests/test_server_schema.py | 2 +- 9 files changed, 86 insertions(+), 122 deletions(-) diff --git a/labthings/core/tasks/__init__.py b/labthings/core/tasks/__init__.py index f9255de8..aa18d802 100644 --- a/labthings/core/tasks/__init__.py +++ b/labthings/core/tasks/__init__.py @@ -1,24 +1,26 @@ __all__ = [ + "Pool", "taskify", "tasks", - "dictionary", + "to_dict", "states", "current_task", "update_task_progress", - "cleanup_tasks", - "remove_task", + "cleanup", + "discard_id", "update_task_data", "ThreadTerminationError", ] from .pool import ( + Pool, tasks, - dictionary, + to_dict, states, current_task, update_task_progress, - cleanup_tasks, - remove_task, + cleanup, + discard_id, update_task_data, taskify, ) diff --git a/labthings/core/tasks/pool.py b/labthings/core/tasks/pool.py index 04aeaa81..855cab82 100644 --- a/labthings/core/tasks/pool.py +++ b/labthings/core/tasks/pool.py @@ -1,111 +1,71 @@ import logging from functools import wraps from gevent import getcurrent +from gevent.pool import Pool as _Pool, PoolFull from .thread import TaskThread -from flask import copy_current_request_context, has_request_context +class Pool(_Pool): + def __init__(self, size=None): + _Pool.__init__(self, size=size, greenlet_class=TaskThread) -class TaskMaster: - def __init__(self, *args, **kwargs): - self._tasks = [] + def add(self, greenlet, blocking=True, timeout=None): + """ + Override the default Gevent pool `add` method so that + tasks are not discarded as soon as they finish. + """ + if not self._semaphore.acquire(blocking=blocking, timeout=timeout): + # We failed to acquire the semaphore. + # If blocking was True, then there was a timeout. If blocking was + # False, then there was no capacity. Either way, raise PoolFull. + raise PoolFull() + + try: + self.greenlets.add(greenlet) + self._empty_event.clear() + except: + self._semaphore.release() + raise - @property def tasks(self): """ Returns: list: List of TaskThread objects. """ - return self._tasks + return list(self.greenlets) - @property - def dict(self): + def states(self): """ Returns: - dict: Dictionary of TaskThread objects. Key is TaskThread ID. + dict: Dictionary of TaskThread.state dictionaries. Key is TaskThread ID. """ - return {str(t.id): t for t in self._tasks} + return {str(t.id): t.state for t in self.greenlets} - @property - def states(self): + def to_dict(self): """ Returns: - dict: Dictionary of TaskThread.state dictionaries. Key is TaskThread ID. + dict: Dictionary of TaskThread objects. Key is TaskThread ID. """ - return {str(t.id): t.state for t in self._tasks} - - def new(self, f, *args, **kwargs): - # copy_current_request_context allows threads to access flask current_app - if has_request_context(): - target = copy_current_request_context(f) - else: - target = f - task = TaskThread(target=target, args=args, kwargs=kwargs) - self._tasks.append(task) - return task + return {str(t.id): t for t in self.greenlets} - def remove(self, task_id): - for task in self._tasks: + def discard_id(self, task_id): + marked_for_discard = set() + for task in self.greenlets: if (str(task.id) == str(task_id)) and task.dead: - self._tasks.remove(task) + marked_for_discard.add(task) + + for greenlet in marked_for_discard: + self.discard(greenlet) def cleanup(self): - for i, task in enumerate(self._tasks): + marked_for_discard = set() + for task in self.greenlets: if task.dead: - # Mark for delection - self._tasks[i] = None - # Remove items marked for deletion - self._tasks = [t for t in self._tasks if t] - - -# Task management - - -def tasks(): - """ - List of tasks in default taskmaster - Returns: - list: List of tasks in default taskmaster - """ - global DEFAULT_TASK_MASTER - return DEFAULT_TASK_MASTER.tasks + marked_for_discard.add(task) - -def dictionary(): - """ - Dictionary of tasks in default taskmaster - Returns: - dict: Dictionary of tasks in default taskmaster - """ - global DEFAULT_TASK_MASTER - return DEFAULT_TASK_MASTER.dict - - -def states(): - """ - Dictionary of TaskThread.state dictionaries. Key is TaskThread ID. - Returns: - dict: Dictionary of task states in default taskmaster - """ - global DEFAULT_TASK_MASTER - return DEFAULT_TASK_MASTER.states - - -def cleanup_tasks(): - """Remove all finished tasks from the task list""" - global DEFAULT_TASK_MASTER - return DEFAULT_TASK_MASTER.cleanup() - - -def remove_task(task_id: str): - """Remove a particular task from the task list - - Arguments: - task_id {str} -- ID of the target task - """ - global DEFAULT_TASK_MASTER - return DEFAULT_TASK_MASTER.remove(task_id) + for greenlet in marked_for_discard: + self.discard(greenlet) # Operations on the current task @@ -161,17 +121,23 @@ def taskify(f): A decorator that wraps the passed in function and surpresses exceptions should one occur """ + global default_pool @wraps(f) def wrapped(*args, **kwargs): - task = DEFAULT_TASK_MASTER.new( + task = default_pool.spawn( f, *args, **kwargs ) # Append to parent object's task list - task.start() # Start the function return task return wrapped # Create our default, protected, module-level task pool -DEFAULT_TASK_MASTER = TaskMaster() +default_pool = Pool() + +tasks = default_pool.tasks +to_dict = default_pool.to_dict +states = default_pool.states +cleanup = default_pool.cleanup +discard_id = default_pool.discard_id diff --git a/labthings/core/tasks/thread.py b/labthings/core/tasks/thread.py index e9d96bb0..8b9b3ea8 100644 --- a/labthings/core/tasks/thread.py +++ b/labthings/core/tasks/thread.py @@ -1,6 +1,7 @@ from gevent import Greenlet, GreenletExit from gevent.thread import get_ident from gevent.event import Event +from flask import copy_current_request_context, has_request_context import datetime import logging import traceback @@ -18,13 +19,8 @@ class TaskKillException(Exception): class TaskThread(Greenlet): - def __init__(self, target=None, args=None, kwargs=None): + def __init__(self, target, *args, **kwargs): Greenlet.__init__(self) - # Handle arguments - if args is None: - args = () - if kwargs is None: - kwargs = {} # A UUID for the TaskThread (not the same as the threading.Thread ident) self._ID = uuid.uuid4() # Task ID @@ -83,7 +79,12 @@ def update_data(self, data: dict): self.data.update(data) def _run(self): # pylint: disable=E0202 - return self._thread_proc(self._target)(*self._args, **self._kwargs) + # copy_current_request_context allows threads to access flask current_app + if has_request_context(): + target = copy_current_request_context(self._target) + else: + target = self._target + return self._thread_proc(target)(*self._args, **self._kwargs) def _thread_proc(self, f): """ diff --git a/labthings/server/default_views/tasks.py b/labthings/server/default_views/tasks.py index 09bdb4fa..e380c9d7 100644 --- a/labthings/server/default_views/tasks.py +++ b/labthings/server/default_views/tasks.py @@ -31,9 +31,9 @@ def get(self, task_id): Includes progress and intermediate data. """ - task_dict = tasks.dictionary() + task_dict = tasks.to_dict() - if not task_id in task_dict: + if task_id not in task_dict: return abort(404) # 404 Not Found task = task_dict.get(task_id) @@ -47,9 +47,9 @@ def delete(self, task_id): If the task is finished, deletes its entry. """ - task_dict = tasks.dictionary() + task_dict = tasks.to_dict() - if not task_id in task_dict: + if task_id not in task_dict: return abort(404) # 404 Not Found task = task_dict.get(task_id) diff --git a/tests/test_core_tasks_pool.py b/tests/test_core_tasks_pool.py index 9bdb09b6..902dd9f3 100644 --- a/tests/test_core_tasks_pool.py +++ b/tests/test_core_tasks_pool.py @@ -55,13 +55,10 @@ def test_tasks_list(): def test_tasks_dict(): assert all( - [ - isinstance(task_obj, gevent.Greenlet) - for task_obj in tasks.dictionary().values() - ] + [isinstance(task_obj, gevent.Greenlet) for task_obj in tasks.to_dict().values()] ) - assert all([k == str(t.id) for k, t in tasks.dictionary().items()]) + assert all([k == str(t.id) for k, t in tasks.to_dict().items()]) def test_task_states(): @@ -80,16 +77,16 @@ def test_task_states(): assert all(k in state for k in state_keys) -def test_remove_task(): +def test_discard_id(): def task_func(): pass task_obj = tasks.taskify(task_func)() - assert str(task_obj.id) in tasks.dictionary() + assert str(task_obj.id) in tasks.to_dict() task_obj.join() - tasks.remove_task(task_obj.id) - assert not str(task_obj.id) in tasks.dictionary() + tasks.discard_id(task_obj.id) + assert not str(task_obj.id) in tasks.to_dict() def test_cleanup_task(): @@ -105,5 +102,5 @@ def task_func(): gevent.joinall(tasks.tasks()) assert len(tasks.tasks()) > 0 - tasks.cleanup_tasks() + tasks.cleanup() assert len(tasks.tasks()) == 0 diff --git a/tests/test_core_tasks_thread.py b/tests/test_core_tasks_thread.py index e82fc22e..890b41fb 100644 --- a/tests/test_core_tasks_thread.py +++ b/tests/test_core_tasks_thread.py @@ -9,9 +9,7 @@ def test_task_with_args(): def task_func(arg, kwarg=False): pass - task_obj = thread.TaskThread( - target=task_func, args=("String arg",), kwargs={"kwarg": True} - ) + task_obj = thread.TaskThread(task_func, "String arg", kwarg=True) assert isinstance(task_obj, gevent.Greenlet) assert task_obj._target == task_func assert task_obj._args == ("String arg",) @@ -110,7 +108,7 @@ def test_task_log_without_thread(): def test_task_log_with_incorrect_thread(): - task_obj = thread.TaskThread() + task_obj = thread.TaskThread(None) task_log_handler = thread.ThreadLogHandler(thread=task_obj) # Should always return False if called from outside the log handlers thread diff --git a/tests/test_server_decorators.py b/tests/test_server_decorators.py index 94116f59..4a523bda 100644 --- a/tests/test_server_decorators.py +++ b/tests/test_server_decorators.py @@ -91,7 +91,7 @@ def func(): def test_marshal_task(app_ctx): def func(): - return TaskThread() + return TaskThread(None) wrapped_func = decorators.marshal_task(func) @@ -102,7 +102,7 @@ def func(): def test_marshal_task_response_tuple(app_ctx): def func(): - return (TaskThread(), 201, {}) + return (TaskThread(None), 201, {}) wrapped_func = decorators.marshal_task(func) diff --git a/tests/test_server_default_views.py b/tests/test_server_default_views.py index 834eb4f2..c609bda3 100644 --- a/tests/test_server_default_views.py +++ b/tests/test_server_default_views.py @@ -1,4 +1,4 @@ -from labthings.core.tasks import taskify, dictionary +from labthings.core import tasks import gevent @@ -22,7 +22,7 @@ def test_tasks_list(thing_client): def task_func(): pass - task_obj = taskify(task_func)() + task_obj = tasks.taskify(task_func)() with thing_client as c: response = c.get("/tasks").json @@ -34,7 +34,7 @@ def test_task_representation(thing_client): def task_func(): pass - task_obj = taskify(task_func)() + task_obj = tasks.taskify(task_func)() task_id = str(task_obj.id) with thing_client as c: @@ -52,12 +52,12 @@ def task_func(): while True: gevent.sleep(0) - task_obj = taskify(task_func)() + task_obj = tasks.taskify(task_func)() task_id = str(task_obj.id) # Wait for task to start task_obj.started_event.wait() - assert task_id in dictionary() + assert task_id in tasks.to_dict() # Send a DELETE request to terminate the task with thing_client as c: diff --git a/tests/test_server_schema.py b/tests/test_server_schema.py index 3bb6fea6..ecd60538 100644 --- a/tests/test_server_schema.py +++ b/tests/test_server_schema.py @@ -55,7 +55,7 @@ def test_field_schema(app_ctx): def test_task_schema(app_ctx): test_schema = schema.TaskSchema() - test_task_thread = TaskThread() + test_task_thread = TaskThread(None) with app_ctx.test_request_context(): d = test_schema.dump(test_task_thread) From 217b87e7a1d30298355dae3b29f7305dffbd2178 Mon Sep 17 00:00:00 2001 From: Joel Collins Date: Thu, 23 Apr 2020 12:33:55 +0100 Subject: [PATCH 2/3] Fixed ws test accounting for INFO logging level --- tests/test_server_default_views_socket_handler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_server_default_views_socket_handler.py b/tests/test_server_default_views_socket_handler.py index 3b626b94..0e390783 100644 --- a/tests/test_server_default_views_socket_handler.py +++ b/tests/test_server_default_views_socket_handler.py @@ -5,5 +5,6 @@ def test_socket_handler(thing_ctx, fake_websocket): with thing_ctx.test_request_context(): ws = fake_websocket("", recieve_once=True) socket_handler(ws) - # Expect no response - assert ws.responses == [] + # Only responses should be announcing new subscribers + for response in ws.responses: + assert '"data": "Added subscriber' in response From 9c723c393b70d3a2cc1c8668647f3aa581c81efc Mon Sep 17 00:00:00 2001 From: Joel Collins Date: Thu, 23 Apr 2020 16:27:09 +0100 Subject: [PATCH 3/3] Improved spec utilities coverage --- tests/test_server_spec_utilities.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_server_spec_utilities.py b/tests/test_server_spec_utilities.py index 2b8fdd35..400134d2 100644 --- a/tests/test_server_spec_utilities.py +++ b/tests/test_server_spec_utilities.py @@ -23,6 +23,23 @@ def test_update_spec(view_cls): } +def test_tag_spec(view_cls): + utilities.tag_spec(view_cls, set(["tag1"])) + assert view_cls.__apispec__.get("tags") == set(["tag1"]) + utilities.tag_spec(view_cls, set(["tag2"])) + assert view_cls.__apispec__.get("tags") == set(["tag1", "tag2"]) + + +def test_tag_spec_string(view_cls): + utilities.tag_spec(view_cls, "tag1") + assert view_cls.__apispec__.get("tags") == set(["tag1"]) + + +def test_tag_spec_invalid(view_cls): + with pytest.raises(TypeError): + utilities.tag_spec(view_cls, set([object(), "tag"])) + + def test_get_spec(view_cls): assert utilities.get_spec(None) == {} assert utilities.get_spec(view_cls) == {}