Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mlin committed Jun 18, 2023
1 parent 1bc3776 commit 00ce21c
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 14 deletions.
2 changes: 1 addition & 1 deletion WDL/StdLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _f(
outfile: BinaryIO = outfile # pyre-ignore
serialize(v, outfile)
filename = outfile.name
chmod_R_plus(filename, file_bits=0o660)
chmod_R_plus(filename, file_bits=0o660) # ensure accessibility to downstream tasks
vfn = self._virtualize_filename(filename)
return Value.File(vfn)

Expand Down
52 changes: 46 additions & 6 deletions WDL/runtime/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import json
import os
import shutil
import hashlib
import logging
from pathlib import Path
from typing import Dict, Optional, Union
Expand All @@ -16,7 +16,7 @@

from . import config

from .. import Env, Value, Type
from .. import Env, Value, Type, Tree
from .._util import (
StructuredLogMessage as _,
FlockHolder,
Expand Down Expand Up @@ -102,7 +102,7 @@ def get(
# check that no files/directories referenced by the inputs & cached outputs are newer
# than the cache file itself
if _check_files_coherence(
self._cfg, self._logger, file_path, inputs
self._cfg, self._logger, file_path, inputs, skip_stdlib_written=True
) and _check_files_coherence(self._cfg, self._logger, file_path, cache):
return cache
else:
Expand Down Expand Up @@ -229,11 +229,11 @@ def put_download(
"""
if directory:
uri = uri.rstrip("/")
logger = logger.getChild("CallCache") if logger else self._logger
p = self.download_cacheable(uri, directory=directory)
if not p:
self.memo_download(uri, filename, directory=directory)
return filename
logger = logger.getChild("CallCache") if logger else self._logger
moved = False
# transient exclusive flock on whole cache directory (serializes entry add/remove)
with FlockHolder(logger) as transient:
Expand Down Expand Up @@ -316,7 +316,11 @@ def flock(


def _check_files_coherence(
cfg: config.Loader, logger: logging.Logger, cache_file: str, values: Env.Bindings[Value.Base]
cfg: config.Loader,
logger: logging.Logger,
cache_file: str,
values: Env.Bindings[Value.Base],
skip_stdlib_written: bool = False,
) -> bool:
"""
Verify that none of the files/directories referenced by values are newer than cache_file itself
Expand All @@ -338,7 +342,12 @@ def raiser(exc):

def check_one(v: Union[Value.File, Value.Directory]):
assert isinstance(v, (Value.File, Value.Directory))
if not downloadable(cfg, v.value):
if not downloadable(cfg, v.value) and not (
# special case to skip stdlib-written files because we cache them based on content
# digest rather than filename+mtime (see derive_call_cache_key below)
skip_stdlib_written
and _is_stdlib_written_file(v.value)
):
try:
if mtime(v.value) > cache_file_mtime:
raise StopIteration
Expand Down Expand Up @@ -388,3 +397,34 @@ def new(cfg: config.Loader, logger: logging.Logger) -> CallCache:
ans = impl_cls(cfg, logger)
assert isinstance(ans, CallCache)
return ans


def derive_call_cache_key(
    exe: Union[Tree.Task, Tree.Workflow], inputs: Env.Bindings[Value.Base]
) -> str:
    """
    Derive the call cache key for invocation of the given task/workflow with the inputs
    """

    # Most File inputs are keyed by filename (with mtime coherence checked elsewhere). Files
    # produced by a stdlib write_* function are the exception: they always receive unique
    # temporary filenames, so a filename-based key could never hit the cache. For those only,
    # substitute a digest of the file contents; such files are small, since their contents
    # were originally held in memory.
    def substitute_content_digest(fn: str) -> str:
        if not _is_stdlib_written_file(fn):
            return fn
        sha = hashlib.sha256()
        with open(fn, "rb") as infile:
            while True:
                chunk = infile.read(1048576)
                if not chunk:
                    break
                sha.update(chunk)
        # placeholder needn't be pretty, since digest_env hashes it again below
        return "_miniwdl_write_" + sha.hexdigest()

    rewritten = Value.rewrite_env_files(inputs, substitute_content_digest)
    return f"{exe.name}/{exe.digest}/{Value.digest_env(rewritten)}"


def _is_stdlib_written_file(fn: str) -> bool:
# heuristic to determine if a file was created by one of the stdlib write_* functions
return os.path.isfile(fn) and os.path.basename(os.path.dirname(fn)) == "_miniwdl_write_"
6 changes: 3 additions & 3 deletions WDL/runtime/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from .._util import StructuredLogMessage as _
from . import config, _statusbar
from .download import able as downloadable, run_cached as download
from .cache import CallCache, new as new_call_cache
from .cache import CallCache, derive_call_cache_key, new as new_call_cache
from .error import OutputError, Interrupted, Terminated, RunFailed, error_json


Expand Down Expand Up @@ -110,7 +110,7 @@ def run_local_task(
cleanup.enter_context(_statusbar.task_slotted())
container = None
try:
cache_key = f"{task.name}/{task.digest}/{Value.digest_env(inputs)}"
cache_key = derive_call_cache_key(task, inputs)
cached = cache.get(cache_key, inputs, task.effective_outputs)
if cached is not None:
for decl in task.outputs:
Expand Down Expand Up @@ -1014,7 +1014,7 @@ def __init__(
container: "runtime.task_container.TaskContainer",
inputs_only: bool,
) -> None:
super().__init__(wdl_version, write_dir=os.path.join(container.host_dir, "write_"))
super().__init__(wdl_version, write_dir=os.path.join(container.host_dir, "_miniwdl_write_"))
self.logger = logger
self.container = container
self.inputs_only = inputs_only
Expand Down
6 changes: 3 additions & 3 deletions WDL/runtime/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
)
from .._util import StructuredLogMessage as _
from . import config, _statusbar
from .cache import CallCache, new as new_call_cache
from .cache import CallCache, derive_call_cache_key, new as new_call_cache
from .error import RunFailed, Terminated, error_json


Expand Down Expand Up @@ -680,7 +680,7 @@ class _StdLib(StdLib.Base):
def __init__(
self, wdl_version: str, cfg: config.Loader, state: StateMachine, cache: CallCache
) -> None:
super().__init__(wdl_version, write_dir=os.path.join(state.run_dir, "write_"))
super().__init__(wdl_version, write_dir=os.path.join(state.run_dir, "_miniwdl_write_"))
self.cfg = cfg
self.state = state
self.cache = cache
Expand Down Expand Up @@ -842,7 +842,7 @@ def run_local_workflow(
write_values_json(inputs, os.path.join(run_dir, "inputs.json"), namespace=workflow.name)

# query call cache
cache_key = f"{workflow.name}/{workflow.digest}/{Value.digest_env(inputs)}"
cache_key = derive_call_cache_key(workflow, inputs)
cached = cache.get(cache_key, inputs, workflow.effective_outputs)
if cached is not None:
for outp in workflow.effective_outputs:
Expand Down
6 changes: 5 additions & 1 deletion tests/test_8cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,10 +525,14 @@ def test_directory_coherence(self):
input:
full_name = read_person.full_name
}
call hello as hello2 {
input:
full_name = write_lines([read_string(read_person.full_name)])
}
}
output {
Array[File] messages = hello.message
Array[File] messages = flatten([hello.message, hello2.message])
}
}
Expand Down

0 comments on commit 00ce21c

Please sign in to comment.