Skip to content

Commit

Permalink
working python test
Browse files Browse the repository at this point in the history
  • Loading branch information
wang committed Dec 4, 2019
1 parent d032d3c commit 7b3a849
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 26 deletions.
4 changes: 3 additions & 1 deletion hail/python/hail/experimental/__init__.py
Expand Up @@ -19,6 +19,7 @@
from .db import DB
from .compile import compile_comparison_binary, compiled_compare
from .vcf_combiner import lgt_to_gt
from .loop import loop

__all__ = ['ld_score',
'ld_score_regression',
Expand Down Expand Up @@ -51,4 +52,5 @@
'decode',
'compile_comparison_binary',
'compiled_compare',
'lgt_to_gt']
'lgt_to_gt',
'loop']
10 changes: 6 additions & 4 deletions hail/python/hail/experimental/loop.py
Expand Up @@ -22,8 +22,6 @@ def loop(f: Callable, typ, *exprs):
Notes
-----
The first argument to the lambda is a marker for the recursive call.
Some infinite loop detection is done, and if an infinite loop is detected, `loop` will not
typecheck.
Parameters
----------
f : function ( (marker, *args) -> :class:`.Expression`
Expand All @@ -44,6 +42,8 @@ def contains_recursive_call(non_recursive):
return True
if isinstance(non_recursive, ir.TailLoop):
return False
if len(non_recursive.children) == 0:
return False
return all([contains_recursive_call(c) for c in non_recursive.children])

def check_tail_recursive(loop_ir):
Expand All @@ -59,6 +59,8 @@ def check_tail_recursive(loop_ir):
elif not isinstance(loop_ir, ir.Recur) and contains_recursive_call(loop_ir):
raise TypeError("found recursive expression outside of tail position!")

loop_name = Env.get_uid()

@typecheck(recur_exprs=expr_any)
def make_loop(*recur_exprs):
if len(recur_exprs) != len(exprs):
Expand All @@ -74,7 +76,7 @@ def make_loop(*recur_exprs):
raise TypeError(err)
irs = [expr._ir for expr in recur_exprs]
indices, aggregations = unify_all(*recur_exprs)
return construct_expr(ir.Recur(irs, typ), typ, indices, aggregations)
return construct_expr(ir.Recur(loop_name, irs, typ), typ, indices, aggregations)

uid_irs = []
loop_vars = []
Expand All @@ -89,4 +91,4 @@ def make_loop(*recur_exprs):
indices, aggregations = unify_all(*exprs, loop_f)
if loop_f.dtype != typ:
raise TypeError(f"requested type {typ} does not match inferred type {loop_f.typ}")
return construct_expr(ir.TailLoop(uid_irs, loop_f._ir), loop_f.dtype, indices, aggregations)
return construct_expr(ir.TailLoop(loop_name, loop_f._ir, uid_irs), loop_f.dtype, indices, aggregations)
66 changes: 66 additions & 0 deletions hail/python/hail/ir/ir.py
Expand Up @@ -388,6 +388,72 @@ def _compute_type(self, env, agg_env):
self._type = env[self.name]


class TailLoop(IR):

@typecheck_method(name=str, body=IR, params=sequenceof(sized_tupleof(str, IR)))
def __init__(self, name, body, params):
super().__init__(*([v for n, v in params] + [body]))
self.name = name
self.params = params
self.body = body

def copy(self, *children):
params = children[:-1]
body = children[-1]
assert len(params) == len(self.params)
return TailLoop(self.name, [(n, v) for (n, _), v in zip(self.params, params)], body)

def head_str(self):
return f'{escape_id(self.name)} ({" ".join([escape_id(n) for n, _ in self.params])})'

def _eq(self, other):
return self.name == other.name

def _compute_type(self, env, agg_env):
self._type = self.body._compute_type(env, agg_env)

@property
def bound_variables(self):
return {n for n, _ in self.params} | {self.name} | super().bound_variables

def _compute_type(self, env, agg_env):
self.body._compute_type(_env_bind(env, self.bindings(len(self.params))), agg_env)
self._type = self.body.typ

def renderable_bindings(self, i, default_value=None):
if i == len(self.params):
if default_value is None:
return {self.name: None, **{n: v.typ for n, v in self.params}}
else:
value = default_value
return {self.name: value, **{n: value for n, _ in self.params}}
else:
return {}


class Recur(IR):
@typecheck_method(name=str, args=sequenceof(IR), return_type=hail_type)
def __init__(self, name, args, return_type):
super().__init__(*args)
self.name = name
self.args = args
self.return_type = return_type
self._free_vars = {name}

def copy(self, args):
return Recur(self.name, args, self.return_type)

def head_str(self):
return f'{escape_id(self.name)} {self.return_type._parsable_string()}'

def _eq(self, other):
return other.name == self.name

def _compute_type(self, env, agg_env):
assert self.name in env
self._type = self.return_type


class ApplyBinaryPrimOp(IR):
@typecheck_method(op=str, l=IR, r=IR)
def __init__(self, op, l, r):
Expand Down
14 changes: 14 additions & 0 deletions hail/python/test/hail/experimental/test_experimental.py
Expand Up @@ -332,3 +332,17 @@ def test_export_block_matrices(self):
a = arrs[i]
a2 = np.loadtxt(f'{prefix}/files/{i}.tsv')
self.assertTrue(np.array_equal(a, a2))

def test_loop(self):
print("foo")
from math import factorial

def fact(f, x, accum):
return hl.cond(x > 0,
f(x - 1, accum * x),
accum)

def hail_factorial(n):
return hl.experimental.loop(fact, n, 1)

self.assertEquals(hl.eval(hail_factorial(20)), factorial(20))
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Emit.scala
Expand Up @@ -1847,7 +1847,7 @@ private class Emit(
.bind(name, (typeToTypeInfo(x.typ), const(false), label.goto))

val newLoopEnv = loopEnv.getOrElse(Env.empty)
val bodyT = emit(body, argEnv, loopEnv = newLoopEnv.bind(name, loopRefs.map(_._2).toArray))
val bodyT = emit(body, argEnv, loopEnv = Some(newLoopEnv.bind(name, loopRefs.map(_._2).toArray)))
val bodyF = Code(
bodyT.setup,
m := bodyT.m,
Expand Down
8 changes: 5 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
Expand Up @@ -627,9 +627,11 @@ object IRParser {
AggLet(name, value, body, isScan)
case "TailLoop" =>
val name = identifier(it)
val body = ir_value_expr(env)(it)
val args = named_value_irs(env)(it)
TailLoop(name, args, body)
val paramNames = identifiers(it)
val params = paramNames.map { n => n -> ir_value_expr(env)(it) }
val bodyEnv = env.update(params.map { case (n, v) => n -> v.typ}.toMap)
val body = ir_value_expr(bodyEnv)(it)
TailLoop(name, params, body)
case "Recur" =>
val name = identifier(it)
val typ = type_expr(env.typEnv)(it)
Expand Down
18 changes: 1 addition & 17 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
Expand Up @@ -216,23 +216,6 @@ object Pretty {
sb += ')'
}(sb += '\n')
}
case TailLoop(name, args, body) =>
sb += '\n'
prettyIdentifier(name)
sb += '\n'
pretty(body, depth + 2)
sb += '\n'
if (args.nonEmpty) {
sb += '\n'
args.foreachBetween { case (n, a) =>
sb.append(" " * (depth + 2))
sb += '('
sb.append(prettyIdentifier(n))
sb += '\n'
pretty(a, depth + 4)
sb += ')'
}(sb += '\n')
}
case _ =>
val header = ir match {
case I32(x) => x.toString
Expand All @@ -252,6 +235,7 @@ object Pretty {
)
case Let(name, _, _) => prettyIdentifier(name)
case AggLet(name, _, _, isScan) => prettyIdentifier(name) + " " + prettyBooleanLiteral(isScan)
case TailLoop(name, args, _) => prettyIdentifier(name) + " " + prettyIdentifiers(args.map(_._1).toFastIndexedSeq)
case Recur(name, _, t) => prettyIdentifier(name) + " " + t.parsableString()
case Ref(name, _) => prettyIdentifier(name)
case RelationalRef(name, t) => prettyIdentifier(name) + " " + t.parsableString()
Expand Down

0 comments on commit 7b3a849

Please sign in to comment.