diff --git a/gc3libs/core.py b/gc3libs/core.py index ecd418ce..71fd22df 100755 --- a/gc3libs/core.py +++ b/gc3libs/core.py @@ -1208,6 +1208,7 @@ def __init__(self, controller, tasks=list(), store=None, self._to_kill = [] self._core = controller self._store = store + self._tasks_by_id = {} for task in tasks: self.add(task) # public attributes @@ -1251,6 +1252,11 @@ def add(self, task): queue = self._get_queue_for_task(task) if not _contained(task, queue): queue.append(task) + if self._store: + try: + self._tasks_by_id[task.persistent_id] = task + except AttributeError: + gc3libs.log.warning("Task %s has no persistent ID!", task) task.attach(self) @@ -1260,9 +1266,23 @@ def remove(self, task): """ queue = self._get_queue_for_task(task) queue.remove(task) + if self._store: + try: + del self._tasks_by_id[task.persistent_id] + except AttributeError: + # already removed + pass task.detach() + def find_task_by_id(self, task_id): + """ + Return the task with the given persistent ID added to this `Engine` instance. + If no task has that ID, raise a `KeyError`. + """ + return self._tasks_by_id[task_id] + + def progress(self): """ Update state of all registered tasks and take appropriate action. diff --git a/gc3libs/testing/helpers.py b/gc3libs/testing/helpers.py index 006c9438..741074e9 100644 --- a/gc3libs/testing/helpers.py +++ b/gc3libs/testing/helpers.py @@ -24,7 +24,8 @@ # stdlib imports from contextlib import contextmanager import sys -from tempfile import NamedTemporaryFile +from tempfile import NamedTemporaryFile, mkdtemp +import shutil # GC3Pie imports from gc3libs import Application, Run @@ -73,6 +74,13 @@ def temporary_core( del cfg.TYPE_CONSTRUCTOR_MAP['noop'] +@contextmanager +def temporary_directory(*args, **kwargs): + tmpdir = mkdtemp(*args, **kwargs) + yield tmpdir + shutil.rmtree(tmpdir, ignore_errors=True) + + @contextmanager def temporary_engine(transition_graph=None, **kw): with temporary_core(transition_graph, **kw) as core: diff --git a/gc3libs/tests/test_engine.py b/gc3libs/tests/test_engine.py index 73ff1053..c6af3ba7 100644 --- a/gc3libs/tests/test_engine.py +++ b/gc3libs/tests/test_engine.py @@ -31,9 +31,10 @@ from gc3libs import Run, Application, create_engine import gc3libs.config from gc3libs.core import Core, Engine, MatchMaker +from gc3libs.persistence.filesystem import FilesystemStore from gc3libs.quantity import GB, hours -from gc3libs.testing.helpers import SimpleParallelTaskCollection, SimpleSequentialTaskCollection, SuccessfulApp, temporary_config, temporary_config_file, temporary_engine +from gc3libs.testing.helpers import SimpleParallelTaskCollection, SimpleSequentialTaskCollection, SuccessfulApp, temporary_config, temporary_config_file, temporary_core, temporary_directory, temporary_engine def test_engine_progress(num_jobs=1, transition_graph=None, max_iter=100): @@ -395,6 +396,57 @@ def test_create_engine_with_core_options(): assert engine._core.auto_enable_auth == False +def test_engine_find_task_by_id(): + """ + Test that saved tasks are can be retrieved from the Engine given their ID only. + """ + with temporary_core() as core: + with temporary_directory() as tmpdir: + store = FilesystemStore(tmpdir) + engine = Engine(core, store=store) + + task = SuccessfulApp() + store.save(task) + engine.add(task) + + task_id = task.persistent_id + assert_equal(task, engine.find_task_by_id(task_id)) + + +@raises(KeyError) +def test_engine_cannot_find_task_by_id_if_not_saved(): + """ + Test that *unsaved* tasks are cannot be retrieved from the Engine given their ID only. + """ + with temporary_core() as core: + with temporary_directory() as tmpdir: + store = FilesystemStore(tmpdir) + engine = Engine(core, store=store) + + task = SuccessfulApp() + engine.add(task) + + store.save(task) # guarantee it has a `.persistent_id` + task_id = task.persistent_id + engine.find_task_by_id(task_id) + + +@raises(KeyError) +def test_engine_cannot_find_task_by_id_if_no_store(): + """ + Test that `Engine.find_task_by_id` always raises `KeyError` if the Engine has no associated store. + """ + with temporary_engine() as engine: + with temporary_directory() as tmpdir: + store = FilesystemStore(tmpdir) + + task = SuccessfulApp() + engine.add(task) + + store.save(task) # guarantee it has a `.persistent_id` + task_id = task.persistent_id + engine.find_task_by_id(task_id) + if __name__ == "__main__": import pytest