Skip to content

Commit

Permalink
fix Value.rewrite_files() on compound values with multiple appearance…
Browse files Browse the repository at this point in the history
…s of one File instance
  • Loading branch information
mlin committed May 30, 2020
1 parent bbe55fa commit 1f26d62
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 62 deletions.
45 changes: 39 additions & 6 deletions WDL/Value.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
from abc import ABC
from typing import Any, List, Optional, Tuple, Dict, Iterable, Union, Callable
from . import Error, Type, Env
from ._util import CustomDeepCopyMixin


class Base(CustomDeepCopyMixin, ABC):
class Base(ABC):
"""The abstract base class for WDL values"""

type: Type.Base
Expand All @@ -30,9 +29,6 @@ class Base(CustomDeepCopyMixin, ABC):
from ``WDL.Expr.eval``
"""

# exempt type & expr from deep-copying since they're immutable
_shallow_copy_attrs: List[str] = ["expr", "type"]

def __init__(self, type: Type.Base, value: Any, expr: "Optional[Expr.Base]" = None) -> None:
assert isinstance(type, Type.Base)
self.type = type
Expand All @@ -45,6 +41,39 @@ def __eq__(self, other) -> bool:
def __str__(self) -> str:
return json.dumps(self.json)

def __deepcopy__(self, memo: Dict[int, Any]) -> Any:
    """
    Produce a deep copy of this value in which nothing inside ``value`` is shared.

    ``expr`` and ``type`` are carried over by reference (large, immutable structures); every
    other attribute except ``value`` is deep-copied using the caller's ``memo``. The contents of
    ``value`` are deep-copied WITHOUT ``memo``, so an object appearing several times inside a
    compound value yields a separate copy per appearance. This accommodates rewrite_files(),
    which modifies the copied File.value in place and isn't expecting to encounter shared ones.

    :param memo: memo dictionary supplied by ``copy.deepcopy``
    :return: a new instance of the same class
    """
    klass = self.__class__
    clone = klass.__new__(klass)

    for attr, obj in self.__dict__.items():
        if attr == "value":
            continue  # copied separately below
        if attr in ("expr", "type"):
            # avoid deep-copying large, immutable structures
            setattr(clone, attr, obj)
        else:
            setattr(clone, attr, copy.deepcopy(obj, memo))

    # NOTE: the copy.deepcopy calls below deliberately omit memo to eliminate sharing
    payload = self.value
    if isinstance(payload, list):
        items = []
        for entry in payload:
            if isinstance(entry, tuple):
                assert len(entry) == 2
                items.append((copy.deepcopy(entry[0]), copy.deepcopy(entry[1])))
            else:
                assert isinstance(entry, Base)
                items.append(copy.deepcopy(entry))
        clone.value = items
    elif isinstance(payload, tuple):
        assert len(payload) == 2
        clone.value = (copy.deepcopy(payload[0]), copy.deepcopy(payload[1]))
    elif isinstance(payload, dict):
        clone.value = {copy.deepcopy(k): copy.deepcopy(v) for k, v in payload.items()}
    else:
        # remaining payloads are expected to be None or an immutable scalar, safe to share
        assert payload is None or isinstance(payload, (int, float, bool, str))
        clone.value = payload
    return clone

def coerce(self, desired_type: Optional[Type.Base] = None) -> "Base":
"""
Coerce the value to the desired type and return it. Types should be
Expand Down Expand Up @@ -293,7 +322,7 @@ class Null(Base):
``type`` and ``value`` are both None."""

def __init__(self, expr: "Optional[Expr.Base]" = None) -> None:
super().__init__(Type.Any(optional=True), expr)
super().__init__(Type.Any(optional=True), None, expr)

def coerce(self, desired_type: Optional[Type.Base] = None) -> Base:
""
Expand Down Expand Up @@ -440,9 +469,13 @@ def rewrite_files(v: Base, f: Callable[[str], str]) -> Base:
(including Files nested inside compound Values).
"""

mapped_files = set()

def map_files(v2: Base) -> Base:
if isinstance(v2, File):
assert id(v2) not in mapped_files, f"File {id(v2)} reused in deepcopy"
v2.value = f(v2.value)
mapped_files.add(id(v2))
for ch in v2.children:
map_files(ch)
return v2
Expand Down
25 changes: 0 additions & 25 deletions WDL/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,31 +513,6 @@ def handle_signal(sig: int, frame: FrameType) -> None:
_terminating = None


@export
class CustomDeepCopyMixin:
    """
    Mixin whose ``__deepcopy__`` shares selected attributes with the copy instead of
    recursively duplicating them. Useful for attributes referencing large, immutable data
    structures.

    Subclasses override the class variable ``_shallow_copy_attrs`` with the list of attribute
    names that should be shallow-copied; all other instance attributes are deep-copied.
    """

    _shallow_copy_attrs: Optional[List[str]] = None

    def __deepcopy__(self, memo: Dict[int, Any]) -> Any:  # pyre-ignore
        """
        Copy the instance, deep-copying every attribute not named in ``_shallow_copy_attrs``.

        :param memo: memo dictionary supplied by ``copy.deepcopy``
        :return: a new instance of the same class
        """
        klass = self.__class__
        clone = klass.__new__(klass)
        # register the clone up front so references back to self resolve to it
        memo[id(self)] = clone

        state = self.__dict__
        # pre-seed memo with the shallow attributes mapped to themselves, so that
        # copy.deepcopy returns the very same object for them
        for name in self._shallow_copy_attrs or ():
            shared = state[name]
            memo[id(shared)] = shared

        for name, member in state.items():
            setattr(clone, name, copy.deepcopy(member, memo))
        return clone


byte_size_units = {
"B": 1,
"K": 1000,
Expand Down
52 changes: 21 additions & 31 deletions WDL/runtime/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,16 +983,10 @@ def _eval_task_inputs(
container.add_files(_filenames(posix_inputs))

# copy posix_inputs with all Files mapped to their in-container paths
def map_files(v: Value.Base) -> Value.Base:
if isinstance(v, Value.File):
v.value = container.input_file_map[v.value]
for ch in v.children:
map_files(ch)
return v
def map_files(fn: str) -> str:
return container.input_file_map[fn]

container_inputs = posix_inputs.map(
lambda binding: Env.Binding(binding.name, map_files(copy.deepcopy(binding.value)))
)
container_inputs = Value.rewrite_env_files(posix_inputs, map_files)

# initialize value environment with the inputs
container_env = Env.Bindings()
Expand Down Expand Up @@ -1216,25 +1210,22 @@ def _eval_task_outputs(
) -> Env.Bindings[Value.Base]:

# helper to rewrite Files from in-container paths to host paths
def rewrite_files(v: Value.Base, output_name: str) -> None:
if isinstance(v, Value.File):
host_file = container.host_file(v.value)
if host_file is None:
logger.warning(
_(
"output file not found in container (error unless declared type is optional)",
name=output_name,
file=v.value,
)
def rewriter(fn: str, output_name: str) -> str:
host_file = container.host_file(fn)
if host_file is None:
logger.warning(
_(
"output file not found in container (error unless declared type is optional)",
name=output_name,
file=fn,
)
else:
logger.debug(_("output file", container=v.value, host=host_file))
# We may overwrite File.value with None, which is an invalid state, then we'll fix it
# up (or abort) below. This trickery is because we don't, at this point, know whether
# the 'desired' output type is File or File?.
v.value = host_file
for ch in v.children:
rewrite_files(ch, output_name)
)
else:
logger.debug(_("output file", container=fn, host=host_file))
# We may overwrite File.value with None, which is an invalid state, then we'll fix it
# up (or abort) below. This trickery is because we don't, at this point, know whether
# the 'desired' output type is File or File?.
return host_file # pyre-fixme

stdlib = OutputStdLib(logger, container)
outputs = Env.Bindings()
Expand All @@ -1258,11 +1249,10 @@ def rewrite_files(v: Value.Base, output_name: str) -> None:
# compound values)

# First bind the value as-is in the environment, so that subsequent output expressions will
# "see" the in-container path(s) if they use this binding. (Copy it though, because we'll
# then clobber v)
env = env.bind(decl.name, copy.deepcopy(v))
# "see" the in-container path(s) if they use this binding.
env = env.bind(decl.name, v)
# Rewrite each File.value to either a host path, or None if the file doesn't exist.
rewrite_files(v, decl.name)
v = Value.rewrite_files(v, lambda fn: rewriter(fn, decl.name))
# File.coerce has a special behavior for us so that, if the value is None:
# - produces Value.Null() if the desired type is File?
# - raises FileNotFoundError otherwise.
Expand Down
33 changes: 33 additions & 0 deletions tests/test_7runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,36 @@ def test_download_cache4(self):
line = json.loads(line)
if "downloaded input files" in line["message"]:
self.assertEqual(line["downloaded"], 0)

class MiscRegressionTests(RunnerTestCase):
    """Miscellaneous regression tests for the runner."""

    def test_repeated_file_rewriting(self):
        """
        Run a task whose Array[File] input holds the same file twice and whose outputs include
        an array repeating one File, then check the file contents were read once per appearance.
        """
        alice_path = os.path.join(self._dir, "alice.txt")
        with open(alice_path, "w") as outfile:
            print("Alice", file=outfile)

        wdl = """
version 1.0
task t {
input {
Array[File] files
}
command <<<
xargs cat < ~{write_lines(files)}
echo Bob > bob.txt
>>>
output {
Array[String] out = read_lines(stdout())
File bob = "bob.txt"
Array[File] bob2 = [bob, bob]
}
}
workflow w {
input {
File file
}
call t {
input:
files = [file, file]
}
}
"""
        outputs = self._run(wdl, {"file": alice_path})
        self.assertEqual(outputs["t.out"], ["Alice", "Alice"])

0 comments on commit 1f26d62

Please sign in to comment.