Skip to content

Commit

Permalink
add unzip()
Browse files Browse the repository at this point in the history
  • Loading branch information
mlin committed Jan 26, 2021
1 parent 1a34833 commit b5489df
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
40 changes: 40 additions & 0 deletions WDL/StdLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def sep(sep: Value.String, iterable: Value.Array) -> Value.String:
self.select_first = _SelectFirst()
self.select_all = _SelectAll()
self.zip = _Zip()
self.unzip = _Unzip()
self.cross = _Cross()
self.flatten = _Flatten()
self.transpose = _Transpose()
Expand Down Expand Up @@ -829,6 +830,45 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.
)


class _Unzip(EagerFunction):
# Array[Pair[X,Y]] -> Pair[Array[X],Array[Y]]
def infer_type(self, expr: "Expr.Apply") -> Type.Base:
if len(expr.arguments) != 1:
raise Error.WrongArity(expr, 1)
arg0ty: Type.Base = expr.arguments[0].type
if (
not isinstance(arg0ty, Type.Array)
or (expr._check_quant and arg0ty.optional)
or not isinstance(arg0ty.item_type, Type.Pair)
or (expr._check_quant and arg0ty.item_type.optional)
):
raise Error.StaticTypeMismatch(
expr.arguments[0], Type.Array(Type.Pair(Type.Any(), Type.Any())), arg0ty
)
return Type.Pair(
Type.Array(arg0ty.item_type.left_type, nonempty=arg0ty.nonempty),
Type.Array(arg0ty.item_type.right_type, nonempty=arg0ty.nonempty),
)

def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.Pair:
pty = self.infer_type(expr)
assert isinstance(pty, Type.Pair)
lty = pty.left_type
assert isinstance(lty, Type.Array)
rty = pty.right_type
assert isinstance(rty, Type.Array)
arr = arguments[0]
assert isinstance(arr, Value.Array)
return Value.Pair(
lty,
rty,
(
Value.Array(lty.item_type, [p.value[0] for p in arr.value]),
Value.Array(rty.item_type, [p.value[1] for p in arr.value]),
),
)


class _Flatten(EagerFunction):
# t array array -> t array
# TODO: if any of the input arrays are statically nonempty then so is output
Expand Down
36 changes: 36 additions & 0 deletions tests/test_5stdlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,42 @@ def test_zip_cross(self):
}
""", expected_exception=WDL.Error.EvalError)

def test_unzip(self):
outputs = self._test_task(R"""
version 1.0
task hello {
Array[Int] xs = [ 1, 2, 3 ]
Array[String] ys = [ "a", "b", "c" ]
Array[String] zs = [ "d", "e" ]
command {}
output {
Pair[Array[Int], Array[String]] unzipped = unzip(zip(xs, ys))
Pair[Array[Int], Array[String]] uncrossed = unzip(cross(xs, zs))
}
}
""")
self.assertEqual(outputs["unzipped"], {
"left": [1, 2, 3],
"right": ["a", "b", "c"]
})
self.assertEqual(outputs["uncrossed"], {
"left": [1, 1, 2, 2, 3, 3],
"right": ["d", "e", "d", "e", "d", "e"]
})

outputs = self._test_task(R"""
version 1.0
task hello {
input {
Array[Array[Int]] x
}
command {}
output {
Array[Pair[Int, Int]] zipped = unzip(x)
}
}
""", expected_exception=WDL.Error.StaticTypeMismatch)

def test_sep(self):
outputs = self._test_task(R"""
version development
Expand Down

0 comments on commit b5489df

Please sign in to comment.