
Commit 1db0765: wip

mlin committed Aug 5, 2020
1 parent 0a8554e

Showing 3 changed files with 46 additions and 24 deletions.
15 changes: 15 additions & 0 deletions WDL/runtime/cache.py
@@ -13,6 +13,7 @@
 from contextlib import AbstractContextManager
 from urllib.parse import urlparse, urlunparse
 from fnmatch import fnmatchcase
+from threading import Lock

 from . import config
@@ -29,10 +30,18 @@ class CallCache(AbstractContextManager):
     _flocker: FlockHolder
     _logger: logging.Logger

+    # URIs->files cached only for the lifetime of this CallCache instance. These are downloaded in
+    # the course of the current workflow run, but not eligible for persistent caching in future
+    # runs; we just want to remember them for potential reuse later in the current run.
+    _workflow_downloads: Dict[str, str]
+    _lock: Lock
+
     def __init__(self, cfg: config.Loader, logger: logging.Logger):
         self._cfg = cfg
         self._logger = logger.getChild("CallCache")
         self._flocker = FlockHolder(self._logger)
+        self._workflow_downloads = {}
+        self._lock = Lock()
         self.call_cache_dir = cfg["call_cache"]["dir"]

         try:
@@ -152,6 +161,9 @@ def get_download(self, uri: str, logger: Optional[logging.Logger] = None) -> Opt
         Return filename of the cached download of uri, if available. If so then opens a shared
         flock on the local file, which will remain for the life of the CallCache object.
         """
+        with self._lock:
+            if uri in self._workflow_downloads:
+                return self._workflow_downloads[uri]
         logger = logger.getChild("CallCache") if logger else self._logger
         p = self.download_path(uri)
         if not (self._cfg["download_cache"].get_bool("get") and p and os.path.isfile(p)):
@@ -187,6 +199,9 @@ def put_download(
                 os.rename(filename, p)
                 logger.info(_("stored in download cache", uri=uri, cache_path=p))
                 ans = p
+        else:
+            with self._lock:
+                self._workflow_downloads[uri] = ans
         self.flock(ans)
         return ans

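
Taken together, the cache.py changes give CallCache a two-tier download lookup: a per-run, in-memory map guarded by a Lock, checked before the persistent on-disk cache. A minimal standalone sketch of the pattern (not miniwdl's actual class; the URI-to-path mapping here is purely illustrative):

import os
from threading import Lock
from typing import Dict, Optional

class TwoTierDownloadCache:
    """Illustrative stand-in for CallCache's download-caching behavior."""

    def __init__(self, persistent_dir: Optional[str]) -> None:
        self._persistent_dir = persistent_dir  # None = persistent caching disabled
        self._run_downloads: Dict[str, str] = {}  # transient; lives only for this run
        self._lock = Lock()

    def _persistent_path(self, uri: str) -> Optional[str]:
        # hypothetical URI -> cache-file mapping (the real one parses the URI)
        if not self._persistent_dir:
            return None
        return os.path.join(self._persistent_dir, uri.replace("/", "_"))

    def get(self, uri: str) -> Optional[str]:
        # 1) per-run map first: covers URIs downloaded earlier in this run,
        #    including ones not eligible for the persistent cache
        with self._lock:
            if uri in self._run_downloads:
                return self._run_downloads[uri]
        # 2) then the persistent cache directory
        p = self._persistent_path(uri)
        return p if p and os.path.isfile(p) else None

    def put(self, uri: str, filename: str) -> str:
        p = self._persistent_path(uri)
        if p:
            os.rename(filename, p)  # move the finished download into the cache
            return p
        # not persistently cacheable: remember it for the rest of this run only
        with self._lock:
            self._run_downloads[uri] = filename
        return filename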
7 changes: 3 additions & 4 deletions WDL/runtime/download.py
@@ -127,10 +127,9 @@ def run_cached(
     cached = cache.get_download(uri, logger=logger)
     if cached:
         return True, cached
-    if not cache.download_cacheable(uri):
-        return False, run(cfg, logger, uri, run_dir=run_dir, **kwargs)
-    # run the download within the cache directory
-    run_dir = os.path.join(cfg["download_cache"]["dir"], "ops")
+    if cache.download_cacheable(uri):
+        # run the download within the cache directory
+        run_dir = os.path.join(cfg["download_cache"]["dir"], "ops")
     filename = run(cfg, logger, uri, run_dir=run_dir, **kwargs)
     return False, cache.put_download(uri, os.path.realpath(filename), logger=logger)

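
With the restructured run_cached, a cacheable URI is downloaded under the cache's "ops" directory and then promoted into the persistent cache by put_download, while a non-cacheable URI is downloaded under the caller's run_dir and put_download only records it in the per-run map. A simplified analogue, reusing the TwoTierDownloadCache sketch above (the download callable stands in for the real run() dispatch; the signature is illustrative, not miniwdl's):

import os
from typing import Callable, Tuple

def run_cached_sketch(
    cache: "TwoTierDownloadCache",  # from the sketch above
    uri: str,
    run_dir: str,
    download: Callable[[str, str], str],  # (uri, run_dir) -> downloaded filename
) -> Tuple[bool, str]:
    cached = cache.get(uri)
    if cached:
        return True, cached  # hit: persistent cache, or downloaded earlier this run
    # (the real code redirects run_dir into the cache's "ops" directory for
    # cacheable URIs before invoking the downloader, as in the diff above)
    filename = download(uri, run_dir)
    # put() persists cacheable files; anything else lands in the per-run map,
    # so a URI shared by several workflow steps downloads at most once per run
    return False, cache.put(uri, os.path.realpath(filename))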
48 changes: 28 additions & 20 deletions WDL/runtime/workflow.py
@@ -223,7 +223,9 @@ def outputs(self) -> Optional[Env.Bindings[Value.Base]]:
     :param inputs: ``WDL.Env.Bindings[Value.Base]`` of call inputs
     """

-    def step(self, cfg: config.Loader) -> "Optional[StateMachine.CallInstructions]":
+    def step(
+        self, cfg: config.Loader, stdlib: StdLib.Base
+    ) -> "Optional[StateMachine.CallInstructions]":
         """
         Advance the workflow state machine, returning the next call to initiate.
@@ -258,7 +260,7 @@ def step(self, cfg: config.Loader) -> "Optional[StateMachine.CallInstructions]":

         # do the job
         try:
-            res = self._do_job(cfg, job)
+            res = self._do_job(cfg, stdlib, job)
         except Exception as exn:
             setattr(exn, "job_id", job.id)
             raise exn
@@ -308,7 +310,7 @@ def _schedule(self, job: _Job) -> None:
         self.waiting.add(job.id)

     def _do_job(
-        self, cfg: config.Loader, job: _Job
+        self, cfg: config.Loader, stdlib: StdLib.Base, job: _Job
     ) -> "Union[StateMachine.CallInstructions, Env.Bindings[Value.Base]]":
         if isinstance(job.node, Tree.Gather):
             return _gather(
@@ -331,8 +333,6 @@ def _do_job(
             )
         )

-        stdlib = _StdLib(self)
-
         if isinstance(job.node, (Tree.Scatter, Tree.Conditional)):
             for newjob in _scatter(self.workflow, job.node, env, job.scatter_stack, stdlib):
                 self._schedule(newjob)
@@ -564,13 +564,21 @@ def _gather(

 class _StdLib(StdLib.Base):
     "checks against & updates the filename whitelist for the read_* and write_* functions"
+    cfg: config.Loader
     state: StateMachine
+    cache: CallCache

-    def __init__(self, state: StateMachine) -> None:
+    def __init__(self, cfg: config.Loader, state: StateMachine, cache: CallCache) -> None:
         super().__init__(write_dir=os.path.join(state.run_dir, "write_"))
+        self.cfg = cfg
         self.state = state
+        self.cache = cache

     def _devirtualize_filename(self, filename: str) -> str:
+        if downloadable(self.cfg, filename):
+            cached = self.cache.get_download(filename)
+            if cached:
+                return cached
         if filename in self.state.filename_whitelist:
             return filename
         raise InputError("attempted read from unknown or inaccessible file " + filename)
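
The new _devirtualize_filename logic means that when a WDL expression calls read_lines() or another read_* function on a File whose value is still a URI, the stdlib first consults the download cache before the whitelist check that previously rejected any unknown filename. A standalone sketch of that resolution order (downloadable() and the whitelist are simplified stand-ins for miniwdl's):

from typing import Dict, Set

class InputError(Exception):
    pass

def devirtualize(filename: str, cache: Dict[str, str], whitelist: Set[str]) -> str:
    """Resolve a WDL File value to a local path, in the order the commit establishes."""
    def downloadable(fn: str) -> bool:
        # simplified stand-in for download.downloadable(cfg, fn)
        return fn.startswith(("https://", "s3://", "gs://"))

    if downloadable(filename):
        cached = cache.get(filename)  # stand-in for CallCache.get_download
        if cached:
            return cached  # downloaded earlier in this run, or persistently cached
    if filename in whitelist:
        return filename  # a local file this run already knows about
    raise InputError("attempted read from unknown or inaccessible file " + filename)

# e.g. read_lines("s3://bucket/table.tsv") would now resolve through the cache:
# devirtualize("s3://bucket/table.tsv", {"s3://bucket/table.tsv": "/runs/x/table.tsv"}, set())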
@@ -727,19 +735,17 @@ def _workflow_main_loop(
         inputs = recv["inputs"]

         # download input files, if needed
-        posix_inputs = _download_input_files(
-            cfg, logger, logger_id, run_dir, inputs, thread_pools[0], cache
-        )
+        _download_input_files(cfg, logger, logger_id, run_dir, inputs, thread_pools[0], cache)

         # run workflow state machine to completion
-        state = StateMachine(".".join(logger_id), run_dir, workflow, posix_inputs)
+        state = StateMachine(".".join(logger_id), run_dir, workflow, inputs)
         while state.outputs is None:
             if _test_pickle:
                 state = pickle.loads(pickle.dumps(state))
             if terminating():
                 raise Terminated()
             # schedule all runnable calls
-            next_call = state.step(cfg)
+            next_call = state.step(cfg, _StdLib(cfg, state, cache))
             while next_call:
                 call_dir = os.path.join(run_dir, next_call.id)
                 if os.path.exists(call_dir):
@@ -765,7 +771,7 @@
                 else:
                     assert False
                 call_futures[future] = next_call.id
-                next_call = state.step(cfg)
+                next_call = state.step(cfg, _StdLib(cfg, state, cache))
             # no more calls to launch right now; wait for an outstanding call to finish
             future = next(futures.as_completed(call_futures), None)
             if future:
@@ -830,12 +836,17 @@ def _download_input_files(
     inputs: Env.Bindings[Value.Base],
     thread_pool: futures.ThreadPoolExecutor,
     cache: CallCache,
-) -> Env.Bindings[Value.Base]:
+) -> None:
     """
     Find all File values in the inputs (including any nested within compound values) that need
-    to / can be downloaded. Download them to some location under run_dir and return a copy of the
-    inputs with the URI values replaced by the downloaded filenames. Parallelize the download
-    operations on thread_pool.
+    to / can be downloaded. Download them to some location under run_dir, parallelized on
+    thread_pool, and put them into the cache (either in the persistent cache if enabled, otherwise
+    in the transient cache just for this run). The inputs env is not modified, but later steps
+    encountering a URI therein can expect it to be cached, with one exception:
+
+    As an optimization, exclude any downloadable File input whose sole use is to be passed directly
+    into a single, non-scattered call. Then downloading it becomes the responsibility of that call,
+    and other unrelated workflow steps may proceed without waiting for it.
     """

     # scan for URIs and schedule their downloads on the thread pool
@@ -861,7 +872,7 @@ def schedule_downloads(v: Value.Base) -> None:

     inputs.map(lambda b: schedule_downloads(b.value))
     if not ops:
-        return inputs
+        return
     logger.notice(_("downloading input files", count=len(ops)))  # pyre-fixme

     # collect the results, with "clean" fail-fast
@@ -909,6 +920,3 @@ def schedule_downloads(v: Value.Base) -> None:
                 cached_bytes=cached_bytes,
             )
         )
-
-    # rewrite the input URIs to the downloaded filenames
-    return Value.rewrite_env_files(inputs, lambda uri: downloaded.get(uri, uri))
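
The revised docstring describes a fan-out/fan-in pattern: scan the inputs for File URIs, schedule each download on a thread pool, and fail fast on the first error, with results going into the cache rather than into a rewritten inputs env. A simplified sketch of that pattern using plain URI lists instead of WDL.Env bindings (the sole-use-input exclusion described in the docstring is omitted here):

from concurrent import futures
from typing import Callable, Dict, List

def download_all(
    uris: List[str],
    download: Callable[[str], str],  # uri -> local filename (may raise)
    max_workers: int = 4,
) -> Dict[str, str]:
    results: Dict[str, str] = {}
    with futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        # each distinct URI is downloaded once, in parallel
        ops = {pool.submit(download, uri): uri for uri in set(uris)}
        for fut in futures.as_completed(ops):
            # fail fast: the first exception propagates out of fut.result()
            results[ops[fut]] = fut.result()
    return results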
