diff --git a/WDL/Value.py b/WDL/Value.py index 53274c9b..b47f07bd 100644 --- a/WDL/Value.py +++ b/WDL/Value.py @@ -12,10 +12,9 @@ from abc import ABC from typing import Any, List, Optional, Tuple, Dict, Iterable, Union, Callable from . import Error, Type, Env -from ._util import CustomDeepCopyMixin -class Base(CustomDeepCopyMixin, ABC): +class Base(ABC): """The abstract base class for WDL values""" type: Type.Base @@ -30,9 +29,6 @@ class Base(CustomDeepCopyMixin, ABC): from ``WDL.Expr.eval`` """ - # exempt type & expr from deep-copying since they're immutable - _shallow_copy_attrs: List[str] = ["expr", "type"] - def __init__(self, type: Type.Base, value: Any, expr: "Optional[Expr.Base]" = None) -> None: assert isinstance(type, Type.Base) self.type = type @@ -45,6 +41,39 @@ def __eq__(self, other) -> bool: def __str__(self) -> str: return json.dumps(self.json) + def __deepcopy__(self, memo: Dict[int, Any]) -> Any: + cls = self.__class__ + cp = cls.__new__(cls) + shallow = ("expr", "type") # avoid deep-copying large, immutable structures + for k, v in self.__dict__.items(): + if k != "value": + setattr(cp, k, copy.deepcopy(v, memo) if k not in shallow else v) + # override deepcopy of self.value to eliminate sharing; this accommodates rewrite_files() + # which wants a deep copy for the purpose of modifying the copied File.value, and isn't + # expecting to encounter shared ones. + if isinstance(self.value, list): + value2 = [] + for elt in self.value: + if isinstance(elt, tuple): + assert len(elt) == 2 + value2.append((copy.deepcopy(elt[0]), copy.deepcopy(elt[1]))) + else: + assert isinstance(elt, Base) + value2.append(copy.deepcopy(elt)) + cp.value = value2 + elif isinstance(self.value, tuple): + assert len(self.value) == 2 + cp.value = (copy.deepcopy(self.value[0]), copy.deepcopy(self.value[1])) + elif isinstance(self.value, dict): + value2 = {} + for key in self.value: + value2[copy.deepcopy(key)] = copy.deepcopy(self.value[key]) + cp.value = value2 + else: + assert self.value is None or isinstance(self.value, (int, float, bool, str)) + cp.value = self.value + return cp + def coerce(self, desired_type: Optional[Type.Base] = None) -> "Base": """ Coerce the value to the desired type and return it. Types should be @@ -293,7 +322,7 @@ class Null(Base): ``type`` and ``value`` are both None.""" def __init__(self, expr: "Optional[Expr.Base]" = None) -> None: - super().__init__(Type.Any(optional=True), expr) + super().__init__(Type.Any(optional=True), None, expr) def coerce(self, desired_type: Optional[Type.Base] = None) -> Base: "" @@ -440,9 +469,13 @@ def rewrite_files(v: Base, f: Callable[[str], str]) -> Base: (including Files nested inside compound Values). """ + mapped_files = set() + def map_files(v2: Base) -> Base: if isinstance(v2, File): + assert id(v2) not in mapped_files, f"File {id(v2)} reused in deepcopy" v2.value = f(v2.value) + mapped_files.add(id(v2)) for ch in v2.children: map_files(ch) return v2 diff --git a/WDL/_util.py b/WDL/_util.py index 71979d57..1421b8d2 100644 --- a/WDL/_util.py +++ b/WDL/_util.py @@ -513,31 +513,6 @@ def handle_signal(sig: int, frame: FrameType) -> None: _terminating = None -@export -class CustomDeepCopyMixin: - """ - Mixin class overrides __deepcopy__ to consult an internal list of attribute names to be merely - shallow-copied when the time comes. Useful for attributes referencing large, immutable data - structures. - - Override class variable _shallow_copy_attrs to a list of the attribute names to be - shallow-copied. - """ - - _shallow_copy_attrs: Optional[List[str]] = None - - def __deepcopy__(self, memo: Dict[int, Any]) -> Any: # pyre-ignore - cls = self.__class__ - cp = cls.__new__(cls) - memo[id(self)] = cp - for k in self._shallow_copy_attrs or []: - v = self.__dict__[k] - memo[id(v)] = v - for k, v in self.__dict__.items(): - setattr(cp, k, copy.deepcopy(v, memo)) - return cp - - byte_size_units = { "B": 1, "K": 1000, diff --git a/WDL/runtime/task.py b/WDL/runtime/task.py index f913e106..2979cc69 100644 --- a/WDL/runtime/task.py +++ b/WDL/runtime/task.py @@ -983,16 +983,10 @@ def _eval_task_inputs( container.add_files(_filenames(posix_inputs)) # copy posix_inputs with all Files mapped to their in-container paths - def map_files(v: Value.Base) -> Value.Base: - if isinstance(v, Value.File): - v.value = container.input_file_map[v.value] - for ch in v.children: - map_files(ch) - return v + def map_files(fn: str) -> str: + return container.input_file_map[fn] - container_inputs = posix_inputs.map( - lambda binding: Env.Binding(binding.name, map_files(copy.deepcopy(binding.value))) - ) + container_inputs = Value.rewrite_env_files(posix_inputs, map_files) # initialize value environment with the inputs container_env = Env.Bindings() @@ -1216,25 +1210,22 @@ def _eval_task_outputs( ) -> Env.Bindings[Value.Base]: # helper to rewrite Files from in-container paths to host paths - def rewrite_files(v: Value.Base, output_name: str) -> None: - if isinstance(v, Value.File): - host_file = container.host_file(v.value) - if host_file is None: - logger.warning( - _( - "output file not found in container (error unless declared type is optional)", - name=output_name, - file=v.value, - ) + def rewriter(fn: str, output_name: str) -> str: + host_file = container.host_file(fn) + if host_file is None: + logger.warning( + _( + "output file not found in container (error unless declared type is optional)", + name=output_name, + file=fn, ) - else: - logger.debug(_("output file", container=v.value, host=host_file)) - # We may overwrite File.value with None, which is an invalid state, then we'll fix it - # up (or abort) below. This trickery is because we don't, at this point, know whether - # the 'desired' output type is File or File?. - v.value = host_file - for ch in v.children: - rewrite_files(ch, output_name) + ) + else: + logger.debug(_("output file", container=fn, host=host_file)) + # We may overwrite File.value with None, which is an invalid state, then we'll fix it + # up (or abort) below. This trickery is because we don't, at this point, know whether + # the 'desired' output type is File or File?. + return host_file # pyre-fixme stdlib = OutputStdLib(logger, container) outputs = Env.Bindings() @@ -1258,11 +1249,10 @@ def rewrite_files(v: Value.Base, output_name: str) -> None: # compound values) # First bind the value as-is in the environment, so that subsequent output expressions will - # "see" the in-container path(s) if they use this binding. (Copy it though, because we'll - # then clobber v) - env = env.bind(decl.name, copy.deepcopy(v)) + # "see" the in-container path(s) if they use this binding. + env = env.bind(decl.name, v) # Rewrite each File.value to either a host path, or None if the file doesn't exist. - rewrite_files(v, decl.name) + v = Value.rewrite_files(v, lambda fn: rewriter(fn, decl.name)) # File.coerce has a special behavior for us so that, if the value is None: # - produces Value.Null() if the desired type is File? # - raises FileNotFoundError otherwise. diff --git a/requirements.dev.txt b/requirements.dev.txt index 9a0c96bc..5a2f99f4 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -2,7 +2,7 @@ # requirements.txt which are needed for miniwdl to run in common use. pyre-check==0.0.27 black==19.10b0 -pylint>=2.5.1 +pylint==2.4.4 sphinx sphinx-autobuild sphinx_rtd_theme diff --git a/tests/test_7runner.py b/tests/test_7runner.py index 979ca53c..c2d7573b 100644 --- a/tests/test_7runner.py +++ b/tests/test_7runner.py @@ -167,3 +167,36 @@ def test_download_cache4(self): line = json.loads(line) if "downloaded input files" in line["message"]: self.assertEqual(line["downloaded"], 0) + +class MiscRegressionTests(RunnerTestCase): + def test_repeated_file_rewriting(self): + wdl = """ + version 1.0 + task t { + input { + Array[File] files + } + command <<< + xargs cat < ~{write_lines(files)} + echo Bob > bob.txt + >>> + output { + Array[String] out = read_lines(stdout()) + File bob = "bob.txt" + Array[File] bob2 = [bob, bob] + } + } + workflow w { + input { + File file + } + call t { + input: + files = [file, file] + } + } + """ + with open(os.path.join(self._dir, "alice.txt"), "w") as alice: + print("Alice", file=alice) + outp = self._run(wdl, {"file": os.path.join(self._dir, "alice.txt")}) + self.assertEqual(outp["t.out"], ["Alice", "Alice"])