Skip to content

Commit

Permalink
Merge ee85cbc into 2d76b84
Browse files Browse the repository at this point in the history
  • Loading branch information
mlin committed Aug 11, 2019
2 parents 2d76b84 + ee85cbc commit 67e5baf
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 43 deletions.
6 changes: 3 additions & 3 deletions WDL/CLI.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ def runner_input(doc, inputs, input_file, empty, task=None):
)
if not isinstance(decl.type, Type.Array) or decl.type.nonempty:
die("Cannot set input {} {} to empty array".format(str(decl.type), decl.name))
input_env = input_env.bind(empty_name, Value.Array(decl.type, []), decl)
input_env = input_env.bind(empty_name, Value.Array(decl.type.item_type, []), decl)

# add in command-line inputs
for one_input in inputs:
Expand Down Expand Up @@ -524,7 +524,7 @@ def runner_input(doc, inputs, input_file, empty, task=None):
existing = None
if existing:
if isinstance(v, Value.Array):
assert isinstance(existing, Value.Array) and existing.type == v.type
assert isinstance(existing, Value.Array) and v.type.coerces(existing.type)
existing.value.extend(v.value)
else:
die(
Expand Down Expand Up @@ -618,7 +618,7 @@ def runner_input_value(s_value, ty):
ty.item_type, (Type.String, Type.File, Type.Int, Type.Float)
):
# just produce a length-1 array, to be combined ex post facto
return Value.Array(ty, [runner_input_value(s_value, ty.item_type)])
return Value.Array(ty.item_type, [runner_input_value(s_value, ty.item_type)])
return die(
"No command-line support yet for inputs of type {}; workaround: specify in JSON file with --input".format(
str(ty)
Expand Down
7 changes: 4 additions & 3 deletions WDL/Expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,8 @@ def _eval(
""
assert isinstance(self.type, Type.Array)
return Value.Array(
self.type, [item.eval(env, stdlib).coerce(self.type.item_type) for item in self.items]
self.type.item_type,
[item.eval(env, stdlib).coerce(self.type.item_type) for item in self.items],
)


Expand Down Expand Up @@ -507,7 +508,7 @@ def _eval(
assert isinstance(self.type, Type.Pair)
lv = self.left.eval(env, stdlib)
rv = self.right.eval(env, stdlib)
return Value.Pair(self.type, (lv, rv))
return Value.Pair(self.left.type, self.right.type, (lv, rv))


class Map(Base):
Expand Down Expand Up @@ -584,7 +585,7 @@ def _eval(
for k, v in self.items:
eitems.append((k.eval(env, stdlib), v.eval(env, stdlib)))
# TODO: complain of duplicate keys
return Value.Map(self.type, eitems)
return Value.Map(self.type.item_type, eitems)


class Struct(Base):
Expand Down
25 changes: 16 additions & 9 deletions WDL/StdLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,9 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.
assert isinstance(arr, Value.Array)
arrty = arr.type
assert isinstance(arrty, Type.Array)
return Value.Array(arrty, [arg for arg in arr.value if not isinstance(arg, Value.Null)])
return Value.Array(
arrty.item_type, [arg for arg in arr.value if not isinstance(arg, Value.Null)]
)


class _ZipOrCross(EagerFunction):
Expand Down Expand Up @@ -589,8 +591,13 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.
if len(lhs.value) != len(rhs.value):
raise Error.EvalError(expr, "zip(): input arrays must have equal length")
return Value.Array(
ty,
[Value.Pair(ty.item_type, (lhs.value[i], rhs.value[i])) for i in range(len(lhs.value))],
ty.item_type,
[
Value.Pair(
ty.item_type.left_type, ty.item_type.right_type, (lhs.value[i], rhs.value[i])
)
for i in range(len(lhs.value))
],
)


Expand All @@ -599,9 +606,9 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.
ty, lhs, rhs = self._coerce_args(expr, arguments)
assert isinstance(ty, Type.Array) and isinstance(ty.item_type, Type.Pair)
return Value.Array(
ty,
ty.item_type,
[
Value.Pair(ty.item_type, (lhs_item, rhs_item))
Value.Pair(ty.item_type.left_type, ty.item_type.right_type, (lhs_item, rhs_item))
for lhs_item in lhs.value
for rhs_item in rhs.value
],
Expand Down Expand Up @@ -634,7 +641,7 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.
ans = []
for row in arguments[0].coerce(Type.Array(ty)).value:
ans.extend(row.value)
return Value.Array(ty, ans)
return Value.Array(ty.item_type, ans)


class _Transpose(EagerFunction):
Expand Down Expand Up @@ -673,7 +680,7 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.
raise Error.EvalError(expr, "transpose(): ragged input matrix")
for i in range(len(row.value)):
ans[i].value.append(row.value[i])
return Value.Array(ty, ans)
return Value.Array(ty.item_type, ans)


class _Range(EagerFunction):
Expand All @@ -700,7 +707,7 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.
assert isinstance(arg0, Value.Int)
if arg0.value < 0:
raise Error.EvalError(expr, "range() got negative argument")
return Value.Array(Type.Array(Type.Int()), [Value.Int(x) for x in range(arg0.value)])
return Value.Array(Type.Int(), [Value.Int(x) for x in range(arg0.value)])


class _Prefix(EagerFunction):
Expand All @@ -720,6 +727,6 @@ def infer_type(self, expr: "Expr.Apply") -> Type.Base:
def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.Base:
pfx = arguments[0].coerce(Type.String()).value
return Value.Array(
Type.Array(Type.String()),
Type.String(),
[Value.String(pfx + s.coerce(Type.String()).value) for s in arguments[1].value],
)
29 changes: 17 additions & 12 deletions WDL/Value.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ class Array(Base):
value: List[Base]
type: Type.Array

def __init__(self, type: Type.Array, value: List[Base]) -> None:
def __init__(self, item_type: Type.Base, value: List[Base]) -> None:
self.value = []
self.type = type
super().__init__(type, value)
self.type = Type.Array(item_type, nonempty=(len(value) > 0))
super().__init__(self.type, value)

@property
def json(self) -> Any:
Expand Down Expand Up @@ -165,10 +165,12 @@ class Map(Base):
value: List[Tuple[Base, Base]]
type: Type.Map

def __init__(self, type: Type.Map, value: List[Tuple[Base, Base]]) -> None:
def __init__(
self, item_type: Tuple[Type.Base, Type.Base], value: List[Tuple[Base, Base]]
) -> None:
self.value = []
self.type = type
super().__init__(type, value)
self.type = Type.Map(item_type)
super().__init__(self.type, value)

@property
def json(self) -> Any:
Expand All @@ -188,7 +190,7 @@ def coerce(self, desired_type: Optional[Type.Base] = None) -> Base:
""
if isinstance(desired_type, Type.Map) and desired_type != self.type:
return Map(
desired_type,
desired_type.item_type,
[
(k.coerce(desired_type.item_type[0]), v.coerce(desired_type.item_type[1]))
for (k, v) in self.value
Expand All @@ -209,10 +211,12 @@ class Pair(Base):
value: Tuple[Base, Base]
type: Type.Pair

def __init__(self, type: Type.Pair, value: Tuple[Base, Base]) -> None:
def __init__(
self, left_type: Type.Base, right_type: Type.Base, value: Tuple[Base, Base]
) -> None:
self.value = value
self.type = type
super().__init__(type, value)
self.type = Type.Pair(left_type, right_type)
super().__init__(self.type, value)

def __str__(self) -> str:
assert isinstance(self.value, tuple)
Expand All @@ -231,7 +235,8 @@ def coerce(self, desired_type: Optional[Type.Base] = None) -> Base:
""
if isinstance(desired_type, Type.Pair) and desired_type != self.type:
return Pair(
desired_type,
desired_type.left_type,
desired_type.right_type,
(
self.value[0].coerce(desired_type.left_type),
self.value[1].coerce(desired_type.right_type),
Expand Down Expand Up @@ -328,7 +333,7 @@ def from_json(type: Type.Base, value: Any) -> Base:
for k, v in value.items():
assert isinstance(k, str)
items.append((from_json(type.item_type[0], k), from_json(type.item_type[1], v)))
return Map(type, items)
return Map(type.item_type, items)
if (
isinstance(type, Type.StructInstance)
and isinstance(value, dict)
Expand Down
18 changes: 9 additions & 9 deletions WDL/runtime/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def parse_lines(s: str) -> Value.Array:
ans = [
Value.String(line) for line in (s[:-1] if s.endswith("\n") else s).split("\n")
]
return Value.Array(Type.Array(Type.String()), ans)
return Value.Array(Type.String(), ans)

self._override_static("read_lines", _read_something(parse_lines))

Expand All @@ -573,7 +573,7 @@ def parse_tsv(s: str) -> Value.Array:
for line in parse_lines(s).value
]
# pyre-ignore
return Value.Array(Type.Array(Type.Array(Type.String())), ans)
return Value.Array(Type.Array(Type.String()), ans)

self._override_static("read_tsv", _read_something(parse_tsv))

Expand All @@ -588,7 +588,7 @@ def parse_map(s: str) -> Value.Map:
raise Error.InputError("read_map(): duplicate key")
keys.add(line.value[0].value)
ans.append((line.value[0], line.value[1]))
return Value.Map(Type.Map((Type.String(), Type.String())), ans)
return Value.Map((Type.String(), Type.String()), ans)

self._override_static("read_map", _read_something(parse_map))

Expand All @@ -599,9 +599,9 @@ def parse_json(s: str) -> Value.Base:
ans = []
for k in j:
ans.append((Value.String(str(k)), Value.String(str(j[k]))))
return Value.Map(Type.Map((Type.String(), Type.String())), ans)
return Value.Map((Type.String(), Type.String()), ans)
if isinstance(j, list):
return Value.Array(Type.Array(Type.String()), [Value.String(str(v)) for v in j])
return Value.Array(Type.String(), [Value.String(str(v)) for v in j])
if isinstance(j, bool):
return Value.Boolean(j)
if isinstance(j, int):
Expand Down Expand Up @@ -653,7 +653,7 @@ def _serialize_lines(array: Value.Array, outfile: BinaryIO) -> None:
_write_something(
lambda v, outfile: _serialize_lines(
Value.Array(
Type.Array(Type.String()),
Type.String(),
[
Value.String(
"\t".join(
Expand All @@ -678,7 +678,7 @@ def _serialize_map(map: Value.Map, outfile: BinaryIO) -> None:
"write_map(): keys & values must not contain tab or newline characters"
)
lines.append(Value.String(k + "\t" + v))
_serialize_lines(Value.Array(Type.Array(Type.String()), lines), outfile)
_serialize_lines(Value.Array(Type.String(), lines), outfile)

self._override_static("write_map", _write_something(_serialize_map)) # pyre-ignore

Expand Down Expand Up @@ -728,7 +728,7 @@ def _glob(pattern: Value.String, lib: OutputStdLib = self) -> Value.Array:
dstrip += "" if dstrip.endswith("/") else "/"
assert hf.startswith(dstrip)
container_files.append(os.path.join(lib.container.container_dir, hf[len(dstrip) :]))
return Value.Array(Type.Array(Type.File()), [Value.File(fn) for fn in container_files])
return Value.Array(Type.File(), [Value.File(fn) for fn in container_files])

self._override_static("glob", _glob)

Expand All @@ -749,5 +749,5 @@ def _call_eager(self, expr: Expr.Apply, arguments: List[Value.Base]) -> Value.Ba
for fn_c in files.value
]
# pyre-ignore
arguments = [Value.Array(files.type, host_files)] + arguments[1:]
arguments = [Value.Array(files.type.item_type, host_files)] + arguments[1:]
return super()._call_eager(expr, arguments)
2 changes: 1 addition & 1 deletion WDL/runtime/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def _gather(
assert v0 is None or isinstance(v0, Value.Base)
# bind the array, singleton value, or None as appropriate
if isinstance(gather.section, Tree.Scatter):
rhs = Value.Array(Type.Array(v0.type if v0 else Type.Any()), values)
rhs = Value.Array((v0.type if v0 else Type.Any()), values)
else:
assert isinstance(gather.section, Tree.Conditional)
assert len(values) <= 1
Expand Down
8 changes: 4 additions & 4 deletions tests/test_0eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,11 @@ def test_interpolation(self):
)

def test_pair(self):
env = cons_env(("p", WDL.Value.Pair(WDL.Type.Pair(WDL.Type.Float(), WDL.Type.Float()),
env = cons_env(("p", WDL.Value.Pair(WDL.Type.Float(), WDL.Type.Float(),
(WDL.Value.Float(3.14159), WDL.Value.Float(2.71828)))),
("q", WDL.Value.Pair(WDL.Type.Pair(WDL.Type.Pair(WDL.Type.Int(), WDL.Type.Int()),
WDL.Type.Float(optional=True)),
(WDL.Value.Pair(WDL.Type.Pair(WDL.Type.Int(), WDL.Type.Int()),
("q", WDL.Value.Pair(WDL.Type.Pair(WDL.Type.Int(), WDL.Type.Int()),
WDL.Type.Float(optional=True),
(WDL.Value.Pair(WDL.Type.Int(), WDL.Type.Int(),
(WDL.Value.Int(4), WDL.Value.Int(2))),
WDL.Value.Null()))))
self._test_tuples(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_1doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ def test_placeholders(self):
}
""")[0]
task.typecheck()
foobar = WDL.Value.Array(WDL.Type.Array(WDL.Type.String()), [WDL.Value.String("foo"), WDL.Value.String("bar")])
foobar = WDL.Value.Array(WDL.Type.String(), [WDL.Value.String("foo"), WDL.Value.String("bar")])
self.assertEqual(task.command.parts[1].eval(WDL.Env.Bindings().bind('s', foobar)).value, 'foo, bar')
foobar = WDL.Value.Array(WDL.Type.Array(WDL.Type.String()), [])
foobar = WDL.Value.Array(WDL.Type.String(), [])
self.assertEqual(task.command.parts[1].eval(WDL.Env.Bindings().bind('s', foobar)).value, '')
with self.assertRaises(WDL.Error.StaticTypeMismatch):
task = WDL.parse_tasks("""
Expand Down

0 comments on commit 67e5baf

Please sign in to comment.