Skip to content

Commit

Permalink
[hail] Introduce ArrayFold2 to permit more deforesting (#6981)
Browse files Browse the repository at this point in the history
* [hail] Introduce `ArrayFold2` to permit more deforesting

Problem 1: functions like `min`/`max` need to do a length check,
preventing us from implementing them in a single primitive fold.

Problem 2: functions like `mean` use  a struct as the accumulator,
leading to allocation (!) every element.

Solution to both: make it possible to have multiple primitive
accumulators. This is `ArrayFold2`.

The node is different from `ArrayFold` in that it:
 * has as sequence of accumulators, not just one
 * has a sequence of seq ops, one for each accumulator. Each of these
   sequence ops can see all the accumulators, and will see the updated
   value from sequence operations with a smaller index.
 * has a result op, which is a function from accumulators to result.

By changing `min`/`max` to use ArrayFold2 and inlining these functions,
we can get a reasonable speedup on `split_multi_hts`:

#6980 (this PR's parent):

```
2019-09-03 07:11:16,374: INFO:     burn in: 42.33s
2019-09-03 07:11:56,085: INFO:     run 1: 39.71s
2019-09-03 07:12:34,916: INFO:     run 2: 38.83s
2019-09-03 07:13:14,087: INFO:     run 3: 39.17s
```

PR:
```
2019-09-03 07:32:10,416: INFO:     burn in: 38.03s
2019-09-03 07:32:39,237: INFO:     run 1: 28.82s
2019-09-03 07:33:07,778: INFO:     run 2: 28.50s
2019-09-03 07:33:35,997: INFO:     run 3: 28.21s
```

* mean benchmarks

* Ported `ArrayFunctions.mean` to ArrayFold2.

Benchmark:
```python
@benchmark
def table_range_means():
    ht = hl.utils.range_table(10_000_000, 16)
    ht = ht.annotate(m = hl.mean(hl.range(0, ht.idx % 1111)))
    ht._force_count()
```

Master:
```
2019-09-03 09:39:05,777: INFO: [1/1] Running table_range_means...
2019-09-03 09:40:52,557: INFO:     burn in: 106.78s
2019-09-03 09:42:34,333: INFO:     run 1: 101.78s
2019-09-03 09:44:14,982: INFO:     run 2: 100.65s
2019-09-03 09:45:53,590: INFO:     run 3: 98.61s
```

PR:
```
2019-09-03 09:47:26,110: INFO: [1/1] Running table_range_means...
2019-09-03 09:47:29,465: INFO:     burn in: 3.35s
2019-09-03 09:47:32,615: INFO:     run 1: 3.15s
2019-09-03 09:47:35,703: INFO:     run 2: 3.09s
2019-09-03 09:47:38,840: INFO:     run 3: 3.14s
```

* add back in arrayfor

* add interpret rule

* add interpret rule

* fix

* remove unused var

* bump!
  • Loading branch information
tpoterba authored and danking committed Sep 16, 2019
1 parent 6fcd804 commit e067ad0
Show file tree
Hide file tree
Showing 24 changed files with 234 additions and 91 deletions.
6 changes: 3 additions & 3 deletions hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3581,7 +3581,7 @@ def mean(collection, filter_missing: bool = True) -> Float64Expression:
-------
:class:`.Expression` of type :py:data:`.tfloat64`
"""
return collection._filter_missing_method(filter_missing, "mean", tfloat64)
return array(collection)._filter_missing_method(filter_missing, "mean", tfloat64)


@typecheck(collection=expr_oneof(expr_set(expr_numeric), expr_array(expr_numeric)))
Expand Down Expand Up @@ -3641,7 +3641,7 @@ def product(collection, filter_missing: bool = True) -> NumericExpression:
-------
:class:`.NumericExpression`
"""
return collection._filter_missing_method(filter_missing, "product", collection.dtype.element_type)
return array(collection)._filter_missing_method(filter_missing, "product", collection.dtype.element_type)


@typecheck(collection=expr_oneof(expr_set(expr_numeric), expr_array(expr_numeric)),
Expand Down Expand Up @@ -3672,7 +3672,7 @@ def sum(collection, filter_missing: bool = True) -> NumericExpression:
-------
:class:`.NumericExpression`
"""
return collection._filter_missing_method(filter_missing, "sum", collection.dtype.element_type)
return array(collection)._filter_missing_method(filter_missing, "sum", collection.dtype.element_type)


@typecheck(a=expr_array(expr_numeric),
Expand Down
3 changes: 0 additions & 3 deletions hail/python/hail/ir/register_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def register_functions():
register_function("median", (dtype("set<?T:numeric>"),), dtype("?T"))
register_function("median", (dtype("array<?T:numeric>"),), dtype("?T"))
register_function("uniqueMinIndex", (dtype("array<?T>"),), dtype("int32"))
register_function("mean", (dtype("set<?T:numeric>"),), dtype("float64"))
register_function("mean", (dtype("array<?T:numeric>"),), dtype("float64"))
register_function("toFloat32", (dtype("?T:numeric"),), dtype("float32"))
register_function("uniqueMaxIndex", (dtype("array<?T>"),), dtype("int32"))
Expand Down Expand Up @@ -57,7 +56,6 @@ def ndarray_floating_point_divide(arg_type, ret_type):
register_function("nanmax", (dtype("?T"),dtype("?T"),), dtype("?T"))
register_function("max_ignore_missing", (dtype("?T"),dtype("?T"),), dtype("?T"))
register_function("nanmax_ignore_missing", (dtype("?T"),dtype("?T"),), dtype("?T"))
register_function("product", (dtype("set<?T:numeric>"),), dtype("?T"))
register_function("product", (dtype("array<?T:numeric>"),), dtype("?T"))
register_function("toInt32", (dtype("?T:numeric"),), dtype("int32"))
register_function("extend", (dtype("array<?T>"),dtype("array<?T>"),), dtype("array<?T>"))
Expand Down Expand Up @@ -103,7 +101,6 @@ def ndarray_floating_point_divide(arg_type, ret_type):
register_function("nanmin", (dtype("?T"),dtype("?T"),), dtype("?T"))
register_function("min_ignore_missing", (dtype("?T"),dtype("?T"),), dtype("?T"))
register_function("nanmin_ignore_missing", (dtype("?T"),dtype("?T"),), dtype("?T"))
register_function("sum", (dtype("set<?T:numeric>"),), dtype("?T"))
register_function("sum", (dtype("array<?T:numeric>"),), dtype("?T"))
register_function("toInt64", (dtype("?T:numeric"),), dtype("int64"))
register_function("contains", (dtype("dict<?key, ?value>"),dtype("?key"),), dtype("bool"))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from os import path
from tempfile import TemporaryDirectory
import hail as hl
Expand Down Expand Up @@ -146,6 +147,13 @@ def table_aggregate_int_stats():
*(hl.agg.explode(lambda elt: hl.agg.stats(elt), ht[f'array{i}']) for i in range(2))]))


@benchmark
def table_range_means():
ht = hl.utils.range_table(10_000_000, 16)
ht = ht.annotate(m=hl.mean(hl.range(0, ht.idx % 1111)))
ht._force_count()


@benchmark
def table_aggregate_counter():
ht = hl.read_table(resource('many_strings_table.ht'))
Expand Down
7 changes: 7 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Binds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ object Bindings {
case ArrayFlatMap(a, name, _) => if (i == 1) Array(name -> -coerce[TStreamable](a.typ).elementType) else empty
case ArrayFilter(a, name, _) => if (i == 1) Array(name -> -coerce[TStreamable](a.typ).elementType) else empty
case ArrayFold(a, zero, accumName, valueName, _) => if (i == 2) Array(accumName -> zero.typ, valueName -> -coerce[TStreamable](a.typ).elementType) else empty
case ArrayFold2(a, accum, valueName, seq, result) =>
if (i <= accum.length)
empty
else if (i < 2 * accum.length + 1)
Array((valueName, -coerce[TStreamable](a.typ).elementType)) ++ accum.map { case (name, value) => (name, value.typ) }
else
accum.map { case (name, value) => (name, value.typ) }
case ArrayScan(a, zero, accumName, valueName, _) => if (i == 2) Array(accumName -> zero.typ, valueName -> -coerce[TStreamable](a.typ).elementType) else empty
case ArrayAggScan(a, name, _) => if (i == 1) FastIndexedSeq(name -> a.typ.asInstanceOf[TStreamable].elementType) else empty
case ArrayLeftJoinDistinct(ll, rr, l, r, _, _) => if (i == 2 || i == 3) Array(l -> -coerce[TStreamable](ll.typ).elementType, r -> -coerce[TStreamable](rr.typ).elementType) else empty
Expand Down
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Children.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ object Children {
Array(a, body)
case ArrayFold(a, zero, accumName, valueName, body) =>
Array(a, zero, body)
case ArrayFold2(a, accum, valueName, seq, result) =>
Array(a) ++ accum.map(_._2) ++ seq ++ Array(result)
case ArrayScan(a, zero, accumName, valueName, body) =>
Array(a, zero, body)
case ArrayLeftJoinDistinct(left, right, l, r, compare, join) =>
Expand Down
8 changes: 8 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Copy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ object Copy {
case ArrayFold(_, _, accumName, valueName, _) =>
val IndexedSeq(a: IR, zero: IR, body: IR) = newChildren
ArrayFold(a, zero, accumName, valueName, body)
case ArrayFold2(_, accum, valueName, seq, _) =>
val ncIR = newChildren.map(_.asInstanceOf[IR])
assert(newChildren.length == 2 + accum.length + seq.length)
ArrayFold2(ncIR(0),
accum.indices.map(i => (accum(i)._1, ncIR(i + 1))),
valueName,
seq.indices.map(i => ncIR(i + 1 + accum.length)), ncIR.last)
case ArrayScan(_, _, accumName, valueName, _) =>
val IndexedSeq(a: IR, zero: IR, body: IR) = newChildren
ArrayScan(a, zero, accumName, valueName, body)
Expand Down Expand Up @@ -214,6 +221,7 @@ object Copy {
case x@ApplyIR(fn, args) =>
val r = ApplyIR(fn, newChildren.map(_.asInstanceOf[IR]))
r.conversion = x.conversion
r.inline = x.inline
r
case Apply(fn, args, t) =>
Apply(fn, newChildren.map(_.asInstanceOf[IR]), t)
Expand Down
61 changes: 61 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,66 @@ private class Emit(
processAElts.addElements))),
xmaccum, xvaccum)

case ArrayFold2(a, acc, valueName, seq, res) =>
val typ = ir.typ
val tarray = coerce[TStreamable](a.typ)
val tti = typeToTypeInfo(typ)
val eti = typeToTypeInfo(tarray.elementType)
val xmv = mb.newField[Boolean](valueName + "_missing")
val xvv = coerce[Any](mb.newField(valueName)(eti))
val accVars = acc.map { case (name, value) =>
val ti = typeToTypeInfo(value.typ)
(name, (ti, mb.newField[Boolean](s"${name}_missing"), mb.newField(name)(ti)))}
val xmtmp = mb.newField[Boolean]("arrayfold2_missing_tmp")

val resEnv = env.bindIterable(accVars.map { case (name, (ti, xm, xv)) => (name, (ti, xm.load(), xv.load())) })
val seqEnv = resEnv.bind(valueName, (eti, xmv.load(), xvv.load()))

val codeZ = acc.map { case (_, value) => emit(value) }
val codeSeq = seq.map(emit(_, env = seqEnv))

val aBase = emitArrayIterator(a)

val cont = { (m: Code[Boolean], v: Code[_]) =>
Code(
xmv := m,
xvv := xmv.mux(defaultValue(tarray.elementType), v),
Code(codeSeq.map(_.setup): _*),
coerce[Unit](Code(codeSeq.zipWithIndex.map { case (et, i) =>
val (_, (_, accm, accv)) = accVars(i)
Code(
xmtmp := et.m,
accv.storeAny(xmtmp.mux(defaultValue(acc(i)._2.typ): Code[_], et.v)),
accm := xmtmp
)
}: _*)))
}

val processAElts = aBase.arrayEmitter(cont)
val marray = processAElts.m.getOrElse(const(false))

val xresm = mb.newField[Boolean]
val xresv = mb.newField(typeToTypeInfo(res.typ))
val codeR = emit(res, env = resEnv)

EmitTriplet(Code(
codeZ.map(_.setup),
accVars.zipWithIndex.map { case ((_, (ti, xm, xv)), i) =>
Code(xm := codeZ(i).m, xv.storeAny(xm.mux(defaultValue(acc(i)._2.typ), codeZ(i).v)))
},
processAElts.setup,
marray.mux(
Code(
xresm := true,
xresv.storeAny(defaultValue(res.typ))),
Code(
aBase.calcLength,
processAElts.addElements,
codeR.setup,
xresm := codeR.m,
xresv.storeAny(codeR.v)))),
xresm, xresv)

case ArrayFor(a, valueName, body) =>
val tarray = coerce[TStreamable](a.typ)
val eti = typeToTypeInfo(tarray.elementType)
Expand Down Expand Up @@ -1238,6 +1298,7 @@ private class Emit(
false,
defaultValue(typ))
case ir@ApplyIR(fn, args) =>
assert(!ir.inline)
val mfield = mb.newField[Boolean]
val vfield = mb.newField()(typeToTypeInfo(ir.typ))

Expand Down
11 changes: 11 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,16 @@ final case class ArrayFlatMap(a: IR, name: String, body: IR) extends IR {
}
final case class ArrayFold(a: IR, zero: IR, accumName: String, valueName: String, body: IR) extends IR

object ArrayFold2 {
def apply(a: ArrayFold): ArrayFold2 = {
ArrayFold2(a.a, FastIndexedSeq((a.accumName, a.zero)), a.valueName, FastSeq(a.body), Ref(a.accumName, a.zero.typ))
}
}

final case class ArrayFold2(a: IR, accum: IndexedSeq[(String, IR)], valueName: String, seq: IndexedSeq[IR], result: IR) extends IR {
assert(accum.length == seq.length)
}

final case class ArrayScan(a: IR, zero: IR, accumName: String, valueName: String, body: IR) extends IR

final case class ArrayFor(a: IR, valueName: String, body: IR) extends IR
Expand Down Expand Up @@ -361,6 +371,7 @@ final case class Die(message: IR, _typ: Type) extends IR

final case class ApplyIR(function: String, args: Seq[IR]) extends IR {
var conversion: Seq[IR] => IR = _
var inline: Boolean = _

private lazy val refs = args.map(a => Ref(genUID(), a.typ)).toArray
lazy val body: IR = conversion(refs).deepCopy()
Expand Down
13 changes: 11 additions & 2 deletions hail/src/main/scala/is/hail/expr/ir/InferPType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,15 @@ object InferPType {

zero.pType2.setRequired(body.pType2.required)
}
case ArrayFold2(a, acc, valueName, seq, res) =>
InferPType(a, env)
acc.foreach { case (_, accIR) => InferPType(accIR, env) }
InferPType(a, env)
val resEnv = env.bind(acc.map { case (name, accIR) => (name, accIR.pType2)}: _*)
val seqEnv = resEnv.bind(valueName -> a.pType2.asInstanceOf[PArray].elementType)
seq.foreach(InferPType(_, seqEnv))
InferPType(res, resEnv)
res.pType2.setRequired(res.pType2.required && a.pType2.required)
case ArrayScan(a, zero, accumName, valueName, body) => {
InferPType(zero, env)

Expand Down Expand Up @@ -238,7 +247,7 @@ object InferPType {
}
case NDArrayRef(nd, idxs) => {
InferPType(nd, env)

var allRequired = nd.pType2.required
val it = idxs.iterator
while(it.hasNext) {
Expand All @@ -247,7 +256,7 @@ object InferPType {
InferPType(idxIR, env)

assert(idxIR.pType2.isOfType(PInt64()) || idxIR.pType2.isOfType(PInt32()))

if(allRequired == true && idxIR.pType2.required == false) {
allRequired = false
}
Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/InferType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ object InferType {
case ArrayFold(a, zero, accumName, valueName, body) =>
assert(body.typ == zero.typ)
zero.typ
case ArrayFold2(_, _, _, _, result) => result.typ
case ArrayScan(a, zero, accumName, valueName, body) =>
assert(body.typ == zero.typ)
coerce[TStreamable](a.typ).copyStreamable(zero.typ)
Expand Down
15 changes: 15 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Interpret.scala
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,21 @@ object Interpret {
}
zeroValue
}
case ArrayFold2(a, accum, valueName, seq, res) =>
val aValue = interpret(a, env, args, aggArgs)
if (aValue == null)
null
else {
val accVals = accum.map { case (name, value) => (name, interpret(value, env, args, aggArgs)) }
var e = env.bindIterable(accVals)
aValue.asInstanceOf[IndexedSeq[Any]].foreach { elt =>
e = e.bind(valueName, elt)
accVals.indices.foreach { i =>
e = e.bind(accum(i)._1, interpret(seq(i), e, args, aggArgs))
}
}
interpret(res, e.delete(valueName), args, aggArgs)
}
case ArrayScan(a, zero, accumName, valueName, body) =>
val aValue = interpret(a, env, args, aggArgs)
if (aValue == null)
Expand Down
9 changes: 9 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean =
val newAccumName = gen()
val newValueName = gen()
ArrayFold(normalize(a), normalize(zero), newAccumName, newValueName, normalize(body, env.bindEval(accumName -> newAccumName, valueName -> newValueName)))
case ArrayFold2(a, accum, valueName, seq, res) =>
val newValueName = gen()
val (accNames, newAcc) = accum.map { case (old, ir) =>
val newName = gen()
((old, newName), (newName, normalize(ir)))
}.unzip
val resEnv = env.bindEval(accNames: _*)
val seqEnv = resEnv.bindEval(valueName, newValueName)
ArrayFold2(normalize(a), newAcc, newValueName, seq.map(normalize(_, seqEnv)), normalize(res, resEnv))
case ArrayScan(a, zero, accumName, valueName, body) =>
val newAccumName = gen()
val newValueName = gen()
Expand Down
11 changes: 11 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,17 @@ object IRParser {
val eltType = -coerce[TStreamable](a.typ).elementType
val body = ir_value_expr(env.update(Map(accumName -> zero.typ, valueName -> eltType)))(it)
ArrayFold(a, zero, accumName, valueName, body)
case "ArrayFold2" =>
val accumNames = identifiers(it)
val valueName = identifier(it)
val a = ir_value_expr(env)(it)
val accs = accumNames.map(name => (name, ir_value_expr(env)(it)))
val eltType = -coerce[TStreamable](a.typ).elementType
val resultEnv = env.update(accs.map { case (name, value) => (name, value.typ) }.toMap)
val seqEnv = resultEnv.update(Map(valueName -> eltType))
val seqs = Array.tabulate(accs.length)(_ => ir_value_expr(seqEnv)(it))
val res = ir_value_expr(resultEnv)(it)
ArrayFold2(a, accs, valueName, seqs, res)
case "ArrayScan" =>
val accumName = identifier(it)
val valueName = identifier(it)
Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ object Pretty {
case ArrayFilter(_, name, _) => prettyIdentifier(name)
case ArrayFlatMap(_, name, _) => prettyIdentifier(name)
case ArrayFold(_, _, accumName, valueName, _) => prettyIdentifier(accumName) + " " + prettyIdentifier(valueName)
case ArrayFold2(_, acc, valueName, _, _) => prettyIdentifiers(acc.map(_._1)) + " " + prettyIdentifier(valueName)
case ArrayScan(_, _, accumName, valueName, _) => prettyIdentifier(accumName) + " " + prettyIdentifier(valueName)
case ArrayLeftJoinDistinct(_, _, l, r, _, _) => prettyIdentifier(l) + " " + prettyIdentifier(r)
case ArrayFor(_, valueName, _) => prettyIdentifier(valueName)
Expand Down
18 changes: 18 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,24 @@ object PruneDeadFields {
bodyEnv.deleteEval(valueName).deleteEval(accumName),
memoizeValueIR(a, aType.copyStreamable(valueType), memo)
)
case ArrayFold2(a, accum, valueName, seq, res) =>
val aType = a.typ.asInstanceOf[TStreamable]
val zeroEnvs = accum.map { case (name, zval) => memoizeValueIR(zval, zval.typ, memo) }
val seqEnvs = seq.map { seq => memoizeValueIR(seq, seq.typ, memo) }
val resEnv = memoizeValueIR(res, requestedType, memo)
val valueType = unifySeq(
aType.elementType,
resEnv.eval.lookupOption(valueName).map(_.result()).getOrElse(Array()) ++
seqEnvs.flatMap(_.eval.lookupOption(valueName).map(_.result()).getOrElse(Array())))

val accumNames = accum.map(_._1)
val seqNames = accumNames ++ Array(valueName)
unifyEnvsSeq(
zeroEnvs
++ Array(resEnv.copy(eval = resEnv.eval.delete(accumNames)))
++ seqEnvs.map(e => e.copy(eval = e.eval.delete(seqNames)))
++ Array(memoizeValueIR(a, aType.copyStreamable(valueType), memo))
)
case ArrayScan(a, zero, accumName, valueName, body) =>
val aType = a.typ.asInstanceOf[TStreamable]
val zeroEnv = memoizeValueIR(zero, zero.typ, memo)
Expand Down
2 changes: 1 addition & 1 deletion hail/src/main/scala/is/hail/expr/ir/Simplify.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ object Simplify {

case ApplyIR("contains", Seq(ToSet(x), element)) if x.typ.isInstanceOf[TArray] => invoke("contains", TBoolean(), x, element)

case x: ApplyIR if x.body.size < 10 => x.explicitNode
case x: ApplyIR if x.inline || x.body.size < 10 => x.explicitNode

case ArrayLen(MakeArray(args, _)) => I32(args.length)

Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/Streamify.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ object Streamify {
case ToSet(a) => ToSet(streamify(a))
case ToDict(a) => ToDict(streamify(a))
case ArrayFold(a, zero, zn, an, body) => ArrayFold(streamify(a), zero, zn, an, body)
case ArrayFold2(a, acc, vn, seq, res) => ArrayFold2(streamify(a), acc, vn, seq, res)
case ArrayFor(a, n, b) => ArrayFor(streamify(a), n, b)
case x: ApplyIR => apply(x.explicitNode)
case _ if node.typ.isInstanceOf[TStreamable] => unstreamify(node)
Expand Down
4 changes: 4 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ object TypeCheck {
assert(a.typ.isInstanceOf[TStreamable])
assert(body.typ == zero.typ)
assert(x.typ == zero.typ)
case x@ArrayFold2(a, accum, valueName, seq, res) =>
assert(a.typ.isInstanceOf[TStreamable])
assert(x.typ == res.typ)
assert(accum.zip(seq).forall { case ((_, z), s) => s.typ == z.typ })
case x@ArrayScan(a, zero, accumName, valueName, body) =>
assert(a.typ.isInstanceOf[TStreamable])
assert(body.typ == zero.typ)
Expand Down
Loading

0 comments on commit e067ad0

Please sign in to comment.