Commit 9ddc9e8: merge 2f6abd7 into 1bc3776
mlin committed Jun 19, 2023
Showing 5 changed files with 78 additions and 39 deletions.
WDL/StdLib.py (2 changes: 1 addition & 1 deletion)

@@ -177,7 +177,7 @@ def _f(
     outfile: BinaryIO = outfile  # pyre-ignore
     serialize(v, outfile)
     filename = outfile.name
-    chmod_R_plus(filename, file_bits=0o660)
+    chmod_R_plus(filename, file_bits=0o660)  # ensure accessibility to downstream tasks
     vfn = self._virtualize_filename(filename)
     return Value.File(vfn)

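For orientation: the hunk above sits inside the stdlib write_* implementation, which serializes a value to a temporary file and then loosens its permissions so that downstream tasks, whose containers may run as a different user, can read it. Below is a minimal sketch of a recursive permission-adding helper in that spirit; the name chmod_plus is hypothetical, and the real chmod_R_plus in WDL._util may differ in detail.

import os
import stat

def chmod_plus(path: str, file_bits: int = 0o660, dir_bits: int = 0o770) -> None:
    # OR the requested bits into each existing mode; bits are only ever added
    def add_bits(p: str, bits: int) -> None:
        os.chmod(p, stat.S_IMODE(os.stat(p).st_mode) | bits)

    if os.path.isdir(path):
        add_bits(path, dir_bits)
        for root, dirs, files in os.walk(path):
            for d in dirs:
                add_bits(os.path.join(root, d), dir_bits)
            for f in files:
                add_bits(os.path.join(root, f), file_bits)
    else:
        add_bits(path, file_bits)

# e.g. chmod_plus("run_dir/_miniwdl_write_/tmp1234.txt") makes that file group-accessible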
WDL/runtime/cache.py (91 changes: 63 additions & 28 deletions)

@@ -5,7 +5,7 @@
 """
 import json
 import os
-import shutil
+import hashlib
 import logging
 from pathlib import Path
 from typing import Dict, Optional, Union
@@ -16,7 +16,7 @@
 
 from . import config
 
-from .. import Env, Value, Type
+from .. import Env, Value, Type, Tree
 from .._util import (
     StructuredLogMessage as _,
     FlockHolder,
@@ -102,7 +102,7 @@ def get(
         # check that no files/directories referenced by the inputs & cached outputs are newer
         # than the cache file itself
         if _check_files_coherence(
-            self._cfg, self._logger, file_path, inputs
+            self._cfg, self._logger, file_path, inputs, skip_stdlib_written=True
         ) and _check_files_coherence(self._cfg, self._logger, file_path, cache):
             return cache
         else:
@@ -229,11 +229,11 @@ def put_download(
         """
         if directory:
             uri = uri.rstrip("/")
+        logger = logger.getChild("CallCache") if logger else self._logger
         p = self.download_cacheable(uri, directory=directory)
         if not p:
             self.memo_download(uri, filename, directory=directory)
             return filename
-        logger = logger.getChild("CallCache") if logger else self._logger
         moved = False
         # transient exclusive flock on whole cache directory (serializes entry add/remove)
         with FlockHolder(logger) as transient:
@@ -316,7 +316,11 @@ def flock(
 
 
 def _check_files_coherence(
-    cfg: config.Loader, logger: logging.Logger, cache_file: str, values: Env.Bindings[Value.Base]
+    cfg: config.Loader,
+    logger: logging.Logger,
+    cache_file: str,
+    values: Env.Bindings[Value.Base],
+    skip_stdlib_written: bool = False,
 ) -> bool:
     """
     Verify that none of the files/directories referenced by values are newer than cache_file itself
@@ -338,30 +342,29 @@ def raiser(exc):
 
     def check_one(v: Union[Value.File, Value.Directory]):
         assert isinstance(v, (Value.File, Value.Directory))
-        if not downloadable(cfg, v.value):
-            try:
-                if mtime(v.value) > cache_file_mtime:
-                    raise StopIteration
-                if isinstance(v, Value.Directory):
-                    # check everything in directory
-                    for root, subdirs, subfiles in os.walk(
-                        v.value, onerror=raiser, followlinks=False
-                    ):
-                        for subdir in subdirs:
-                            if mtime(os.path.join(root, subdir)) > cache_file_mtime:
-                                raise StopIteration
-                        for fn in subfiles:
-                            if mtime(os.path.join(root, fn)) > cache_file_mtime:
-                                raise StopIteration
-            except (FileNotFoundError, NotADirectoryError, StopIteration):
-                logger.warning(
-                    _(
-                        "cache entry invalid due to deleted or modified file/directory",
-                        cache_file=cache_file,
-                        changed=v.value,
-                    )
-                )
-                raise StopIteration
+        if downloadable(cfg, v.value) or (skip_stdlib_written and _is_stdlib_written_file(v.value)):
+            return
+        try:
+            if mtime(v.value) > cache_file_mtime:
+                raise StopIteration
+            if isinstance(v, Value.Directory):
+                # check everything in directory
+                for root, subdirs, subfiles in os.walk(v.value, onerror=raiser, followlinks=False):
+                    for subdir in subdirs:
+                        if mtime(os.path.join(root, subdir)) > cache_file_mtime:
+                            raise StopIteration
+                    for fn in subfiles:
+                        if mtime(os.path.join(root, fn)) > cache_file_mtime:
+                            raise StopIteration
+        except (FileNotFoundError, NotADirectoryError, StopIteration):
+            logger.warning(
+                _(
+                    "cache entry invalid due to deleted or modified file/directory",
+                    cache_file=cache_file,
+                    changed=v.value,
+                )
+            )
+            raise StopIteration
 
     try:
         Value.rewrite_env_paths(values, check_one)
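
A note on the new skip_stdlib_written flag: a file produced by a stdlib write_* function during the current run is necessarily newer than any cache entry recorded by an earlier run, so subjecting it to the mtime check would invalidate every hit; skipping it is safe because derive_call_cache_key (added below) folds such files' contents into the cache key instead. The underlying coherence rule, as a minimal sketch with hypothetical names (the real mtime helper may differ in detail):

import os
from typing import Iterable

def is_coherent(cache_file: str, referenced_paths: Iterable[str]) -> bool:
    # a cache entry is usable only if nothing it references is newer than the entry itself
    entry_mtime = os.stat(cache_file).st_mtime_ns
    try:
        return all(os.stat(p).st_mtime_ns <= entry_mtime for p in referenced_paths)
    except OSError:
        return False  # a referenced path was deleted or is otherwise unreadable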
@@ -388,3 +391,35 @@ def new(cfg: config.Loader, logger: logging.Logger) -> CallCache:
     ans = impl_cls(cfg, logger)
     assert isinstance(ans, CallCache)
     return ans
+
+
+def derive_call_cache_key(
+    exe: Union[Tree.Task, Tree.Workflow], inputs: Env.Bindings[Value.Base]
+) -> str:
+    """
+    Derive the call cache key for invocation of the given task/workflow with the inputs
+    """
+
+    # Digesting inputs: for most Files we cache based on the absolute path (and implicitly mtime
+    # relative to the cache entry itself). But we make an exception for files written by a stdlib
+    # write_* function, since those always get unique temporary filenames that therefore could
+    # never be cached based on path. For those only, we use a content digest in the cache key
+    # derivation. Such files cannot be too large since their contents were originally held in
+    # memory.
+    def rewriter(fn: str) -> str:
+        if _is_stdlib_written_file(fn):
+            hasher = hashlib.sha256()
+            with open(fn, "rb") as f:
+                for chunk in iter(lambda: f.read(1048576), b""):
+                    hasher.update(chunk)
+            # need not be pretty, since we just digest it again below:
+            return "_miniwdl_write_" + hasher.hexdigest()
+        return fn
+
+    inputs = Value.rewrite_env_files(inputs, rewriter)
+    return f"{exe.name}/{exe.digest}/{Value.digest_env(inputs)}"
+
+
+def _is_stdlib_written_file(fn: str) -> bool:
+    # heuristic to determine if a file was created by one of the stdlib write_* functions
+    return os.path.isfile(fn) and os.path.basename(os.path.dirname(fn)) == "_miniwdl_write_"
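
To see what the new key derivation buys, here is a hedged, self-contained illustration; the helper names are hypothetical stand-ins for the rewriter and _is_stdlib_written_file above. A file whose parent directory is named _miniwdl_write_ is keyed by a SHA-256 digest of its contents rather than by its path, so two runs that write identical content produce the same token even though their temporary filenames differ.

import hashlib
import os
import tempfile

def is_stdlib_written(fn: str) -> bool:
    # same heuristic as _is_stdlib_written_file: the parent directory's name marks the file
    return os.path.isfile(fn) and os.path.basename(os.path.dirname(fn)) == "_miniwdl_write_"

def cache_token(fn: str) -> str:
    # mirrors the rewriter: content digest for stdlib-written files, path otherwise
    if not is_stdlib_written(fn):
        return fn
    hasher = hashlib.sha256()
    with open(fn, "rb") as f:
        for chunk in iter(lambda: f.read(1048576), b""):
            hasher.update(chunk)
    return "_miniwdl_write_" + hasher.hexdigest()

# two "runs" write identical content to differently named temporary files
tokens = []
for _ in range(2):
    write_dir = os.path.join(tempfile.mkdtemp(), "_miniwdl_write_")
    os.makedirs(write_dir)
    fd, fn = tempfile.mkstemp(dir=write_dir)
    with os.fdopen(fd, "w") as outfile:
        outfile.write("Alyssa P Hacker\n")
    tokens.append(cache_token(fn))

assert tokens[0] == tokens[1]  # same contents, same token, despite different paths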
WDL/runtime/task.py (6 changes: 3 additions & 3 deletions)

@@ -33,7 +33,7 @@
 from .._util import StructuredLogMessage as _
 from . import config, _statusbar
 from .download import able as downloadable, run_cached as download
-from .cache import CallCache, new as new_call_cache
+from .cache import CallCache, derive_call_cache_key, new as new_call_cache
 from .error import OutputError, Interrupted, Terminated, RunFailed, error_json


@@ -110,7 +110,7 @@ def run_local_task(
         cleanup.enter_context(_statusbar.task_slotted())
         container = None
         try:
-            cache_key = f"{task.name}/{task.digest}/{Value.digest_env(inputs)}"
+            cache_key = derive_call_cache_key(task, inputs)
             cached = cache.get(cache_key, inputs, task.effective_outputs)
             if cached is not None:
                 for decl in task.outputs:
@@ -1014,7 +1014,7 @@ def __init__(
         container: "runtime.task_container.TaskContainer",
         inputs_only: bool,
     ) -> None:
-        super().__init__(wdl_version, write_dir=os.path.join(container.host_dir, "write_"))
+        super().__init__(wdl_version, write_dir=os.path.join(container.host_dir, "_miniwdl_write_"))
         self.logger = logger
         self.container = container
         self.inputs_only = inputs_only
WDL/runtime/workflow.py (6 changes: 3 additions & 3 deletions)

@@ -64,7 +64,7 @@
 )
 from .._util import StructuredLogMessage as _
 from . import config, _statusbar
-from .cache import CallCache, new as new_call_cache
+from .cache import CallCache, derive_call_cache_key, new as new_call_cache
 from .error import RunFailed, Terminated, error_json


@@ -680,7 +680,7 @@ class _StdLib(StdLib.Base):
     def __init__(
         self, wdl_version: str, cfg: config.Loader, state: StateMachine, cache: CallCache
     ) -> None:
-        super().__init__(wdl_version, write_dir=os.path.join(state.run_dir, "write_"))
+        super().__init__(wdl_version, write_dir=os.path.join(state.run_dir, "_miniwdl_write_"))
         self.cfg = cfg
         self.state = state
         self.cache = cache
@@ -842,7 +842,7 @@ def run_local_workflow(
     write_values_json(inputs, os.path.join(run_dir, "inputs.json"), namespace=workflow.name)
 
     # query call cache
-    cache_key = f"{workflow.name}/{workflow.digest}/{Value.digest_env(inputs)}"
+    cache_key = derive_call_cache_key(workflow, inputs)
     cached = cache.get(cache_key, inputs, workflow.effective_outputs)
     if cached is not None:
         for outp in workflow.effective_outputs:
tests/test_8cache.py (12 changes: 8 additions & 4 deletions)

@@ -525,10 +525,14 @@ def test_directory_coherence(self):
                 input:
                     full_name = read_person.full_name
             }
+            call hello as hello2 {
+                input:
+                    full_name = write_lines([read_string(read_person.full_name)])
+            }
         }
         output {
-            Array[File] messages = hello.message
+            Array[File] messages = flatten([hello.message, hello2.message])
         }
     }
@@ -542,7 +546,7 @@
         command {}
         output {
-            File full_name = write_lines([sep(" ", [person.first, person.last])])
+            File full_name = write_lines([sep(" ", select_all([person.first, person.middle, person.last]))])
         }
     }
@@ -576,7 +580,7 @@ def test_workflow_digest(self):
 
         # ensure digest is sensitive to changes in the struct type and called task (but not the
         # uncalled task, or comments/whitespace)
-        doc2 = WDL.parse_document(self.test_workflow_wdl.replace("String? middle", ""))
+        doc2 = WDL.parse_document(self.test_workflow_wdl.replace("String? middle", "String? middle Int? age"))
         doc2.typecheck()
         self.assertNotEqual(doc.workflow.digest, doc2.workflow.digest)

@@ -626,5 +630,5 @@ def test_workflow_cache(self):
             print('{"first":"Alyssa","last":"Hacker","middle":"P"}', file=outfile)
         _, outp2 = self._run(self.test_workflow_wdl, inp, cfg=self.cfg)
         self.assertEqual(wmock.call_count, 1)
-        self.assertEqual(tmock.call_count, 2)  # reran Alyssa, cached Ben
+        self.assertEqual(tmock.call_count, 3)  # reran Alyssa, cached Ben
         self.assertNotEqual(WDL.values_to_json(outp), WDL.values_to_json(outp2))
