Skip to content

Commit

Permalink
[hail] new aggregator path for TableAggregateByKey (#7195)
Browse files Browse the repository at this point in the history
* [hail] new aggregator path for TableAggregateByKey

* bump
  • Loading branch information
Arcturus Wang authored and danking committed Oct 7, 2019
1 parent 915efc9 commit cc29553
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
Expand Up @@ -1476,6 +1476,106 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR {
val prev = child.execute(ctx)
val prevRVD = prev.rvd

if (HailContext.getFlag("newaggs") != null) {
try {
val res = genUID()
val extracted = agg.Extract(expr, res)

val (_, makeInit) = ir.CompileWithAggregators2[Long, Unit](
extracted.aggs,
"global", prev.globals.t,
extracted.init)

val (_, makeSeq) = ir.CompileWithAggregators2[Long, Long, Unit](
extracted.aggs,
"global", prev.globals.t,
"row", prev.rvd.rowPType,
extracted.seqPerElt)

val valueIR = Let(res, extracted.results, extracted.postAggIR)
val keyType = PType.canonical(prev.typ.keyType).asInstanceOf[PStruct]
val key = Ref(genUID(), keyType.virtualType)
val value = Ref(genUID(), valueIR.typ)
val (rowType: PStruct, makeRow) = ir.CompileWithAggregators2[Long, Long, Long](
extracted.aggs,
"global", prev.globals.t,
key.name, keyType,
Let(value.name, valueIR,
InsertFields(key, typ.valueType.fieldNames.map(n => n -> GetField(value, n)))))
assert(rowType.virtualType == typ.rowType, s"$rowType, ${ typ.rowType }")

val localChildRowType = prevRVD.rowPType
val keyIndices = prev.typ.keyFieldIdx
val keyOrd = prevRVD.typ.kRowOrd
val globalsBc = prev.globals.broadcast

val newRVDType = prevRVD.typ.copy(rowType = rowType)

val newRVD = prevRVD
.repartition(prevRVD.partitioner.strictify)
.boundary
.mapPartitionsWithIndex(newRVDType, { (i, ctx, it) =>
val partRegion = ctx.freshRegion

val globalsOff = globalsBc.value.readRegionValue(partRegion)

val initialize = makeInit(i, partRegion)
val sequence = makeSeq(i, partRegion)
val newRowF = makeRow(i, partRegion)

val aggRegion = ctx.freshRegion

new Iterator[RegionValue] {
var isEnd = false
var current: RegionValue = _
val rowKey: WritableRegionValue = WritableRegionValue(keyType, ctx.freshRegion)
val consumerRegion: Region = ctx.region
val newRV = RegionValue(consumerRegion)

def hasNext: Boolean = {
if (isEnd || (current == null && !it.hasNext)) {
isEnd = true
return false
}
if (current == null)
current = it.next()
true
}

def next(): RegionValue = {
if (!hasNext)
throw new java.util.NoSuchElementException()

rowKey.setSelect(localChildRowType, keyIndices, current)
val region = current.region

aggRegion.clear()
initialize.newAggState(aggRegion)
initialize(region, globalsOff, false)
sequence.setAggState(aggRegion, initialize.getAggOffset())

do {
val region = current.region
sequence(region,
globalsOff, false,
current.offset, false)
current = null
} while (hasNext && keyOrd.equiv(rowKey.value, current))
newRowF.setAggState(aggRegion, sequence.getAggOffset())
newRV.setOffset(newRowF(consumerRegion, globalsOff, false, rowKey.offset, false))
newRV
}
}
})

return prev.copy(rvd = newRVD, typ = typ)

} catch {
case e: agg.UnsupportedExtraction =>
log.info(s"couldn't lower TableAggregate: $e")
}
}

val (rvAggs, makeInit, makeSeq, aggResultType, postAggIR) = ir.CompileWithAggregators[Long, Long, Long](
"global", prev.globals.t,
"global", prev.globals.t,
Expand Down

0 comments on commit cc29553

Please sign in to comment.