Skip to content

Commit

Permalink
Merge c61a371 into f75b475
Browse files · Browse the repository at this point in the history
  • Loading branch information
mlin committed Jan 9, 2021
2 parents f75b475 + c61a371 commit 8eb0555
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 21 deletions.
46 changes: 39 additions & 7 deletions WDL/runtime/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __exit__(self, *args) -> None:
self._flocker.__exit__(*args)

def get(
self, key: str, output_types: Env.Bindings[Type.Base], inputs: Env.Bindings[Value.Base]
self, key: str, inputs: Env.Bindings[Value.Base], output_types: Env.Bindings[Type.Base]
) -> Optional[Env.Bindings[Value.Base]]:
"""
Resolve cache key to call outputs, if available, or None. When matching outputs are found, check to ensure the
Expand Down Expand Up @@ -214,7 +214,7 @@ def put_download(
) -> str:
"""
Move the downloaded file to the cache location & return the new path; or if the uri isn't
cacheable, return the given path.
cacheable, memoize the association and return the given path.
"""
if directory:
uri = uri.rstrip("/")
Expand Down Expand Up @@ -243,23 +243,35 @@ def put_download(
if directory and os.path.isdir(p):
rmtree_atomic(p)
os.renames(filename, p)
self.flock(p)
# the renames() op should be atomic, because the download operation should have
# been run under the cache directory (download.py:run_cached)
logger.info(_("stored in download cache", uri=uri, cache_path=p))
ans = p
if not p:
with self._lock:
(self._workflow_directory_downloads if directory else self._workflow_downloads)[
uri
] = ans
self.flock(ans)
self.memo_download(uri, filename, directory=directory)
return ans

def download_cacheable(self, uri: str, directory: bool = False) -> Optional[str]:
    """
    If the download cache is configured to store new entries ([download_cache] put = true),
    return the cache path the given URI maps to; otherwise return None.
    """
    if self._cfg["download_cache"].get_bool("put"):
        return self.download_path(uri, directory=directory)
    return None

def memo_download(
    self,
    uri: str,
    filename: str,
    directory: bool = False,
) -> None:
    """
    Record (for the lifetime of self) that filename is a local copy of uri; also flock the
    file so it stays usable while memoized.
    """
    with self._lock:
        # pick the memo table for directories vs. plain files
        table = self._workflow_downloads
        if directory:
            table = self._workflow_directory_downloads
        # first association wins; don't overwrite or re-flock on repeat calls
        if uri not in table:
            table[uri] = filename
            self.flock(filename)

def flock(self, filename: str, exclusive: bool = False) -> None:
    """
    Acquire an advisory lock on filename via the shared flocker (shared by default;
    exclusive on request).
    """
    # update_atime=True refreshes the file's access time — presumably so atime-based
    # cache cleanup won't evict in-use entries; confirm against the flocker implementation.
    self._flocker.flock(filename, update_atime=True, exclusive=exclusive)

Expand Down Expand Up @@ -317,3 +329,23 @@ def check_one(v: Union[Value.File, Value.Directory]):
return True
except StopIteration:
return False


# registry of available CallCache implementations, populated lazily under _backends_lock
_backends_lock = Lock()
_backends = {}


def new(cfg: config.Loader, logger: logging.Logger) -> CallCache:
    """
    Instantiate a CallCache, either the built-in implementation or a plugin-defined subclass per
    the configuration ([call_cache] backend).
    """
    with _backends_lock:
        # on first use, discover backends contributed by plugins (plus the built-in default)
        if not _backends:
            for plugin_name, plugin_cls in config.load_plugins(cfg, "cache_backend"):
                _backends[plugin_name] = plugin_cls
        backend_cls = _backends[cfg["call_cache"]["backend"]]
        cache = backend_cls(cfg, logger)
        assert isinstance(cache, CallCache)
        return cache
7 changes: 7 additions & 0 deletions WDL/runtime/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,13 @@ def _parse_list(v: str) -> List[Any]:
value="WDL.runtime.task_container:SwarmContainer",
)
],
"cache_backend": [
importlib_metadata.EntryPoint(
group="miniwdl.plugin.cache_backend",
name="dir",
value="WDL.runtime.cache:CallCache",
)
],
}


Expand Down
4 changes: 4 additions & 0 deletions WDL/runtime/config_templates/default.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ docker = "google/cloud-sdk@sha256:64ddc4e5d3f7fdc5a198c8acf1c361702994462dbe79e7
put = false
# enable retrieval of cached outputs
get = false
# pluggable implementation: the default stores cache JSON files in a local directory, and checks
# posix mtimes of any local files referenced in the cached inputs/outputs (invalidating the cache
# entry if any referenced files are newer)
backend = dir
dir = ~/.cache/miniwdl


Expand Down
10 changes: 3 additions & 7 deletions WDL/runtime/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .._util import StructuredLogMessage as _
from . import config, _statusbar
from .download import able as downloadable, run_cached as download
from .cache import CallCache
from .cache import CallCache, new as new_call_cache
from .error import OutputError, Interrupted, Terminated, CommandFailed, RunFailed, error_json
from .task_container import TaskContainer, new as new_task_container

Expand Down Expand Up @@ -95,7 +95,7 @@ def run_local_task(
write_values_json(inputs, os.path.join(run_dir, "inputs.json"))

if not _run_id_stack:
cache = _cache or cleanup.enter_context(CallCache(cfg, logger))
cache = _cache or cleanup.enter_context(new_call_cache(cfg, logger))
cache.flock(logfile, exclusive=True) # no containing workflow; flock task.log
else:
cache = _cache
Expand All @@ -104,11 +104,7 @@ def run_local_task(
container = None
try:
cache_key = f"{task.name}/{task.digest}/{Value.digest_env(inputs)}"
cached = cache.get(
key=cache_key,
output_types=task.effective_outputs,
inputs=inputs,
)
cached = cache.get(cache_key, inputs, task.effective_outputs)
if cached is not None:
for decl in task.outputs:
v = cached[decl.name]
Expand Down
10 changes: 3 additions & 7 deletions WDL/runtime/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
)
from .._util import StructuredLogMessage as _
from . import config, _statusbar
from .cache import CallCache
from .cache import CallCache, new as new_call_cache
from .error import RunFailed, Terminated, error_json


Expand Down Expand Up @@ -651,13 +651,9 @@ def run_local_workflow(
write_values_json(inputs, os.path.join(run_dir, "inputs.json"), namespace=workflow.name)

# query call cache
cache = _cache if _cache else cleanup.enter_context(CallCache(cfg, logger))
cache = _cache if _cache else cleanup.enter_context(new_call_cache(cfg, logger))
cache_key = f"{workflow.name}/{workflow.digest}/{Value.digest_env(inputs)}"
cached = cache.get(
key=cache_key,
output_types=workflow.effective_outputs,
inputs=inputs,
)
cached = cache.get(cache_key, inputs, workflow.effective_outputs)
if cached is not None:
for outp in workflow.effective_outputs:
v = cached[outp.name]
Expand Down

0 comments on commit 8eb0555

Please sign in to comment.