diff --git a/hail/python/hail/expr/__init__.py b/hail/python/hail/expr/__init__.py
index dac6e84d6ef..3aa25730828 100644
--- a/hail/python/hail/expr/__init__.py
+++ b/hail/python/hail/expr/__init__.py
@@ -35,7 +35,7 @@
     contig_length, liftover, min_rep, uniroot, format, approx_equal, reversed, bit_and, bit_or,
     bit_xor, bit_lshift, bit_rshift, bit_not, bit_count, binary_search, logit, expit,
     _values_similar, _showstr, _sort_by, _compare, _locus_windows_per_contig, shuffle, _console_log, dnorm, dchisq,
-    query_table, keyed_union, keyed_intersection, repeat)
+    query_table, keyed_union, keyed_intersection, repeat, _zip_join_producers)
 
 __all__ = ['HailType',
            'hail_type',
@@ -309,5 +309,6 @@
            'query_table',
            'keyed_union',
            'keyed_intersection',
+           '_zip_join_producers',
            'repeat',
            ]
diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py
index 4be0cf23130..ffa37438f70 100644
--- a/hail/python/hail/expr/functions.py
+++ b/hail/python/hail/expr/functions.py
@@ -5058,6 +5058,32 @@
     return result_f(construct_expr(zj, zj.typ, indices, aggs))
 
 
+def _zip_join_producers(contexts, stream_f, key, join_f):
+    ctx_uid = Env.get_uid()
+
+    ctx_var = construct_variable(ctx_uid, contexts.dtype.element_type)
+    stream_expr = stream_f(ctx_var)
+    make_prod_ir = stream_expr._ir
+    if isinstance(make_prod_ir.typ, hl.tarray):
+        make_prod_ir = ir.ToStream(make_prod_ir)
+    t = stream_expr.dtype.element_type
+
+    key_typ = hl.tstruct(**{k: t[k] for k in key})
+    vals_typ = hl.tarray(t)
+
+    key_uid = Env.get_uid()
+    vals_uid = Env.get_uid()
+
+    key_var = construct_variable(key_uid, key_typ)
+    vals_var = construct_variable(vals_uid, vals_typ)
+
+    join_ir = join_f(key_var, vals_var)
+    zj = ir.ToArray(
+        ir.StreamZipJoinProducers(contexts._ir, ctx_uid, make_prod_ir, key, key_uid, vals_uid, join_ir._ir))
+    indices, aggs = unify_all(contexts, stream_expr, join_ir)
+    return construct_expr(zj, zj.typ, indices, aggs)
+
+
 @typecheck(arrays=expr_oneof(expr_stream(expr_any), expr_array(expr_any)),
            key=sequenceof(builtins.str))
 def keyed_intersection(*arrays, key):
     """Compute the intersection of sorted arrays on a given key.
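A quick usage sketch for reviewers (mirrors `test_zip_join_producers` added below; assumes a running Hail session via `import hail as hl`):

    import hail as hl

    # One keyed producer stream per context value; producers must be sorted by
    # (and distinct on) the join key.
    contexts = hl.literal([1, 2, 3])
    joined = hl._zip_join_producers(
        contexts,
        lambda i: hl.range(i).map(lambda x: hl.struct(k=x, stream_id=i)),  # producer for context i
        ['k'],                                                             # join key fields
        lambda k, vals: k.annotate(vals=vals))                             # combine rows sharing a key
    # One result per distinct key; vals has one slot per producer,
    # None where that producer lacks the key.
    hl.eval(joined)
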
diff --git a/hail/python/hail/ir/__init__.py b/hail/python/hail/ir/__init__.py
index 7f593785f3f..37c7170bf3d 100644
--- a/hail/python/hail/ir/__init__.py
+++ b/hail/python/hail/ir/__init__.py
@@ -18,7 +18,7 @@
     GetTupleElement, Die, ConsoleLog, Apply, ApplySeeded, RNGStateLiteral, RNGSplit,\
     TableCount, TableGetGlobals, TableCollect, TableAggregate, MatrixCount, \
     MatrixAggregate, TableWrite, udf, subst, clear_session_functions, ReadPartition, \
-    PartitionNativeIntervalReader, StreamMultiMerge, StreamZipJoin, StreamAgg
+    PartitionNativeIntervalReader, StreamMultiMerge, StreamZipJoin, StreamAgg, StreamZipJoinProducers
 from .register_functions import register_functions
 from .register_aggregators import register_aggregators
 from .table_ir import (MatrixRowsTable, TableJoin, TableLeftJoinRightDistinct, TableIntervalJoin,
@@ -173,6 +173,7 @@
     'toStream',
     'ToStream',
     'StreamZipJoin',
+    'StreamZipJoinProducers',
     'StreamMultiMerge',
     'LowerBoundOnOrderedCollection',
     'GroupByKey',
diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py
index 347331e1458..aed93fdebea 100644
--- a/hail/python/hail/ir/ir.py
+++ b/hail/python/hail/ir/ir.py
@@ -1368,6 +1368,70 @@
         return tstream(self.a.typ.element_type)
 
 
+class StreamZipJoinProducers(IR):
+    @typecheck_method(contexts=IR,
+                      ctx_name=str,
+                      make_producer=IR,
+                      key=sequenceof(str),
+                      cur_key=str,
+                      cur_vals=str,
+                      join_f=IR)
+    def __init__(self, contexts, ctx_name, make_producer, key, cur_key, cur_vals, join_f):
+        super().__init__(contexts, make_producer, join_f)
+        self.contexts = contexts
+        self.ctx_name = ctx_name
+        self.make_producer = make_producer
+        self.key = key
+        self.cur_key = cur_key
+        self.cur_vals = cur_vals
+        self.join_f = join_f
+
+    def _handle_randomness(self, create_uids):
+        assert not create_uids
+        return self
+
+    @typecheck_method(new_irs=IR)
+    def copy(self, *new_irs):
+        assert len(new_irs) == 3
+        return StreamZipJoinProducers(new_irs[0], self.ctx_name, new_irs[1],
+                                      self.key, self.cur_key, self.cur_vals, new_irs[2])
+
+    def head_str(self):
+        return '({}) {} {} {}'.format(' '.join([escape_id(x) for x in self.key]), self.ctx_name,
+                                      self.cur_key, self.cur_vals)
+
+    def _compute_type(self, env, agg_env, deep_typecheck):
+        self.contexts.compute_type(env, agg_env, deep_typecheck)
+        ctx_elt_type = self.contexts.typ.element_type
+        self.make_producer.compute_type({**env, self.ctx_name: ctx_elt_type}, agg_env, deep_typecheck)
+        stream_t = self.make_producer.typ
+        struct_t = stream_t.element_type
+        new_env = {**env}
+        new_env[self.cur_key] = tstruct(**{k: struct_t[k] for k in self.key})
+        new_env[self.cur_vals] = tarray(struct_t)
+        self.join_f.compute_type(new_env, agg_env, deep_typecheck)
+        return tstream(self.join_f.typ)
+
+    def renderable_bindings(self, i, default_value=None):
+        if i == 1:
+            if default_value is None:
+                ctx_t = self.contexts.typ.element_type
+            else:
+                ctx_t = default_value
+            return {self.ctx_name: ctx_t}
+        elif i == 2:
+            if default_value is None:
+                struct_t = self.make_producer.typ.element_type
+                key_x = tstruct(**{k: struct_t[k] for k in self.key})
+                vals_x = tarray(struct_t)
+            else:
+                key_x = default_value
+                vals_x = default_value
+            return {self.cur_key: key_x, self.cur_vals: vals_x}
+        else:
+            return {}
+
+
 class StreamZipJoin(IR):
     @typecheck_method(streams=sequenceof(IR), key=sequenceof(str), cur_key=str, cur_vals=str, join_f=IR)
     def __init__(self, streams, key, cur_key, cur_vals, join_f):
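For review, the node's semantics can be modeled in plain Python as a k-way merge that groups equal keys, with one `vals` slot per producer. This is a reference sketch only (not Hail API); it assumes producers are key-sorted with distinct keys, as the IR node does:

    import heapq

    def zip_join_producers(contexts, make_producer, key_of, join):
        # Reference semantics for StreamZipJoinProducers.
        iters = [iter(make_producer(c)) for c in contexts]
        heap = []
        for idx, it in enumerate(iters):
            row = next(it, None)
            if row is not None:
                heapq.heappush(heap, (key_of(row), idx, row))
        out = []
        while heap:
            cur_key = heap[0][0]
            vals = [None] * len(iters)          # one slot per producer
            while heap and heap[0][0] == cur_key:
                _, idx, row = heapq.heappop(heap)
                vals[idx] = row
                nxt = next(iters[idx], None)
                if nxt is not None:
                    heapq.heappush(heap, (key_of(nxt), idx, nxt))
            out.append(join(cur_key, vals))
        return out

    # zip_join_producers([1, 2, 3],
    #                    lambda i: [{'k': x, 'stream_id': i} for x in range(i)],
    #                    lambda row: row['k'],
    #                    lambda k, vals: {'k': k, 'vals': vals})
    # matches the expected output of test_zip_join_producers below.

The emitter at the end of this PR implements the same merge with a tournament tree rather than a heap.
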
diff --git a/hail/python/test/hail/expr/test_expr.py b/hail/python/test/hail/expr/test_expr.py
index def5ba76948..852cf064d56 100644
--- a/hail/python/test/hail/expr/test_expr.py
+++ b/hail/python/test/hail/expr/test_expr.py
@@ -4160,4 +4160,29 @@ def test_reservoir_sampling():
 
 def test_local_agg():
     x = hl.literal([1,2,3,4])
-    assert hl.eval(x.aggregate(lambda x: hl.agg.sum(x))) == 10
\ No newline at end of file
+    assert hl.eval(x.aggregate(lambda x: hl.agg.sum(x))) == 10
+
+
+def test_zip_join_producers():
+    contexts = hl.literal([1,2,3])
+    zj = hl._zip_join_producers(contexts,
+                                lambda i: hl.range(i).map(lambda x: hl.struct(k=x, stream_id=i)),
+                                ['k'],
+                                lambda k, vals: k.annotate(vals=vals))
+    assert hl.eval(zj) == [
+        hl.utils.Struct(k=0, vals=[
+            hl.utils.Struct(k=0, stream_id=1),
+            hl.utils.Struct(k=0, stream_id=2),
+            hl.utils.Struct(k=0, stream_id=3),
+        ]),
+        hl.utils.Struct(k=1, vals=[
+            None,
+            hl.utils.Struct(k=1, stream_id=2),
+            hl.utils.Struct(k=1, stream_id=3),
+        ]),
+        hl.utils.Struct(k=2, vals=[
+            None,
+            None,
+            hl.utils.Struct(k=2, stream_id=3),
+        ]),
+    ]
diff --git a/hail/src/main/scala/is/hail/expr/ir/Binds.scala b/hail/src/main/scala/is/hail/expr/ir/Binds.scala
index 095da71c979..fe961351ff2 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Binds.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Binds.scala
@@ -25,6 +25,16 @@
           curVals -> TArray(eltType))
       else
         empty
+    case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) =>
+      val contextType = TIterable.elementType(contexts.typ)
+      val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType)
+      if (i == 1)
+        Array(ctxName -> contextType)
+      else if (i == 2)
+        Array(curKey -> eltType.typeAfterSelectNames(key),
+          curVals -> TArray(eltType))
+      else
+        empty
     case StreamFor(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
     case StreamFlatMap(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
     case StreamFilter(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty
diff --git a/hail/src/main/scala/is/hail/expr/ir/Children.scala b/hail/src/main/scala/is/hail/expr/ir/Children.scala
index 146ef135f7e..4d973055cfc 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Children.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Children.scala
@@ -118,6 +118,8 @@
       as :+ body
     case StreamZipJoin(as, _, _, _, joinF) =>
       as :+ joinF
+    case StreamZipJoinProducers(contexts, _, makeProducer, _, _, _, joinF) =>
+      Array(contexts, makeProducer, joinF)
     case StreamMultiMerge(as, _) =>
       as
     case StreamFilter(a, name, cond) =>
diff --git a/hail/src/main/scala/is/hail/expr/ir/Copy.scala b/hail/src/main/scala/is/hail/expr/ir/Copy.scala
index ba68a1f4786..76fb399421f 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Copy.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Copy.scala
@@ -200,6 +200,10 @@
     case StreamZipJoin(as, key, curKey, curVals, _) =>
       assert(newChildren.length == as.length + 1)
       StreamZipJoin(newChildren.init.asInstanceOf[IndexedSeq[IR]], key, curKey, curVals, newChildren(as.length).asInstanceOf[IR])
+    case StreamZipJoinProducers(_, ctxName, _, key, curKey, curVals, _) =>
+      assert(newChildren.length == 3)
+      StreamZipJoinProducers(newChildren(0).asInstanceOf[IR], ctxName, newChildren(1).asInstanceOf[IR],
+        key, curKey, curVals, newChildren(2).asInstanceOf[IR])
     case StreamMultiMerge(as, key) =>
      assert(newChildren.length == as.length)
      StreamMultiMerge(newChildren.asInstanceOf[IndexedSeq[IR]], key)
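The Python `renderable_bindings` above and the Scala `Bindings` here must agree on child order (0: contexts, 1: makeProducer, 2: joinF) and on what each child sees. A hypothetical check of that contract from the Python side (assumes `ir.Ref(name, type)` and `tstream` are importable as shown; purely illustrative):

    import hail as hl
    from hail import ir
    from hail.expr.types import tstream

    elt = hl.tstruct(k=hl.tint32, stream_id=hl.tint32)
    node = ir.StreamZipJoinProducers(
        ir.Ref('ctxs', hl.tarray(hl.tint32)),        # child 0: contexts (no new bindings)
        'ctx',
        ir.Ref('producer', tstream(elt)),            # child 1: make_producer (sees 'ctx')
        ['k'], 'curKey', 'curVals',
        ir.Ref('curKey', hl.tstruct(k=hl.tint32)))   # child 2: join_f (sees 'curKey'/'curVals')

    assert node.renderable_bindings(0) == {}
    assert node.renderable_bindings(1) == {'ctx': hl.tint32}
    assert node.renderable_bindings(2) == {'curKey': hl.tstruct(k=hl.tint32),
                                           'curVals': hl.tarray(elt)}
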
diff --git a/hail/src/main/scala/is/hail/expr/ir/IR.scala b/hail/src/main/scala/is/hail/expr/ir/IR.scala
index 013edc28131..1efbb74c51c 100644
--- a/hail/src/main/scala/is/hail/expr/ir/IR.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/IR.scala
@@ -356,6 +356,11 @@
   override def typ: TStream = tcoerce[TStream](super.typ)
 }
 
+final case class StreamZipJoinProducers(contexts: IR, ctxName: String, makeProducer: IR,
+  key: IndexedSeq[String], curKey: String, curVals: String, joinF: IR) extends IR {
+  override def typ: TStream = tcoerce[TStream](super.typ)
+}
+
 /**
   * The StreamZipJoin node assumes that input streams have distinct keys. If input streams
   * do not have distinct keys, the key that is included in the result is undefined, but
diff --git a/hail/src/main/scala/is/hail/expr/ir/InferType.scala b/hail/src/main/scala/is/hail/expr/ir/InferType.scala
index d0e789d15cb..9dadea3f476 100644
--- a/hail/src/main/scala/is/hail/expr/ir/InferType.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/InferType.scala
@@ -136,6 +136,8 @@
       TStream(body.typ)
     case StreamZipJoin(_, _, _, _, joinF) =>
       TStream(joinF.typ)
+    case StreamZipJoinProducers(_, _, _, _, _, _, joinF) =>
+      TStream(joinF.typ)
     case StreamMultiMerge(as, _) =>
       TStream(tcoerce[TStream](as.head.typ).elementType)
     case StreamFilter(a, name, cond) =>
diff --git a/hail/src/main/scala/is/hail/expr/ir/Interpretable.scala b/hail/src/main/scala/is/hail/expr/ir/Interpretable.scala
index caf1820ce20..3024703c790 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Interpretable.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Interpretable.scala
@@ -41,6 +41,7 @@
            _: ReadValue |
            _: WriteValue |
            _: NDArrayWrite |
+           _: StreamZipJoinProducers |
            _: RNGStateLiteral => false
     case x: ApplyIR =>
       !Exists(x.body, {
diff --git a/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala b/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala
index 2754e0b645e..ee1b92cd449 100644
--- a/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala
@@ -60,6 +60,13 @@
       case StreamZip(as, _, body, _, _) =>
         as.foreach(computeIR(_, depth))
         computeIR(body, depth.incrementEval)
+      case StreamZipJoin(as, _, _, _, joinF) =>
+        as.foreach(computeIR(_, depth))
+        computeIR(joinF, depth.incrementEval)
+      case StreamZipJoinProducers(contexts, _, makeProducer, _, _, _, joinF) =>
+        computeIR(contexts, depth)
+        computeIR(makeProducer, depth.incrementEval)
+        computeIR(joinF, depth.incrementEval)
       case StreamFor(a, valueName, body) =>
         computeIR(a, depth)
         computeIR(body, depth.incrementEval)
diff --git a/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala b/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala
index dd6d4e21358..42843cc6430 100644
--- a/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala
@@ -98,6 +98,15 @@
           newAs <- as.mapRecur(normalize(_))
           newJoinF <- normalize(joinF, env.bindEval(curKey -> newCurKey, curVals -> newCurVals))
         } yield StreamZipJoin(newAs, key, newCurKey, newCurVals, newJoinF)
+      case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) =>
+        val newCtxName = gen()
+        val newCurKey = gen()
+        val newCurVals = gen()
+        for {
+          newCtxs <- normalize(contexts)
+          newMakeProducer <- normalize(makeProducer, env.bindEval(ctxName -> newCtxName))
+          newJoinF <- normalize(joinF, env.bindEval(curKey -> newCurKey, curVals -> newCurVals))
+        } yield StreamZipJoinProducers(newCtxs, newCtxName, newMakeProducer, key, newCurKey, newCurVals, newJoinF)
       case StreamFilter(a, name, body) =>
         val newName = gen()
         for {
diff --git a/hail/src/main/scala/is/hail/expr/ir/Parser.scala b/hail/src/main/scala/is/hail/expr/ir/Parser.scala
index 9a6fb167bb8..4ea71179b4f 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Parser.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Parser.scala
@@ -1133,6 +1133,19 @@
           as <- names.mapRecur(_ => ir_value_expr(env)(it))
           body <- ir_value_expr(env.bindEval(names.zip(as.map(a => tcoerce[TStream](a.typ).elementType)): _*))(it)
         } yield StreamZip(as, names, body, behavior, errorID)
+      case "StreamZipJoinProducers" =>
+        val key = identifiers(it)
+        val ctxName = identifier(it)
+        val curKey = identifier(it)
+        val curVals = identifier(it)
+        for {
+          ctxs <- ir_value_expr(env)(it)
+          makeProducer <- ir_value_expr(env.bindEval(ctxName, TIterable.elementType(ctxs.typ)))(it)
+          body <- {
+            val structType = TIterable.elementType(makeProducer.typ).asInstanceOf[TStruct]
+            ir_value_expr(env.bindEval((curKey, structType.typeAfterSelectNames(key)), (curVals, TArray(structType))))(it)
+          }
+        } yield StreamZipJoinProducers(ctxs, ctxName, makeProducer, key, curKey, curVals, body)
       case "StreamZipJoin" =>
         val nStreams = int32_literal(it)
         val key = identifiers(it)
diff --git a/hail/src/main/scala/is/hail/expr/ir/Pretty.scala b/hail/src/main/scala/is/hail/expr/ir/Pretty.scala
index 44e2d8e2cda..b00415defd6 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Pretty.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Pretty.scala
@@ -222,6 +222,8 @@
       }, prettyIdentifiers(names))
     case StreamZipJoin(streams, key, curKey, curVals, _) if !elideBindings =>
       FastSeq(streams.length.toString, prettyIdentifiers(key), prettyIdentifier(curKey), prettyIdentifier(curVals))
+    case StreamZipJoinProducers(_, ctxName, _, key, curKey, curVals, _) if !elideBindings =>
+      FastSeq(prettyIdentifiers(key), prettyIdentifier(ctxName), prettyIdentifier(curKey), prettyIdentifier(curVals))
     case StreamMultiMerge(_, key) => single(prettyIdentifiers(key))
     case StreamFilter(_, name, _) if !elideBindings => single(prettyIdentifier(name))
     case StreamTakeWhile(_, name, _) if !elideBindings => single(prettyIdentifier(name))
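Given `head_str` in ir.py and the parser rule above, the pretty-printed node should look like `(StreamZipJoinProducers (k) ctx curKey curVals <contexts> <makeProducer> <joinF>)`, and the parser reads the fields back in the same order: key list, ctxName, curKey, curVals. A small sanity check, reusing the illustrative construction from the earlier sketch (same caveats about assumed constructor signatures):

    import hail as hl
    from hail import ir
    from hail.expr.types import tstream

    elt = hl.tstruct(k=hl.tint32)
    node = ir.StreamZipJoinProducers(
        ir.Ref('ctxs', hl.tarray(hl.tint32)), 'ctx',
        ir.Ref('producer', tstream(elt)), ['k'], 'curKey', 'curVals',
        ir.Ref('curKey', elt))
    # head_str supplies everything between the node name and the children.
    assert node.head_str() == '(k) ctx curKey curVals'
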
diff --git a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
index 617065c3d15..cc47db672ff 100644
--- a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala
@@ -1134,6 +1134,17 @@
           uses(curVals, bodyEnv.eval).map(TIterable.elementType) :+ selectKey(eltType, key)
         )
         unifyEnvsSeq(as.map(memoizeValueIR(ctx, _, TStream(childRequestedEltType), memo)))
+      case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) =>
+        val baseEltType = tcoerce[TStruct](TIterable.elementType(makeProducer.typ))
+        val requestedEltType = tcoerce[TStream](requestedType).elementType
+        val bodyEnv = memoizeValueIR(ctx, joinF, requestedEltType, memo)
+        val producerRequestedEltType = unifySeq(
+          baseEltType,
+          uses(curVals, bodyEnv.eval).map(TIterable.elementType) :+ selectKey(baseEltType, key)
+        )
+        val producerEnv = memoizeValueIR(ctx, makeProducer, TStream(producerRequestedEltType), memo)
+        val ctxEnv = memoizeValueIR(ctx, contexts, TArray(unifySeq(TIterable.elementType(contexts.typ), uses(ctxName, producerEnv.eval))), memo)
+        unifyEnvsSeq(Array(bodyEnv, producerEnv, ctxEnv))
       case StreamMultiMerge(as, key) =>
         val eltType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType)
         val requestedEltType = tcoerce[TStream](requestedType).elementType
@@ -1926,6 +1937,16 @@
             env.bindEval(curKey -> selectKey(newEltType, key), curVals -> TArray(newEltType)),
             memo)
           StreamZipJoin(newAs, key, curKey, curVals, newJoinF)
+        case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) =>
+          val newContexts = rebuildIR(ctx, contexts, env, memo)
+          val newCtxType = TIterable.elementType(newContexts.typ)
+          val newMakeProducer = rebuildIR(ctx, makeProducer, env.bindEval(ctxName, newCtxType), memo)
+          val newEltType = TIterable.elementType(newMakeProducer.typ).asInstanceOf[TStruct]
+          val newJoinF = rebuildIR(ctx,
+            joinF,
+            env.bindEval(curKey -> selectKey(newEltType, key), curVals -> TArray(newEltType)),
+            memo)
+          StreamZipJoinProducers(newContexts, ctxName, newMakeProducer, key, curKey, curVals, newJoinF)
        case StreamMultiMerge(as, key) =>
          val eltType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType)
          val requestedEltType = tcoerce[TStream](requestedType).elementType
diff --git a/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala b/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
index 371f0b8412d..97803dd82c7 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala
@@ -213,6 +213,29 @@
         uses.foreach { u => defs.bind(u, valTypes) }
         as.foreach { a => dependents.getOrElseUpdate(a, mutable.Set[RefEquality[BaseIR]]()) ++= uses }
       }
+      case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) =>
+        val ctxType = tcoerce[RIterable](lookup(contexts)).elementType
+        if (refMap.contains(ctxName)) {
+          val uses = refMap(ctxName)
+          uses.foreach { u => defs.bind(u, Array(ctxType)) }
+          dependents.getOrElseUpdate(contexts, mutable.Set[RefEquality[BaseIR]]()) ++= uses
+        }
+
+        val producerElementType = tcoerce[RStruct](tcoerce[RIterable](lookup(makeProducer)).elementType)
+        if (refMap.contains(curKey)) {
+          val uses = refMap(curKey)
+          val keyType = RStruct.fromNamesAndTypes(key.map(k => k -> producerElementType.fieldType(k)))
+          uses.foreach { u => defs.bind(u, Array(keyType)) }
+          dependents.getOrElseUpdate(makeProducer, mutable.Set[RefEquality[BaseIR]]()) ++= uses
+        }
+        if (refMap.contains(curVals)) {
+          val uses = refMap(curVals)
+          val optional = producerElementType.copy(producerElementType.children)
+          optional.union(false)
+          uses.foreach { u => defs.bind(u, Array(RIterable(optional))) }
+          dependents.getOrElseUpdate(makeProducer, mutable.Set[RefEquality[BaseIR]]()) ++= uses
+        }
+
      case StreamFilter(a, name, cond) => addElementBinding(name, a)
      case StreamTakeWhile(a, name, cond) => addElementBinding(name, a)
      case StreamDropWhile(a, name, cond) => addElementBinding(name, a)
@@ -584,6 +607,10 @@
         requiredness.union(as.forall(lookup(_).required))
         val eltType = tcoerce[RIterable](requiredness).elementType
         eltType.unionFrom(lookup(joinF))
+      case StreamZipJoinProducers(contexts, ctxName, makeProducer, _, curKey, curVals, joinF) =>
+        requiredness.union(lookup(contexts).required)
+        val eltType = tcoerce[RIterable](requiredness).elementType
+        eltType.unionFrom(lookup(joinF))
      case StreamMultiMerge(as, _) =>
        requiredness.union(as.forall(lookup(_).required))
        val elt = tcoerce[RStruct](tcoerce[RIterable](requiredness).elementType)
diff --git a/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala b/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala
index 66addb26218..e06d3b2cd2b 100644
--- a/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala
@@ -342,6 +342,12 @@
       val eltType = tcoerce[TStruct](streamType.elementType)
       assert(key.forall(eltType.hasField))
       assert(x.typ.elementType == joinF.typ)
+    case x@StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) =>
+      assert(contexts.typ.isInstanceOf[TArray])
+      val streamType = tcoerce[TStream](makeProducer.typ)
+      val eltType = tcoerce[TStruct](streamType.elementType)
+      assert(key.forall(eltType.hasField))
+      assert(x.typ.elementType == joinF.typ)
     case x@StreamMultiMerge(as, key) =>
       val streamType = tcoerce[TStream](as.head.typ)
       assert(as.forall(_.typ == streamType))
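The staged tournament-tree control flow in the EmitStream change below is the densest part of this PR. This plain-Python model of the same bookkeeping (a `bracket` array of losers, replaying contests along the leaf-to-root path, contestant `k` acting as positive infinity) may help when reading it. It is an illustrative sketch, not Hail code; the names and the grouping of equal keys into one `vals` row are assumptions of the model:

    EXHAUSTED = object()

    def loser_tree_merge_grouped(streams, key=lambda x: x):
        """k-way merge of key-sorted streams, grouping rows with equal keys.

        Mirrors the emitter: 'bracket' holds the loser at each internal
        tournament-tree node; replacing a leaf replays the contests on the
        path from that leaf to the root. Contestant index k means 'exhausted'.
        """
        k = len(streams)
        its = [iter(s) for s in streams]
        heads = [EXHAUSTED] * k
        bracket = [-1] * k            # -1: empty slot, beats everything (setup only)

        def advance(j):               # pull next element of stream j
            heads[j] = next(its[j], EXHAUSTED)
            return k if heads[j] is EXHAUSTED else j

        def wins(a, b):               # does contestant a beat contestant b?
            if a == k:
                return False          # k loses to everything
            if b == k:
                return True
            return key(heads[a]) <= key(heads[b])

        winner = k
        for leaf in range(k):         # setup: each insertion parks one loser
            w = advance(leaf)
            node = (leaf + k) >> 1
            while node > 0 and bracket[node] != -1:
                if wins(bracket[node], w):
                    bracket[node], w = w, bracket[node]
                node >>= 1
            if node == 0:
                winner = w            # the last insertion bubbles to the root
            else:
                bracket[node] = w

        out = []
        while winner != k:
            cur_key = key(heads[winner])
            vals = [None] * k         # one slot per stream
            while winner != k and key(heads[winner]) == cur_key:
                vals[winner] = heads[winner]
                w = advance(winner)   # refill the winning leaf...
                node = (winner + k) >> 1
                while node > 0:       # ...and replay its path to the root
                    if wins(bracket[node], w):
                        bracket[node], w = w, bracket[node]
                    node >>= 1
                winner = w
            out.append((cur_key, vals))
        return out

    # loser_tree_merge_grouped([[0], [0, 1], [0, 1, 2]])
    # -> [(0, [0, 0, 0]), (1, [None, 1, 1]), (2, [None, None, 2])]
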
diff --git a/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala b/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala
index b8010a6cf37..8e57825e249 100644
--- a/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala
@@ -15,7 +15,7 @@
 import is.hail.types.physical.stypes.primitives.{SFloat64Value, SInt32Value}
 import is.hail.types.physical.stypes.{EmitType, SSettable}
 import is.hail.types.physical.{PCanonicalArray, PCanonicalBinary, PCanonicalStruct, PType}
 import is.hail.types.virtual._
-import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq}
+import is.hail.types.{RIterable, TypeWithRequiredness, VirtualTypeWithReq}
 import is.hail.utils._
 import is.hail.variant.Locus
 import org.objectweb.asm.Opcodes._
@@ -2741,6 +2741,241 @@
           SStreamValue(producer)
         }
 
+      case x@StreamZipJoinProducers(contexts, ctxName, makeProducer, key, keyRef, valsRef, joinIR) =>
+        emit(contexts, cb).map(cb) { case contextsArray: SIndexableValue =>
+          val nStreams = cb.memoizeField(contextsArray.loadLength())
+          val iterArray = cb.memoizeField(Code.newArray[NoBoxLongIterator](nStreams), "iterArray")
+          val idx = cb.newLocal[Int]("i", 0)
+          val eltType = VirtualTypeWithReq(TIterable.elementType(makeProducer.typ),
+            emitter.ctx.req.lookup(makeProducer).asInstanceOf[RIterable].elementType)
+            .canonicalPType.asInstanceOf[PCanonicalStruct].setRequired(false)
+          var streamRequiresMemoryManagement = false
+          cb.whileLoop(idx < nStreams, {
+            val iter = produceIterator(makeProducer,
+              eltType,
+              cb,
+              outerRegion,
+              env.bind(ctxName, cb.memoize(contextsArray.loadElement(cb, idx))))
+              .get(cb, "streams in zipJoinProducers cannot be missing")
+              .asInstanceOf[SStreamConcrete]
+            streamRequiresMemoryManagement = iter.st.requiresMemoryManagement
+            cb += iterArray.update(idx, iter.it)
+            cb.assign(idx, idx + 1)
+          })
+
+          val keyType = eltType.selectFields(key)
+          val curValsType = PCanonicalArray(eltType)
+
+          val _elementRegion = mb.genFieldThisRef[Region]("szj_region")
+
+          // The algorithm maintains a tournament tree of comparisons between the
+          // current values of the k streams. The tournament tree is a complete
+          // binary tree with k leaves. The leaves of the tree are the streams,
+          // and each internal node represents the "contest" between the "winners"
+          // of the two subtrees, where the winner is the stream with the smaller
+          // current key. Each internal node stores the index of the stream which
+          // *lost* that contest.
+          // Each time we remove the overall winner, and replace that stream's
+          // leaf with its next value, we only need to rerun the contests on the
+          // path from that leaf to the root, comparing the new value with what
+          // previously lost that contest to the previous overall winner.
+
+          val k = nStreams
+          // The leaf nodes of the tournament tree, each of which holds a pointer
+          // to the current value of that stream.
+          val heads = mb.genFieldThisRef[Array[Long]]("merge_heads")
+          // The internal nodes of the tournament tree, laid out in breadth-first
+          // order, each of which holds the index of the stream which lost that
+          // contest.
+          val bracket = mb.genFieldThisRef[Array[Int]]("merge_bracket")
+          // When updating the tournament tree, holds the winner of the subtree
+          // containing the updated leaf. Otherwise, holds the overall winner, i.e.
+          // the current least element.
+          val winner = mb.genFieldThisRef[Int]("merge_winner")
+          val result = mb.genFieldThisRef[Array[Long]]("merge_result")
+          val i = mb.genFieldThisRef[Int]("merge_i")
+
+          val curKey = mb.newPField("st_grpby_curkey", keyType.sType)
+
+          val xKey = mb.newEmitField("zipjoin_key", keyType.sType, required = true)
+          val xElts = mb.newEmitField("zipjoin_elts", curValsType.sType, required = true)
+
+          val joinResult: EmitCode = EmitCode.fromI(mb) { cb =>
+            val newEnv = env.bind((keyRef -> xKey), (valsRef -> xElts))
+            emit(joinIR, cb, env = newEnv)
+          }
+
+          val regionArray: Settable[Array[Region]] = if (streamRequiresMemoryManagement)
+            mb.genFieldThisRef[Array[Region]]("szj_region_array")
+          else
+            null
+
+          val producer = new StreamProducer {
+            override def method: EmitMethodBuilder[_] = mb
+
+            override val length: Option[EmitCodeBuilder => Code[Int]] = None
+
+            override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = {
+              if (streamRequiresMemoryManagement)
+                cb.assign(regionArray, Code.newArray[Region](nStreams))
+              cb.assign(bracket, Code.newArray[Int](k))
+              cb.assign(heads, Code.newArray[Long](k))
+              cb.forLoop(cb.assign(i, 0), i < k, cb.assign(i, i + 1), {
+                cb += (bracket(i) = -1)
+                val eltRegion: Value[Region] = if (streamRequiresMemoryManagement) {
+                  val r = cb.memoize(Region.stagedCreate(Region.REGULAR, outerRegion.getPool()))
+                  cb += regionArray.update(i, r)
+                  r
+                } else outerRegion
+                cb += iterArray(i).invoke[Region, Region, Unit]("init", outerRegion, eltRegion)
+              })
+              cb.assign(result, Code._null)
+              cb.assign(i, 0)
+              cb.assign(winner, 0)
+            }
+
+            override val elementRegion: Settable[Region] = _elementRegion
+            override val requiresMemoryManagementPerElement: Boolean = streamRequiresMemoryManagement
+
+            override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb =>
+              val LrunMatch = CodeLabel()
+              val LpullChild = CodeLabel()
+              val LloopEnd = CodeLabel()
+              val LaddToResult = CodeLabel()
+              val LstartNewKey = CodeLabel()
+              val Lpush = CodeLabel()
+
+              def inSetup: Code[Boolean] = result.isNull
+
+              cb.ifx(inSetup, {
+                cb.assign(i, 0)
+                cb.goto(LpullChild)
+              }, {
+                cb.ifx(winner.ceq(k), cb.goto(LendOfStream), cb.goto(LstartNewKey))
+              })
+
+              cb.define(Lpush)
+              cb.assign(xKey, EmitCode.present(cb.emb, curKey))
+              cb.assign(xElts, EmitCode.present(cb.emb, curValsType.constructFromElements(cb, elementRegion, k, false) { (cb, i) =>
+                IEmitCode(cb, result(i).ceq(0L), eltType.loadCheapSCode(cb, result(i)))
+              }))
+              cb.goto(LproduceElementDone)
+
+              cb.define(LstartNewKey)
+              cb.forLoop(cb.assign(i, 0), i < k, cb.assign(i, i + 1), {
+                cb += (result(i) = 0L)
+              })
+              cb.assign(curKey, eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*)
+                .castTo(cb, elementRegion, curKey.st, true))
+              cb.goto(LaddToResult)
+
+              cb.define(LaddToResult)
+              cb += (result(winner) = heads(winner))
+              if (streamRequiresMemoryManagement) {
+                val r = cb.newLocal[Region]("tzj_winner_region", regionArray(winner))
+                cb += elementRegion.trackAndIncrementReferenceCountOf(r)
+                cb += r.clearRegion()
+              }
+              cb.goto(LpullChild)
+
+              val matchIdx = mb.genFieldThisRef[Int]("merge_match_idx")
+              val challenger = mb.genFieldThisRef[Int]("merge_challenger")
+              // Compare 'winner' with value in 'matchIdx', loser goes in 'matchIdx',
+              // winner goes on to next round. A contestant '-1' beats everything
+              // (negative infinity), a contestant 'k' loses to everything
+              // (positive infinity), and values in between are indices into 'heads'.
+
+              cb.define(LrunMatch)
+              cb.assign(challenger, bracket(matchIdx))
+              cb.ifx(matchIdx.ceq(0) || challenger.ceq(-1), cb.goto(LloopEnd))
+
+              val LafterChallenge = CodeLabel()
+
+              cb.ifx(challenger.cne(k), {
+                val LchallengerWins = CodeLabel()
+
+                cb.ifx(winner.ceq(k), cb.goto(LchallengerWins))
+
+                val left = eltType.loadCheapSCode(cb, heads(challenger)).subset(key: _*)
+                val right = eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*)
+                val ord = StructOrdering.make(left.st, right.st, cb.emb.ecb, missingFieldsEqual = false)
+                cb.ifx(ord.lteqNonnull(cb, left, right),
+                  cb.goto(LchallengerWins),
+                  cb.goto(LafterChallenge))
+
+                cb.define(LchallengerWins)
+                cb += (bracket(matchIdx) = winner)
+                cb.assign(winner, challenger)
+              })
+              cb.define(LafterChallenge)
+              cb.assign(matchIdx, matchIdx >>> 1)
+              cb.goto(LrunMatch)
+
+              cb.define(LloopEnd)
+              cb.ifx(matchIdx.ceq(0), {
+                // 'winner' is smallest of all k heads. If 'winner' = k, all heads
+                // must be k, and all streams are exhausted.
+
+                cb.ifx(inSetup, {
+                  cb.ifx(winner.ceq(k),
+                    cb.goto(LendOfStream),
+                    {
+                      cb.assign(result, Code.newArray[Long](k))
+                      cb.goto(LstartNewKey)
+                    })
+                }, {
+                  cb.ifx(winner.ceq(k), cb.goto(Lpush))
+                  val left = eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*)
+                  val right = curKey
+                  val ord = StructOrdering.make(left.st, right.st.asInstanceOf[SBaseStruct],
+                    cb.emb.ecb, missingFieldsEqual = false)
+                  cb.ifx(ord.equivNonnull(cb, left, right), cb.goto(LaddToResult), cb.goto(Lpush))
+                })
+              }, {
+                // We're still in the setup phase
+                cb += (bracket(matchIdx) = winner)
+                cb.assign(i, i + 1)
+                cb.assign(winner, i)
+                cb.goto(LpullChild)
+              })
+
+              cb.define(LpullChild)
+              cb.ifx(winner >= nStreams, cb.goto(LendOfStream)) // can only happen if k=0
+              val winnerIter = cb.memoize(iterArray(winner))
+              val winnerNextElt = cb.memoize(winnerIter.invoke[Long]("next"))
+              cb.ifx(winnerIter.invoke[Boolean]("eos"), {
+                cb.assign(matchIdx, (winner + k) >>> 1)
+                cb.assign(winner, k)
+              }, {
+                cb.assign(matchIdx, (winner + k) >>> 1)
+                cb += heads.update(winner, winnerNextElt)
+              })
+              cb.goto(LrunMatch)
+            }
+
+            override val element: EmitCode = joinResult
+
+            override def close(cb: EmitCodeBuilder): Unit = {
+              cb.assign(i, 0)
+              cb.whileLoop(i < nStreams, {
+                cb += iterArray(i).invoke[Unit]("close")
+                if (requiresMemoryManagementPerElement)
+                  cb += regionArray(i).invoke[Unit]("invalidate")
+                cb.assign(i, i + 1)
+              })
+              if (requiresMemoryManagementPerElement)
+                cb.assign(regionArray, Code._null)
+              cb.assign(bracket, Code._null)
+              cb.assign(heads, Code._null)
+              cb.assign(result, Code._null)
+            }
+          }
+
+          SStreamValue(producer)
+        }
+
       case x@StreamMultiMerge(as, key) =>
         IEmitCode.multiMapEmitCodes(cb, as.map(a => EmitCode.fromI(mb)(cb => emit(a, cb)))) { children =>
           val producers = children.map(_.asStream.getProducer(mb))