Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[query] Add StreamZipJoinProducers node to zip join streams defined by an IR function #13222

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion hail/python/hail/expr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
contig_length, liftover, min_rep, uniroot, format, approx_equal, reversed, bit_and, bit_or,
bit_xor, bit_lshift, bit_rshift, bit_not, bit_count, binary_search, logit, expit, _values_similar,
_showstr, _sort_by, _compare, _locus_windows_per_contig, shuffle, _console_log, dnorm, dchisq,
query_table, keyed_union, keyed_intersection, repeat)
query_table, keyed_union, keyed_intersection, repeat, _zip_join_producers)

__all__ = ['HailType',
'hail_type',
Expand Down Expand Up @@ -309,5 +309,6 @@
'query_table',
'keyed_union',
'keyed_intersection',
'_zip_join_producers',
'repeat',
]
26 changes: 26 additions & 0 deletions hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5058,6 +5058,32 @@ def _union_intersection_base(name, arrays, key, join_f, result_f):
return result_f(construct_expr(zj, zj.typ, indices, aggs))


def _zip_join_producers(contexts, stream_f, key, join_f):
    """Zip-join the keyed streams produced from each element of `contexts`.

    For each context element, ``stream_f`` builds a keyed stream (or array,
    which is coerced to a stream). The streams are merged on ``key``; for each
    distinct key value, ``join_f`` receives the key struct and the array of
    per-producer values and produces one output element. Returns an array
    expression of the join results.
    """
    ctx_name = Env.get_uid()
    producer_expr = stream_f(construct_variable(ctx_name, contexts.dtype.element_type))

    producer_ir = producer_expr._ir
    # The IR node requires a stream; coerce an array-typed producer.
    if isinstance(producer_ir.typ, hl.tarray):
        producer_ir = ir.ToStream(producer_ir)

    elt_t = producer_expr.dtype.element_type
    cur_key_name = Env.get_uid()
    cur_vals_name = Env.get_uid()
    cur_key = construct_variable(cur_key_name, hl.tstruct(**{f: elt_t[f] for f in key}))
    cur_vals = construct_variable(cur_vals_name, hl.tarray(elt_t))

    joined = join_f(cur_key, cur_vals)
    result_ir = ir.ToArray(ir.StreamZipJoinProducers(
        contexts._ir, ctx_name, producer_ir, key, cur_key_name, cur_vals_name, joined._ir))
    indices, aggs = unify_all(contexts, producer_expr, joined)
    return construct_expr(result_ir, result_ir.typ, indices, aggs)


@typecheck(arrays=expr_oneof(expr_stream(expr_any), expr_array(expr_any)), key=sequenceof(builtins.str))
def keyed_intersection(*arrays, key):
"""Compute the intersection of sorted arrays on a given key.
Expand Down
3 changes: 2 additions & 1 deletion hail/python/hail/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
GetTupleElement, Die, ConsoleLog, Apply, ApplySeeded, RNGStateLiteral, RNGSplit,\
TableCount, TableGetGlobals, TableCollect, TableAggregate, MatrixCount, \
MatrixAggregate, TableWrite, udf, subst, clear_session_functions, ReadPartition, \
PartitionNativeIntervalReader, StreamMultiMerge, StreamZipJoin, StreamAgg
PartitionNativeIntervalReader, StreamMultiMerge, StreamZipJoin, StreamAgg, StreamZipJoinProducers
from .register_functions import register_functions
from .register_aggregators import register_aggregators
from .table_ir import (MatrixRowsTable, TableJoin, TableLeftJoinRightDistinct, TableIntervalJoin,
Expand Down Expand Up @@ -173,6 +173,7 @@
'toStream',
'ToStream',
'StreamZipJoin',
'StreamZipJoinProducers',
'StreamMultiMerge',
'LowerBoundOnOrderedCollection',
'GroupByKey',
Expand Down
64 changes: 64 additions & 0 deletions hail/python/hail/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,70 @@ def _compute_type(self, env, agg_env, deep_typecheck):
return tstream(self.a.typ.element_type)


class StreamZipJoinProducers(IR):
    """IR node that zip-joins the streams produced from each context element.

    ``contexts`` is an array; for each element (bound as ``ctx_name``),
    ``make_producer`` yields a keyed stream. The producer streams are merged
    on ``key``: for each distinct key value, ``join_f`` is evaluated with the
    key struct bound as ``cur_key`` and the per-producer values (an array
    aligned with ``contexts``; missing where a producer lacks the key — see
    test_zip_join_producers) bound as ``cur_vals``.
    """

    @typecheck_method(contexts=IR,
                      ctx_name=str,
                      make_producer=IR,
                      key=sequenceof(str),
                      cur_key=str,
                      cur_vals=str,
                      join_f=IR)
    def __init__(self, contexts, ctx_name, make_producer, key, cur_key, cur_vals, join_f):
        # Children, in order: contexts, make_producer, join_f.
        super().__init__(contexts, make_producer, join_f)
        self.contexts = contexts
        self.ctx_name = ctx_name
        self.make_producer = make_producer
        self.key = key
        self.cur_key = cur_key
        self.cur_vals = cur_vals
        self.join_f = join_f

    def _handle_randomness(self, create_uids):
        # Element UIDs are not supported for this node.
        assert not create_uids
        return self

    # FIX: the typecheck spec must name the varargs parameter (`new_irs`);
    # the original `new_ir=IR` does not match the signature and breaks the
    # typecheck decorator.
    @typecheck_method(new_irs=IR)
    def copy(self, *new_irs):
        """Rebuild this node with new children (contexts, make_producer, join_f)."""
        assert len(new_irs) == 3
        return StreamZipJoinProducers(new_irs[0], self.ctx_name, new_irs[1],
                                      self.key, self.cur_key, self.cur_vals, new_irs[2])

    def head_str(self):
        # Render order matches the Scala parser: key, ctx_name, cur_key, cur_vals.
        return '({}) {} {} {}'.format(' '.join([escape_id(x) for x in self.key]), self.ctx_name,
                                      self.cur_key, self.cur_vals)

    def _compute_type(self, env, agg_env, deep_typecheck):
        self.contexts.compute_type(env, agg_env, deep_typecheck)
        ctx_elt_type = self.contexts.typ.element_type
        # The producer is typed with the context element in scope.
        self.make_producer.compute_type({**env, self.ctx_name: ctx_elt_type}, agg_env, deep_typecheck)
        stream_t = self.make_producer.typ
        struct_t = stream_t.element_type
        # The join body sees the selected key struct and the array of values.
        new_env = {**env}
        new_env[self.cur_key] = tstruct(**{k: struct_t[k] for k in self.key})
        new_env[self.cur_vals] = tarray(struct_t)
        self.join_f.compute_type(new_env, agg_env, deep_typecheck)
        return tstream(self.join_f.typ)

    def renderable_bindings(self, i, default_value=None):
        # Child 1 (make_producer) binds the context element; child 2 (join_f)
        # binds the current key and the per-producer values.
        if i == 1:
            if default_value is None:
                ctx_t = self.contexts.typ.element_type
            else:
                ctx_t = default_value
            return {self.ctx_name: ctx_t}
        elif i == 2:
            if default_value is None:
                struct_t = self.make_producer.typ.element_type
                key_x = tstruct(**{k: struct_t[k] for k in self.key})
                vals_x = tarray(struct_t)
            else:
                key_x = default_value
                vals_x = default_value
            return {self.cur_key: key_x, self.cur_vals: vals_x}
        else:
            return {}


class StreamZipJoin(IR):
@typecheck_method(streams=sequenceof(IR), key=sequenceof(str), cur_key=str, cur_vals=str, join_f=IR)
def __init__(self, streams, key, cur_key, cur_vals, join_f):
Expand Down
27 changes: 26 additions & 1 deletion hail/python/test/hail/expr/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4160,4 +4160,29 @@ def test_reservoir_sampling():

def test_local_agg():
x = hl.literal([1,2,3,4])
assert hl.eval(x.aggregate(lambda x: hl.agg.sum(x))) == 10
assert hl.eval(x.aggregate(lambda x: hl.agg.sum(x))) == 10


def test_zip_join_producers():
    """Zip-join three producers keyed by `k` and check the merged output."""
    contexts = hl.literal([1, 2, 3])
    joined = hl._zip_join_producers(
        contexts,
        lambda i: hl.range(i).map(lambda x: hl.struct(k=x, stream_id=i)),
        ['k'],
        lambda k, vals: k.annotate(vals=vals))

    # Producer i emits keys 0..i-1, so key k is present in producer i iff k < i;
    # absent producers contribute None in the aligned `vals` array.
    expected = [
        hl.utils.Struct(k=k, vals=[
            hl.utils.Struct(k=k, stream_id=i) if k < i else None
            for i in range(1, 4)
        ])
        for k in range(3)
    ]
    assert hl.eval(joined) == expected
10 changes: 10 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Binds.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ object Bindings {
curVals -> TArray(eltType))
else
empty
case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) =>
val contextType = TIterable.elementType(contexts.typ)
val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType)
if (i == 1)
Array(ctxName -> contextType)
else if (i == 2)
Array(curKey -> eltType.typeAfterSelectNames(key),
curVals -> TArray(eltType))
else
empty
case StreamFor(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
case StreamFlatMap(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
case StreamFilter(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
Expand Down
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Children.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ object Children {
as :+ body
case StreamZipJoin(as, _, _, _, joinF) =>
as :+ joinF
case StreamZipJoinProducers(contexts, _, makeProducer, _, _, _, joinF) =>
Array(contexts, makeProducer, joinF)
case StreamMultiMerge(as, _) =>
as
case StreamFilter(a, name, cond) =>
Expand Down
4 changes: 4 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Copy.scala
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ object Copy {
case StreamZipJoin(as, key, curKey, curVals, _) =>
assert(newChildren.length == as.length + 1)
StreamZipJoin(newChildren.init.asInstanceOf[IndexedSeq[IR]], key, curKey, curVals, newChildren(as.length).asInstanceOf[IR])
case StreamZipJoinProducers(_, ctxName, _, key, curKey, curVals, _) =>
assert(newChildren.length == 3)
StreamZipJoinProducers(newChildren(0).asInstanceOf[IR], ctxName, newChildren(1).asInstanceOf[IR],
key, curKey, curVals, newChildren(2).asInstanceOf[IR])
case StreamMultiMerge(as, key) =>
assert(newChildren.length == as.length)
StreamMultiMerge(newChildren.asInstanceOf[IndexedSeq[IR]], key)
Expand Down
5 changes: 5 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,11 @@ final case class StreamMultiMerge(as: IndexedSeq[IR], key: IndexedSeq[String]) e
override def typ: TStream = tcoerce[TStream](super.typ)
}

/**
 * Zip-joins the keyed streams produced from each element of `contexts`.
 *
 * For each element of `contexts` (bound as `ctxName`), `makeProducer` yields a
 * stream of structs containing the fields in `key`. The producers are merged
 * on `key`; `joinF` is evaluated per distinct key with the key struct bound as
 * `curKey` and the aligned per-producer values bound as `curVals`.
 * The result element type is `joinF.typ` (see InferType).
 */
final case class StreamZipJoinProducers(contexts: IR, ctxName: String, makeProducer: IR,
  key: IndexedSeq[String], curKey: String, curVals: String, joinF: IR) extends IR {
  override def typ: TStream = tcoerce[TStream](super.typ)
}

/**
* The StreamZipJoin node assumes that input streams have distinct keys. If input streams
* do not have distinct keys, the key that is included in the result is undefined, but
Expand Down
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/InferType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ object InferType {
TStream(body.typ)
case StreamZipJoin(_, _, _, _, joinF) =>
TStream(joinF.typ)
case StreamZipJoinProducers(_, _, _, _, _, _, joinF) =>
TStream(joinF.typ)
case StreamMultiMerge(as, _) =>
TStream(tcoerce[TStream](as.head.typ).elementType)
case StreamFilter(a, name, cond) =>
Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/Interpretable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ object Interpretable {
_: ReadValue |
_: WriteValue |
_: NDArrayWrite |
_: StreamZipJoinProducers |
_: RNGStateLiteral => false
case x: ApplyIR =>
!Exists(x.body, {
Expand Down
7 changes: 7 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ object NestingDepth {
case StreamZip(as, _, body, _, _) =>
as.foreach(computeIR(_, depth))
computeIR(body, depth.incrementEval)
case StreamZipJoin(as, _, _, _, joinF) =>
as.foreach(computeIR(_, depth))
computeIR(joinF, depth.incrementEval)
case StreamZipJoinProducers(contexts, _, makeProducer, _, _, _, joinF) =>
computeIR(contexts, depth)
computeIR(makeProducer, depth.incrementEval)
computeIR(joinF, depth.incrementEval)
case StreamFor(a, valueName, body) =>
computeIR(a, depth)
computeIR(body, depth.incrementEval)
Expand Down
9 changes: 9 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,15 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean =
newAs <- as.mapRecur(normalize(_))
newJoinF <- normalize(joinF, env.bindEval(curKey -> newCurKey, curVals -> newCurVals))
} yield StreamZipJoin(newAs, key, newCurKey, newCurVals, newJoinF)
case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) =>
val newCtxName = gen()
val newCurKey = gen()
val newCurVals = gen()
for {
newCtxs <- normalize(contexts)
newMakeProducer <- normalize(makeProducer, env.bindEval(ctxName -> newCtxName))
newJoinF <- normalize(joinF, env.bindEval(curKey -> newCurKey, curVals -> newCurVals))
} yield StreamZipJoinProducers(newCtxs, newCtxName, newMakeProducer, key, newCurKey, newCurVals, newJoinF)
case StreamFilter(a, name, body) =>
val newName = gen()
for {
Expand Down
13 changes: 13 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1133,6 +1133,19 @@ object IRParser {
as <- names.mapRecur(_ => ir_value_expr(env)(it))
body <- ir_value_expr(env.bindEval(names.zip(as.map(a => tcoerce[TStream](a.typ).elementType)): _*))(it)
} yield StreamZip(as, names, body, behavior, errorID)
case "StreamZipJoinProducers" =>
val key = identifiers(it)
val ctxName = identifier(it)
val curKey = identifier(it)
val curVals = identifier(it)
for {
ctxs <- ir_value_expr(env)(it)
makeProducer <- ir_value_expr(env.bindEval(ctxName, TIterable.elementType(ctxs.typ)))(it)
body <- {
val structType = TIterable.elementType(makeProducer.typ).asInstanceOf[TStruct]
ir_value_expr(env.bindEval((curKey, structType.typeAfterSelectNames(key)), (curVals, TArray(structType))))(it)
}
} yield StreamZipJoinProducers(ctxs, ctxName, makeProducer, key, curKey, curVals, body)
case "StreamZipJoin" =>
val nStreams = int32_literal(it)
val key = identifiers(it)
Expand Down
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Pretty.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int,
}, prettyIdentifiers(names))
case StreamZipJoin(streams, key, curKey, curVals, _) if !elideBindings =>
FastSeq(streams.length.toString, prettyIdentifiers(key), prettyIdentifier(curKey), prettyIdentifier(curVals))
case StreamZipJoinProducers(_, ctxName, _, key, curKey, curVals, _) if !elideBindings =>
FastSeq(prettyIdentifiers(key), prettyIdentifier(ctxName), prettyIdentifier(curKey), prettyIdentifier(curVals))
case StreamMultiMerge(_, key) => single(prettyIdentifiers(key))
case StreamFilter(_, name, _) if !elideBindings => single(prettyIdentifier(name))
case StreamTakeWhile(_, name, _) if !elideBindings => single(prettyIdentifier(name))
Expand Down
21 changes: 21 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1134,6 +1134,17 @@ object PruneDeadFields {
uses(curVals, bodyEnv.eval).map(TIterable.elementType) :+ selectKey(eltType, key)
)
unifyEnvsSeq(as.map(memoizeValueIR(ctx, _, TStream(childRequestedEltType), memo)))
case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) =>
val baseEltType = tcoerce[TStruct](TIterable.elementType(makeProducer.typ))
val requestedEltType = tcoerce[TStream](requestedType).elementType
val bodyEnv = memoizeValueIR(ctx, joinF, requestedEltType, memo)
val producerRequestedEltType = unifySeq(
baseEltType,
uses(curVals, bodyEnv.eval).map(TIterable.elementType) :+ selectKey(baseEltType, key)
)
val producerEnv = memoizeValueIR(ctx, makeProducer, TStream(producerRequestedEltType), memo)
val ctxEnv = memoizeValueIR(ctx, contexts, TArray(unifySeq(TIterable.elementType(contexts.typ), uses(ctxName, producerEnv.eval))), memo)
unifyEnvsSeq(Array(bodyEnv, producerEnv, ctxEnv))
case StreamMultiMerge(as, key) =>
val eltType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType)
val requestedEltType = tcoerce[TStream](requestedType).elementType
Expand Down Expand Up @@ -1926,6 +1937,16 @@ object PruneDeadFields {
env.bindEval(curKey -> selectKey(newEltType, key), curVals -> TArray(newEltType)),
memo)
StreamZipJoin(newAs, key, curKey, curVals, newJoinF)
case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) =>
val newContexts = rebuildIR(ctx, contexts, env, memo)
val newCtxType = TIterable.elementType(newContexts.typ)
val newMakeProducer = rebuildIR(ctx, makeProducer, env.bindEval(ctxName, newCtxType), memo)
val newEltType = TIterable.elementType(newMakeProducer.typ).asInstanceOf[TStruct]
val newJoinF = rebuildIR(ctx,
joinF,
env.bindEval(curKey -> selectKey(newEltType, key), curVals -> TArray(newEltType)),
memo)
StreamZipJoinProducers(newContexts, ctxName, newMakeProducer,key, curKey, curVals, newJoinF)
case StreamMultiMerge(as, key) =>
val eltType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType)
val requestedEltType = tcoerce[TStream](requestedType).elementType
Expand Down
27 changes: 27 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,29 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
uses.foreach { u => defs.bind(u, valTypes) }
as.foreach { a => dependents.getOrElseUpdate(a, mutable.Set[RefEquality[BaseIR]]()) ++= uses }
}
case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) =>
val ctxType = tcoerce[RIterable](lookup(contexts)).elementType
if (refMap.contains(ctxName)) {
val uses = refMap(ctxName)
uses.foreach { u => defs.bind(u, Array(ctxType)) }
dependents.getOrElseUpdate(contexts, mutable.Set[RefEquality[BaseIR]]()) ++= uses
}

val producerElementType = tcoerce[RStruct](tcoerce[RIterable](lookup(makeProducer)).elementType)
if (refMap.contains(curKey)) {
val uses = refMap(curKey)
val keyType = RStruct.fromNamesAndTypes(key.map(k => k -> producerElementType.fieldType(k)))
uses.foreach { u => defs.bind(u, Array(keyType)) }
dependents.getOrElseUpdate(makeProducer, mutable.Set[RefEquality[BaseIR]]()) ++= uses
}
if (refMap.contains(curVals)) {
val uses = refMap(curVals)
val optional = producerElementType.copy(producerElementType.children)
optional.union(false)
uses.foreach { u => defs.bind(u, Array(RIterable(optional))) }
dependents.getOrElseUpdate(makeProducer, mutable.Set[RefEquality[BaseIR]]()) ++= uses
}

case StreamFilter(a, name, cond) => addElementBinding(name, a)
case StreamTakeWhile(a, name, cond) => addElementBinding(name, a)
case StreamDropWhile(a, name, cond) => addElementBinding(name, a)
Expand Down Expand Up @@ -584,6 +607,10 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) {
requiredness.union(as.forall(lookup(_).required))
val eltType = tcoerce[RIterable](requiredness).elementType
eltType.unionFrom(lookup(joinF))
case StreamZipJoinProducers(contexts, ctxName, makeProducer, _, curKey, curVals, joinF) =>
requiredness.union(lookup(contexts).required)
val eltType = tcoerce[RIterable](requiredness).elementType
eltType.unionFrom(lookup(joinF))
case StreamMultiMerge(as, _) =>
requiredness.union(as.forall(lookup(_).required))
val elt = tcoerce[RStruct](tcoerce[RIterable](requiredness).elementType)
Expand Down
6 changes: 6 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,12 @@ object TypeCheck {
val eltType = tcoerce[TStruct](streamType.elementType)
assert(key.forall(eltType.hasField))
assert(x.typ.elementType == joinF.typ)
case x@StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) =>
assert(contexts.typ.isInstanceOf[TArray])
val streamType = tcoerce[TStream](makeProducer.typ)
val eltType = tcoerce[TStruct](streamType.elementType)
assert(key.forall(eltType.hasField))
assert(x.typ.elementType == joinF.typ)
case x@StreamMultiMerge(as, key) =>
val streamType = tcoerce[TStream](as.head.typ)
assert(as.forall(_.typ == streamType))
Expand Down