From cc295538befa612093445939f11e018a78a23fe0 Mon Sep 17 00:00:00 2001 From: Arcturus Wang Date: Mon, 7 Oct 2019 17:41:09 -0400 Subject: [PATCH] [hail] new aggregator path for TableAggregateByKey (#7195) * [hail] new aggregator path for TableAggregateByKey * bump --- .../main/scala/is/hail/expr/ir/TableIR.scala | 100 ++++++++++++++++++ 1 file changed, 100 insertions(+) diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index 4e2d886d794..72d27f20faf 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -1476,6 +1476,106 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { val prev = child.execute(ctx) val prevRVD = prev.rvd + if (HailContext.getFlag("newaggs") != null) { + try { + val res = genUID() + val extracted = agg.Extract(expr, res) + + val (_, makeInit) = ir.CompileWithAggregators2[Long, Unit]( + extracted.aggs, + "global", prev.globals.t, + extracted.init) + + val (_, makeSeq) = ir.CompileWithAggregators2[Long, Long, Unit]( + extracted.aggs, + "global", prev.globals.t, + "row", prev.rvd.rowPType, + extracted.seqPerElt) + + val valueIR = Let(res, extracted.results, extracted.postAggIR) + val keyType = PType.canonical(prev.typ.keyType).asInstanceOf[PStruct] + val key = Ref(genUID(), keyType.virtualType) + val value = Ref(genUID(), valueIR.typ) + val (rowType: PStruct, makeRow) = ir.CompileWithAggregators2[Long, Long, Long]( + extracted.aggs, + "global", prev.globals.t, + key.name, keyType, + Let(value.name, valueIR, + InsertFields(key, typ.valueType.fieldNames.map(n => n -> GetField(value, n))))) + assert(rowType.virtualType == typ.rowType, s"$rowType, ${ typ.rowType }") + + val localChildRowType = prevRVD.rowPType + val keyIndices = prev.typ.keyFieldIdx + val keyOrd = prevRVD.typ.kRowOrd + val globalsBc = prev.globals.broadcast + + val newRVDType = prevRVD.typ.copy(rowType = rowType) + + val newRVD = prevRVD + .repartition(prevRVD.partitioner.strictify) + .boundary + .mapPartitionsWithIndex(newRVDType, { (i, ctx, it) => + val partRegion = ctx.freshRegion + + val globalsOff = globalsBc.value.readRegionValue(partRegion) + + val initialize = makeInit(i, partRegion) + val sequence = makeSeq(i, partRegion) + val newRowF = makeRow(i, partRegion) + + val aggRegion = ctx.freshRegion + + new Iterator[RegionValue] { + var isEnd = false + var current: RegionValue = _ + val rowKey: WritableRegionValue = WritableRegionValue(keyType, ctx.freshRegion) + val consumerRegion: Region = ctx.region + val newRV = RegionValue(consumerRegion) + + def hasNext: Boolean = { + if (isEnd || (current == null && !it.hasNext)) { + isEnd = true + return false + } + if (current == null) + current = it.next() + true + } + + def next(): RegionValue = { + if (!hasNext) + throw new java.util.NoSuchElementException() + + rowKey.setSelect(localChildRowType, keyIndices, current) + val region = current.region + + aggRegion.clear() + initialize.newAggState(aggRegion) + initialize(region, globalsOff, false) + sequence.setAggState(aggRegion, initialize.getAggOffset()) + + do { + val region = current.region + sequence(region, + globalsOff, false, + current.offset, false) + current = null + } while (hasNext && keyOrd.equiv(rowKey.value, current)) + newRowF.setAggState(aggRegion, sequence.getAggOffset()) + newRV.setOffset(newRowF(consumerRegion, globalsOff, false, rowKey.offset, false)) + newRV + } + } + }) + + return prev.copy(rvd = newRVD, typ = typ) + + } catch { + case e: agg.UnsupportedExtraction => + log.info(s"couldn't lower TableAggregate: $e") + } + } + val (rvAggs, makeInit, makeSeq, aggResultType, postAggIR) = ir.CompileWithAggregators[Long, Long, Long]( "global", prev.globals.t, "global", prev.globals.t,