Skip to content

Commit

Permalink
Merge 0d4d66c into a7eabc5
Browse files Browse the repository at this point in the history
  • Loading branch information
mlin committed Aug 19, 2019
2 parents a7eabc5 + 0d4d66c commit 0d91b70
Show file tree
Hide file tree
Showing 13 changed files with 380 additions and 281 deletions.
8 changes: 4 additions & 4 deletions WDL/Expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,8 @@ def infer_type(
# override the + operator with the within-interpolation version which accepts String?
# operands and produces a String? result
stdlib = stdlib or StdLib.Base()
with stdlib._context_override("_add", StdLib.InterpolationAddOperator()):
return super().infer_type(type_env, stdlib, check_quant)
setattr(stdlib, "_add", StdLib.InterpolationAddOperator())
return super().infer_type(type_env, stdlib, check_quant)

def _infer_type(self, type_env: Env.Bindings[Type.Base]) -> Type.Base:
if isinstance(self.expr.type, Type.Array):
Expand Down Expand Up @@ -291,8 +291,8 @@ def _eval(
# override the + operator with the within-interpolation version which evaluates to None
# if either operand is None
stdlib = stdlib or StdLib.Base()
with stdlib._context_override("_add", StdLib.InterpolationAddOperator()):
v = self.expr.eval(env, stdlib)
setattr(stdlib, "_add", StdLib.InterpolationAddOperator())
v = self.expr.eval(env, stdlib)
if isinstance(v, Value.Null):
if "default" in self.options:
return Value.String(self.options["default"])
Expand Down
236 changes: 205 additions & 31 deletions WDL/StdLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import math
import os
import re
from typing import List, Tuple, Callable
import json
import tempfile
from typing import List, Tuple, Callable, BinaryIO
from abc import ABC, abstractmethod
from contextlib import contextmanager
from . import Type, Value, Expr, Env, Error


Expand All @@ -19,7 +20,11 @@ class Base:
output sections.
"""

def __init__(self):
_write_dir: str # directory in which write_* functions create

def __init__(self, write_dir: str = ""):
self._write_dir = write_dir if write_dir else tempfile.gettempdir()

# language built-ins
self._at = _At()
self._land = _And()
Expand Down Expand Up @@ -55,41 +60,131 @@ def __init__(self):
Type.Boolean(),
lambda v: Value.Boolean(not isinstance(v, Value.Null)),
),
(
"write_lines",
[Type.Array(Type.String())],
Type.File(),
self._write(_serialize_lines),
),
(
"write_tsv",
[Type.Array(Type.Array(Type.String()))],
Type.File(),
self._write(
lambda v, outfile: _serialize_lines(
Value.Array(
Type.String(),
[
Value.String(
"\t".join(
[part.coerce(Type.String()).value for part in parts.value]
)
)
for parts in v.value
],
),
outfile,
)
),
),
(
"write_map",
[Type.Map((Type.Any(), Type.Any()))],
Type.File(),
self._write(_serialize_map),
),
(
"write_json",
[Type.Any()],
Type.File(),
self._write(lambda v, outfile: outfile.write(json.dumps(v.json).encode("utf-8"))),
),
("read_int", [Type.File()], Type.Int(), self._read(lambda s: Value.Int(int(s)))),
("read_boolean", [Type.File()], Type.Boolean(), self._read(_parse_boolean)),
(
"read_string",
[Type.File()],
Type.String(),
self._read(lambda s: Value.String(s[:-1] if s.endswith("\n") else s)),
),
(
"read_float",
[Type.File()],
Type.Float(),
self._read(lambda s: Value.Float(float(s))),
),
(
"read_map",
[Type.File()],
Type.Map((Type.String(), Type.String())),
self._read(_parse_map),
),
("read_lines", [Type.File()], Type.Array(Type.String()), self._read(_parse_lines)),
(
"read_tsv",
[Type.File()],
Type.Array(Type.Array(Type.String())),
self._read(_parse_tsv),
),
("read_json", [Type.File()], Type.Any(), self._read(_parse_json)),
# context-dependent:
("write_lines", [Type.Array(Type.String())], Type.File(), _notimpl),
("write_tsv", [Type.Array(Type.Array(Type.String()))], Type.File(), _notimpl),
("write_map", [Type.Map((Type.Any(), Type.Any()))], Type.File(), _notimpl),
("write_json", [Type.Any()], Type.File(), _notimpl),
("stdout", [], Type.File(), _notimpl),
("stderr", [], Type.File(), _notimpl),
("glob", [Type.String()], Type.Array(Type.File()), _notimpl),
("read_int", [Type.File()], Type.Int(), _notimpl),
("read_boolean", [Type.File()], Type.Boolean(), _notimpl),
("read_string", [Type.File()], Type.String(), _notimpl),
("read_float", [Type.File()], Type.Float(), _notimpl),
("read_map", [Type.File()], Type.Map((Type.String(), Type.String())), _notimpl),
("read_lines", [Type.File()], Type.Array(Type.Any()), _notimpl),
("read_tsv", [Type.File()], Type.Array(Type.Array(Type.String())), _notimpl),
("read_json", [Type.File()], Type.Any(), _notimpl),
]:
setattr(self, name, StaticFunction(name, argument_types, return_type, F))

# polymorphically typed stdlib functions which require specialized
# infer_type logic
self.range = _Range()
self.prefix = _Prefix()
self.size = _Size()
self.size = _Size(self)
self.select_first = _SelectFirst()
self.select_all = _SelectAll()
self.zip = _Zip()
self.cross = _Cross()
self.flatten = _Flatten()
self.transpose = _Transpose()

def _override(self, name: str, fn: "Function") -> None:
    """
    Swap the Function stored on this library under ``name`` for ``fn``.

    ``name`` must already refer to a Function attribute; this is checked with an assertion
    before the attribute is replaced.
    """
    existing = getattr(self, name)
    assert isinstance(existing, Function)
    setattr(self, name, fn)
def _read(self, parse: Callable[[str], Value.Base]) -> Callable[[Value.File], Value.Base]:
    """
    Generate a read_* function implementation based on ``parse``.

    The returned function takes a File value, resolves its filename through
    ``self._devirtualize_filename``, reads the entire file as text, and returns ``parse(text)``.
    Any exception raised while devirtualizing, opening, reading or parsing propagates to the
    caller.
    """

    def f(file: Value.File) -> Value.Base:
        # Decode explicitly as UTF-8, the encoding the write_* serializers in this module emit,
        # instead of relying on the platform's locale-dependent default text encoding.
        with open(self._devirtualize_filename(file.value), "r", encoding="utf-8") as infile:
            return parse(infile.read())

    return f

def _devirtualize_filename(self, filename: str) -> str:
    """
    Map the filename carried by a File value to a path that can be open()ed on the local host.

    This is the hook through which the read_* functions and size() locate their input. The base
    implementation always raises NotImplementedError. An implementation may also refuse access
    (by raising an exception) to files outside a designated directory or whitelist.
    """
    raise NotImplementedError()

def _write(
    self, serialize: Callable[[Value.Base, BinaryIO], None]
) -> Callable[[Value.Base], Value.File]:
    """
    Generate a write_* function implementation based on ``serialize``.

    The returned function creates ``self._write_dir`` if needed, opens a fresh, persistent
    (``delete=False``) temporary file there, lets ``serialize`` write the value into it, and
    returns a File value holding the virtualized name of that file.
    """

    def _f(v: Value.Base,) -> Value.File:
        # the write directory is created on demand, on every call
        os.makedirs(self._write_dir, exist_ok=True)
        tmp = tempfile.NamedTemporaryFile(dir=self._write_dir, delete=False)
        with tmp as outfile:
            outfile: BinaryIO = outfile  # pyre-ignore
            serialize(v, outfile)
            local_path = outfile.name
        # translate the local path into the name the File value should present
        return Value.File(self._virtualize_filename(local_path))

    return _f

def _virtualize_filename(self, filename: str) -> str:
    """
    Map a local path inside write_dir to the filename that should appear in the resulting File
    value.

    This is the hook the write_* functions call after writing their temporary file. The base
    implementation always raises NotImplementedError.
    """
    raise NotImplementedError()

def _override_static(self, name: str, f: Callable) -> None:
# replace the implementation lambda of a StaticFunction (keeping its
Expand All @@ -98,16 +193,6 @@ def _override_static(self, name: str, f: Callable) -> None:
assert isinstance(sf, StaticFunction)
setattr(sf, "F", f)

@contextmanager
def _context_override(self, name: str, fn: "Function"):
    """
    Context manager that installs ``fn`` under ``name`` for the duration of the ``with`` body
    and then puts back whatever was registered before, even if the body raises.

    Yields this library object.
    """
    previous = getattr(self, name)
    self._override(name, fn)
    try:
        yield self
    finally:
        # always restore the function that was in place on entry
        self._override(name, previous)


class Function(ABC):
# Abstract interface to a standard library function implementation
Expand Down Expand Up @@ -196,6 +281,88 @@ def _notimpl(*args, **kwargs) -> None:
exec("raise NotImplementedError('function not available in this context')")


def _parse_lines(s: str) -> Value.Array:
    """
    Split file text into an Array of String values, one per line.

    Empty text gives an empty array. A single trailing newline is treated as the terminator of
    the last line, not as the start of an additional empty line.
    """
    if not s:
        return Value.Array(Type.String(), [])
    body = s[:-1] if s.endswith("\n") else s
    return Value.Array(Type.String(), [Value.String(text) for text in body.split("\n")])


def _parse_boolean(s: str) -> Value.Boolean:
    """
    Interpret file text as a Boolean value.

    Trailing whitespace is ignored; what remains must be exactly ``true`` or ``false``
    (lowercase), otherwise Error.InputError is raised.
    """
    literals = {"true": True, "false": False}
    token = s.rstrip()
    if token not in literals:
        raise Error.InputError('read_boolean(): file content is not "true" or "false"')
    return Value.Boolean(literals[token])


def _parse_tsv(s: str) -> Value.Array:
    """
    Parse tab-separated text into an array of rows, each row an array of String fields.

    Lines are obtained from _parse_lines() and each is split on tab characters.
    """
    # TODO: should a blank line parse as [] or ['']?
    # (as written, a blank line becomes a row holding one empty String)
    # NOTE(review): the row arrays and the outer array are both constructed with
    # Type.Array(Type.String()) -- confirm that matches what Value.Array expects.
    rows = []
    for line in _parse_lines(s).value:
        fields = [Value.String(field) for field in line.value.split("\t")]
        rows.append(Value.Array(Type.Array(Type.String()), fields))
    # pyre-ignore
    return Value.Array(Type.Array(Type.String()), rows)


def _parse_map(s: str) -> Value.Map:
    """
    Parse two-column tab-separated text into a Map[String,String].

    Rows come from _parse_tsv(); the first field of each row is the key and the second is the
    value, kept in file order. Raises Error.InputError if any row does not have exactly two
    fields or if a key occurs more than once.
    """
    seen = set()
    pairs = []
    for row in _parse_tsv(s).value:
        assert isinstance(row, Value.Array)
        fields = row.value
        if len(fields) != 2:
            raise Error.InputError("read_map(): each line must have two fields")
        key, val = fields
        if key.value in seen:
            raise Error.InputError("read_map(): duplicate key")
        seen.add(key.value)
        pairs.append((key, val))
    return Value.Map((Type.String(), Type.String()), pairs)


def _parse_json(s: str) -> Value.Base:
    """
    Parse JSON text into a WDL value.

    - object -> Map[String,String]; keys and values are converted with str()
    - array  -> Array[String]; elements are converted with str()
    - string -> String
    - true/false -> Boolean
    - integer -> Int, other number -> Float
    - null -> Null

    Malformed JSON propagates the exception raised by json.loads(). Error.InputError is raised
    if the decoded value is of none of the kinds above.
    """
    # TODO: parse int/float/boolean inside map or list as such
    j = json.loads(s)
    if isinstance(j, dict):
        return Value.Map(
            (Type.String(), Type.String()),
            [(Value.String(str(k)), Value.String(str(j[k]))) for k in j],
        )
    if isinstance(j, list):
        return Value.Array(Type.String(), [Value.String(str(v)) for v in j])
    if isinstance(j, str):
        # a bare JSON string is a valid document; previously it fell through to the error below
        return Value.String(j)
    # bool must be tested before int, since bool is a subclass of int
    if isinstance(j, bool):
        return Value.Boolean(j)
    if isinstance(j, int):
        return Value.Int(j)
    if isinstance(j, float):
        return Value.Float(j)
    if j is None:
        return Value.Null()
    raise Error.InputError("read_json(): unsupported JSON value")


def _serialize_lines(array: Value.Array, outfile: BinaryIO) -> None:
    """
    Write each element of ``array`` to ``outfile`` as one UTF-8 encoded line.

    Every element is coerced to String first, and every line (including the last) is terminated
    with a newline byte.
    """
    for element in array.value:
        text = element.coerce(Type.String()).value
        outfile.write(text.encode("utf-8"))
        outfile.write(b"\n")


def _serialize_map(map: Value.Map, outfile: BinaryIO) -> None:
    """
    Write ``map`` to ``outfile`` as two-column tab-separated lines, one ``key<TAB>value`` line
    per entry.

    Keys and values are coerced to String. All entries are validated before anything is written:
    ValueError is raised if any key or value contains a tab or newline character.
    """
    rows = []
    for key, val in map.value:
        key_text = key.coerce(Type.String()).value
        val_text = val.coerce(Type.String()).value
        if any(ch in text for text in (key_text, val_text) for ch in "\n\t"):
            raise ValueError(
                "write_map(): keys & values must not contain tab or newline characters"
            )
        rows.append(Value.String(key_text + "\t" + val_text))
    _serialize_lines(Value.Array(Type.String(), rows), outfile)


def _basename(*args) -> Value.String:
assert len(args) in (1, 2)
assert isinstance(args[0], Value.String)
Expand Down Expand Up @@ -452,6 +619,10 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.

class _Size(EagerFunction):
# size(): first argument can be File? or Array[File?]
stdlib: Base

def __init__(self, stdlib: Base) -> None:
self.stdlib = stdlib

def infer_type(self, expr: "Expr.Apply") -> Type.Base:
if not expr.arguments:
Expand Down Expand Up @@ -479,7 +650,10 @@ def _call_eager(self, expr: "Expr.Apply", arguments: List[Value.Base]) -> Value.
files = arguments[0].coerce(Type.Array(Type.File()))
unit = arguments[1].coerce(Type.String()) if len(arguments) > 1 else None

ans = sum(float(os.path.getsize(fn.value)) for fn in files.value)
ans = []
for file in files.value:
ans.append(os.path.getsize(self.stdlib._devirtualize_filename(file.value)))
ans = float(sum(ans))

if unit:
try:
Expand Down
4 changes: 2 additions & 2 deletions WDL/Type.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
1. ``Int`` coerces to ``Float``
2. ``Boolean``, ``Int``, ``Float``, and ``File`` coerce to ``String``
3. ``String`` coerces to ``File``
3. ``String`` coerces to ``File``, ``Int``, and ``Float``
4. ``Array[T]`` coerces to ``String`` provided ``T`` does as well
5. ``T`` coerces to ``T?`` but the reverse is not true in general*
6. ``Array[T]+`` coerces to ``Array[T]`` but the reverse is not true in general*
Expand Down Expand Up @@ -171,7 +171,7 @@ def __init__(self, optional: bool = False) -> None:

def coerces(self, rhs: Base, check_quant: bool = True) -> bool:
""
if isinstance(rhs, File):
if isinstance(rhs, (File, Int, Float)):
return self._check_optional(rhs, check_quant)
return super().coerces(rhs, check_quant)

Expand Down
15 changes: 14 additions & 1 deletion WDL/Value.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ def coerce(self, desired_type: Optional[Type.Base] = None) -> Base:
""
if isinstance(desired_type, Type.File) and not isinstance(self, File):
return File(self.value)
try:
if isinstance(desired_type, Type.Int):
return Int(int(self.value))
if isinstance(desired_type, Type.Float):
return Float(float(self.value))
except ValueError as exn:
if self.expr:
raise Error.EvalError(self.expr, "coercing String to number: " + str(exn)) from exn
raise
return super().coerce(desired_type)


Expand Down Expand Up @@ -281,9 +290,13 @@ def __init__(
super().__init__(type, value)
self.value = dict(value)
if isinstance(type, Type.StructInstance):
assert type.members
# coerce values to member types
for k in self.value:
assert k in type.members
self.value[k] = self.value[k].coerce(type.members[k])
# if initializer (map or object literal) omits optional members,
# fill them in with null
assert type.members
for k in type.members:
if k not in self.value:
assert type.members[k].optional
Expand Down
Loading

0 comments on commit 0d91b70

Please sign in to comment.