diff --git a/HISTORY.rst b/HISTORY.rst
index ee16445..21f0b1c 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -1,5 +1,14 @@
 .. :changelog:
 
+.. Unreleased Changes
+
+0.4.0 (2018-04-18)
+------------------
+* Auto-versioning now happens via ``setuptools_scm``, replacing previous calls to ``natcap.versioner``.
+* Added an option to the `TaskGraph` constructor to allow negative values in the `n_workers` argument to indicate that the entire object should run in the main thread. A value of 0 indicates that no multiprocessing will be used, but concurrency will still be allowed for non-blocking `add_task`.
+* Added an abstract class `task.EncapsulatedTaskOp` that can be used to instantiate a class that needs scope in order to be used as an operation passed to a process. The advantage of using `EncapsulatedTaskOp` is that the `__name__` hash used by `TaskGraph` to determine whether a task is unique is calculated in the superclass, and the subclass need only worry about the implementation of `__call__`.
+* Added an optional scalar `priority` argument to `TaskGraph.add_task` to indicate the priority preference of the task to be executed. A higher priority task whose dependencies are satisfied will be executed before one with a lower priority.
+
 0.3.0 (2017-11-17)
 ------------------
 * Refactor of core scheduler. Old scheduler used asynchronicity to attempt to test if a Task was complete, occasionally testing all Tasks in potential work queue per task completion. Scheduler now uses bookkeeping to keep track of all dependencies and submits tasks for work only when all dependencies are satisfied.
diff --git a/setup.py b/setup.py
index c94bf4e..96184bf 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,9 @@
 
 setup(
     name='taskgraph',
-    natcap_version='taskgraph/version.py',
+    use_scm_version={'version_scheme': 'post-release',
+                     'local_scheme': 'node-and-date'},
+    setup_requires=['setuptools_scm'],
     description='Parallel task graph framework.',
     long_description=README,
     maintainer='Rich Sharp',
diff --git a/taskgraph/Task.py b/taskgraph/Task.py
index 640973f..26304c9 100644
--- a/taskgraph/Task.py
+++ b/taskgraph/Task.py
@@ -1,4 +1,5 @@
 """Task graph framework."""
+import heapq
 import pprint
 import types
 import collections
@@ -12,6 +13,7 @@ import errno
 import Queue
 import inspect
+import abc
 
 try:
     import psutil
@@ -35,24 +37,36 @@ def __init__(self, taskgraph_cache_dir_path, n_workers):
             taskgraph_cache_dir_path (string): path to a directory that
                 either contains a taskgraph cache from a previous instance
                 or will create a new one if none exists.
-            n_workers (int): number of parallel workers to allow during
-                task graph execution. If set to 0, use current process.
+            n_workers (int): number of parallel *subprocess* workers to allow
+                during task graph execution. If set to 0, don't use
+                subprocesses. If set to <0, use only the main thread for any
+                execution and scheduling. In the case of the latter,
+                `add_task` will be a blocking call.
         """
         # the work queue is the feeder to active worker threads
         self.taskgraph_cache_dir_path = taskgraph_cache_dir_path
-        self.work_queue = Queue.Queue()
         self.n_workers = n_workers
-        # used to synchronize a pass through potential tasks to add to the
-        # work queue
-        self.process_pending_tasks_event = threading.Event()
-
         # keep track if the task graph has been forcibly terminated
         self.terminated = False
 
+        # use this to keep track of all the tasks added to the graph by their
+        # task ids. Used to determine if an identical task has been added
+        # to the taskgraph during `add_task`
+        self.task_id_map = dict()
+
+        # used to remember if task_graph has been closed
+        self.closed = False
+
         # if n_workers > 0 this will be a multiprocessing pool used to execute
         # the __call__ functions in Tasks
         self.worker_pool = None
+
+        # no need to set up schedulers if n_workers is single threaded
+        if n_workers < 0:
+            return
+
+        # set up multiprocessing if n_workers > 0
         if n_workers > 0:
             self.worker_pool = multiprocessing.Pool(n_workers)
             if HAS_PSUTIL:
@@ -67,17 +81,10 @@ def __init__(self, taskgraph_cache_dir_path, n_workers):
                             "to nice %s. This might be a bug in `psutil` so "
                             "it should be okay to ignore.")
 
-        # use this to keep track of all the tasks added to the graph by their
-        # task ids. Used to determine if an identical task has been added
-        # to the taskgraph during `add_task`
-        self.task_id_map = dict()
-
-        # used to remember if task_graph has been closed
-        self.closed = False
-
-        # this is the set of threads created by taskgraph
-        self.thread_set = set()
-
+        # used to synchronize a pass through potential tasks to add to the
+        # work queue
+        self.work_queue = Queue.Queue()
+        self.worker_semaphore = threading.Semaphore(max(1, n_workers))
         # launch threads to manage the workers
         for thread_id in xrange(max(1, n_workers)):
             worker_thread = threading.Thread(
@@ -85,60 +92,47 @@ def __init__(self, taskgraph_cache_dir_path, n_workers):
                 name='taskgraph_worker_thread_%d' % thread_id)
             worker_thread.daemon = True
             worker_thread.start()
-            self.thread_set.add(worker_thread)
-
-        # tasks that get passed right to add_task get put in this queue for
-        # scheduling
-        self.pending_task_queue = Queue.Queue()
-
-        # launch thread to monitor the pending task set
-        pending_task_scheduler = threading.Thread(
-            target=self._process_pending_tasks,
-            name='_pending_task_scheduler')
-        pending_task_scheduler.daemon = True
-        pending_task_scheduler.start()
-        self.thread_set.add(pending_task_scheduler)
 
+        # tasks that get passed add_task get put in this queue for scheduling
         self.waiting_task_queue = Queue.Queue()
-
         waiting_task_scheduler = threading.Thread(
             target=self._process_waiting_tasks,
             name='_waiting_task_scheduler')
         waiting_task_scheduler.daemon = True
         waiting_task_scheduler.start()
-        self.thread_set.add(waiting_task_scheduler)
+
+        # tasks in the work ready queue have dependencies satisfied but need
+        # priority scheduling
+        self.work_ready_queue = Queue.Queue()
+        priority_task_scheduler = threading.Thread(
+            target=self._schedule_priority_tasks,
+            name='_priority_task_scheduler')
+        priority_task_scheduler.daemon = True
+        priority_task_scheduler.start()
 
     def _task_worker(self):
         """Execute and manage Task objects.
-
-        This worker extracts (task object, args, kwargs) tuples from
-        `self.work_queue`, processes the return value to ensure either a
-        successful completion OR handle an error. On successful completion
-        the task's hash and dependent files are recorded in TaskGraph's
-        cache structure to test and prevent future-re-executions.
""" for task in iter(self.work_queue.get, 'STOP'): try: - if not task.is_precalculated(): - target_path_stats = task._call() - else: - task._task_complete_event.set() - # task complete, signal to pending task scheduler that this - # task is complete + # precondition: task wouldn't be in queue if it were + # precalculated + task._call() + self.worker_semaphore.release() self.waiting_task_queue.put((task, 'done')) - except Exception as subprocess_exception: + except Exception: # An error occurred on a call, terminate the taskgraph LOGGER.exception( 'A taskgraph _task_worker failed on Task ' - '%s with exception "%s". ' - 'Terminating taskgraph.', task, subprocess_exception) + '%s. Terminating taskgraph.', task) self._terminate() raise def add_task( self, func=None, args=None, kwargs=None, task_name=None, target_path_list=None, ignore_path_list=None, - dependent_task_list=None, ignore_directories=True): + dependent_task_list=None, ignore_directories=True, + priority=0): """Add a task to the task graph. Parameters: @@ -160,6 +154,11 @@ def add_task( ignore_directories (boolean): if the existence/timestamp of any directories discovered in args or kwargs is used as part of the work token hash. + priority (numeric): the priority of a task is considered when + there is more than one task whose dependencies have been + met and are ready for scheduling. Tasks are inserted into the + work queue in order of decreasing priority. This value can be + positive, negative, and/or floating point. Returns: Task which was just added to the graph or an existing Task that @@ -189,7 +188,7 @@ def add_task( new_task = Task( task_name, func, args, kwargs, target_path_list, ignore_path_list, dependent_task_list, ignore_directories, - self.worker_pool, self.taskgraph_cache_dir_path) + self.worker_pool, self.taskgraph_cache_dir_path, priority) task_hash = new_task.task_hash # it may be this task was already created in an earlier call, @@ -198,7 +197,20 @@ def add_task( return self.task_id_map[task_hash] self.task_id_map[task_hash] = new_task - self.pending_task_queue.put(new_task) + + if self.n_workers < 0: + # call directly if single threaded + if not new_task.is_precalculated(): + new_task._call() + else: + # send to scheduler + if not new_task.is_precalculated(): + self.waiting_task_queue.put((new_task, 'wait')) + else: + # this is a shortcut to clear pre-calculated tasks + new_task._task_complete_event.set() + self.waiting_task_queue.put((new_task, 'done')) + return new_task except Exception: @@ -206,35 +218,43 @@ def add_task( self._terminate() raise - def _process_pending_tasks(self): - """Process pending task queue, send ready tasks to work queue. + def _schedule_priority_tasks(self): + """Priority schedules the `self.work_ready` queue. - There are two reasons a task will be on the pending_task_queue. One - is to potentially process it for work. The other is to alert that - the task is complete and any tasks that were dependent on it may - now be processed. + Reads the `self.work_ready` queue and feeds in highest priority tasks + when the self.work_queue is ready for them. """ - tasks_sent_to_work = set() - for task in iter(self.pending_task_queue.get, 'STOP'): - # invariant: a task coming in was put in the queue before it was - # complete and because it was a dependent task of another task - # that completed. OR a task is complete and alerting that any - # tasks that were dependent on it can be processed. 
-
-            # it's a new task, check and see if its dependencies are complete
-            outstanding_dependent_task_list = [
-                dep_task for dep_task in task.dependent_task_list
-                if not dep_task.is_complete()]
-
-            if outstanding_dependent_task_list:
-                # if outstanding tasks, delay execution and put a reminder
-                # that this task is dependent on another
-                self.waiting_task_queue.put((task, 'wait'))
-            elif task.task_hash not in tasks_sent_to_work:
-                # otherwise if not already sent to work, put in the work queue
-                # and record it was sent
-                tasks_sent_to_work.add(task.task_hash)
-                self.work_queue.put(task)
+        stopped = False
+        priority_queue = []
+        while not stopped:
+            while True:
+                try:
+                    # only block if the priority queue is empty
+                    task = self.work_ready_queue.get(not priority_queue)
+                    if task == 'STOP':
+                        # encounter STOP so break and don't get more elements
+                        stopped = True
+                        break
+                    # push task to priority queue
+                    heapq.heappush(priority_queue, task)
+                except Queue.Empty:
+                    # this triggers when work_ready_queue is empty and
+                    # there's something in the priority_queue
+                    break
+            # only put elements if there are workers available
+            self.worker_semaphore.acquire()
+            while priority_queue:
+                # push the highest priority task onto the work queue; if the
+                # thread is stopped, drain the whole priority queue
+                self.work_queue.put(priority_queue[0])
+                heapq.heappop(priority_queue)
+                if not stopped:
+                    # by stopping after one put, we can give the chance for
+                    # other higher priority tasks to flow in
+                    break
+        # got a 'STOP' so signal worker threads to stop too
+        for _ in xrange(max(1, self.n_workers)):
+            self.work_queue.put('STOP')
 
     def _process_waiting_tasks(self):
         """Process any tasks that are waiting on dependencies.
@@ -242,30 +262,26 @@ def _process_waiting_tasks(self):
         This worker monitors the self.waiting_task_queue Queue and looks
         for (task, 'wait'), or (task, 'done') tuples.
 
-        If mode is 'wait' the task is indexed locally with reference to
-        its incomplete tasks. If its depedent tasks are complete, the
-        task is sent to the work queue. If mode is 'done' this signals the
-        worker to re-'wait' any task that was dependent on the one that
-        arrived in the queue.
+        If mode is 'wait' the task is indexed locally with reference to
+        its incomplete tasks. If its dependent tasks are complete, the
+        task is sent to the work queue. If mode is 'done' this signals the
+        worker to re-'wait' any task that was dependent on the one that
+        arrived in the queue.
         """
         task_dependent_map = collections.defaultdict(set)
         dependent_task_map = collections.defaultdict(set)
         completed_tasks = set()
         for task, mode in iter(self.waiting_task_queue.get, 'STOP'):
+            tasks_ready_to_work = set()
             if mode == 'wait':
-                # invariant: task has come directly from `add_task` and has
-                # been determined that is has at least one unsatisfied
-                # dependency
-
+                # see if this task's dependencies are satisfied, if so send
+                # to work.
                 outstanding_dependent_task_list = [
                     dep_task for dep_task in task.dependent_task_list
                     if dep_task not in completed_tasks]
 
-                # possible a dependency has been satisfied since `add_task`
-                # was able to add this task to the waiting queue.
                 if not outstanding_dependent_task_list:
                     # if nothing is outstanding, send to work queue
-                    self.work_queue.put(task)
-                    continue
+                    tasks_ready_to_work.add(task)
 
                 # there are unresolved tasks that the waiting process
                 # scheduler has not been notified of. Record dependencies.
@@ -279,6 +295,10 @@ def _process_waiting_tasks(self):
                 # invariant: task has not previously been sent as a 'done'
                 # notification and task is done.
                 completed_tasks.add(task)
+                if task not in task_dependent_map:
+                    # this can occur if add_task identifies task is complete
+                    # before any other analysis.
+                    continue
                 for waiting_task in task_dependent_map[task]:
                     # remove `task` from the set of tasks that `waiting_task`
                     # was waiting on.
@@ -288,12 +308,15 @@ def _process_waiting_tasks(self):
                     if not dependent_task_map[waiting_task]:
                         # if we removed the last task we can put it to the
                         # work queue
-                        self.work_queue.put(waiting_task)
+                        tasks_ready_to_work.add(waiting_task)
                 del task_dependent_map[task]
+            for ready_task in sorted(
+                    tasks_ready_to_work, key=lambda x: x.priority):
+                self.work_ready_queue.put(ready_task)
+            tasks_ready_to_work = None
 
         # if we got here, the waiting task queue is shut down, pass signal
-        # to the workers
-        for _ in xrange(max(1, self.n_workers)):
-            self.work_queue.put('STOP')
+        # to the lower queue
+        self.work_ready_queue.put('STOP')
 
     def join(self, timeout=None):
         """Join all threads in the graph.
@@ -306,6 +329,9 @@ def join(self, timeout=None):
         Returns:
             True if successful join, False if timed out.
         """
+        # if single threaded, nothing to join.
+        if self.n_workers < 0:
+            return True
         try:
             timedout = False
             for task in self.task_id_map.itervalues():
@@ -317,18 +343,16 @@ def join(self, timeout=None):
             if self.closed:
                 # inject sentinels to the queues
                 self.waiting_task_queue.put('STOP')
-                for _ in xrange(max(1, self.n_workers)):
-                    self.work_queue.put('STOP')
             return not timedout
-        except Exception as e:
+        except Exception:
             # If there's an exception on a join it means that a task failed
             # to execute correctly. Print a helpful message then terminate the
             # taskgraph object.
-            LOGGER.error(
-                "Exception \"%s\" raised when joining task %s. It's possible "
+            LOGGER.exception(
+                "Exception raised when joining task %s. It's possible "
                 "that this task did not cause the exception, rather another "
                 "exception terminated the task_graph. Check the log to see "
-                "if there are other exceptions.", e, task)
+                "if there are other exceptions.", task)
             self._terminate()
             raise
@@ -337,11 +361,9 @@ def close(self):
         if self.closed:
             return
         self.closed = True
-        self.pending_task_queue.put('STOP')
 
     def _terminate(self):
         """Forcefully terminate remaining task graph computation."""
-        LOGGER.debug("********* calling _terminate")
         if self.terminated:
             return
         self.close()
@@ -358,7 +380,7 @@ class Task(object):
     def __init__(
             self, task_name, func, args, kwargs, target_path_list,
             ignore_path_list, dependent_task_list, ignore_directories,
-            worker_pool, cache_dir):
+            worker_pool, cache_dir, priority):
         """Make a Task.
 
         Parameters:
@@ -386,6 +408,11 @@ def __init__(
                 multiprocessing pool that can be used for `_call` execution.
             cache_dir (string): path to a directory to both write and expect
                 data recorded from a previous Taskgraph run.
+            priority (numeric): the priority of a task is considered when
+                there is more than one task whose dependencies have been
+                met and are ready for scheduling. Tasks are inserted into the
+                work queue in order of decreasing priority. This value can be
+                positive, negative, and/or floating point.
""" self.task_name = task_name self.func = func @@ -397,6 +424,8 @@ def __init__( self.ignore_path_list = ignore_path_list self.ignore_directories = ignore_directories self.worker_pool = worker_pool + # invert the priority since heapq goes smallest to largest + self.priority = -priority self.terminated = False self.exception_object = None @@ -445,10 +474,25 @@ def __init__( [x for x in self.task_hash[0:3]] + [self.task_hash + '.json'])) + def __eq__(self, other): + """Two tasks are equal if their hashes are equal.""" + if isinstance(self, other.__class__): + return self.task_hash == other.task_hash + return False + + def __ne__(self, other): + """Inverse of __eq__.""" + return not self.__eq__(other) + + def __lt__(self, other): + """Less than based on priority.""" + return self.priority < other.priority + def __str__(self): return "Task object %s:\n\n" % (id(self)) + pprint.pformat( { "task_name": self.task_name, + "priority": self.priority, "target_path_list": self.target_path_list, "dependent_task_list": self.dependent_task_list, "ignore_path_list": self.ignore_path_list, @@ -485,7 +529,7 @@ def _call(self): missing_target_paths = [ target_path for target_path in self.target_path_list if not os.path.exists(target_path)] - if len(missing_target_paths) > 0: + if missing_target_paths: raise RuntimeError( "The following paths were expected but not found " "after the function call: %s" % missing_target_paths) @@ -500,7 +544,7 @@ def _call(self): raise RuntimeError( "In Task: %s\nMissing expected target path results.\n" "Expected: %s\nObserved: %s\n" % ( - self.task_name, target_path_list, + self.task_name, self.target_path_list, result_target_path_set)) # otherwise record target path stats in a file located at @@ -516,7 +560,7 @@ def _call(self): # successful run, return target path stats return result_target_path_stats except Exception as e: - LOGGER.error("Exception %s in Task: %s" % (e, self)) + LOGGER.exception("Exception Task: %s", self) self._terminate(e) raise finally: @@ -549,16 +593,19 @@ def is_precalculated(self): True if the Task's target paths exist in the same state as the last recorded run. False otherwise. """ - if not os.path.exists(self.task_cache_path): - return False - with open(self.task_cache_path, 'rb') as task_cache_file: - result_target_path_stats = pickle.load(task_cache_file) - for path, modified_time, size in result_target_path_stats: - if not (os.path.exists(path) and - modified_time == os.path.getmtime(path) and - size == os.path.getsize(path)): + try: + if not os.path.exists(self.task_cache_path): return False - return True + with open(self.task_cache_path, 'rb') as task_cache_file: + result_target_path_stats = pickle.load(task_cache_file) + for path, modified_time, size in result_target_path_stats: + if not (os.path.exists(path) and + modified_time == os.path.getmtime(path) and + size == os.path.getsize(path)): + return False + return True + except EOFError: + return False def join(self, timeout=None): """Block until task is complete, raise exception if runtime failed.""" @@ -572,6 +619,35 @@ def _terminate(self, exception_object=None): self._task_complete_event.set() +class EncapsulatedTaskOp: + """Used as a superclass for Task operations that need closures. + + This class will automatically hash the subclass's __call__ method source + as well as the arguments to its __init__ function to calculate the + Task's unique hash. 
+ """ + __metaclass__ = abc.ABCMeta + + def __init__(self, *args, **kwargs): + # try to get the source code of __call__ so task graph will recompute + # if the function has changed + args_as_str = str([args, kwargs]) + try: + # hash the args plus source code of __call__ + id_hash = hashlib.sha1(args_as_str + inspect.getsource( + self.__class__.__call__)).hexdigest() + except IOError: + # this will fail if the code is compiled, that's okay just do + # the args + id_hash = hashlib.sha1(args_as_str) + # prefix the classname + self.__name__ = '%s_%s' % (self.__class__.__name__, id_hash) + + @abc.abstractmethod + def __call__(self, *args, **kwargs): + pass + + def _get_file_stats(base_value, ignore_list, ignore_directories): """Iterate over any values that are filepaths by getting filestats. diff --git a/taskgraph/__init__.py b/taskgraph/__init__.py index f2891ab..c930e78 100644 --- a/taskgraph/__init__.py +++ b/taskgraph/__init__.py @@ -1,7 +1,20 @@ """taskgraph module.""" -from taskgraph.Task import TaskGraph -from taskgraph.Task import Task +import pkg_resources +from taskgraph.Task import * -import natcap.versioner - -__version__ = natcap.versioner.get_version('taskgraph') +try: + __version__ = pkg_resources.get_distribution(__name__).version +except pkg_resources.DistributionNotFound: + # Package is not installed, so the package metadata is not available. + # This should only happen if a package is importable but the package + # metadata is not, as might happen if someone copied files into their + # system site-packages or they're importing this package from the CWD. + raise RuntimeError( + "Could not load version from installed metadata.\n\n" + "This is often because the package was not installed properly. " + "Ensure that the package is installed in a way that the metadata is " + "maintained. Calls to ``pip`` and this package's ``setup.py`` " + "maintain metadata. Examples include:\n" + " * python setup.py install\n" + " * python setup.py develop\n" + " * pip install ") diff --git a/tests/test_task.py b/tests/test_task.py index 14acf97..6540d4d 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -9,6 +9,7 @@ import logging import taskgraph +import mock logging.basicConfig(level=logging.DEBUG) @@ -37,7 +38,6 @@ def _div_by_zero(): """Divide by zero to raise an exception.""" return 1/0 - class TaskGraphTests(unittest.TestCase): """Tests for the taskgraph.""" @@ -51,6 +51,27 @@ def tearDown(self): """Overriding tearDown function to remove temporary directory.""" shutil.rmtree(self.workspace_dir) + def test_version_loaded(self): + """TaskGraph: verify we can load the version.""" + try: + import taskgraph + # Verifies that there's a version attribute and it has a value. + self.assertTrue(len(taskgraph.__version__) > 0) + except Exception as error: + self.fail('Could not load the taskgraph version as expected.') + + def test_version_not_loaded(self): + """TaskGraph: verify exception when not installed.""" + from pkg_resources import DistributionNotFound + import taskgraph + + with mock.patch('taskgraph.pkg_resources.get_distribution', + side_effect=DistributionNotFound('taskgraph')): + with self.assertRaises(RuntimeError): + # RuntimeError is a side effect of `import taskgraph`, so we + # reload it to retrigger the metadata load. 
+                taskgraph = reload(taskgraph)
+
     def test_single_task(self):
         """TaskGraph: Test a single task."""
         task_graph = taskgraph.TaskGraph(self.workspace_dir, 0)
@@ -61,6 +82,7 @@ def test_single_task(self):
             func=_create_list_on_disk,
             args=(value, list_len, target_path),
             target_path_list=[target_path])
+        task_graph.close()
         task_graph.join()
         result = pickle.load(open(target_path, 'rb'))
         self.assertEqual(result, [value]*list_len)
@@ -71,6 +93,7 @@ def test_timeout_task(self):
         target_path = os.path.join(self.workspace_dir, '1000.dat')
         _ = task_graph.add_task(
             func=_long_running_function,)
+        task_graph.close()
         timedout = not task_graph.join(0.5)
         # this should timeout since function runs for 5 seconds
         self.assertTrue(timedout)
@@ -97,6 +120,7 @@ def test_precomputed_task(self):
             func=_create_list_on_disk,
             args=(value, list_len, target_path),
             target_path_list=[target_path])
+        task_graph.close()
         task_graph.join()
 
         # taskgraph shouldn't have recomputed the result
@@ -146,16 +170,70 @@ def test_task_chain(self):
             args=(target_a_path, result_path, result_2_path),
             target_path_list=[result_2_path],
             dependent_task_list=[task_a, sum_task])
+        task_graph.close()
+        sum_3_task.join()
+        result3 = pickle.load(open(result_2_path, 'rb'))
+        expected_result = [(value_a*2+value_b)]*list_len
+        self.assertEqual(result3, expected_result)
+        task_graph.join()
+
+
+    def test_task_chain_single_thread(self):
+        """TaskGraph: Test a single threaded task chain."""
+        task_graph = taskgraph.TaskGraph(self.workspace_dir, -1)
+        target_a_path = os.path.join(self.workspace_dir, 'a.dat')
+        target_b_path = os.path.join(self.workspace_dir, 'b.dat')
+        result_path = os.path.join(self.workspace_dir, 'result.dat')
+        result_2_path = os.path.join(self.workspace_dir, 'result2.dat')
+        value_a = 5
+        value_b = 10
+        list_len = 10
+        task_a = task_graph.add_task(
+            func=_create_list_on_disk,
+            args=(value_a, list_len, target_a_path),
+            target_path_list=[target_a_path])
+        task_b = task_graph.add_task(
+            func=_create_list_on_disk,
+            args=(value_b, list_len, target_b_path),
+            target_path_list=[target_b_path])
+        sum_task = task_graph.add_task(
+            func=_sum_lists_from_disk,
+            args=(target_a_path, target_b_path, result_path),
+            target_path_list=[result_path],
+            dependent_task_list=[task_a, task_b])
+        sum_task.join()
+
+        result = pickle.load(open(result_path, 'rb'))
+        self.assertEqual(result, [value_a+value_b]*list_len)
+
+        sum_2_task = task_graph.add_task(
+            func=_sum_lists_from_disk,
+            args=(target_a_path, result_path, result_2_path),
+            target_path_list=[result_2_path],
+            dependent_task_list=[task_a, sum_task])
+        sum_2_task.join()
+        result2 = pickle.load(open(result_2_path, 'rb'))
+        expected_result = [(value_a*2+value_b)]*list_len
+        self.assertEqual(result2, expected_result)
+
+        sum_3_task = task_graph.add_task(
+            func=_sum_lists_from_disk,
+            args=(target_a_path, result_path, result_2_path),
+            target_path_list=[result_2_path],
+            dependent_task_list=[task_a, sum_task])
+        task_graph.close()
         sum_3_task.join()
         result3 = pickle.load(open(result_2_path, 'rb'))
         expected_result = [(value_a*2+value_b)]*list_len
         self.assertEqual(result3, expected_result)
+        task_graph.join()
 
     def test_broken_task(self):
         """TaskGraph: Test that a task with an exception won't hang."""
         task_graph = taskgraph.TaskGraph(self.workspace_dir, 1)
         _ = task_graph.add_task(
             func=_div_by_zero, task_name='test_broken_task')
+        task_graph.close()
         with self.assertRaises(RuntimeError):
             task_graph.join()
         file_results = glob.glob(os.path.join(self.workspace_dir, '*'))
@@ -176,6 +254,7 @@ def test_broken_task_chain(self):
             args=(value, list_len, target_path),
             target_path_list=[target_path],
             dependent_task_list=[base_task])
+        task_graph.close()
         with self.assertRaises(RuntimeError):
             task_graph.join()
         file_results = glob.glob(os.path.join(self.workspace_dir, '*'))
@@ -186,6 +265,7 @@ def test_empty_task(self):
         """TaskGraph: Test an empty task."""
         task_graph = taskgraph.TaskGraph(self.workspace_dir, 0)
         _ = task_graph.add_task()
+        task_graph.close()
         task_graph.join()
         file_results = glob.glob(os.path.join(self.workspace_dir, '*'))
         # we should have a file in there that's the token
@@ -203,6 +283,7 @@ def test_closed_graph(self):
                 func=_create_list_on_disk,
                 args=(value, list_len, target_path),
                 target_path_list=[target_path])
+        task_graph.join()
 
     def test_single_task_multiprocessing(self):
         """TaskGraph: Test a single task with multiprocessing."""
@@ -214,6 +295,7 @@ def test_single_task_multiprocessing(self):
             func=_create_list_on_disk,
             args=(value, list_len, target_path),
             target_path_list=[target_path])
+        task_graph.close()
         task_graph.join()
         result = pickle.load(open(target_path, 'rb'))
         self.assertEqual(result, [value]*list_len)
@@ -240,3 +322,53 @@ def test_get_file_stats(self):
 
         result = list(_get_file_stats(u'foo', [], False))
         self.assertEqual(result, [])
+
+    def test_encapsulatedtaskop(self):
+        """TaskGraph: Test abstract closure task class."""
+        from taskgraph.Task import EncapsulatedTaskOp
+
+        class TestAbstract(EncapsulatedTaskOp):
+            def __init__(self):
+                pass
+
+        # __call__ is abstract so TypeError since it's not implemented
+        with self.assertRaises(TypeError):
+            x = TestAbstract()
+
+        class TestA(EncapsulatedTaskOp):
+            def __call__(self, x):
+                return x
+
+        class TestB(EncapsulatedTaskOp):
+            def __call__(self, x):
+                return x
+
+        # TestA and TestB should be different because of different class names
+        a = TestA()
+        b = TestB()
+        # results of calls should be the same
+        self.assertEqual(a.__call__(5), b.__call__(5))
+        self.assertNotEqual(a.__name__, b.__name__)
+
+        # two instances with same args should be the same
+        self.assertEqual(TestA().__name__, TestA().__name__)
+
+        # redefine TestA so we get a different hashed __name__
+        class TestA(EncapsulatedTaskOp):
+            def __call__(self, x):
+                return x*x
+
+        new_a = TestA()
+        self.assertNotEqual(a.__name__, new_a.__name__)
+
+        # change internal class constructor to get different hashes
+        class TestA(EncapsulatedTaskOp):
+            def __init__(self, q):
+                super(TestA, self).__init__(q)
+                self.q = q
+
+            def __call__(self, x):
+                return x*x
+
+        init_new_a = TestA(1)
+        self.assertNotEqual(new_a.__name__, init_new_a.__name__)
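
Reviewer note, not part of the patch: the sketch below is a minimal, hypothetical usage example of the API this changeset introduces (negative `n_workers`, the `priority` keyword on `add_task`, and subclassing `EncapsulatedTaskOp`). The `ScaleOp` class, the workspace path, and the values are illustrative only and assume the patched taskgraph 0.4.0 is installed.

import os
import pickle

import taskgraph
from taskgraph.Task import EncapsulatedTaskOp


class ScaleOp(EncapsulatedTaskOp):
    """Illustrative operation that needs constructor scope to do its work."""
    def __init__(self, factor):
        # the superclass hashes the constructor args and the source of
        # __call__, so editing either re-triggers the task on a later run
        super(ScaleOp, self).__init__(factor)
        self.factor = factor

    def __call__(self, value, target_path):
        with open(target_path, 'wb') as target_file:
            pickle.dump(value * self.factor, target_file)


workspace_dir = 'taskgraph_example_workspace'  # hypothetical cache directory
# n_workers semantics in this changeset: < 0 runs everything in the main
# thread (add_task blocks), 0 uses scheduling threads but no subprocesses,
# > 0 also starts a multiprocessing pool with that many workers.
task_graph = taskgraph.TaskGraph(workspace_dir, 0)

target_path = os.path.join(workspace_dir, 'scaled.pickle')
scale_task = task_graph.add_task(
    func=ScaleOp(3),
    args=(5, target_path),
    target_path_list=[target_path],
    priority=10)  # among ready tasks, higher priority is scheduled first

task_graph.close()
task_graph.join()
print(pickle.load(open(target_path, 'rb')))  # expected: 15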