Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions labthings/core/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
__all__ = [
"Pool",
"taskify",
"tasks",
"dictionary",
"to_dict",
"states",
"current_task",
"update_task_progress",
"cleanup_tasks",
"remove_task",
"cleanup",
"discard_id",
"update_task_data",
"ThreadTerminationError",
]

from .pool import (
Pool,
tasks,
dictionary,
to_dict,
states,
current_task,
update_task_progress,
cleanup_tasks,
remove_task,
cleanup,
discard_id,
update_task_data,
taskify,
)
Expand Down
132 changes: 49 additions & 83 deletions labthings/core/tasks/pool.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,71 @@
import logging
from functools import wraps
from gevent import getcurrent
from gevent.pool import Pool as _Pool, PoolFull

from .thread import TaskThread

from flask import copy_current_request_context, has_request_context

class Pool(_Pool):
def __init__(self, size=None):
_Pool.__init__(self, size=size, greenlet_class=TaskThread)

class TaskMaster:
def __init__(self, *args, **kwargs):
self._tasks = []
def add(self, greenlet, blocking=True, timeout=None):
"""
Override the default Gevent pool `add` method so that
tasks are not discarded as soon as they finish.
"""
if not self._semaphore.acquire(blocking=blocking, timeout=timeout):
# We failed to acquire the semaphore.
# If blocking was True, then there was a timeout. If blocking was
# False, then there was no capacity. Either way, raise PoolFull.
raise PoolFull()

try:
self.greenlets.add(greenlet)
self._empty_event.clear()
except:
self._semaphore.release()
raise

@property
def tasks(self):
"""
Returns:
list: List of TaskThread objects.
"""
return self._tasks
return list(self.greenlets)

@property
def dict(self):
def states(self):
"""
Returns:
dict: Dictionary of TaskThread objects. Key is TaskThread ID.
dict: Dictionary of TaskThread.state dictionaries. Key is TaskThread ID.
"""
return {str(t.id): t for t in self._tasks}
return {str(t.id): t.state for t in self.greenlets}

@property
def states(self):
def to_dict(self):
"""
Returns:
dict: Dictionary of TaskThread.state dictionaries. Key is TaskThread ID.
dict: Dictionary of TaskThread objects. Key is TaskThread ID.
"""
return {str(t.id): t.state for t in self._tasks}

def new(self, f, *args, **kwargs):
# copy_current_request_context allows threads to access flask current_app
if has_request_context():
target = copy_current_request_context(f)
else:
target = f
task = TaskThread(target=target, args=args, kwargs=kwargs)
self._tasks.append(task)
return task
return {str(t.id): t for t in self.greenlets}

def remove(self, task_id):
for task in self._tasks:
def discard_id(self, task_id):
marked_for_discard = set()
for task in self.greenlets:
if (str(task.id) == str(task_id)) and task.dead:
self._tasks.remove(task)
marked_for_discard.add(task)

for greenlet in marked_for_discard:
self.discard(greenlet)

def cleanup(self):
for i, task in enumerate(self._tasks):
marked_for_discard = set()
for task in self.greenlets:
if task.dead:
# Mark for delection
self._tasks[i] = None
# Remove items marked for deletion
self._tasks = [t for t in self._tasks if t]


# Task management


def tasks():
"""
List of tasks in default taskmaster
Returns:
list: List of tasks in default taskmaster
"""
global DEFAULT_TASK_MASTER
return DEFAULT_TASK_MASTER.tasks
marked_for_discard.add(task)


def dictionary():
"""
Dictionary of tasks in default taskmaster
Returns:
dict: Dictionary of tasks in default taskmaster
"""
global DEFAULT_TASK_MASTER
return DEFAULT_TASK_MASTER.dict


def states():
"""
Dictionary of TaskThread.state dictionaries. Key is TaskThread ID.
Returns:
dict: Dictionary of task states in default taskmaster
"""
global DEFAULT_TASK_MASTER
return DEFAULT_TASK_MASTER.states


def cleanup_tasks():
"""Remove all finished tasks from the task list"""
global DEFAULT_TASK_MASTER
return DEFAULT_TASK_MASTER.cleanup()


def remove_task(task_id: str):
"""Remove a particular task from the task list

Arguments:
task_id {str} -- ID of the target task
"""
global DEFAULT_TASK_MASTER
return DEFAULT_TASK_MASTER.remove(task_id)
for greenlet in marked_for_discard:
self.discard(greenlet)


# Operations on the current task
Expand Down Expand Up @@ -161,17 +121,23 @@ def taskify(f):
A decorator that wraps the passed in function
and surpresses exceptions should one occur
"""
global default_pool

@wraps(f)
def wrapped(*args, **kwargs):
task = DEFAULT_TASK_MASTER.new(
task = default_pool.spawn(
f, *args, **kwargs
) # Append to parent object's task list
task.start() # Start the function
return task

return wrapped


# Create our default, protected, module-level task pool
DEFAULT_TASK_MASTER = TaskMaster()
default_pool = Pool()

tasks = default_pool.tasks
to_dict = default_pool.to_dict
states = default_pool.states
cleanup = default_pool.cleanup
discard_id = default_pool.discard_id
15 changes: 8 additions & 7 deletions labthings/core/tasks/thread.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from gevent import Greenlet, GreenletExit
from gevent.thread import get_ident
from gevent.event import Event
from flask import copy_current_request_context, has_request_context
import datetime
import logging
import traceback
Expand All @@ -18,13 +19,8 @@ class TaskKillException(Exception):


class TaskThread(Greenlet):
def __init__(self, target=None, args=None, kwargs=None):
def __init__(self, target, *args, **kwargs):
Greenlet.__init__(self)
# Handle arguments
if args is None:
args = ()
if kwargs is None:
kwargs = {}

# A UUID for the TaskThread (not the same as the threading.Thread ident)
self._ID = uuid.uuid4() # Task ID
Expand Down Expand Up @@ -83,7 +79,12 @@ def update_data(self, data: dict):
self.data.update(data)

def _run(self): # pylint: disable=E0202
return self._thread_proc(self._target)(*self._args, **self._kwargs)
# copy_current_request_context allows threads to access flask current_app
if has_request_context():
target = copy_current_request_context(self._target)
else:
target = self._target
return self._thread_proc(target)(*self._args, **self._kwargs)

def _thread_proc(self, f):
"""
Expand Down
8 changes: 4 additions & 4 deletions labthings/server/default_views/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ def get(self, task_id):

Includes progress and intermediate data.
"""
task_dict = tasks.dictionary()
task_dict = tasks.to_dict()

if not task_id in task_dict:
if task_id not in task_dict:
return abort(404) # 404 Not Found

task = task_dict.get(task_id)
Expand All @@ -47,9 +47,9 @@ def delete(self, task_id):

If the task is finished, deletes its entry.
"""
task_dict = tasks.dictionary()
task_dict = tasks.to_dict()

if not task_id in task_dict:
if task_id not in task_dict:
return abort(404) # 404 Not Found

task = task_dict.get(task_id)
Expand Down
17 changes: 7 additions & 10 deletions tests/test_core_tasks_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,10 @@ def test_tasks_list():

def test_tasks_dict():
assert all(
[
isinstance(task_obj, gevent.Greenlet)
for task_obj in tasks.dictionary().values()
]
[isinstance(task_obj, gevent.Greenlet) for task_obj in tasks.to_dict().values()]
)

assert all([k == str(t.id) for k, t in tasks.dictionary().items()])
assert all([k == str(t.id) for k, t in tasks.to_dict().items()])


def test_task_states():
Expand All @@ -80,16 +77,16 @@ def test_task_states():
assert all(k in state for k in state_keys)


def test_remove_task():
def test_discard_id():
def task_func():
pass

task_obj = tasks.taskify(task_func)()
assert str(task_obj.id) in tasks.dictionary()
assert str(task_obj.id) in tasks.to_dict()
task_obj.join()

tasks.remove_task(task_obj.id)
assert not str(task_obj.id) in tasks.dictionary()
tasks.discard_id(task_obj.id)
assert not str(task_obj.id) in tasks.to_dict()


def test_cleanup_task():
Expand All @@ -105,5 +102,5 @@ def task_func():
gevent.joinall(tasks.tasks())

assert len(tasks.tasks()) > 0
tasks.cleanup_tasks()
tasks.cleanup()
assert len(tasks.tasks()) == 0
6 changes: 2 additions & 4 deletions tests/test_core_tasks_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ def test_task_with_args():
def task_func(arg, kwarg=False):
pass

task_obj = thread.TaskThread(
target=task_func, args=("String arg",), kwargs={"kwarg": True}
)
task_obj = thread.TaskThread(task_func, "String arg", kwarg=True)
assert isinstance(task_obj, gevent.Greenlet)
assert task_obj._target == task_func
assert task_obj._args == ("String arg",)
Expand Down Expand Up @@ -121,7 +119,7 @@ def test_task_log_without_thread():

def test_task_log_with_incorrect_thread():

task_obj = thread.TaskThread()
task_obj = thread.TaskThread(None)
task_log_handler = thread.ThreadLogHandler(thread=task_obj)

# Should always return False if called from outside the log handlers thread
Expand Down
4 changes: 2 additions & 2 deletions tests/test_server_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def func():

def test_marshal_task(app_ctx):
def func():
return TaskThread()
return TaskThread(None)

wrapped_func = decorators.marshal_task(func)

Expand All @@ -102,7 +102,7 @@ def func():

def test_marshal_task_response_tuple(app_ctx):
def func():
return (TaskThread(), 201, {})
return (TaskThread(None), 201, {})

wrapped_func = decorators.marshal_task(func)

Expand Down
10 changes: 5 additions & 5 deletions tests/test_server_default_views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from labthings.core.tasks import taskify, dictionary
from labthings.core import tasks

import gevent

Expand All @@ -22,7 +22,7 @@ def test_tasks_list(thing_client):
def task_func():
pass

task_obj = taskify(task_func)()
task_obj = tasks.taskify(task_func)()

with thing_client as c:
response = c.get("/tasks").json
Expand All @@ -34,7 +34,7 @@ def test_task_representation(thing_client):
def task_func():
pass

task_obj = taskify(task_func)()
task_obj = tasks.taskify(task_func)()
task_id = str(task_obj.id)

with thing_client as c:
Expand All @@ -52,12 +52,12 @@ def task_func():
while True:
gevent.sleep(0)

task_obj = taskify(task_func)()
task_obj = tasks.taskify(task_func)()
task_id = str(task_obj.id)

# Wait for task to start
task_obj.started_event.wait()
assert task_id in dictionary()
assert task_id in tasks.to_dict()

# Send a DELETE request to terminate the task
with thing_client as c:
Expand Down
Loading