Commit 5df6620

Merge 86fa8d0 into cb71834
mlin committed Aug 5, 2020
2 parents cb71834 + 86fa8d0 commit 5df6620
Showing 4 changed files with 121 additions and 49 deletions.
33 changes: 26 additions & 7 deletions WDL/runtime/cache.py
@@ -13,6 +13,7 @@
from contextlib import AbstractContextManager
from urllib.parse import urlparse, urlunparse
from fnmatch import fnmatchcase
from threading import Lock

from . import config

@@ -29,10 +30,18 @@ class CallCache(AbstractContextManager):
_flocker: FlockHolder
_logger: logging.Logger

# URIs->files cached only for the lifetime of this CallCache instance. These are downloaded in
# the course of the current workflow run, but not eligible for persistent caching in future
# runs; we just want to remember them for potential reuse later in the current run.
_workflow_downloads: Dict[str, str]
_lock: Lock

def __init__(self, cfg: config.Loader, logger: logging.Logger):
self._cfg = cfg
self._logger = logger.getChild("CallCache")
self._flocker = FlockHolder(self._logger)
self._workflow_downloads = {}
self._lock = Lock()
self.call_cache_dir = cfg["call_cache"]["dir"]

try:
@@ -152,6 +161,9 @@ def get_download(self, uri: str, logger: Optional[logging.Logger] = None) -> Opt
Return filename of the cached download of uri, if available. If so then opens a shared
flock on the local file, which will remain for the life of the CallCache object.
"""
with self._lock:
if uri in self._workflow_downloads:
return self._workflow_downloads[uri]
logger = logger.getChild("CallCache") if logger else self._logger
p = self.download_path(uri)
if not (self._cfg["download_cache"].get_bool("get") and p and os.path.isfile(p)):
@@ -181,16 +193,23 @@ def put_download(
"""
logger = logger.getChild("CallCache") if logger else self._logger
ans = filename
if self._cfg["download_cache"].get_bool("put"):
p = self.download_path(uri)
if p:
os.makedirs(os.path.dirname(p), exist_ok=True)
os.rename(filename, p)
logger.info(_("stored in download cache", uri=uri, cache_path=p))
ans = p
p = self.download_cacheable(uri)
if p:
os.makedirs(os.path.dirname(p), exist_ok=True)
os.rename(filename, p)
logger.info(_("stored in download cache", uri=uri, cache_path=p))
ans = p
else:
with self._lock:
self._workflow_downloads[uri] = ans
self.flock(ans)
return ans

def download_cacheable(self, uri: str) -> Optional[str]:
if not self._cfg["download_cache"].get_bool("put"):
return None
return self.download_path(uri)

def flock(self, filename: str, exclusive: bool = False) -> None:
self._flocker.flock(filename, update_atime=True, exclusive=exclusive)

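
Putting the cache.py changes together: put_download now moves the file into the persistent cache only when the URI is cacheable, and otherwise memoizes the downloaded path in the per-run _workflow_downloads dict (guarded by the new Lock) so get_download can reuse it later in the same run. Below is a minimal illustrative sketch of that two-tier logic; the MiniCache class and its persistent_dir layout are hypothetical stand-ins, not miniwdl's actual CallCache.

# Illustrative two-tier cache, loosely mirroring the CallCache changes above;
# class and attribute names here are hypothetical, not miniwdl's API.
import os
import shutil
from threading import Lock
from typing import Dict, Optional


class MiniCache:
    def __init__(self, persistent_dir: Optional[str]) -> None:
        self._persistent_dir = persistent_dir          # None: persistent caching disabled
        self._workflow_downloads: Dict[str, str] = {}  # per-run memo: URI -> local file
        self._lock = Lock()

    def _cache_path(self, uri: str) -> Optional[str]:
        # map a URI to its location under the persistent cache, if one is configured
        if not self._persistent_dir:
            return None
        return os.path.join(self._persistent_dir, uri.replace("://", "/"))

    def get_download(self, uri: str) -> Optional[str]:
        with self._lock:
            if uri in self._workflow_downloads:        # reuse within the current run
                return self._workflow_downloads[uri]
        p = self._cache_path(uri)
        return p if p and os.path.isfile(p) else None

    def put_download(self, uri: str, filename: str) -> str:
        p = self._cache_path(uri)
        if p:                                          # persistently cacheable URI
            os.makedirs(os.path.dirname(p), exist_ok=True)
            shutil.move(filename, p)
            return p
        with self._lock:                               # otherwise remember for this run only
            self._workflow_downloads[uri] = filename
        return filename
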
7 changes: 3 additions & 4 deletions WDL/runtime/download.py
@@ -127,10 +127,9 @@ def run_cached(
cached = cache.get_download(uri, logger=logger)
if cached:
return True, cached
if not cfg["download_cache"].get_bool("put") or not cache.download_path(uri):
return False, run(cfg, logger, uri, run_dir=run_dir, **kwargs)
# run the download within the cache directory
run_dir = os.path.join(cfg["download_cache"]["dir"], "ops")
if cache.download_cacheable(uri):
# run the download within the cache directory
run_dir = os.path.join(cfg["download_cache"]["dir"], "ops")
filename = run(cfg, logger, uri, run_dir=run_dir, **kwargs)
return False, cache.put_download(uri, os.path.realpath(filename), logger=logger)
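
The run_cached change amounts to: try the cache first; if the URI is persistently cacheable, run the download inside the cache's ops directory so the result can be moved into place; either way, hand the file to put_download. A rough sketch of that control flow follows, with simplified signatures — run_cached_sketch and its parameters are illustrative, not the real WDL/runtime/download.py API.

# Rough control-flow sketch of run_cached after this change; signatures are
# simplified stand-ins for the real WDL/runtime/download.py function.
import os
from typing import Callable, Tuple


def run_cached_sketch(
    cache,                                    # object exposing get_download / download_cacheable / put_download
    uri: str,
    run_dir: str,
    cache_ops_dir: str,
    run: Callable[[str, str], str],           # (uri, run_dir) -> downloaded filename
) -> Tuple[bool, str]:
    cached = cache.get_download(uri)
    if cached:
        return True, cached                   # cache hit: reuse without downloading
    if cache.download_cacheable(uri):
        run_dir = cache_ops_dir               # run the download within the cache directory
    filename = run(uri, run_dir)
    return False, cache.put_download(uri, os.path.realpath(filename))
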

79 changes: 41 additions & 38 deletions WDL/runtime/workflow.py
@@ -223,7 +223,9 @@ def outputs(self) -> Optional[Env.Bindings[Value.Base]]:
:param inputs: ``WDL.Env.Bindings[Value.Base]`` of call inputs
"""

def step(self, cfg: config.Loader) -> "Optional[StateMachine.CallInstructions]":
def step(
self, cfg: config.Loader, stdlib: StdLib.Base
) -> "Optional[StateMachine.CallInstructions]":
"""
Advance the workflow state machine, returning the next call to initiate.
@@ -258,7 +260,7 @@ def step(self, cfg: config.Loader) -> "Optional[StateMachine.CallInstructions]":

# do the job
try:
res = self._do_job(cfg, job)
res = self._do_job(cfg, stdlib, job)
except Exception as exn:
setattr(exn, "job_id", job.id)
raise exn
@@ -308,7 +310,7 @@ def _schedule(self, job: _Job) -> None:
self.waiting.add(job.id)

def _do_job(
self, cfg: config.Loader, job: _Job
self, cfg: config.Loader, stdlib: StdLib.Base, job: _Job
) -> "Union[StateMachine.CallInstructions, Env.Bindings[Value.Base]]":
if isinstance(job.node, Tree.Gather):
return _gather(
@@ -331,8 +333,6 @@ def _do_job(
)
)

stdlib = _StdLib(self)

if isinstance(job.node, (Tree.Scatter, Tree.Conditional)):
for newjob in _scatter(self.workflow, job.node, env, job.scatter_stack, stdlib):
self._schedule(newjob)
@@ -564,13 +564,21 @@ def _gather(

class _StdLib(StdLib.Base):
"checks against & updates the filename whitelist for the read_* and write_* functions"
cfg: config.Loader
state: StateMachine
cache: CallCache

def __init__(self, state: StateMachine) -> None:
def __init__(self, cfg: config.Loader, state: StateMachine, cache: CallCache) -> None:
super().__init__(write_dir=os.path.join(state.run_dir, "write_"))
self.cfg = cfg
self.state = state
self.cache = cache

def _devirtualize_filename(self, filename: str) -> str:
if downloadable(self.cfg, filename):
cached = self.cache.get_download(filename)
if cached:
return cached
if filename in self.state.filename_whitelist:
return filename
raise InputError("attempted read from unknown or inaccessible file " + filename)
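
With this change, the read_* stdlib functions can accept a downloadable URI directly: the filename is first checked against the primed download cache, then against the per-run whitelist of known local files. A hedged reduction of that lookup order is sketched below, using placeholder callables in place of downloadable(), CallCache.get_download(), and StateMachine.filename_whitelist.

# Hedged reduction of the new _devirtualize_filename lookup order; the callables
# and whitelist below are placeholders for miniwdl's internals.
from typing import Callable, Optional, Set


def devirtualize(
    filename: str,
    is_downloadable: Callable[[str], bool],        # stand-in for downloadable(cfg, filename)
    cache_lookup: Callable[[str], Optional[str]],  # stand-in for CallCache.get_download
    whitelist: Set[str],                           # stand-in for StateMachine.filename_whitelist
) -> str:
    if is_downloadable(filename):
        cached = cache_lookup(filename)            # URI primed into the cache earlier?
        if cached:
            return cached
    if filename in whitelist:                      # ordinary local file the workflow may read
        return filename
    raise ValueError("attempted read from unknown or inaccessible file " + filename)
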
@@ -727,19 +735,17 @@ def _workflow_main_loop(
inputs = recv["inputs"]

# download input files, if needed
posix_inputs = _download_input_files(
cfg, logger, logger_id, run_dir, inputs, thread_pools[0], cache
)
_download_input_files(cfg, logger, logger_id, run_dir, inputs, thread_pools[0], cache)

# run workflow state machine to completion
state = StateMachine(".".join(logger_id), run_dir, workflow, posix_inputs)
state = StateMachine(".".join(logger_id), run_dir, workflow, inputs)
while state.outputs is None:
if _test_pickle:
state = pickle.loads(pickle.dumps(state))
if terminating():
raise Terminated()
# schedule all runnable calls
next_call = state.step(cfg)
next_call = state.step(cfg, _StdLib(cfg, state, cache))
while next_call:
call_dir = os.path.join(run_dir, next_call.id)
if os.path.exists(call_dir):
@@ -765,7 +771,7 @@ else:
else:
assert False
call_futures[future] = next_call.id
next_call = state.step(cfg)
next_call = state.step(cfg, _StdLib(cfg, state, cache))
# no more calls to launch right now; wait for an outstanding call to finish
future = next(futures.as_completed(call_futures), None)
if future:
@@ -830,38 +836,38 @@ def _download_input_files(
inputs: Env.Bindings[Value.Base],
thread_pool: futures.ThreadPoolExecutor,
cache: CallCache,
) -> Env.Bindings[Value.Base]:
) -> None:
"""
Find all File values in the inputs (including any nested within compound values) that need
to / can be downloaded. Download them to some location under run_dir and return a copy of the
inputs with the URI values replaced by the downloaded filenames. Parallelize the download
operations on thread_pool.
Find all File values in the inputs, including any nested within compound values, that are
downloadable URIs, and ensure the cache is "primed" with them -- including performing actual
download tasks on thread_pool, if necessary. The inputs are not modified, but the CallCache
will be ready to quickly produce a local filename corresponding to any URI therein, because
it's either stored in the persistent download cache (if enabled), or downloaded to the
current/parent run directory and transiently memoized.
"""

# scan for URIs and schedule their downloads on the thread pool
ops = {}

def schedule_downloads(v: Value.Base) -> None:
def schedule_download(uri: str) -> str:
nonlocal ops
if isinstance(v, Value.File):
if v.value not in ops and downloadable(cfg, v.value):
logger.info(_("schedule input file download", uri=v.value))
future = thread_pool.submit(
download,
cfg,
logger,
cache,
v.value,
run_dir=os.path.join(run_dir, "download", str(len(ops)), "."),
logger_prefix=logger_prefix + [f"download{len(ops)}"],
)
ops[future] = v.value
for ch in v.children:
schedule_downloads(ch)
if downloadable(cfg, uri):
logger.info(_("schedule input file download", uri=uri))
future = thread_pool.submit(
download,
cfg,
logger,
cache,
uri,
run_dir=os.path.join(run_dir, "download", str(len(ops)), "."),
logger_prefix=logger_prefix + [f"download{len(ops)}"],
)
ops[future] = uri
return uri

inputs.map(lambda b: schedule_downloads(b.value))
Value.rewrite_env_files(inputs, schedule_download)
if not ops:
return inputs
return
logger.notice(_("downloading input files", count=len(ops))) # pyre-fixme

# collect the results, with "clean" fail-fast
@@ -909,6 +915,3 @@ def schedule_downloads(v: Value.Base) -> None:
cached_bytes=cached_bytes,
)
)

# rewrite the input URIs to the downloaded filenames
return Value.rewrite_env_files(inputs, lambda uri: downloaded.get(uri, uri))
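
Net effect of the _download_input_files rewrite: Value.rewrite_env_files is used purely for its traversal — the callback schedules a download for each new downloadable URI and returns the URI unchanged, so the inputs keep their URIs and the CallCache is what gets primed. A simplified sketch of that pattern follows; prime_downloads and its parameters are illustrative, not the real helper.

# Simplified sketch of priming downloads on a thread pool without rewriting the
# inputs; prime_downloads and its parameters are illustrative stand-ins.
from concurrent import futures
from typing import Callable, Dict, List


def prime_downloads(
    uris: List[str],                          # File values found among the inputs
    is_downloadable: Callable[[str], bool],
    download: Callable[[str], str],           # blocking download -> local filename
    pool: futures.ThreadPoolExecutor,
) -> None:
    ops: Dict[futures.Future, str] = {}
    for uri in uris:
        if uri not in ops.values() and is_downloadable(uri):
            ops[pool.submit(download, uri)] = uri  # schedule; don't rewrite the URI
    for future in futures.as_completed(ops):       # collect results, failing fast
        future.result()
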
51 changes: 51 additions & 0 deletions tests/test_7runner.py
@@ -169,6 +169,57 @@ def test_download_cache4(self):
if "downloaded input files" in line["message"]:
self.assertEqual(line["downloaded"], 0)

def test_download_cache5(self):
# passing workflow-level URI inputs through to task, which should find them in the cache
wdl5 = """
version 1.0
task t {
input {
File f1
File f2
}
command {}
output {
Int size2 = floor(size(f1) + size(f2))
}
}
workflow w {
input {
Array[File] af1
}
scatter (f1 in af1) {
call t { input: f1 = f1 }
}
output {
Array[Int] sizes = t.size2
}
}
"""
cfg = WDL.runtime.config.Loader(logging.getLogger(self.id()))
cfg.override({
"download_cache": {
"put": True,
"get": True,
"dir": os.path.join(self._dir, "cache5"),
"disable_patterns": ["*://google.com/*"]
},
"logging": { "json": True }
})
inp = {
"af1": ["https://raw.githubusercontent.com/chanzuckerberg/miniwdl/main/tests/alyssa_ben.txt", "s3://1000genomes/CHANGELOG" ],
"t.f2": "https://google.com/robots.txt"
}
self._run(wdl5, inp, cfg=cfg)
with open(os.path.join(self._rundir, "workflow.log")) as logfile:
for line in logfile:
line = json.loads(line)
if "t:call-t" not in line["source"] and "downloaded input files" in line["message"]:
self.assertEqual(line["downloaded"], 3)
if "t:call-t" in line["source"] and "downloaded input files" in line["message"]:
self.assertEqual(line["downloaded"], 0)
self.assertEqual(line["cached"], 2)


class MiscRegressionTests(RunnerTestCase):
def test_repeated_file_rewriting(self):
wdl = """
