diff --git a/WDL/CLI.py b/WDL/CLI.py index 038c8a9d..07ed44dd 100644 --- a/WDL/CLI.py +++ b/WDL/CLI.py @@ -496,7 +496,7 @@ def runner_input(doc, inputs, input_file, empty, task=None): ) if not isinstance(decl.type, Type.Array) or decl.type.nonempty: die("Cannot set input {} {} to empty array".format(str(decl.type), decl.name)) - input_env = input_env.bind(empty_name, Value.Array(decl.type, []), decl) + input_env = input_env.bind(empty_name, Value.Array(decl.type.item_type, []), decl) # add in command-line inputs for one_input in inputs: @@ -524,7 +524,7 @@ def runner_input(doc, inputs, input_file, empty, task=None): existing = None if existing: if isinstance(v, Value.Array): - assert isinstance(existing, Value.Array) and existing.type == v.type + assert isinstance(existing, Value.Array) and v.type.coerces(existing.type) existing.value.extend(v.value) else: die( @@ -618,7 +618,7 @@ def runner_input_value(s_value, ty): ty.item_type, (Type.String, Type.File, Type.Int, Type.Float) ): # just produce a length-1 array, to be combined ex post facto - return Value.Array(ty, [runner_input_value(s_value, ty.item_type)]) + return Value.Array(ty.item_type, [runner_input_value(s_value, ty.item_type)]) return die( "No command-line support yet for inputs of type {}; workaround: specify in JSON file with --input".format( str(ty) diff --git a/WDL/Expr.py b/WDL/Expr.py index dbfad163..b00e62ee 100644 --- a/WDL/Expr.py +++ b/WDL/Expr.py @@ -462,7 +462,8 @@ def _eval( "" assert isinstance(self.type, Type.Array) return Value.Array( - self.type, [item.eval(env, stdlib).coerce(self.type.item_type) for item in self.items] + self.type.item_type, + [item.eval(env, stdlib).coerce(self.type.item_type) for item in self.items], ) @@ -507,7 +508,7 @@ def _eval( assert isinstance(self.type, Type.Pair) lv = self.left.eval(env, stdlib) rv = self.right.eval(env, stdlib) - return Value.Pair(self.type, (lv, rv)) + return Value.Pair(self.left.type, self.right.type, (lv, rv)) class Map(Base): @@ -584,7 +585,7 @@ def _eval( for k, v in self.items: eitems.append((k.eval(env, stdlib), v.eval(env, stdlib))) # TODO: complain of duplicate keys - return Value.Map(self.type, eitems) + return Value.Map(self.type.item_type, eitems) class Struct(Base): diff --git a/WDL/StdLib.py b/WDL/StdLib.py index 6d9fdf2a..a67e3740 100644 --- a/WDL/StdLib.py +++ b/WDL/StdLib.py @@ -548,7 +548,9 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value. assert isinstance(arr, Value.Array) arrty = arr.type assert isinstance(arrty, Type.Array) - return Value.Array(arrty, [arg for arg in arr.value if not isinstance(arg, Value.Null)]) + return Value.Array( + arrty.item_type, [arg for arg in arr.value if not isinstance(arg, Value.Null)] + ) class _ZipOrCross(EagerFunction): @@ -589,8 +591,13 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value. if len(lhs.value) != len(rhs.value): raise Error.EvalError(expr, "zip(): input arrays must have equal length") return Value.Array( - ty, - [Value.Pair(ty.item_type, (lhs.value[i], rhs.value[i])) for i in range(len(lhs.value))], + ty.item_type, + [ + Value.Pair( + ty.item_type.left_type, ty.item_type.right_type, (lhs.value[i], rhs.value[i]) + ) + for i in range(len(lhs.value)) + ], ) @@ -599,9 +606,9 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value. ty, lhs, rhs = self._coerce_args(expr, arguments) assert isinstance(ty, Type.Array) and isinstance(ty.item_type, Type.Pair) return Value.Array( - ty, + ty.item_type, [ - Value.Pair(ty.item_type, (lhs_item, rhs_item)) + Value.Pair(ty.item_type.left_type, ty.item_type.right_type, (lhs_item, rhs_item)) for lhs_item in lhs.value for rhs_item in rhs.value ], @@ -634,7 +641,7 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value. ans = [] for row in arguments[0].coerce(Type.Array(ty)).value: ans.extend(row.value) - return Value.Array(ty, ans) + return Value.Array(ty.item_type, ans) class _Transpose(EagerFunction): @@ -673,7 +680,7 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value. raise Error.EvalError(expr, "transpose(): ragged input matrix") for i in range(len(row.value)): ans[i].value.append(row.value[i]) - return Value.Array(ty, ans) + return Value.Array(ty.item_type, ans) class _Range(EagerFunction): @@ -700,7 +707,7 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value. assert isinstance(arg0, Value.Int) if arg0.value < 0: raise Error.EvalError(expr, "range() got negative argument") - return Value.Array(Type.Array(Type.Int()), [Value.Int(x) for x in range(arg0.value)]) + return Value.Array(Type.Int(), [Value.Int(x) for x in range(arg0.value)]) class _Prefix(EagerFunction): @@ -720,6 +727,6 @@ def infer_type(self, expr: "Expr.Apply") -> Type.Base: def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.Base: pfx = arguments[0].coerce(Type.String()).value return Value.Array( - Type.Array(Type.String()), + Type.String(), [Value.String(pfx + s.coerce(Type.String()).value) for s in arguments[1].value], ) diff --git a/WDL/Value.py b/WDL/Value.py index 6e9c4128..ca9764af 100644 --- a/WDL/Value.py +++ b/WDL/Value.py @@ -131,10 +131,10 @@ class Array(Base): value: List[Base] type: Type.Array - def __init__(self, type: Type.Array, value: List[Base]) -> None: + def __init__(self, item_type: Type.Base, value: List[Base]) -> None: self.value = [] - self.type = type - super().__init__(type, value) + self.type = Type.Array(item_type, nonempty=(len(value) > 0)) + super().__init__(self.type, value) @property def json(self) -> Any: @@ -165,10 +165,12 @@ class Map(Base): value: List[Tuple[Base, Base]] type: Type.Map - def __init__(self, type: Type.Map, value: List[Tuple[Base, Base]]) -> None: + def __init__( + self, item_type: Tuple[Type.Base, Type.Base], value: List[Tuple[Base, Base]] + ) -> None: self.value = [] - self.type = type - super().__init__(type, value) + self.type = Type.Map(item_type) + super().__init__(self.type, value) @property def json(self) -> Any: @@ -188,7 +190,7 @@ def coerce(self, desired_type: Optional[Type.Base] = None) -> Base: "" if isinstance(desired_type, Type.Map) and desired_type != self.type: return Map( - desired_type, + desired_type.item_type, [ (k.coerce(desired_type.item_type[0]), v.coerce(desired_type.item_type[1])) for (k, v) in self.value @@ -209,10 +211,12 @@ class Pair(Base): value: Tuple[Base, Base] type: Type.Pair - def __init__(self, type: Type.Pair, value: Tuple[Base, Base]) -> None: + def __init__( + self, left_type: Type.Base, right_type: Type.Base, value: Tuple[Base, Base] + ) -> None: self.value = value - self.type = type - super().__init__(type, value) + self.type = Type.Pair(left_type, right_type) + super().__init__(self.type, value) def __str__(self) -> str: assert isinstance(self.value, tuple) @@ -231,7 +235,8 @@ def coerce(self, desired_type: Optional[Type.Base] = None) -> Base: "" if isinstance(desired_type, Type.Pair) and desired_type != self.type: return Pair( - desired_type, + desired_type.left_type, + desired_type.right_type, ( self.value[0].coerce(desired_type.left_type), self.value[1].coerce(desired_type.right_type), @@ -328,7 +333,7 @@ def from_json(type: Type.Base, value: Any) -> Base: for k, v in value.items(): assert isinstance(k, str) items.append((from_json(type.item_type[0], k), from_json(type.item_type[1], v))) - return Map(type, items) + return Map(type.item_type, items) if ( isinstance(type, Type.StructInstance) and isinstance(value, dict) diff --git a/WDL/runtime/task.py b/WDL/runtime/task.py index dd8399e0..4b8dec38 100644 --- a/WDL/runtime/task.py +++ b/WDL/runtime/task.py @@ -559,7 +559,7 @@ def parse_lines(s: str) -> Value.Array: ans = [ Value.String(line) for line in (s[:-1] if s.endswith("\n") else s).split("\n") ] - return Value.Array(Type.Array(Type.String()), ans) + return Value.Array(Type.String(), ans) self._override_static("read_lines", _read_something(parse_lines)) @@ -573,7 +573,7 @@ def parse_tsv(s: str) -> Value.Array: for line in parse_lines(s).value ] # pyre-ignore - return Value.Array(Type.Array(Type.Array(Type.String())), ans) + return Value.Array(Type.Array(Type.String()), ans) self._override_static("read_tsv", _read_something(parse_tsv)) @@ -588,7 +588,7 @@ def parse_map(s: str) -> Value.Map: raise Error.InputError("read_map(): duplicate key") keys.add(line.value[0].value) ans.append((line.value[0], line.value[1])) - return Value.Map(Type.Map((Type.String(), Type.String())), ans) + return Value.Map((Type.String(), Type.String()), ans) self._override_static("read_map", _read_something(parse_map)) @@ -599,9 +599,9 @@ def parse_json(s: str) -> Value.Base: ans = [] for k in j: ans.append((Value.String(str(k)), Value.String(str(j[k])))) - return Value.Map(Type.Map((Type.String(), Type.String())), ans) + return Value.Map((Type.String(), Type.String()), ans) if isinstance(j, list): - return Value.Array(Type.Array(Type.String()), [Value.String(str(v)) for v in j]) + return Value.Array(Type.String(), [Value.String(str(v)) for v in j]) if isinstance(j, bool): return Value.Boolean(j) if isinstance(j, int): @@ -653,7 +653,7 @@ def _serialize_lines(array: Value.Array, outfile: BinaryIO) -> None: _write_something( lambda v, outfile: _serialize_lines( Value.Array( - Type.Array(Type.String()), + Type.String(), [ Value.String( "\t".join( @@ -678,7 +678,7 @@ def _serialize_map(map: Value.Map, outfile: BinaryIO) -> None: "write_map(): keys & values must not contain tab or newline characters" ) lines.append(Value.String(k + "\t" + v)) - _serialize_lines(Value.Array(Type.Array(Type.String()), lines), outfile) + _serialize_lines(Value.Array(Type.String(), lines), outfile) self._override_static("write_map", _write_something(_serialize_map)) # pyre-ignore @@ -728,7 +728,7 @@ def _glob(pattern: Value.String, lib: OutputStdLib = self) -> Value.Array: dstrip += "" if dstrip.endswith("/") else "/" assert hf.startswith(dstrip) container_files.append(os.path.join(lib.container.container_dir, hf[len(dstrip) :])) - return Value.Array(Type.Array(Type.File()), [Value.File(fn) for fn in container_files]) + return Value.Array(Type.File(), [Value.File(fn) for fn in container_files]) self._override_static("glob", _glob) @@ -749,5 +749,5 @@ def _call_eager(self, expr: Expr.Apply, arguments: List[Value.Base]) -> Value.Ba for fn_c in files.value ] # pyre-ignore - arguments = [Value.Array(files.type, host_files)] + arguments[1:] + arguments = [Value.Array(files.type.item_type, host_files)] + arguments[1:] return super()._call_eager(expr, arguments) diff --git a/WDL/runtime/workflow.py b/WDL/runtime/workflow.py index 1b465d7e..cf3a8cf9 100644 --- a/WDL/runtime/workflow.py +++ b/WDL/runtime/workflow.py @@ -508,7 +508,7 @@ def _gather( assert v0 is None or isinstance(v0, Value.Base) # bind the array, singleton value, or None as appropriate if isinstance(gather.section, Tree.Scatter): - rhs = Value.Array(Type.Array(v0.type if v0 else Type.Any()), values) + rhs = Value.Array((v0.type if v0 else Type.Any()), values) else: assert isinstance(gather.section, Tree.Conditional) assert len(values) <= 1 diff --git a/tests/test_0eval.py b/tests/test_0eval.py index 4afb9604..7a45aaf0 100644 --- a/tests/test_0eval.py +++ b/tests/test_0eval.py @@ -301,11 +301,11 @@ def test_interpolation(self): ) def test_pair(self): - env = cons_env(("p", WDL.Value.Pair(WDL.Type.Pair(WDL.Type.Float(), WDL.Type.Float()), + env = cons_env(("p", WDL.Value.Pair(WDL.Type.Float(), WDL.Type.Float(), (WDL.Value.Float(3.14159), WDL.Value.Float(2.71828)))), - ("q", WDL.Value.Pair(WDL.Type.Pair(WDL.Type.Pair(WDL.Type.Int(), WDL.Type.Int()), - WDL.Type.Float(optional=True)), - (WDL.Value.Pair(WDL.Type.Pair(WDL.Type.Int(), WDL.Type.Int()), + ("q", WDL.Value.Pair(WDL.Type.Pair(WDL.Type.Int(), WDL.Type.Int()), + WDL.Type.Float(optional=True), + (WDL.Value.Pair(WDL.Type.Int(), WDL.Type.Int(), (WDL.Value.Int(4), WDL.Value.Int(2))), WDL.Value.Null())))) self._test_tuples( diff --git a/tests/test_1doc.py b/tests/test_1doc.py index 75d0c733..a2f78bf4 100644 --- a/tests/test_1doc.py +++ b/tests/test_1doc.py @@ -189,9 +189,9 @@ def test_placeholders(self): } """)[0] task.typecheck() - foobar = WDL.Value.Array(WDL.Type.Array(WDL.Type.String()), [WDL.Value.String("foo"), WDL.Value.String("bar")]) + foobar = WDL.Value.Array(WDL.Type.String(), [WDL.Value.String("foo"), WDL.Value.String("bar")]) self.assertEqual(task.command.parts[1].eval(WDL.Env.Bindings().bind('s', foobar)).value, 'foo, bar') - foobar = WDL.Value.Array(WDL.Type.Array(WDL.Type.String()), []) + foobar = WDL.Value.Array(WDL.Type.String(), []) self.assertEqual(task.command.parts[1].eval(WDL.Env.Bindings().bind('s', foobar)).value, '') with self.assertRaises(WDL.Error.StaticTypeMismatch): task = WDL.parse_tasks("""